├── config ├── data │ ├── __init__.py │ ├── megadepth_test_1500.py │ ├── scannet_test_1500.py │ └── base.py └── defaultmf.py ├── matchformer.png ├── requirements.txt ├── model ├── backbone │ ├── __init__.py │ ├── fine_preprocess.py │ ├── fine_matching.py │ ├── coarse_matching.py │ ├── match_LA_large.py │ ├── match_LA_lite.py │ ├── match_SEA_large.py │ └── match_SEA_lite.py ├── utils │ ├── dataloader.py │ ├── profiler.py │ ├── augment.py │ ├── misc.py │ ├── metrics.py │ └── comm.py ├── matchformer.py ├── datasets │ ├── sampler.py │ ├── scannet.py │ ├── megadepth.py │ └── dataset.py ├── lightning_loftr.py └── data.py ├── test.py ├── README.md └── LICENSE /config/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /matchformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InSAI-Lab/MatchFormer/HEAD/matchformer.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.2.0.32 2 | tqdm==4.36.1 3 | h5py==2.10.0 4 | matplotlib==3.1.2 5 | numpy==1.17.4 6 | torch==1.3.1 7 | einops==0.3.0 8 | kornia==0.4.1 9 | loguru==0.5.3 10 | pillow==7.1.1 11 | scipy==1.4.1 12 | pytorch-lightning==1.3.5 -------------------------------------------------------------------------------- /config/data/megadepth_test_1500.py: -------------------------------------------------------------------------------- 1 | from config.data.base import cfg 2 | 3 | TEST_BASE_PATH = "data/megadepth/index" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" 9 | 10 | cfg.DATASET.MGDPT_IMG_RESIZE = 840 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 12 | -------------------------------------------------------------------------------- /config/data/scannet_test_1500.py: -------------------------------------------------------------------------------- 1 | from config.data.base import cfg 2 | 3 | TEST_BASE_PATH = "data/scannet/index" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" 9 | cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" 10 | 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .match_LA_lite import Matchformer_LA_lite 2 | from .match_LA_large import Matchformer_LA_large 3 | from .match_SEA_lite import Matchformer_SEA_lite 4 | from .match_SEA_large import Matchformer_SEA_large 5 | 6 | 7 | def build_backbone(config): 8 | if config['backbone_type'] == 'litela': 9 | return Matchformer_LA_lite() 10 | elif config['backbone_type'] == 'largela': 11 | return Matchformer_LA_large() 12 | elif config['backbone_type'] == 'litesea': 13 | return Matchformer_SEA_lite() 14 | elif config['backbone_type'] == 'largesea': 15 | return Matchformer_SEA_large() 16 | else: 17 | raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not 
supported.") -------------------------------------------------------------------------------- /model/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # --- PL-DATAMODULE --- 5 | 6 | def get_local_split(items: list, world_size: int, rank: int, seed: int): 7 | """ The local rank only loads a split of the dataset. """ 8 | n_items = len(items) 9 | items_permute = np.random.RandomState(seed).permutation(items) 10 | if n_items % world_size == 0: 11 | padded_items = items_permute 12 | else: 13 | padding = np.random.RandomState(seed).choice( 14 | items, 15 | world_size - (n_items % world_size), 16 | replace=True) 17 | padded_items = np.concatenate([items_permute, padding]) 18 | assert len(padded_items) % world_size == 0, \ 19 | f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' 20 | n_per_rank = len(padded_items) // world_size 21 | local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] 22 | 23 | return local_items 24 | -------------------------------------------------------------------------------- /config/data/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data config will be the last one merged into the main config. 3 | Setups in data configs will override all existed setups! 4 | """ 5 | 6 | from yacs.config import CfgNode as CN 7 | _CN = CN() 8 | _CN.DATASET = CN() 9 | _CN.TRAINER = CN() 10 | 11 | # training data config 12 | _CN.DATASET.TRAIN_DATA_ROOT = None 13 | _CN.DATASET.TRAIN_POSE_ROOT = None 14 | _CN.DATASET.TRAIN_NPZ_ROOT = None 15 | _CN.DATASET.TRAIN_LIST_PATH = None 16 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 17 | # validation set config 18 | _CN.DATASET.VAL_DATA_ROOT = None 19 | _CN.DATASET.VAL_POSE_ROOT = None 20 | _CN.DATASET.VAL_NPZ_ROOT = None 21 | _CN.DATASET.VAL_LIST_PATH = None 22 | _CN.DATASET.VAL_INTRINSIC_PATH = None 23 | 24 | # testing data config 25 | _CN.DATASET.TEST_DATA_ROOT = None 26 | _CN.DATASET.TEST_POSE_ROOT = None 27 | _CN.DATASET.TEST_NPZ_ROOT = None 28 | _CN.DATASET.TEST_LIST_PATH = None 29 | _CN.DATASET.TEST_INTRINSIC_PATH = None 30 | 31 | # dataset config 32 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 33 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 34 | 35 | cfg = _CN 36 | -------------------------------------------------------------------------------- /model/utils/profiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler 3 | from contextlib import contextmanager 4 | from pytorch_lightning.utilities import rank_zero_only 5 | 6 | 7 | class InferenceProfiler(SimpleProfiler): 8 | """ 9 | This profiler records duration of actions with cuda.synchronize() 10 | Use this in test time. 
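    A minimal usage sketch (`model` and `batch` are placeholders for any callable
    and its input; `build_profiler` is the factory defined at the bottom of this file):

        profiler = build_profiler('inference')
        with profiler.profile('model forward'):
            out = model(batch)
        print(profiler.summary())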
11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.start = rank_zero_only(self.start) 16 | self.stop = rank_zero_only(self.stop) 17 | self.summary = rank_zero_only(self.summary) 18 | 19 | @contextmanager 20 | def profile(self, action_name: str) -> None: 21 | try: 22 | torch.cuda.synchronize() 23 | self.start(action_name) 24 | yield action_name 25 | finally: 26 | torch.cuda.synchronize() 27 | self.stop(action_name) 28 | 29 | 30 | def build_profiler(name): 31 | if name == 'inference': 32 | return InferenceProfiler() 33 | elif name == 'pytorch': 34 | from pytorch_lightning.profiler import PyTorchProfiler 35 | return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) 36 | elif name is None: 37 | return PassThroughProfiler() 38 | else: 39 | raise ValueError(f'Invalid profiler: {name}') 40 | -------------------------------------------------------------------------------- /model/utils/augment.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | 3 | 4 | class DarkAug(object): 5 | """ 6 | Extreme dark augmentation aiming at Aachen Day-Night 7 | """ 8 | 9 | def __init__(self) -> None: 10 | self.augmentor = A.Compose([ 11 | A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), 12 | A.Blur(p=0.1, blur_limit=(3, 9)), 13 | A.MotionBlur(p=0.2, blur_limit=(3, 25)), 14 | A.RandomGamma(p=0.1, gamma_limit=(15, 65)), 15 | A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) 16 | ], p=0.75) 17 | 18 | def __call__(self, x): 19 | return self.augmentor(image=x)['image'] 20 | 21 | 22 | class MobileAug(object): 23 | """ 24 | Random augmentations aiming at images of mobile/handhold devices. 25 | """ 26 | 27 | def __init__(self): 28 | self.augmentor = A.Compose([ 29 | A.MotionBlur(p=0.25), 30 | A.ColorJitter(p=0.5), 31 | A.RandomRain(p=0.1), # random occlusion 32 | A.RandomSunFlare(p=0.1), 33 | A.JpegCompression(p=0.25), 34 | A.ISONoise(p=0.25) 35 | ], p=1.0) 36 | 37 | def __call__(self, x): 38 | return self.augmentor(image=x)['image'] 39 | 40 | 41 | def build_augmentor(method=None, **kwargs): 42 | if method is not None: 43 | raise NotImplementedError('Using of augmentation functions are not supported yet!') 44 | if method == 'dark': 45 | return DarkAug() 46 | elif method == 'mobile': 47 | return MobileAug() 48 | elif method is None: 49 | return None 50 | else: 51 | raise ValueError(f'Invalid augmentation method: {method}') 52 | 53 | 54 | if __name__ == '__main__': 55 | augmentor = build_augmentor('FDA') 56 | -------------------------------------------------------------------------------- /model/matchformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .backbone import build_backbone 4 | from .backbone.fine_preprocess import FinePreprocess 5 | from .backbone.coarse_matching import CoarseMatching 6 | from .backbone.fine_matching import FineMatching 7 | from einops.einops import rearrange 8 | 9 | 10 | class Matchformer(nn.Module): 11 | 12 | def __init__(self, config): 13 | super().__init__() 14 | # Misc 15 | self.config = config 16 | self.backbone = build_backbone(config) 17 | self.coarse_matching = CoarseMatching(config['match_coarse']) 18 | self.fine_preprocess = FinePreprocess(config) 19 | self.fine_matching = FineMatching() 20 | 21 | 22 | def forward(self, data): 23 | data.update({ 24 | 'bs': data['image0'].size(0), 25 | 'hw0_i': data['image0'].shape[2:], 'hw1_i': 
data['image1'].shape[2:] 26 | }) 27 | 28 | mask_c0 = mask_c1 = None # mask is useful in training 29 | if 'mask0' in data: 30 | mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) 31 | 32 | if data['hw0_i'] == data['hw1_i']: 33 | feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) 34 | (feat_c0, feat_c1),(feat_f0, feat_f1) = feats_c.split(data['bs']),feats_f.split(data['bs']) 35 | else: 36 | (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) 37 | 38 | data.update({ 39 | 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], 40 | 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] 41 | }) 42 | 43 | feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') 44 | feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') 45 | 46 | # match coarse-level 47 | self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) 48 | 49 | # fine-level refinement 50 | feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) 51 | 52 | # match fine-level 53 | self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) 54 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import argparse 3 | import pprint 4 | from loguru import logger as loguru_logger 5 | from config.defaultmf import get_cfg_defaults 6 | from model.data import MultiSceneDataModule 7 | from model.lightning_loftr import PL_LoFTR 8 | 9 | 10 | def parse_args(): 11 | # init a costum parser which will be added into pl.Trainer parser 12 | # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags 13 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | parser.add_argument( 15 | 'data_cfg_path', type=str, help='data config path') 16 | parser.add_argument( 17 | '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') 18 | parser.add_argument( 19 | '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") 20 | parser.add_argument( 21 | '--profiler_name', type=str, default='inference', help='options: [inference, pytorch], or leave it unset') 22 | parser.add_argument( 23 | '--batch_size', type=int, default=1, help='batch_size per gpu') 24 | parser.add_argument( 25 | '--num_workers', type=int, default=2) 26 | parser.add_argument( 27 | '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') 28 | 29 | parser = pl.Trainer.add_argparse_args(parser) 30 | return parser.parse_args() 31 | 32 | 33 | if __name__ == '__main__': 34 | # parse arguments 35 | args = parse_args() 36 | # init default-cfg and merge it with the main- and data-cfg 37 | config = get_cfg_defaults() 38 | config.merge_from_file(args.data_cfg_path) 39 | pl.seed_everything(config.TRAINER.SEED) # reproducibility 40 | 41 | # tune when testing 42 | if args.thr is not None: 43 | config.LOFTR.MATCH_COARSE.THR = args.thr 44 | 45 | # lightning module 46 | model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, dump_dir=args.dump_dir) 47 | 48 | # lightning data 49 | data_module = MultiSceneDataModule(args, config) 50 | 51 | # lightning trainer 52 | trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) 53 | 54 | loguru_logger.info(f"Start testing!") 55 | trainer.test(model, 
datamodule=data_module, verbose=False) 56 | -------------------------------------------------------------------------------- /model/backbone/fine_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange, repeat 5 | 6 | 7 | class FinePreprocess(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | self.config = config 12 | self.cat_c_feat = config['fine_concat_coarse_feat'] 13 | self.W = self.config['fine_window_size'] 14 | 15 | d_model_c = self.config['coarse']['d_model'] 16 | d_model_f = self.config['fine']['d_model'] 17 | self.d_model_f = d_model_f 18 | if self.cat_c_feat: 19 | self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) 20 | self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) 21 | 22 | self._reset_parameters() 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") 28 | 29 | def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): 30 | W = self.W 31 | stride = data['hw0_f'][0] // data['hw0_c'][0] 32 | 33 | data.update({'W': W}) 34 | if data['b_ids'].shape[0] == 0: 35 | feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 36 | feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 37 | return feat0, feat1 38 | 39 | # 1. unfold(crop) all local windows 40 | feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) 41 | feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 42 | feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) 43 | feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 44 | 45 | # 2. 
select only the predicted matches 46 | feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] 47 | feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] 48 | 49 | # option: use coarse-level loftr feature as context: concat and linear 50 | if self.cat_c_feat: 51 | feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], 52 | feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] 53 | feat_cf_win = self.merge_feat(torch.cat([ 54 | torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] 55 | repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] 56 | ], -1)) 57 | feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) 58 | 59 | return feat_f0_unfold, feat_f1_unfold 60 | -------------------------------------------------------------------------------- /model/backbone/fine_matching.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from kornia.geometry.subpix import dsnt 6 | from kornia.utils.grid import create_meshgrid 7 | 8 | 9 | class FineMatching(nn.Module): 10 | """FineMatching with s2d paradigm""" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, feat_f0, feat_f1, data): 16 | """ 17 | Args: 18 | feat0 (torch.Tensor): [M, WW, C] 19 | feat1 (torch.Tensor): [M, WW, C] 20 | data (dict) 21 | Update: 22 | data (dict):{ 23 | 'expec_f' (torch.Tensor): [M, 3], 24 | 'mkpts0_f' (torch.Tensor): [M, 2], 25 | 'mkpts1_f' (torch.Tensor): [M, 2]} 26 | """ 27 | M, WW, C = feat_f0.shape 28 | W = int(math.sqrt(WW)) 29 | scale = data['hw0_i'][0] / data['hw0_f'][0] 30 | self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale 31 | 32 | # corner case: if no coarse matches found 33 | if M == 0: 34 | assert self.training == False, "M is always >0, when training, see coarse_matching.py" 35 | # logger.warning('No matches found in coarse-level.') 36 | data.update({ 37 | 'expec_f': torch.empty(0, 3, device=feat_f0.device), 38 | 'mkpts0_f': data['mkpts0_c'], 39 | 'mkpts1_f': data['mkpts1_c'], 40 | }) 41 | return 42 | 43 | feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] 44 | sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) 45 | softmax_temp = 1. 
/ C**.5 46 | heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) 47 | 48 | # compute coordinates from heatmap 49 | coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] 50 | grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] 51 | 52 | # compute std over 53 | var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] 54 | std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability 55 | 56 | # for fine-level supervision 57 | data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) 58 | 59 | # compute absolute kpt coords 60 | self.get_fine_match(coords_normalized, data) 61 | 62 | @torch.no_grad() 63 | def get_fine_match(self, coords_normed, data): 64 | W, WW, C, scale = self.W, self.WW, self.C, self.scale 65 | 66 | # mkpts0_f and mkpts1_f 67 | mkpts0_f = data['mkpts0_c'] 68 | scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale 69 | mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] 70 | 71 | data.update({ 72 | "mkpts0_f": mkpts0_f, 73 | "mkpts1_f": mkpts1_f 74 | }) 75 | -------------------------------------------------------------------------------- /model/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import contextlib 3 | import joblib 4 | from typing import Union 5 | from loguru import _Logger, logger 6 | from itertools import chain 7 | 8 | import torch 9 | from yacs.config import CfgNode as CN 10 | from pytorch_lightning.utilities import rank_zero_only 11 | 12 | 13 | def lower_config(yacs_cfg): 14 | if not isinstance(yacs_cfg, CN): 15 | return yacs_cfg 16 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 17 | 18 | 19 | def upper_config(dict_cfg): 20 | if not isinstance(dict_cfg, dict): 21 | return dict_cfg 22 | return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} 23 | 24 | 25 | def log_on(condition, message, level): 26 | if condition: 27 | assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] 28 | logger.log(level, message) 29 | 30 | 31 | def get_rank_zero_only_logger(logger: _Logger): 32 | if rank_zero_only.rank == 0: 33 | return logger 34 | else: 35 | for _level in logger._core.levels.keys(): 36 | level = _level.lower() 37 | setattr(logger, level, 38 | lambda x: None) 39 | logger._log = lambda x: None 40 | return logger 41 | 42 | 43 | def setup_gpus(gpus: Union[str, int]) -> int: 44 | """ A temporary fix for pytorch-lighting 1.3.x """ 45 | gpus = str(gpus) 46 | gpu_ids = [] 47 | 48 | if ',' not in gpus: 49 | n_gpus = int(gpus) 50 | return n_gpus if n_gpus != -1 else torch.cuda.device_count() 51 | else: 52 | gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] 53 | 54 | # setup environment variables 55 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 56 | if visible_devices is None: 57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 58 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) 59 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 60 | logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') 61 | else: 62 | logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') 63 | return len(gpu_ids) 64 | 65 | 66 | def flattenList(x): 67 | return list(chain(*x)) 68 | 69 | 70 | 
@contextlib.contextmanager 71 | def tqdm_joblib(tqdm_object): 72 | """Context manager to patch joblib to report into tqdm progress bar given as argument 73 | 74 | Usage: 75 | with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: 76 | Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) 77 | 78 | When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) 79 | ret_vals = Parallel(n_jobs=args.world_size)( 80 | delayed(lambda x: _compute_cov_score(pid, *x))(param) 81 | for param in tqdm(combinations(image_ids, 2), 82 | desc=f'Computing cov_score of [{pid}]', 83 | total=len(image_ids)*(len(image_ids)-1)/2)) 84 | Src: https://stackoverflow.com/a/58936697 85 | """ 86 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): 87 | def __init__(self, *args, **kwargs): 88 | super().__init__(*args, **kwargs) 89 | 90 | def __call__(self, *args, **kwargs): 91 | tqdm_object.update(n=self.batch_size) 92 | return super().__call__(*args, **kwargs) 93 | 94 | old_batch_callback = joblib.parallel.BatchCompletionCallBack 95 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback 96 | try: 97 | yield tqdm_object 98 | finally: 99 | joblib.parallel.BatchCompletionCallBack = old_batch_callback 100 | tqdm_object.close() 101 | 102 | -------------------------------------------------------------------------------- /config/defaultmf.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.MATCHFORMER = CN() 5 | _CN.MATCHFORMER.BACKBONE_TYPE = 'largela'# litela,largela,litesea,largesea 6 | _CN.MATCHFORMER.SCENS = 'indoor' # indoor, outdoor 7 | _CN.MATCHFORMER.RESOLUTION = (8,2) #(8,2),(8,4) 8 | _CN.MATCHFORMER.FINE_WINDOW_SIZE = 5 9 | _CN.MATCHFORMER.FINE_CONCAT_COARSE_FEAT = True 10 | 11 | _CN.MATCHFORMER.COARSE = CN() 12 | _CN.MATCHFORMER.COARSE.D_MODEL = 256 13 | _CN.MATCHFORMER.COARSE.D_FFN = 256 14 | 15 | _CN.MATCHFORMER.MATCH_COARSE = CN() 16 | _CN.MATCHFORMER.MATCH_COARSE.THR = 0.2 17 | _CN.MATCHFORMER.MATCH_COARSE.BORDER_RM = 0 18 | _CN.MATCHFORMER.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 19 | _CN.MATCHFORMER.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 20 | _CN.MATCHFORMER.MATCH_COARSE.SKH_ITERS = 3 21 | _CN.MATCHFORMER.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 22 | _CN.MATCHFORMER.MATCH_COARSE.SKH_PREFILTER = False 23 | _CN.MATCHFORMER.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 24 | _CN.MATCHFORMER.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 25 | _CN.MATCHFORMER.MATCH_COARSE.SPARSE_SPVS = True 26 | 27 | _CN.MATCHFORMER.FINE = CN() 28 | _CN.MATCHFORMER.FINE.D_MODEL = 128 29 | _CN.MATCHFORMER.FINE.D_FFN = 128 30 | 31 | ############## Dataset ############## 32 | _CN.DATASET = CN() 33 | # 1. 
data config 34 | # training and validating 35 | _CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] 36 | _CN.DATASET.TRAIN_DATA_ROOT = None 37 | _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) 38 | _CN.DATASET.TRAIN_NPZ_ROOT = None 39 | _CN.DATASET.TRAIN_LIST_PATH = None 40 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 41 | _CN.DATASET.VAL_DATA_ROOT = None 42 | _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) 43 | _CN.DATASET.VAL_NPZ_ROOT = None 44 | _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file 45 | _CN.DATASET.VAL_INTRINSIC_PATH = None 46 | # testing 47 | _CN.DATASET.TEST_DATA_SOURCE = None 48 | _CN.DATASET.TEST_DATA_ROOT = None 49 | _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) 50 | _CN.DATASET.TEST_NPZ_ROOT = None 51 | _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file 52 | _CN.DATASET.TEST_INTRINSIC_PATH = None 53 | 54 | # 2. dataset config 55 | # general options 56 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score 57 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 58 | _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] 59 | 60 | # MegaDepth options 61 | _CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. 62 | _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE 63 | _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 64 | _CN.DATASET.MGDPT_DF = 8 65 | 66 | # geometric metrics and pose solver 67 | _CN.TRAINER = CN() 68 | _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) 69 | _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] 70 | _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] 71 | _CN.TRAINER.RANSAC_PIXEL_THR = 0.5 72 | _CN.TRAINER.RANSAC_CONF = 0.99999 73 | _CN.TRAINER.RANSAC_MAX_ITERS = 10000 74 | _CN.TRAINER.USE_MAGSACPP = False 75 | 76 | # data sampler 77 | _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] 78 | _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 79 | _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not 80 | _CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not 81 | _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data 82 | _CN.TRAINER.SEED = 66 83 | 84 | def get_cfg_defaults(): 85 | """Get a yacs CfgNode object with default values for my_project.""" 86 | # Return a clone so that the defaults will not be altered 87 | # This is for the "local variable" use pattern 88 | return _CN.clone() 89 | -------------------------------------------------------------------------------- /model/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler, ConcatDataset 3 | 4 | 5 | class RandomConcatSampler(Sampler): 6 | """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset 7 | in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. 8 | However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. 
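    A minimal usage sketch (the dataset names are illustrative; the argument values
    follow the defaults in config/defaultmf.py):

        concat_ds = ConcatDataset([scene_dataset_0, scene_dataset_1])
        sampler = RandomConcatSampler(concat_ds, n_samples_per_subset=200,
                                      subset_replacement=True, shuffle=True,
                                      repeat=1, seed=66)
        loader = torch.utils.data.DataLoader(concat_ds, batch_size=1, sampler=sampler)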
9 | 10 | For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. 11 | Args: 12 | shuffle (bool): shuffle the random sampled indices across all sub-datsets. 13 | repeat (int): repeatedly use the sampled indices multiple times for training. 14 | [arXiv:1902.05509, arXiv:1901.09335] 15 | NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) 16 | NOTE: This sampler behaves differently with DistributedSampler. 17 | It assume the dataset is splitted across ranks instead of replicated. 18 | TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. 19 | ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 20 | """ 21 | def __init__(self, 22 | data_source: ConcatDataset, 23 | n_samples_per_subset: int, 24 | subset_replacement: bool=True, 25 | shuffle: bool=True, 26 | repeat: int=1, 27 | seed: int=None): 28 | if not isinstance(data_source, ConcatDataset): 29 | raise TypeError("data_source should be torch.utils.data.ConcatDataset") 30 | 31 | self.data_source = data_source 32 | self.n_subset = len(self.data_source.datasets) 33 | self.n_samples_per_subset = n_samples_per_subset 34 | self.n_samples = self.n_subset * self.n_samples_per_subset * repeat 35 | self.subset_replacement = subset_replacement 36 | self.repeat = repeat 37 | self.shuffle = shuffle 38 | self.generator = torch.manual_seed(seed) 39 | assert self.repeat >= 1 40 | 41 | def __len__(self): 42 | return self.n_samples 43 | 44 | def __iter__(self): 45 | indices = [] 46 | # sample from each sub-dataset 47 | for d_idx in range(self.n_subset): 48 | low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] 49 | high = self.data_source.cumulative_sizes[d_idx] 50 | if self.subset_replacement: 51 | rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), 52 | generator=self.generator, dtype=torch.int64) 53 | else: # sample without replacement 54 | len_subset = len(self.data_source.datasets[d_idx]) 55 | rand_tensor = torch.randperm(len_subset, generator=self.generator) + low 56 | if len_subset >= self.n_samples_per_subset: 57 | rand_tensor = rand_tensor[:self.n_samples_per_subset] 58 | else: # padding with replacement 59 | rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), 60 | generator=self.generator, dtype=torch.int64) 61 | rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) 62 | indices.append(rand_tensor) 63 | indices = torch.cat(indices) 64 | if self.shuffle: # shuffle the sampled dataset (from multiple subsets) 65 | rand_tensor = torch.randperm(len(indices), generator=self.generator) 66 | indices = indices[rand_tensor] 67 | 68 | # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) 69 | if self.repeat > 1: 70 | repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] 71 | if self.shuffle: 72 | _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] 73 | repeat_indices = map(_choice, repeat_indices) 74 | indices = torch.cat([indices, *repeat_indices], 0) 75 | 76 | assert indices.shape[0] == self.n_samples 77 | return iter(indices.tolist()) 78 | -------------------------------------------------------------------------------- /model/lightning_loftr.py: -------------------------------------------------------------------------------- 1 | import 
pprint 2 | from loguru import logger 3 | from pathlib import Path 4 | 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | 9 | from .matchformer import Matchformer 10 | from .utils.metrics import ( 11 | compute_symmetrical_epipolar_errors, 12 | compute_pose_errors, 13 | aggregate_metrics 14 | ) 15 | from .utils.comm import gather 16 | from .utils.misc import lower_config, flattenList 17 | from .utils.profiler import PassThroughProfiler 18 | 19 | 20 | class PL_LoFTR(pl.LightningModule): 21 | def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): 22 | super().__init__() 23 | # Misc 24 | self.config = config # full config 25 | _config = lower_config(self.config) 26 | self.profiler = profiler or PassThroughProfiler() 27 | 28 | # Matcher: LoFTR 29 | self.matcher = Matchformer(config=_config['matchformer']) 30 | 31 | # Pretrained weights 32 | if pretrained_ckpt: 33 | self.matcher.load_state_dict({k.replace('matcher.',''):v for k,v in torch.load(pretrained_ckpt, map_location='cpu').items()}) 34 | logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") 35 | 36 | # Testing 37 | self.dump_dir = dump_dir 38 | 39 | 40 | def _compute_metrics(self, batch): 41 | with self.profiler.profile("Copmute metrics"): 42 | compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match 43 | compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair 44 | 45 | rel_pair_names = list(zip(*batch['pair_names'])) 46 | bs = batch['image0'].size(0) 47 | metrics = { 48 | # to filter duplicate pairs caused by DistributedSampler 49 | 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], 50 | 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], 51 | 'R_errs': batch['R_errs'], 52 | 't_errs': batch['t_errs'], 53 | 'inliers': batch['inliers']} 54 | ret_dict = {'metrics': metrics} 55 | return ret_dict, rel_pair_names 56 | 57 | 58 | def test_step(self, batch, batch_idx): 59 | with self.profiler.profile("LoFTR"): 60 | self.matcher(batch) 61 | 62 | ret_dict, rel_pair_names = self._compute_metrics(batch) 63 | 64 | with self.profiler.profile("dump_results"): 65 | if self.dump_dir is not None: 66 | # dump results for further analysis 67 | keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} 68 | pair_names = list(zip(*batch['pair_names'])) 69 | bs = batch['image0'].shape[0] 70 | dumps = [] 71 | for b_id in range(bs): 72 | item = {} 73 | mask = batch['m_bids'] == b_id 74 | item['pair_names'] = pair_names[b_id] 75 | item['identifier'] = '#'.join(rel_pair_names[b_id]) 76 | for key in keys_to_save: 77 | item[key] = batch[key][mask].cpu().numpy() 78 | for key in ['R_errs', 't_errs', 'inliers']: 79 | item[key] = batch[key][b_id] 80 | dumps.append(item) 81 | ret_dict['dumps'] = dumps 82 | 83 | return ret_dict 84 | 85 | def test_epoch_end(self, outputs): 86 | # metrics: dict of list, numpy 87 | _metrics = [o['metrics'] for o in outputs] 88 | metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} 89 | 90 | # [{key: [{...}, *#bs]}, *#batch] 91 | if self.dump_dir is not None: 92 | Path(self.dump_dir).mkdir(parents=True, exist_ok=True) 93 | _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] 94 | dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] 95 | logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') 96 | 97 | if self.trainer.global_rank == 0: 98 | print(self.profiler.summary()) 
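            # Note: at this point `metrics` maps each key to a flat list gathered from
            # all ranks (see the gather/flattenList calls above); aggregate_metrics()
            # drops duplicate pairs via 'identifiers' before computing the pose-error
            # AUCs and the epipolar matching precision (see model/utils/metrics.py).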
99 |             val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
100 |             logger.info('\n' + pprint.pformat(val_metrics_4tb))
101 |             if self.dump_dir is not None:
102 |                 np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MatchFormer
2 | 
3 | ### MatchFormer: Interleaving Attention in Transformers for Feature Matching
4 | 
5 | Qing Wang∗, [Jiaming Zhang](https://jamycheung.github.io/)∗, [Kailun Yang](https://yangkailun.com/)†, Kunyu Peng, [Rainer Stiefelhagen](https://cvhci.anthropomatik.kit.edu/people_596.php)
6 | 
7 | ∗ denotes equal contribution and † denotes corresponding author
8 | 
9 | ### News
10 | - [09/2022] **MatchFormer** [[**PDF**](https://arxiv.org/pdf/2203.09645.pdf)] is accepted to **ACCV2022**.
11 | 
12 | ![matchformer](matchformer.png)
13 | 
14 | ### Introduction
15 | 
16 | In this work, we propose a novel hierarchical extract-and-match transformer, termed as **MatchFormer**. Inside each stage of the hierarchical encoder, we interleave self-attention for feature extraction and cross-attention for feature matching, enabling a human-intuitive **extract-and-match** scheme.
17 | 
18 | More details can be found in our [arXiv](https://arxiv.org/pdf/2203.09645.pdf) paper.
19 | 
20 | ### Installation
21 | 
22 | The requirements are listed in the `requirements.txt` file. To create your own environment, for example:
23 | 
24 | ```bash
25 | conda create -n matchformer python=3.7
26 | conda activate matchformer
27 | cd /path/to/matchformer
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | ### Datasets
32 | 
33 | You can prepare the test datasets in the same way as [LoFTR](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md); place the datasets and index files in the `data` directory.
34 | 
35 | The dataset structure should be:
36 | 
37 | ```
38 | data
39 | ├── scannet
40 | │   ├── index
41 | │   │   ├── intrinsics.npz
42 | │   │   ├── scannet_test.txt
43 | │   │   └── test.npz
44 | │   └── test
45 | │       ├── scene0707_00
46 | │       ├── ...
47 | │       └── scene0806_00
48 | └── megadepth
49 |     ├── index
50 |     │   ├── 0015_0.1_0.3.npz
51 |     │   ├── ...
52 |     │   ├── 0022_0.5_0.7.npz
53 |     │   └── megadepth_test_1500.txt
54 |     └── test
55 |         ├── Undistorted_SfM
56 |         └── phoenix
57 | ```
58 | 
59 | 
60 | 
61 | ### Evaluation
62 | 
63 | The evaluation configuration can be adjusted in `/config/defaultmf.py`.
64 | 
65 | The weights can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1JSnoQMfr32eoIXwJ1gpwUaPKv4kjdqJ7?usp=sharing).
66 | 
67 | Put the weights in `model/weights`.
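Before running the full evaluation, you can sanity-check that a downloaded checkpoint matches the configured backbone by loading it the same way `model/lightning_loftr.py` does. A minimal sketch (the checkpoint filename and the `largesea` setting are examples; adjust both to the weight you downloaded):

```python
import torch

from config.defaultmf import get_cfg_defaults
from model.matchformer import Matchformer
from model.utils.misc import lower_config

config = get_cfg_defaults()
config.MATCHFORMER.BACKBONE_TYPE = 'largesea'  # must match the checkpoint variant

# Build the raw matcher and load the weights, stripping the 'matcher.' prefix
# exactly as PL_LoFTR does in model/lightning_loftr.py.
matcher = Matchformer(config=lower_config(config)['matchformer'])
state_dict = {k.replace('matcher.', ''): v
              for k, v in torch.load('model/weights/indoor-large-SEA.ckpt',
                                     map_location='cpu').items()}
matcher.load_state_dict(state_dict)
print('checkpoint loaded:', sum(p.numel() for p in matcher.parameters()), 'parameters')
```

If the backbone type or feature dimensions in `/config/defaultmf.py` do not match the checkpoint, `load_state_dict` raises a key or shape mismatch error, which is the signal to fix the config before launching `test.py`.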
68 | 69 | #### Indoor: 70 | 71 | ``` 72 | # adjust large SEA model config: 73 | MATCHFORMER.BACKBONE_TYPE = 'largesea' 74 | MATCHFORMER.SCENS = 'indoor' 75 | MATCHFORMER.RESOLUTION = (8,2) 76 | MATCHFORMER.COARSE.D_MODEL = 256 77 | MATCHFORMER.COARSE.D_FFN = 256 78 | 79 | python test.py /config/data/scannet_test_1500.py --ckpt_path /model/weights/indoor-large-SEA.ckpt --gpus=1 --accelerator="ddp" 80 | ``` 81 | 82 | ``` 83 | # adjust lite LA model config: 84 | MATCHFORMER.BACKBONE_TYPE = 'litela' 85 | MATCHFORMER.SCENS = 'indoor' 86 | MATCHFORMER.RESOLUTION = (8,4) 87 | MATCHFORMER.COARSE.D_MODEL = 192 88 | MATCHFORMER.COARSE.D_FFN = 192 89 | 90 | python test.py /config/data/scannet_test_1500.py --ckpt_path /model/weights/indoor-lite-LA.ckpt --gpus=1 --accelerator="ddp" 91 | ``` 92 | 93 | #### Outdoor: 94 | 95 | ``` 96 | # adjust large LA model config: 97 | MATCHFORMER.BACKBONE_TYPE = 'largela' 98 | MATCHFORMER.SCENS = 'outdoor' 99 | MATCHFORMER.RESOLUTION = (8,2) 100 | MATCHFORMER.COARSE.D_MODEL = 256 101 | MATCHFORMER.COARSE.D_FFN = 256 102 | 103 | python test.py /config/data/megadepth_test_1500.py --ckpt_path /model/weights/outdoor-large-LA.ckpt --gpus=1 --accelerator="ddp" 104 | ``` 105 | 106 | ``` 107 | # adjust lite SEA model config: 108 | MATCHFORMER.BACKBONE_TYPE = 'litesea' 109 | MATCHFORMER.SCENS = 'outdoor' 110 | MATCHFORMER.RESOLUTION = (8,4) 111 | MATCHFORMER.COARSE.D_MODEL = 192 112 | MATCHFORMER.COARSE.D_FFN = 192 113 | 114 | python test.py /config/data/megadepth_test_1500.py --ckpt_path /model/weights/indoor-large-SEA.ckpt --gpus=1 --accelerator="ddp" 115 | ``` 116 | 117 | ### Training 118 | 119 | Based on the LOFTER code to train MatchFormer, replace LoFTR/src/loftr/backbone/ with model/backbone/match_**.py to train. 120 | 121 | ### Citation 122 | 123 | If you are interested in this work, please cite the following work: 124 | 125 | ``` 126 | @inproceedings{wang2022matchformer, 127 | title={MatchFormer: Interleaving Attention in Transformers for Feature Matching}, 128 | author={Wang, Qing and Zhang, Jiaming and Yang, Kailun and Peng, Kunyu and Stiefelhagen, Rainer}, 129 | booktitle={Asian Conference on Computer Vision}, 130 | year={2022} 131 | } 132 | ``` 133 | 134 | ### Acknowledgments 135 | 136 | Our work is based on [LoFTR](https://github.com/zju3dv/LoFTR) and we use their code. We appreciate the previous open-source repository [LoFTR](https://github.com/zju3dv/LoFTR). 137 | -------------------------------------------------------------------------------- /model/datasets/scannet.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from typing import Dict 3 | from unicodedata import name 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils as utils 8 | from numpy.linalg import inv 9 | from .dataset import ( 10 | read_scannet_gray, 11 | read_scannet_depth, 12 | read_scannet_pose 13 | ) 14 | 15 | 16 | class ScanNetDataset(utils.data.Dataset): 17 | def __init__(self, 18 | root_dir, 19 | npz_path, 20 | intrinsic_path, 21 | mode='train', 22 | min_overlap_score=0.4, 23 | augment_fn=None, 24 | pose_dir=None, 25 | **kwargs): 26 | """Manage one scene of ScanNet Dataset. 27 | Args: 28 | root_dir (str): ScanNet root directory that contains scene folders. 29 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 30 | intrinsic_path (str): path to depth-camera intrinsic file. 31 | mode (str): options are ['train', 'val', 'test']. 
32 | augment_fn (callable, optional): augments images with pre-defined visual effects. 33 | pose_dir (str): ScanNet root directory that contains all poses. 34 | (we use a separate (optional) pose_dir since we store images and poses separately.) 35 | """ 36 | super().__init__() 37 | self.root_dir = root_dir 38 | self.pose_dir = pose_dir if pose_dir is not None else root_dir 39 | self.mode = mode 40 | 41 | # prepare data_names, intrinsics and extrinsics(T) 42 | with np.load(npz_path) as data: 43 | self.data_names = data['name'] 44 | if 'score' in data.keys() and mode not in ['val' or 'test']: 45 | kept_mask = data['score'] > min_overlap_score 46 | self.data_names = self.data_names[kept_mask] 47 | self.intrinsics = dict(np.load(intrinsic_path)) 48 | 49 | # for training LoFTR 50 | self.augment_fn = augment_fn if mode == 'train' else None 51 | 52 | def __len__(self): 53 | return len(self.data_names) 54 | 55 | def _read_abs_pose(self, scene_name, name): 56 | pth = osp.join(self.pose_dir, 57 | scene_name, 58 | 'pose', f'{name}.txt') 59 | return read_scannet_pose(pth) 60 | 61 | def _compute_rel_pose(self, scene_name, name0, name1): 62 | pose0 = self._read_abs_pose(scene_name, name0) 63 | pose1 = self._read_abs_pose(scene_name, name1) 64 | 65 | return np.matmul(pose1, inv(pose0)) # (4, 4) 66 | 67 | def __getitem__(self, idx): 68 | data_name = self.data_names[idx] 69 | scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name 70 | scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' 71 | 72 | # read the grayscale image which will be resized to (1, 480, 640) 73 | img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') 74 | img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') 75 | 76 | # TODO: Support augmentation & handle seeds for each worker correctly. 
77 | image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) 78 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 79 | image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) 80 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 81 | 82 | # read the depthmap which is stored as (480, 640) 83 | if self.mode in ['train', 'val']: 84 | depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) 85 | depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) 86 | else: 87 | depth0 = depth1 = torch.tensor([]) 88 | 89 | # read the intrinsic of depthmap 90 | K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) 91 | 92 | # read and compute relative poses 93 | T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), 94 | dtype=torch.float32) 95 | T_1to0 = T_0to1.inverse() 96 | 97 | data = { 98 | 'image0': image0, # (1, h, w) 99 | 'depth0': depth0, # (h, w) 100 | 'image1': image1, 101 | 'depth1': depth1, 102 | 'T_0to1': T_0to1, # (4, 4) 103 | 'T_1to0': T_1to0, 104 | 'K0': K_0, # (3, 3) 105 | 'K1': K_1, 106 | 'dataset_name': 'ScanNet', 107 | 'scene_id': scene_name, 108 | 'pair_id': idx, 109 | 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), 110 | osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) 111 | } 112 | 113 | return data 114 | -------------------------------------------------------------------------------- /model/datasets/megadepth.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | from loguru import logger 7 | from .dataset import read_megadepth_gray, read_megadepth_depth 8 | 9 | 10 | class MegaDepthDataset(Dataset): 11 | def __init__(self, 12 | root_dir, 13 | npz_path, 14 | mode='train', 15 | min_overlap_score=0.4, 16 | img_resize=None, 17 | df=None, 18 | img_padding=False, 19 | depth_padding=False, 20 | augment_fn=None, 21 | **kwargs): 22 | """ 23 | Manage one scene(npz_path) of MegaDepth dataset. 24 | 25 | Args: 26 | root_dir (str): megadepth root directory that has `phoenix`. 27 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 28 | mode (str): options are ['train', 'val', 'test'] 29 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. 30 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. 31 | This is useful during training with batches and testing with memory intensive algorithms. 32 | df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. 33 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. 34 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. 35 | augment_fn (callable, optional): augments images with pre-defined visual effects. 36 | """ 37 | super().__init__() 38 | self.root_dir = root_dir 39 | self.mode = mode 40 | self.scene_id = npz_path.split('.')[0] 41 | 42 | # prepare scene_info and pair_info 43 | if mode == 'test' and min_overlap_score != 0: 44 | logger.warning("You are using `min_overlap_score`!=0 in test mode. 
Set to 0.") 45 | min_overlap_score = 0 46 | self.scene_info = np.load(npz_path, allow_pickle=True) 47 | self.pair_infos = self.scene_info['pair_infos'].copy() 48 | del self.scene_info['pair_infos'] 49 | self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] 50 | 51 | # parameters for image resizing, padding and depthmap padding 52 | if mode == 'train': 53 | assert img_resize is not None and img_padding and depth_padding 54 | self.img_resize = img_resize 55 | self.df = df 56 | self.img_padding = img_padding 57 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 58 | 59 | # for training LoFTR 60 | self.augment_fn = augment_fn if mode == 'train' else None 61 | self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) 62 | 63 | def __len__(self): 64 | return len(self.pair_infos) 65 | 66 | def __getitem__(self, idx): 67 | (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] 68 | 69 | # read grayscale image and mask. (1, h, w) and (h, w) 70 | img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) 71 | img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) 72 | 73 | # TODO: Support augmentation & handle seeds for each worker correctly. 74 | image0, mask0, scale0 = read_megadepth_gray( 75 | img_name0, self.img_resize, self.df, self.img_padding, None) 76 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 77 | image1, mask1, scale1 = read_megadepth_gray( 78 | img_name1, self.img_resize, self.df, self.img_padding, None) 79 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 80 | 81 | # read depth. shape: (h, w) 82 | if self.mode in ['train', 'val']: 83 | depth0 = read_megadepth_depth( 84 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) 85 | depth1 = read_megadepth_depth( 86 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) 87 | else: 88 | depth0 = depth1 = torch.tensor([]) 89 | 90 | # read intrinsics of original size 91 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) 92 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) 93 | 94 | # read and compute relative poses 95 | T0 = self.scene_info['poses'][idx0] 96 | T1 = self.scene_info['poses'][idx1] 97 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) 98 | T_1to0 = T_0to1.inverse() 99 | 100 | data = { 101 | 'image0': image0, # (1, h, w) 102 | 'depth0': depth0, # (h, w) 103 | 'image1': image1, 104 | 'depth1': depth1, 105 | 'T_0to1': T_0to1, # (4, 4) 106 | 'T_1to0': T_1to0, 107 | 'K0': K_0, # (3, 3) 108 | 'K1': K_1, 109 | 'scale0': scale0, # [scale_w, scale_h] 110 | 'scale1': scale1, 111 | 'dataset_name': 'MegaDepth', 112 | 'scene_id': self.scene_id, 113 | 'pair_id': idx, 114 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), 115 | } 116 | 117 | # for LoFTR training 118 | if mask0 is not None: # img_padding is True 119 | if self.coarse_scale: 120 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), 121 | scale_factor=self.coarse_scale, 122 | mode='nearest', 123 | recompute_scale_factor=False)[0].bool() 124 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) 125 | 126 | return data 127 | -------------------------------------------------------------------------------- 
/model/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from collections import OrderedDict 5 | from loguru import logger 6 | from kornia.geometry.epipolar import numeric 7 | from kornia.geometry.conversions import convert_points_to_homogeneous 8 | 9 | 10 | # --- METRICS --- 11 | 12 | def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): 13 | # angle error between 2 vectors 14 | t_gt = T_0to1[:3, 3] 15 | n = np.linalg.norm(t) * np.linalg.norm(t_gt) 16 | t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) 17 | t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity 18 | if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging 19 | t_err = 0 20 | 21 | # angle error between 2 rotation matrices 22 | R_gt = T_0to1[:3, :3] 23 | cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 24 | cos = np.clip(cos, -1., 1.) # handle numercial errors 25 | R_err = np.rad2deg(np.abs(np.arccos(cos))) 26 | 27 | return t_err, R_err 28 | 29 | 30 | def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): 31 | """Squared symmetric epipolar distance. 32 | This can be seen as a biased estimation of the reprojection error. 33 | Args: 34 | pts0 (torch.Tensor): [N, 2] 35 | E (torch.Tensor): [3, 3] 36 | """ 37 | pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 38 | pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 39 | pts0 = convert_points_to_homogeneous(pts0) 40 | pts1 = convert_points_to_homogeneous(pts1) 41 | 42 | Ep0 = pts0 @ E.T # [N, 3] 43 | p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] 44 | Etp1 = pts1 @ E # [N, 3] 45 | 46 | d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N 47 | return d 48 | 49 | 50 | def compute_symmetrical_epipolar_errors(data): 51 | """ 52 | Update: 53 | data (dict):{"epi_errs": [M]} 54 | """ 55 | Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) 56 | E_mat = Tx @ data['T_0to1'][:, :3, :3] 57 | 58 | m_bids = data['m_bids'] 59 | pts0 = data['mkpts0_f'] 60 | pts1 = data['mkpts1_f'] 61 | 62 | epi_errs = [] 63 | for bs in range(Tx.size(0)): 64 | mask = m_bids == bs 65 | epi_errs.append( 66 | symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) 67 | epi_errs = torch.cat(epi_errs, dim=0) 68 | 69 | data.update({'epi_errs': epi_errs}) 70 | 71 | 72 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): 73 | if len(kpts0) < 5: 74 | return None 75 | # normalize keypoints 76 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 77 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 78 | 79 | # normalize ransac threshold 80 | ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) 81 | 82 | # compute pose with cv2 83 | E, mask = cv2.findEssentialMat( 84 | kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) 85 | if E is None: 86 | print("\nE is None while trying to recover pose.\n") 87 | return None 88 | 89 | # recover pose from E 90 | best_num_inliers = 0 91 | ret = None 92 | for _E in np.split(E, len(E) / 3): 93 | n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) 94 | if n > best_num_inliers: 95 | ret = (R, t[:, 0], mask.ravel() > 0) 96 | best_num_inliers = n 97 | 98 | return ret 99 | 100 | 101 | def compute_pose_errors(data, config): 102 | """ 103 | Update: 104 | data (dict):{ 105 | "R_errs" List[float]: [N] 106 | 
"t_errs" List[float]: [N] 107 | "inliers" List[np.ndarray]: [N] 108 | } 109 | """ 110 | pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 111 | conf = config.TRAINER.RANSAC_CONF # 0.99999 112 | data.update({'R_errs': [], 't_errs': [], 'inliers': []}) 113 | 114 | m_bids = data['m_bids'].cpu().numpy() 115 | pts0 = data['mkpts0_f'].cpu().numpy() 116 | pts1 = data['mkpts1_f'].cpu().numpy() 117 | K0 = data['K0'].cpu().numpy() 118 | K1 = data['K1'].cpu().numpy() 119 | T_0to1 = data['T_0to1'].cpu().numpy() 120 | 121 | for bs in range(K0.shape[0]): 122 | mask = m_bids == bs 123 | ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) 124 | 125 | if ret is None: 126 | data['R_errs'].append(np.inf) 127 | data['t_errs'].append(np.inf) 128 | data['inliers'].append(np.array([]).astype(np.bool)) 129 | else: 130 | R, t, inliers = ret 131 | t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) 132 | data['R_errs'].append(R_err) 133 | data['t_errs'].append(t_err) 134 | data['inliers'].append(inliers) 135 | 136 | 137 | # --- METRIC AGGREGATION --- 138 | 139 | def error_auc(errors, thresholds): 140 | """ 141 | Args: 142 | errors (list): [N,] 143 | thresholds (list) 144 | """ 145 | errors = [0] + sorted(list(errors)) 146 | recall = list(np.linspace(0, 1, len(errors))) 147 | 148 | aucs = [] 149 | thresholds = [5, 10, 20] 150 | for thr in thresholds: 151 | last_index = np.searchsorted(errors, thr) 152 | y = recall[:last_index] + [recall[last_index-1]] 153 | x = errors[:last_index] + [thr] 154 | aucs.append(np.trapz(y, x) / thr) 155 | 156 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} 157 | 158 | 159 | def epidist_prec(errors, thresholds, ret_dict=False): 160 | precs = [] 161 | for thr in thresholds: 162 | prec_ = [] 163 | for errs in errors: 164 | correct_mask = errs < thr 165 | prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) 166 | precs.append(np.mean(prec_) if len(prec_) > 0 else 0) 167 | if ret_dict: 168 | return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} 169 | else: 170 | return precs 171 | 172 | 173 | def aggregate_metrics(metrics, epi_err_thr=5e-4): 174 | """ Aggregate metrics for the whole dataset: 175 | (This method should be called once per dataset) 176 | 1. AUC of the pose error (angular) at the threshold [5, 10, 20] 177 | 2. 
Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) 178 | """ 179 | # filter duplicates 180 | unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) 181 | unq_ids = list(unq_ids.values()) 182 | logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') 183 | 184 | # pose auc 185 | angular_thresholds = [5, 10, 20] 186 | pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] 187 | aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) 188 | 189 | # matching precision 190 | dist_thresholds = [epi_err_thr] 191 | precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) 192 | 193 | return {**aucs, **precs} 194 | -------------------------------------------------------------------------------- /model/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | from loguru import logger 3 | 4 | import cv2 5 | import numpy as np 6 | import h5py 7 | import torch 8 | from numpy.linalg import inv 9 | 10 | 11 | try: 12 | # for internel use only 13 | from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT 14 | except Exception: 15 | MEGADEPTH_CLIENT = SCANNET_CLIENT = None 16 | 17 | # --- DATA IO --- 18 | 19 | def load_array_from_s3( 20 | path, client, cv_type, 21 | use_h5py=False, 22 | ): 23 | byte_str = client.Get(path) 24 | try: 25 | if not use_h5py: 26 | raw_array = np.fromstring(byte_str, np.uint8) 27 | data = cv2.imdecode(raw_array, cv_type) 28 | else: 29 | f = io.BytesIO(byte_str) 30 | data = np.array(h5py.File(f, 'r')['/depth']) 31 | except Exception as ex: 32 | print(f"==> Data loading failure: {path}") 33 | raise ex 34 | 35 | assert data is not None 36 | return data 37 | 38 | 39 | def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): 40 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ 41 | else cv2.IMREAD_COLOR 42 | if str(path).startswith('s3://'): 43 | image = load_array_from_s3(str(path), client, cv_type) 44 | else: 45 | image = cv2.imread(str(path), cv_type) 46 | 47 | if augment_fn is not None: 48 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | image = augment_fn(image) 51 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 52 | return image # (h, w) 53 | 54 | def imread_rgb(path, augment_fn=None, client=SCANNET_CLIENT): 55 | cv_type = cv2.IMREAD_COLOR 56 | if str(path).startswith('s3://'): 57 | image = load_array_from_s3(str(path), client, cv_type) 58 | else: 59 | image = cv2.imread(str(path), cv_type) 60 | 61 | if augment_fn is not None: 62 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 63 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | image = augment_fn(image) 65 | 66 | return image # (3, h, w) 67 | 68 | def get_resized_wh(w, h, resize=None): 69 | if resize is not None: # resize the longer edge 70 | scale = resize / max(h, w) 71 | w_new, h_new = int(round(w*scale)), int(round(h*scale)) 72 | else: 73 | w_new, h_new = w, h 74 | return w_new, h_new 75 | 76 | 77 | def get_divisible_wh(w, h, df=None): 78 | if df is not None: 79 | w_new, h_new = map(lambda x: int(x // df * df), [w, h]) 80 | else: 81 | w_new, h_new = w, h 82 | return w_new, h_new 83 | 84 | 85 | def pad_bottom_right(inp, pad_size, ret_mask=False): 86 | assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" 87 | mask = None 88 | if inp.ndim == 2: 89 | padded = 
np.zeros((pad_size, pad_size), dtype=inp.dtype) 90 | padded[:inp.shape[0], :inp.shape[1]] = inp 91 | if ret_mask: 92 | mask = np.zeros((pad_size, pad_size), dtype=bool) 93 | mask[:inp.shape[0], :inp.shape[1]] = True 94 | elif inp.ndim == 3: 95 | padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) 96 | padded[:, :inp.shape[1], :inp.shape[2]] = inp 97 | if ret_mask: 98 | mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) 99 | mask[:, :inp.shape[1], :inp.shape[2]] = True 100 | else: 101 | raise NotImplementedError() 102 | return padded, mask 103 | 104 | # --- MEGADEPTH --- 105 | def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): 106 | """ 107 | Args: 108 | resize (int, optional): the longer edge of resized images. None for no resize. 109 | padding (bool): If set to 'True', zero-pad resized images to squared size. 110 | augment_fn (callable, optional): augments images with pre-defined visual effects 111 | Returns: 112 | image (torch.tensor): (1, h, w) 113 | mask (torch.tensor): (h, w) 114 | scale (torch.tensor): [w/w_new, h/h_new] 115 | """ 116 | # read image 117 | image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 118 | 119 | # resize image 120 | w, h = image.shape[1], image.shape[0] 121 | w_new, h_new = get_resized_wh(w, h, resize) 122 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 123 | 124 | image = cv2.resize(image, (w_new, h_new)) 125 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 126 | 127 | if padding: # padding 128 | pad_to = max(h_new, w_new) 129 | image, mask = pad_bottom_right(image, pad_to, ret_mask=True) 130 | else: 131 | mask = None 132 | image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized 133 | mask = torch.from_numpy(mask) 134 | 135 | return image, mask, scale 136 | 137 | 138 | def read_megadepth_rgb(path, resize=None, df=None, padding=False, augment_fn=None): 139 | """ 140 | Args: 141 | resize (int, optional): the longer edge of resized images. None for no resize. 142 | padding (bool): If set to 'True', zero-pad resized images to squared size. 
143 | augment_fn (callable, optional): augments images with pre-defined visual effects 144 | Returns: 145 | image (torch.tensor): (3, h, w) 146 | mask (torch.tensor): (h, w) 147 | scale (torch.tensor): [w/w_new, h/h_new] 148 | """ 149 | # read image 150 | image_gray = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 151 | image = imread_rgb(path, augment_fn, client=MEGADEPTH_CLIENT) 152 | 153 | # resize image 154 | w, h = image.shape[1], image.shape[0] 155 | w_new, h_new = get_resized_wh(w, h, resize) 156 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 157 | 158 | image = cv2.resize(image, (w_new, h_new)) 159 | image = image.transpose(2,0,1) 160 | image_gray = cv2.resize(image_gray, (w_new, h_new)) 161 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 162 | 163 | if padding: # padding 164 | pad_to = max(h_new, w_new) 165 | image, _ = pad_bottom_right(image, pad_to, ret_mask=True) 166 | _, mask = pad_bottom_right(image_gray, pad_to, ret_mask=True) 167 | else: 168 | mask = None 169 | 170 | image = torch.from_numpy(image).float() / 255 # (3, h, w) -> (3, h, w) normalized 171 | mask = torch.from_numpy(mask) 172 | return image, mask, scale 173 | 174 | def read_megadepth_depth(path, pad_to=None): 175 | if str(path).startswith('s3://'): 176 | depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) 177 | else: 178 | depth = np.array(h5py.File(path, 'r')['depth']) 179 | if pad_to is not None: 180 | depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) 181 | depth = torch.from_numpy(depth).float() # (h, w) 182 | return depth 183 | 184 | 185 | # --- ScanNet --- 186 | 187 | def read_scannet_gray(path, resize=(640, 480), augment_fn=None): 188 | """ 189 | Args: 190 | resize (tuple): align image to depthmap, in (w, h). 191 | augment_fn (callable, optional): augments images with pre-defined visual effects 192 | Returns: 193 | image (torch.tensor): (1, h, w) 194 | mask (torch.tensor): (h, w) 195 | scale (torch.tensor): [w/w_new, h/h_new] 196 | """ 197 | # read and resize image 198 | image = imread_gray(path, augment_fn) 199 | image = cv2.resize(image, resize) 200 | 201 | # (h, w) -> (1, h, w) and normalized 202 | image = torch.from_numpy(image).float()[None] / 255 203 | return image 204 | 205 | 206 | def read_scannet_depth(path): 207 | if str(path).startswith('s3://'): 208 | depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) 209 | else: 210 | depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) 211 | depth = depth / 1000 212 | depth = torch.from_numpy(depth).float() # (h, w) 213 | return depth 214 | 215 | 216 | def read_scannet_pose(path): 217 | """ Read ScanNet's Camera2World pose and transform it to World2Camera. 218 | 219 | Returns: 220 | pose_w2c (np.ndarray): (4, 4) 221 | """ 222 | cam2world = np.loadtxt(path, delimiter=' ') 223 | world2cam = inv(cam2world) 224 | return world2cam 225 | 226 | 227 | def read_scannet_intrinsic(path): 228 | """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 229 | """ 230 | intrinsic = np.loadtxt(path, delimiter=' ') 231 | return intrinsic[:-1, :-1] 232 | -------------------------------------------------------------------------------- /model/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | [Copied from detectron2] 4 | This file contains primitives for multi-gpu communication. 5 | This is useful when doing distributed training. 
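A minimal usage sketch of the gather primitives defined below (the `model.utils.comm` import path is only an assumption based on the repository layout shown above; with a single process `all_gather` simply returns `[data]`, so the snippet also runs without an initialized process group):

    from model.utils import comm

    local_stats = {'R_errs': [0.5, 1.2], 't_errs': [0.3, 0.9]}
    gathered = comm.all_gather(local_stats)          # one dict per rank
    if comm.is_main_process():
        merged = {k: sum([d[k] for d in gathered], []) for k in local_stats}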
6 | """ 7 | 8 | import functools 9 | import logging 10 | import numpy as np 11 | import pickle 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert _LOCAL_PROCESS_GROUP is not None 48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 49 | 50 | 51 | def get_local_size() -> int: 52 | """ 53 | Returns: 54 | The size of the per-machine process group, 55 | i.e. the number of processes per machine. 56 | """ 57 | if not dist.is_available(): 58 | return 1 59 | if not dist.is_initialized(): 60 | return 1 61 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 62 | 63 | 64 | def is_main_process() -> bool: 65 | return get_rank() == 0 66 | 67 | 68 | def synchronize(): 69 | """ 70 | Helper function to synchronize (barrier) among all processes when 71 | using distributed training 72 | """ 73 | if not dist.is_available(): 74 | return 75 | if not dist.is_initialized(): 76 | return 77 | world_size = dist.get_world_size() 78 | if world_size == 1: 79 | return 80 | dist.barrier() 81 | 82 | 83 | @functools.lru_cache() 84 | def _get_global_gloo_group(): 85 | """ 86 | Return a process group based on gloo backend, containing all the ranks 87 | The result is cached. 88 | """ 89 | if dist.get_backend() == "nccl": 90 | return dist.new_group(backend="gloo") 91 | else: 92 | return dist.group.WORLD 93 | 94 | 95 | def _serialize_to_tensor(data, group): 96 | backend = dist.get_backend(group) 97 | assert backend in ["gloo", "nccl"] 98 | device = torch.device("cpu" if backend == "gloo" else "cuda") 99 | 100 | buffer = pickle.dumps(data) 101 | if len(buffer) > 1024 ** 3: 102 | logger = logging.getLogger(__name__) 103 | logger.warning( 104 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 105 | get_rank(), len(buffer) / (1024 ** 3), device 106 | ) 107 | ) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to(device=device) 110 | return tensor 111 | 112 | 113 | def _pad_to_largest_tensor(tensor, group): 114 | """ 115 | Returns: 116 | list[int]: size of the tensor, on each rank 117 | Tensor: padded tensor that has the max size 118 | """ 119 | world_size = dist.get_world_size(group=group) 120 | assert ( 121 | world_size >= 1 122 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
123 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 124 | size_list = [ 125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 126 | ] 127 | dist.all_gather(size_list, local_size, group=group) 128 | 129 | size_list = [int(size.item()) for size in size_list] 130 | 131 | max_size = max(size_list) 132 | 133 | # we pad the tensor because torch all_gather does not support 134 | # gathering tensors of different shapes 135 | if local_size != max_size: 136 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 137 | tensor = torch.cat((tensor, padding), dim=0) 138 | return size_list, tensor 139 | 140 | 141 | def all_gather(data, group=None): 142 | """ 143 | Run all_gather on arbitrary picklable data (not necessarily tensors). 144 | 145 | Args: 146 | data: any picklable object 147 | group: a torch process group. By default, will use a group which 148 | contains all ranks on gloo backend. 149 | 150 | Returns: 151 | list[data]: list of data gathered from each rank 152 | """ 153 | if get_world_size() == 1: 154 | return [data] 155 | if group is None: 156 | group = _get_global_gloo_group() 157 | if dist.get_world_size(group) == 1: 158 | return [data] 159 | 160 | tensor = _serialize_to_tensor(data, group) 161 | 162 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 163 | max_size = max(size_list) 164 | 165 | # receiving Tensor from all ranks 166 | tensor_list = [ 167 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 168 | ] 169 | dist.all_gather(tensor_list, tensor, group=group) 170 | 171 | data_list = [] 172 | for size, tensor in zip(size_list, tensor_list): 173 | buffer = tensor.cpu().numpy().tobytes()[:size] 174 | data_list.append(pickle.loads(buffer)) 175 | 176 | return data_list 177 | 178 | 179 | def gather(data, dst=0, group=None): 180 | """ 181 | Run gather on arbitrary picklable data (not necessarily tensors). 182 | 183 | Args: 184 | data: any picklable object 185 | dst (int): destination rank 186 | group: a torch process group. By default, will use a group which 187 | contains all ranks on gloo backend. 188 | 189 | Returns: 190 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 191 | an empty list. 192 | """ 193 | if get_world_size() == 1: 194 | return [data] 195 | if group is None: 196 | group = _get_global_gloo_group() 197 | if dist.get_world_size(group=group) == 1: 198 | return [data] 199 | rank = dist.get_rank(group=group) 200 | 201 | tensor = _serialize_to_tensor(data, group) 202 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 203 | 204 | # receiving Tensor from all ranks 205 | if rank == dst: 206 | max_size = max(size_list) 207 | tensor_list = [ 208 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 209 | ] 210 | dist.gather(tensor, tensor_list, dst=dst, group=group) 211 | 212 | data_list = [] 213 | for size, tensor in zip(size_list, tensor_list): 214 | buffer = tensor.cpu().numpy().tobytes()[:size] 215 | data_list.append(pickle.loads(buffer)) 216 | return data_list 217 | else: 218 | dist.gather(tensor, [], dst=dst, group=group) 219 | return [] 220 | 221 | 222 | def shared_random_seed(): 223 | """ 224 | Returns: 225 | int: a random number that is the same across all workers. 226 | If workers need a shared RNG, they can use this shared seed to 227 | create one. 228 | 229 | All workers must call this function, otherwise it will deadlock. 
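    A sketch of the intended use (the variable names are illustrative only):

        seed = shared_random_seed()
        rng = np.random.RandomState(seed)  # identical RNG stream on every rank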
230 | """ 231 | ints = np.random.randint(2 ** 31) 232 | all_ints = all_gather(ints) 233 | return all_ints[0] 234 | 235 | 236 | def reduce_dict(input_dict, average=True): 237 | """ 238 | Reduce the values in the dictionary from all processes so that process with rank 239 | 0 has the reduced results. 240 | 241 | Args: 242 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 243 | average (bool): whether to do average or sum 244 | 245 | Returns: 246 | a dict with the same keys as input_dict, after reduction. 247 | """ 248 | world_size = get_world_size() 249 | if world_size < 2: 250 | return input_dict 251 | with torch.no_grad(): 252 | names = [] 253 | values = [] 254 | # sort the keys so that they are consistent across processes 255 | for k in sorted(input_dict.keys()): 256 | names.append(k) 257 | values.append(input_dict[k]) 258 | values = torch.stack(values, dim=0) 259 | dist.reduce(values, dst=0) 260 | if dist.get_rank() == 0 and average: 261 | # only main process gets accumulated, so only divide by 262 | # world_size in this case 263 | values /= world_size 264 | reduced_dict = {k: v for k, v in zip(names, values)} 265 | return reduced_dict 266 | -------------------------------------------------------------------------------- /model/backbone/coarse_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | 6 | INF = 1e9 7 | 8 | def mask_border(m, b: int, v): 9 | """ Mask borders with value 10 | Args: 11 | m (torch.Tensor): [N, H0, W0, H1, W1] 12 | b (int) 13 | v (m.dtype) 14 | """ 15 | if b <= 0: 16 | return 17 | 18 | m[:, :b] = v 19 | m[:, :, :b] = v 20 | m[:, :, :, :b] = v 21 | m[:, :, :, :, :b] = v 22 | m[:, -b:] = v 23 | m[:, :, -b:] = v 24 | m[:, :, :, -b:] = v 25 | m[:, :, :, :, -b:] = v 26 | 27 | 28 | def mask_border_with_padding(m, bd, v, p_m0, p_m1): 29 | if bd <= 0: 30 | return 31 | 32 | m[:, :bd] = v 33 | m[:, :, :bd] = v 34 | m[:, :, :, :bd] = v 35 | m[:, :, :, :, :bd] = v 36 | 37 | h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() 38 | h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() 39 | for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): 40 | m[b_idx, h0 - bd:] = v 41 | m[b_idx, :, w0 - bd:] = v 42 | m[b_idx, :, :, h1 - bd:] = v 43 | m[b_idx, :, :, :, w1 - bd:] = v 44 | 45 | 46 | def compute_max_candidates(p_m0, p_m1): 47 | """Compute the max candidates of all pairs within a batch 48 | 49 | Args: 50 | p_m0, p_m1 (torch.Tensor): padded masks 51 | """ 52 | h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] 53 | h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] 54 | max_cand = torch.sum( 55 | torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) 56 | return max_cand 57 | 58 | 59 | class CoarseMatching(nn.Module): 60 | def __init__(self, config): 61 | super().__init__() 62 | self.config = config 63 | # general config 64 | self.thr = config['thr'] 65 | self.border_rm = config['border_rm'] 66 | # -- # for trainig fine-level LoFTR 67 | self.train_coarse_percent = config['train_coarse_percent'] 68 | self.train_pad_num_gt_min = config['train_pad_num_gt_min'] 69 | 70 | # we provide 2 options for differentiable matching 71 | self.match_type = config['match_type'] 72 | if self.match_type == 'dual_softmax': 73 | self.temperature = config['dsmax_temperature'] 74 | else: 75 | raise NotImplementedError() 76 | 77 | def 
forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): 78 | """ 79 | Args: 80 | feat0 (torch.Tensor): [N, L, C] 81 | feat1 (torch.Tensor): [N, S, C] 82 | data (dict) 83 | mask_c0 (torch.Tensor): [N, L] (optional) 84 | mask_c1 (torch.Tensor): [N, S] (optional) 85 | Update: 86 | data (dict): { 87 | 'b_ids' (torch.Tensor): [M'], 88 | 'i_ids' (torch.Tensor): [M'], 89 | 'j_ids' (torch.Tensor): [M'], 90 | 'gt_mask' (torch.Tensor): [M'], 91 | 'mkpts0_c' (torch.Tensor): [M, 2], 92 | 'mkpts1_c' (torch.Tensor): [M, 2], 93 | 'mconf' (torch.Tensor): [M]} 94 | NOTE: M' != M during training. 95 | """ 96 | N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) 97 | 98 | # normalize 99 | feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, 100 | [feat_c0, feat_c1]) 101 | 102 | if self.match_type == 'dual_softmax': 103 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, 104 | feat_c1) / self.temperature 105 | if mask_c0 is not None: 106 | sim_matrix.masked_fill_( 107 | ~(mask_c0[..., None] * mask_c1[:, None]).bool(), 108 | -INF) 109 | conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) 110 | 111 | 112 | data.update({'conf_matrix': conf_matrix}) 113 | 114 | # predict coarse matches from conf_matrix 115 | data.update(**self.get_coarse_match(conf_matrix, data)) 116 | 117 | @torch.no_grad() 118 | def get_coarse_match(self, conf_matrix, data): 119 | """ 120 | Args: 121 | conf_matrix (torch.Tensor): [N, L, S] 122 | data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] 123 | Returns: 124 | coarse_matches (dict): { 125 | 'b_ids' (torch.Tensor): [M'], 126 | 'i_ids' (torch.Tensor): [M'], 127 | 'j_ids' (torch.Tensor): [M'], 128 | 'gt_mask' (torch.Tensor): [M'], 129 | 'm_bids' (torch.Tensor): [M], 130 | 'mkpts0_c' (torch.Tensor): [M, 2], 131 | 'mkpts1_c' (torch.Tensor): [M, 2], 132 | 'mconf' (torch.Tensor): [M]} 133 | """ 134 | axes_lengths = { 135 | 'h0c': data['hw0_c'][0], 136 | 'w0c': data['hw0_c'][1], 137 | 'h1c': data['hw1_c'][0], 138 | 'w1c': data['hw1_c'][1] 139 | } 140 | _device = conf_matrix.device 141 | # 1. confidence thresholding 142 | mask = conf_matrix > self.thr 143 | mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', 144 | **axes_lengths) 145 | if 'mask0' not in data: 146 | mask_border(mask, self.border_rm, False) 147 | else: 148 | mask_border_with_padding(mask, self.border_rm, False, 149 | data['mask0'], data['mask1']) 150 | mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', 151 | **axes_lengths) 152 | 153 | # 2. mutual nearest 154 | mask = mask \ 155 | * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ 156 | * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) 157 | 158 | # 3. find all valid coarse matches 159 | # this only works when at most one `True` in each row 160 | mask_v, all_j_ids = mask.max(dim=2) 161 | b_ids, i_ids = torch.where(mask_v) 162 | j_ids = all_j_ids[b_ids, i_ids] 163 | mconf = conf_matrix[b_ids, i_ids, j_ids] 164 | 165 | # 4. Random sampling of training samples for fine-level LoFTR 166 | # (optional) pad samples with gt coarse-level matches 167 | if self.training: 168 | # NOTE: 169 | # The sampling is performed across all pairs in a batch without manually balancing 170 | # #samples for fine-level increases w.r.t. 
batch_size 171 | if 'mask0' not in data: 172 | num_candidates_max = mask.size(0) * max( 173 | mask.size(1), mask.size(2)) 174 | else: 175 | num_candidates_max = compute_max_candidates( 176 | data['mask0'], data['mask1']) 177 | num_matches_train = int(num_candidates_max * 178 | self.train_coarse_percent) 179 | num_matches_pred = len(b_ids) 180 | assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" 181 | 182 | # pred_indices is to select from prediction 183 | if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: 184 | pred_indices = torch.arange(num_matches_pred, device=_device) 185 | else: 186 | pred_indices = torch.randint( 187 | num_matches_pred, 188 | (num_matches_train - self.train_pad_num_gt_min, ), 189 | device=_device) 190 | 191 | # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) 192 | gt_pad_indices = torch.randint( 193 | len(data['spv_b_ids']), 194 | (max(num_matches_train - num_matches_pred, 195 | self.train_pad_num_gt_min), ), 196 | device=_device) 197 | mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero 198 | 199 | b_ids, i_ids, j_ids, mconf = map( 200 | lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], 201 | dim=0), 202 | *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], 203 | [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) 204 | 205 | # These matches select patches that feed into fine-level network 206 | coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} 207 | 208 | # 4. Update with matches in original image resolution 209 | scale = data['hw0_i'][0] / data['hw0_c'][0] 210 | scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale 211 | scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale 212 | mkpts0_c = torch.stack( 213 | [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], 214 | dim=1) * scale0 215 | mkpts1_c = torch.stack( 216 | [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], 217 | dim=1) * scale1 218 | 219 | # These matches is the current prediction (for visualization) 220 | coarse_matches.update({ 221 | 'gt_mask': mconf == 0, 222 | 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches 223 | 'mkpts0_c': mkpts0_c[mconf != 0], 224 | 'mkpts1_c': mkpts1_c[mconf != 0], 225 | 'mconf': mconf[mconf != 0] 226 | }) 227 | 228 | return coarse_matches 229 | -------------------------------------------------------------------------------- /model/backbone/match_LA_large.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def 
__init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6 , cross = False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 55 | 56 | self.cross = cross 57 | self.feature_map = elu_feature_map 58 | self.eps = eps 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | 64 | def forward(self, x): 65 | x_q, x_kv = x, x 66 | B, N, C = x_q.shape 67 | MiniB = B // 2 68 | query = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 1, 2, 3) 69 | kv = self.kv(x_kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) 70 | 71 | if self.cross == True: 72 | k1, k2 = kv[0].split(MiniB) 73 | v1, v2 = kv[1].split(MiniB) 74 | key = torch.cat([k2, k1], dim=0) 75 | value = torch.cat([v2, v1], dim=0) 76 | else: 77 | key, value = kv[0], kv[1] 78 | 79 | Q = self.feature_map(query) 80 | K = self.feature_map(key) 81 | v_length = value.size(1) 82 | value = value / v_length 83 | KV = torch.einsum("nshd,nshv->nhdv", K, value) 84 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 85 | x = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 86 | x = x.contiguous().view(B, -1, C) 87 | 88 | return x 89 | 90 | class Block(nn.Module): 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,cross = False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, 97 | num_heads=num_heads, qkv_bias=qkv_bias, cross= cross) 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 99 | self.norm = norm_layer(dim ) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | 104 | def forward(self, x,H,W): 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | x = x + self.drop_path(self.mlp(self.norm(x),H,W)) 107 | 108 | return x 109 | 110 | class Positional(nn.Module): 111 | def __init__(self, dim): 112 | super().__init__() 113 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 114 | self.sigmoid = nn.Sigmoid() 115 | 116 | def forward(self, x): 117 | return x * self.sigmoid(self.pa_conv(x)) 118 | 119 | class PatchEmbed(nn.Module): 120 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 121 | super().__init__() 122 | img_size = to_2tuple(img_size) 123 | patch_size = to_2tuple(patch_size) 124 | self.img_size = img_size 125 | self.patch_size = patch_size 126 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 127 | self.num_patches = self.H * self.W 128 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 129 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 130 | 131 | self.with_pos = with_pos 132 | if self.with_pos: 133 | self.pos = Positional(embed_dim) 134 | 135 | self.norm = nn.LayerNorm(embed_dim) 136 | 137 | def forward(self, x): 138 | x = self.proj(x) 139 | if self.with_pos: 140 | x = self.pos(x) 141 | _, _, H, W = x.shape 142 | x = x.flatten(2).transpose(1, 2) 143 | x = self.norm(x) 144 | 145 | return x, H, W 146 | 147 | class AttentionBlock(nn.Module): 148 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, 149 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 150 | super().__init__() 151 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = stride, in_chans=in_chans, 152 | embed_dim=embed_dims) 153 | self.block = nn.ModuleList([Block( 154 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, 155 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 156 | for i in range(depths)]) 157 | self.norm = norm_layer(embed_dims) 158 | 159 | def forward(self, x): 160 | B = x.shape[0] 161 | x, H, W = self.patch_embed(x) 162 | for i, blk in enumerate(self.block): 163 | x = blk(x,H,W) 164 | x = self.norm(x) 165 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 166 | 167 | return x 168 | 169 | class Matchformer_LA_large(nn.Module): 170 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[8, 8, 8, 8], 171 | stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 172 | super().__init__() 173 | #Attention 174 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 175 | embed_dims=embed_dims[0],stride=2,depths=3, cross=stage1_cross) 176 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 177 | embed_dims=embed_dims[1],stride=2,depths=3, cross=stage2_cross) 178 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 179 | 
embed_dims=embed_dims[2],stride=2,depths=3, cross=stage3_cross) 180 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 181 | embed_dims=embed_dims[3],stride=2,depths=3, cross=stage4_cross) 182 | 183 | #FPN 184 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 185 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 186 | self.layer3_outconv2 = nn.Sequential( 187 | conv3x3(embed_dims[3], embed_dims[3]), 188 | nn.BatchNorm2d(embed_dims[3]), 189 | nn.LeakyReLU(), 190 | conv3x3(embed_dims[3], embed_dims[2]), 191 | ) 192 | 193 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 194 | self.layer2_outconv2 = nn.Sequential( 195 | conv3x3(embed_dims[2], embed_dims[2]), 196 | nn.BatchNorm2d(embed_dims[2]), 197 | nn.LeakyReLU(), 198 | conv3x3(embed_dims[2], embed_dims[1]), 199 | ) 200 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 201 | self.layer1_outconv2 = nn.Sequential( 202 | conv3x3(embed_dims[1], embed_dims[1]), 203 | nn.BatchNorm2d(embed_dims[1]), 204 | nn.LeakyReLU(), 205 | conv3x3(embed_dims[1], embed_dims[0]), 206 | ) 207 | 208 | self.apply(self._init_weights) 209 | 210 | def _init_weights(self, m): 211 | if isinstance(m, nn.Linear): 212 | trunc_normal_(m.weight, std=.02) 213 | if isinstance(m, nn.Linear) and m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.LayerNorm): 216 | nn.init.constant_(m.bias, 0) 217 | nn.init.constant_(m.weight, 1.0) 218 | elif isinstance(m, nn.Conv2d): 219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 220 | fan_out //= m.groups 221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | def forward(self, x): 226 | # stage 1 # 1/2 227 | x = self.AttentionBlock1(x) 228 | out1 = x 229 | # stage 2 # 1/4 230 | x = self.AttentionBlock2(x) 231 | out2 = x 232 | # stage 3 # 1/8 233 | x = self.AttentionBlock3(x) 234 | out3 = x 235 | # stage 3 # 1/16 236 | x = self.AttentionBlock4(x) 237 | out4 = x 238 | 239 | #FPN 240 | c4_out = self.layer4_outconv(out4) 241 | _,_,H,W = out3.shape 242 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 243 | c3_out = self.layer3_outconv(out3) 244 | _,_,H,W = out2.shape 245 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 246 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 247 | c2_out = self.layer2_outconv(out2) 248 | _,_,H,W = out1.shape 249 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x) 250 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True) 251 | c1_out = self.layer1_outconv(out1) 252 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x) 253 | 254 | return c3_out,c1_out 255 | -------------------------------------------------------------------------------- /model/backbone/match_LA_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class 
DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6 , cross = False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 55 | 56 | self.cross = cross 57 | self.feature_map = elu_feature_map 58 | self.eps = eps 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | 64 | def forward(self, x): 65 | x_q, x_kv = x, x 66 | B, N, C = x_q.shape 67 | MiniB = B // 2 68 | query = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 1, 2, 3) 69 | kv = self.kv(x_kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) 70 | 71 | if self.cross == True: 72 | k1, k2 = kv[0].split(MiniB) 73 | v1, v2 = kv[1].split(MiniB) 74 | key = torch.cat([k2, k1], dim=0) 75 | value = torch.cat([v2, v1], dim=0) 76 | else: 77 | key, value = kv[0], kv[1] 78 | 79 | Q = self.feature_map(query) 80 | K = self.feature_map(key) 81 | v_length = value.size(1) 82 | value = value / v_length 83 | KV = torch.einsum("nshd,nshv->nhdv", K, value) 84 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 85 | x = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 86 | x = x.contiguous().view(B, -1, C) 87 | 88 | return x 89 | 90 | class Block(nn.Module): 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,cross = False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, 97 | num_heads=num_heads, qkv_bias=qkv_bias, cross= cross) 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 99 | self.norm = norm_layer(dim ) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | 104 | def forward(self, x,H,W): 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | x = x + self.drop_path(self.mlp(self.norm(x),H,W)) 107 | 108 | return x 109 | 110 | class Positional(nn.Module): 111 | def __init__(self, dim): 112 | super().__init__() 113 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 114 | self.sigmoid = nn.Sigmoid() 115 | 116 | def forward(self, x): 117 | return x * self.sigmoid(self.pa_conv(x)) 118 | 119 | class PatchEmbed(nn.Module): 120 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 121 | super().__init__() 122 | img_size = to_2tuple(img_size) 123 | patch_size = to_2tuple(patch_size) 124 | self.img_size = img_size 125 | self.patch_size = patch_size 126 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 127 | self.num_patches = self.H * self.W 128 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 129 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 130 | 131 | self.with_pos = with_pos 132 | if self.with_pos: 133 | self.pos = Positional(embed_dim) 134 | 135 | self.norm = nn.LayerNorm(embed_dim) 136 | 137 | def forward(self, x): 138 | x = self.proj(x) 139 | if self.with_pos: 140 | x = self.pos(x) 141 | _, _, H, W = x.shape 142 | x = x.flatten(2).transpose(1, 2) 143 | x = self.norm(x) 144 | 145 | return x, H, W 146 | 147 | class AttentionBlock(nn.Module): 148 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, 149 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 150 | super().__init__() 151 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = stride, in_chans=in_chans, 152 | embed_dim=embed_dims) 153 | self.block = nn.ModuleList([Block( 154 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, 155 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 156 | for i in range(depths)]) 157 | self.norm = norm_layer(embed_dims) 158 | 159 | def forward(self, x): 160 | B = x.shape[0] 161 | x, H, W = self.patch_embed(x) 162 | for i, blk in enumerate(self.block): 163 | x = blk(x,H,W) 164 | x = self.norm(x) 165 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 166 | 167 | return x 168 | 169 | class Matchformer_LA_lite(nn.Module): 170 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[8, 8, 8, 8], 171 | stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 172 | super().__init__() 173 | #Attention 174 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 175 | embed_dims=embed_dims[0],stride=4,depths=3, cross=stage1_cross) 176 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 177 | embed_dims=embed_dims[1],stride=2,depths=3, cross=stage2_cross) 178 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 179 | 
embed_dims=embed_dims[2],stride=2,depths=3, cross=stage3_cross) 180 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 181 | embed_dims=embed_dims[3],stride=2,depths=3, cross=stage4_cross) 182 | 183 | #FPN 184 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 185 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 186 | self.layer3_outconv2 = nn.Sequential( 187 | conv3x3(embed_dims[3], embed_dims[3]), 188 | nn.BatchNorm2d(embed_dims[3]), 189 | nn.LeakyReLU(), 190 | conv3x3(embed_dims[3], embed_dims[2]), 191 | ) 192 | 193 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 194 | self.layer2_outconv2 = nn.Sequential( 195 | conv3x3(embed_dims[2], embed_dims[2]), 196 | nn.BatchNorm2d(embed_dims[2]), 197 | nn.LeakyReLU(), 198 | conv3x3(embed_dims[2], embed_dims[1]), 199 | ) 200 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 201 | self.layer1_outconv2 = nn.Sequential( 202 | conv3x3(embed_dims[1], embed_dims[1]), 203 | nn.BatchNorm2d(embed_dims[1]), 204 | nn.LeakyReLU(), 205 | conv3x3(embed_dims[1], embed_dims[0]), 206 | ) 207 | 208 | self.apply(self._init_weights) 209 | 210 | def _init_weights(self, m): 211 | if isinstance(m, nn.Linear): 212 | trunc_normal_(m.weight, std=.02) 213 | if isinstance(m, nn.Linear) and m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.LayerNorm): 216 | nn.init.constant_(m.bias, 0) 217 | nn.init.constant_(m.weight, 1.0) 218 | elif isinstance(m, nn.Conv2d): 219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 220 | fan_out //= m.groups 221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | def forward(self, x): 226 | # stage 1 # 1/4 227 | x = self.AttentionBlock1(x) 228 | out1 = x 229 | # stage 2 # 1/8 230 | x = self.AttentionBlock2(x) 231 | out2 = x 232 | # stage 3 # 1/16 233 | x = self.AttentionBlock3(x) 234 | out3 = x 235 | # stage 3 # 1/32 236 | x = self.AttentionBlock4(x) 237 | out4 = x 238 | 239 | #FPN 240 | c4_out = self.layer4_outconv(out4) 241 | _,_,H,W = out3.shape 242 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 243 | c3_out = self.layer3_outconv(out3) 244 | _,_,H,W = out2.shape 245 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 246 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 247 | c2_out = self.layer2_outconv(out2) 248 | _,_,H,W = out1.shape 249 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x) 250 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True) 251 | c1_out = self.layer1_outconv(out1) 252 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x) 253 | 254 | return c2_out,c1_out -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /model/backbone/match_SEA_large.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, cross= False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
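        # In this SEA backbone the attention cost is reduced by spatially
        # downsampling the key/value features with a strided conv (self.sr)
        # whenever sr_ratio > 1, shrinking the attention matrix by roughly a
        # factor of sr_ratio**2. When `cross` is True, the two halves of the
        # batch (the two images of a pair) exchange keys and values, turning
        # the block into cross-attention between the images.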
55 | self.cross = cross 56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | def forward(self, x, H, W): 74 | B, N, C = x.shape 75 | if self.cross == True: 76 | MiniB = B // 2 77 | #cross attention 78 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 79 | q1,q2 = q.split(MiniB) 80 | 81 | if self.sr_ratio > 1: 82 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 83 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 84 | x_ = self.norm(x_) 85 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | else: 87 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 88 | 89 | k1, k2 = kv[0].split(MiniB) 90 | v1, v2 = kv[1].split(MiniB) 91 | 92 | attn1 = (q1 @ k2.transpose(-2, -1)) * self.scale 93 | attn1 = attn1.softmax(dim=-1) 94 | attn1 = self.attn_drop(attn1) 95 | 96 | attn2 = (q2 @ k1.transpose(-2, -1)) * self.scale 97 | attn2 = attn2.softmax(dim=-1) 98 | attn2 = self.attn_drop(attn2) 99 | 100 | x1 = (attn1 @ v2).transpose(1, 2).reshape(MiniB, N, C) 101 | x2 = (attn2 @ v1).transpose(1, 2).reshape(MiniB, N, C) 102 | 103 | x = torch.cat([x1, x2], dim=0) 104 | else: 105 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 106 | 107 | 108 | if self.sr_ratio > 1: 109 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 110 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 111 | x_ = self.norm(x_) 112 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 113 | else: 114 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | k, v = kv[0], kv[1] 116 | 117 | attn = (q @ k.transpose(-2, -1)) * self.scale 118 | attn = attn.softmax(dim=-1) 119 | attn = self.attn_drop(attn) 120 | 121 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 122 | 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | 126 | return x 127 | 128 | class Block(nn.Module): 129 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 130 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, cross = False): 131 | super().__init__() 132 | self.norm1 = norm_layer(dim) 133 | self.attn = Attention( 134 | dim, 135 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 136 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, cross= cross) 137 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 138 | self.norm2 = norm_layer(dim) 139 | mlp_hidden_dim = int(dim * mlp_ratio) 140 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 141 | 142 | def forward(self, x, H, W): 143 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 144 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 145 | 146 | return x 147 | 148 | class Positional(nn.Module): 149 | def __init__(self, dim): 150 | super().__init__() 151 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 152 | self.sigmoid = nn.Sigmoid() 153 | 154 | def forward(self, x): 155 | return x * self.sigmoid(self.pa_conv(x)) 156 | 157 | class PatchEmbed(nn.Module): 158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | self.img_size = img_size 163 | self.patch_size = patch_size 164 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 165 | self.num_patches = self.H * self.W 166 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 167 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 168 | 169 | self.with_pos = with_pos 170 | if self.with_pos: 171 | self.pos = Positional(embed_dim) 172 | 173 | self.norm = nn.LayerNorm(embed_dim) 174 | 175 | def forward(self, x): 176 | x = self.proj(x) 177 | if self.with_pos: 178 | x = self.pos(x) 179 | _, _, H, W = x.shape 180 | x = x.flatten(2).transpose(1, 2) 181 | x = self.norm(x) 182 | 183 | return x, H, W 184 | 185 | class AttentionBlock(nn.Module): 186 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, sr_ratios=8, 187 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 188 | super().__init__() 189 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = stride, in_chans=in_chans, 190 | embed_dim=embed_dims) 191 | self.block = nn.ModuleList([Block( 192 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,sr_ratio= sr_ratios, 193 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 194 | for i in range(depths)]) 195 | self.norm = norm_layer(embed_dims) 196 | 197 | def forward(self, x): 198 | B = x.shape[0] 199 | x, H, W = self.patch_embed(x) 200 | for i, blk in enumerate(self.block): 201 | x = blk(x,H,W) 202 | x = self.norm(x) 203 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 204 | 205 | return x 206 | 207 | class Matchformer_SEA_large(nn.Module): 208 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[1, 2, 4, 8],sr_ratios=[4,2,2,1], 209 | stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 210 | super().__init__() 211 | #Attention 212 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 213 | embed_dims=embed_dims[0],stride=2,sr_ratios=sr_ratios[0],depths=3, cross=stage1_cross) 214 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 215 | embed_dims=embed_dims[1],stride=2,sr_ratios=sr_ratios[1],depths=3, cross=stage2_cross) 216 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, 
num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 217 | embed_dims=embed_dims[2],stride=2,sr_ratios=sr_ratios[2],depths=3, cross=stage3_cross) 218 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 219 | embed_dims=embed_dims[3],stride=2,sr_ratios=sr_ratios[3],depths=3, cross=stage4_cross) 220 | 221 | #FPN 222 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 223 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 224 | self.layer3_outconv2 = nn.Sequential( 225 | conv3x3(embed_dims[3], embed_dims[3]), 226 | nn.BatchNorm2d(embed_dims[3]), 227 | nn.LeakyReLU(), 228 | conv3x3(embed_dims[3], embed_dims[2]), 229 | ) 230 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 231 | self.layer2_outconv2 = nn.Sequential( 232 | conv3x3(embed_dims[2], embed_dims[2]), 233 | nn.BatchNorm2d(embed_dims[2]), 234 | nn.LeakyReLU(), 235 | conv3x3(embed_dims[2], embed_dims[1]), 236 | ) 237 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 238 | self.layer1_outconv2 = nn.Sequential( 239 | conv3x3(embed_dims[1], embed_dims[1]), 240 | nn.BatchNorm2d(embed_dims[1]), 241 | nn.LeakyReLU(), 242 | conv3x3(embed_dims[1], embed_dims[0]), 243 | ) 244 | 245 | self.apply(self._init_weights) 246 | 247 | def _init_weights(self, m): 248 | if isinstance(m, nn.Linear): 249 | trunc_normal_(m.weight, std=.02) 250 | if isinstance(m, nn.Linear) and m.bias is not None: 251 | nn.init.constant_(m.bias, 0) 252 | elif isinstance(m, nn.LayerNorm): 253 | nn.init.constant_(m.bias, 0) 254 | nn.init.constant_(m.weight, 1.0) 255 | elif isinstance(m, nn.Conv2d): 256 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 257 | fan_out //= m.groups 258 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 259 | if m.bias is not None: 260 | m.bias.data.zero_() 261 | 262 | def forward(self, x): 263 | # stage 1 # 1/4 264 | x = self.AttentionBlock1(x) 265 | out1 = x 266 | # stage 2 # 1/8 267 | x = self.AttentionBlock2(x) 268 | out2 = x 269 | # stage 3 # 1/16 270 | x = self.AttentionBlock3(x) 271 | out3 = x 272 | # stage 3 # 1/32 273 | x = self.AttentionBlock4(x) 274 | out4 = x 275 | 276 | #FPN 277 | c4_out = self.layer4_outconv(out4) 278 | _,_,H,W = out3.shape 279 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 280 | c3_out = self.layer3_outconv(out3) 281 | _,_,H,W = out2.shape 282 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 283 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 284 | c2_out = self.layer2_outconv(out2) 285 | _,_,H,W = out1.shape 286 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x) 287 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True) 288 | c1_out = self.layer1_outconv(out1) 289 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x) 290 | 291 | return c3_out,c1_out -------------------------------------------------------------------------------- /model/backbone/match_SEA_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | 
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, cross= False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 55 | self.cross = cross 56 | 57 | self.dim = dim 58 | self.num_heads = num_heads 59 | head_dim = dim // num_heads 60 | self.scale = qk_scale or head_dim ** -0.5 61 | 62 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.proj_drop = nn.Dropout(proj_drop) 67 | 68 | self.sr_ratio = sr_ratio 69 | if sr_ratio > 1: 70 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 71 | self.norm = nn.LayerNorm(dim) 72 | 73 | def forward(self, x, H, W): 74 | B, N, C = x.shape 75 | if self.cross == True: 76 | MiniB = B // 2 77 | #cross attention 78 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 79 | q1,q2 = q.split(MiniB) 80 | 81 | if self.sr_ratio > 1: 82 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 83 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 84 | x_ = self.norm(x_) 85 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 86 | else: 87 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 88 | 89 | k1, k2 = kv[0].split(MiniB) 90 | v1, v2 = kv[1].split(MiniB) 91 | 92 | attn1 = (q1 @ k2.transpose(-2, -1)) * self.scale 93 | attn1 = attn1.softmax(dim=-1) 94 | attn1 = self.attn_drop(attn1) 95 | 96 | attn2 = (q2 @ k1.transpose(-2, -1)) * self.scale 97 | attn2 = attn2.softmax(dim=-1) 98 | attn2 = self.attn_drop(attn2) 99 | 100 | x1 = (attn1 @ v2).transpose(1, 2).reshape(MiniB, N, C) 101 | x2 = (attn2 @ v1).transpose(1, 2).reshape(MiniB, N, C) 102 | 103 | x = torch.cat([x1, x2], dim=0) 104 | else: 105 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 106 | 107 | 108 | if self.sr_ratio > 1: 109 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 110 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 111 | x_ = self.norm(x_) 112 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 
113 | else: 114 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | k, v = kv[0], kv[1] 116 | 117 | attn = (q @ k.transpose(-2, -1)) * self.scale 118 | attn = attn.softmax(dim=-1) 119 | attn = self.attn_drop(attn) 120 | 121 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 122 | 123 | x = self.proj(x) 124 | x = self.proj_drop(x) 125 | 126 | return x 127 | 128 | class Block(nn.Module): 129 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 130 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, cross = False): 131 | super().__init__() 132 | self.norm1 = norm_layer(dim) 133 | self.attn = Attention( 134 | dim, 135 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 136 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, cross= cross) 137 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 138 | self.norm2 = norm_layer(dim) 139 | mlp_hidden_dim = int(dim * mlp_ratio) 140 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 141 | 142 | def forward(self, x, H, W): 143 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 144 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 145 | 146 | return x 147 | 148 | class Positional(nn.Module): 149 | def __init__(self, dim): 150 | super().__init__() 151 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 152 | self.sigmoid = nn.Sigmoid() 153 | 154 | def forward(self, x): 155 | return x * self.sigmoid(self.pa_conv(x)) 156 | 157 | class PatchEmbed(nn.Module): 158 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | self.img_size = img_size 163 | self.patch_size = patch_size 164 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 165 | self.num_patches = self.H * self.W 166 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 167 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 168 | 169 | self.with_pos = with_pos 170 | if self.with_pos: 171 | self.pos = Positional(embed_dim) 172 | 173 | self.norm = nn.LayerNorm(embed_dim) 174 | 175 | def forward(self, x): 176 | x = self.proj(x) 177 | if self.with_pos: 178 | x = self.pos(x) 179 | _, _, H, W = x.shape 180 | x = x.flatten(2).transpose(1, 2) 181 | x = self.norm(x) 182 | 183 | return x, H, W 184 | 185 | class AttentionBlock(nn.Module): 186 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, sr_ratios=8, 187 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 188 | super().__init__() 189 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = stride, in_chans=in_chans, 190 | embed_dim=embed_dims) 191 | self.block = nn.ModuleList([Block( 192 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,sr_ratio= sr_ratios, 193 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 194 | for i in range(depths)]) 195 | self.norm = norm_layer(embed_dims) 196 | 197 | def forward(self, x): 198 | B = x.shape[0] 199 | x, H, W = self.patch_embed(x) 200 | for i, blk in enumerate(self.block): 201 | x = blk(x,H,W) 202 | x = self.norm(x) 203 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 
2).contiguous() 204 | 205 | return x 206 | 207 | class Matchformer_SEA_lite(nn.Module): 208 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[1, 2, 4, 8],sr_ratios=[8,4,2,1] 209 | ,stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 210 | super().__init__() 211 | #Attention 212 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 213 | embed_dims=embed_dims[0],stride=4,sr_ratios=sr_ratios[0],depths=3, cross=stage1_cross) 214 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 215 | embed_dims=embed_dims[1],stride=2,sr_ratios=sr_ratios[1],depths=3, cross=stage2_cross) 216 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 217 | embed_dims=embed_dims[2],stride=2,sr_ratios=sr_ratios[2],depths=3, cross=stage3_cross) 218 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 219 | embed_dims=embed_dims[3],stride=2,sr_ratios=sr_ratios[3],depths=3, cross=stage4_cross) 220 | 221 | #FPN 222 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 223 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 224 | self.layer3_outconv2 = nn.Sequential( 225 | conv3x3(embed_dims[3], embed_dims[3]), 226 | nn.BatchNorm2d(embed_dims[3]), 227 | nn.LeakyReLU(), 228 | conv3x3(embed_dims[3], embed_dims[2]), 229 | ) 230 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 231 | self.layer2_outconv2 = nn.Sequential( 232 | conv3x3(embed_dims[2], embed_dims[2]), 233 | nn.BatchNorm2d(embed_dims[2]), 234 | nn.LeakyReLU(), 235 | conv3x3(embed_dims[2], embed_dims[1]), 236 | ) 237 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 238 | self.layer1_outconv2 = nn.Sequential( 239 | conv3x3(embed_dims[1], embed_dims[1]), 240 | nn.BatchNorm2d(embed_dims[1]), 241 | nn.LeakyReLU(), 242 | conv3x3(embed_dims[1], embed_dims[0]), 243 | ) 244 | 245 | self.apply(self._init_weights) 246 | 247 | def _init_weights(self, m): 248 | if isinstance(m, nn.Linear): 249 | trunc_normal_(m.weight, std=.02) 250 | if isinstance(m, nn.Linear) and m.bias is not None: 251 | nn.init.constant_(m.bias, 0) 252 | elif isinstance(m, nn.LayerNorm): 253 | nn.init.constant_(m.bias, 0) 254 | nn.init.constant_(m.weight, 1.0) 255 | elif isinstance(m, nn.Conv2d): 256 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 257 | fan_out //= m.groups 258 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 259 | if m.bias is not None: 260 | m.bias.data.zero_() 261 | 262 | def forward(self, x): 263 | # stage 1 # 1/4 264 | x = self.AttentionBlock1(x) 265 | out1 = x 266 | # stage 2 # 1/8 267 | x = self.AttentionBlock2(x) 268 | out2 = x 269 | # stage 3 # 1/16 270 | x = self.AttentionBlock3(x) 271 | out3 = x 272 | # stage 3 # 1/32 273 | x = self.AttentionBlock4(x) 274 | out4 = x 275 | 276 | #FPN 277 | c4_out = self.layer4_outconv(out4) 278 | _,_,H,W = out3.shape 279 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 280 | c3_out = self.layer3_outconv(out3) 281 | _,_,H,W = out2.shape 282 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 283 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 284 | 
c2_out = self.layer2_outconv(out2)
285 | _,_,H,W = out1.shape
286 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x)
287 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True)
288 | c1_out = self.layer1_outconv(out1)
289 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x)
290 |
291 | return c2_out,c1_out
--------------------------------------------------------------------------------
/model/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from collections import abc
4 | from loguru import logger
5 | from torch.utils.data.dataset import Dataset
6 | from tqdm import tqdm
7 | from os import path as osp
8 | from pathlib import Path
9 | from joblib import Parallel, delayed
10 |
11 | import pytorch_lightning as pl
12 | from torch import distributed as dist
13 | from torch.utils.data import (
14 | Dataset,
15 | DataLoader,
16 | ConcatDataset,
17 | DistributedSampler,
18 | RandomSampler,
19 | dataloader
20 | )
21 |
22 | from .utils.augment import build_augmentor
23 | from .utils.dataloader import get_local_split
24 | from .utils.misc import tqdm_joblib
25 | from .utils import comm
26 | from .datasets.megadepth import MegaDepthDataset
27 | from .datasets.scannet import ScanNetDataset
28 | from .datasets.sampler import RandomConcatSampler
29 |
30 |
31 | class MultiSceneDataModule(pl.LightningDataModule):
32 | """
33 | For distributed training, each training process is assigned
34 | only a part of the training scenes to reduce memory overhead.
35 | """
36 | def __init__(self, args, config):
37 | super().__init__()
38 |
39 | # 1. data config
40 | # Train and Val should be from the same data source
41 | self.trainval_data_source = config.DATASET.TRAINVAL_DATA_SOURCE
42 | self.test_data_source = config.DATASET.TEST_DATA_SOURCE
43 | # training and validating
44 | self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
45 | self.train_pose_root = config.DATASET.TRAIN_POSE_ROOT # (optional)
46 | self.train_npz_root = config.DATASET.TRAIN_NPZ_ROOT
47 | self.train_list_path = config.DATASET.TRAIN_LIST_PATH
48 | self.train_intrinsic_path = config.DATASET.TRAIN_INTRINSIC_PATH
49 | self.val_data_root = config.DATASET.VAL_DATA_ROOT
50 | self.val_pose_root = config.DATASET.VAL_POSE_ROOT # (optional)
51 | self.val_npz_root = config.DATASET.VAL_NPZ_ROOT
52 | self.val_list_path = config.DATASET.VAL_LIST_PATH
53 | self.val_intrinsic_path = config.DATASET.VAL_INTRINSIC_PATH
54 | # testing
55 | self.test_data_root = config.DATASET.TEST_DATA_ROOT
56 | self.test_pose_root = config.DATASET.TEST_POSE_ROOT # (optional)
57 | self.test_npz_root = config.DATASET.TEST_NPZ_ROOT
58 | self.test_list_path = config.DATASET.TEST_LIST_PATH
59 | self.test_intrinsic_path = config.DATASET.TEST_INTRINSIC_PATH
60 |
61 | # 2. dataset config
62 | # general options
63 | self.min_overlap_score_test = config.DATASET.MIN_OVERLAP_SCORE_TEST # 0.0, omit data with overlap_score < min_overlap_score
64 | self.min_overlap_score_train = config.DATASET.MIN_OVERLAP_SCORE_TRAIN
65 | self.augment_fn = build_augmentor(config.DATASET.AUGMENTATION_TYPE) # None, options: [None, 'dark', 'mobile']
66 |
67 | # MegaDepth options
68 | self.mgdpt_img_resize = config.DATASET.MGDPT_IMG_RESIZE # 840
69 | self.mgdpt_img_pad = config.DATASET.MGDPT_IMG_PAD # True
70 | self.mgdpt_depth_pad = config.DATASET.MGDPT_DEPTH_PAD # True
71 | self.mgdpt_df = config.DATASET.MGDPT_DF # 8
72 | self.coarse_scale = 1 / config.MATCHFORMER.RESOLUTION[0] # 0.125. for training loftr.
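        # `coarse_scale` is passed to MegaDepthDataset in _build_concat_dataset below so
        # the dataset can prepare tensors (such as the image padding masks) at the coarse
        # feature resolution (1 / MATCHFORMER.RESOLUTION[0] of the input).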
73 |
74 | # 3.loader parameters
75 | self.train_loader_params = {
76 | 'batch_size': args.batch_size,
77 | 'num_workers': args.num_workers,
78 | 'pin_memory': getattr(args, 'pin_memory', True)
79 | }
80 | self.val_loader_params = {
81 | 'batch_size': 1,
82 | 'shuffle': False,
83 | 'num_workers': args.num_workers,
84 | 'pin_memory': getattr(args, 'pin_memory', True)
85 | }
86 | self.test_loader_params = {
87 | 'batch_size': 1,
88 | 'shuffle': False,
89 | 'num_workers': args.num_workers,
90 | 'pin_memory': True
91 | }
92 |
93 | # 4. sampler
94 | self.data_sampler = config.TRAINER.DATA_SAMPLER
95 | self.n_samples_per_subset = config.TRAINER.N_SAMPLES_PER_SUBSET
96 | self.subset_replacement = config.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT
97 | self.shuffle = config.TRAINER.SB_SUBSET_SHUFFLE
98 | self.repeat = config.TRAINER.SB_REPEAT
99 |
100 | # (optional) RandomSampler for debugging
101 |
102 | # misc configurations
103 | self.parallel_load_data = getattr(args, 'parallel_load_data', False)
104 | self.seed = config.TRAINER.SEED # 66
105 |
106 | def setup(self, stage=None):
107 | """
108 | Setup train / val / test dataset. This method will be called by PL automatically.
109 | Args:
110 | stage (str): 'fit' in training phase, and 'test' in testing phase.
111 | """
112 |
113 | assert stage in ['fit', 'test'], "stage must be either fit or test"
114 |
115 | try:
116 | self.world_size = dist.get_world_size()
117 | self.rank = dist.get_rank()
118 | logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
119 | except AssertionError as ae:
120 | self.world_size = 1
121 | self.rank = 0
122 | logger.warning(str(ae) + " (set world_size=1 and rank=0)")
123 |
124 | if stage == 'fit':
125 | self.train_dataset = self._setup_dataset(
126 | self.train_data_root,
127 | self.train_npz_root,
128 | self.train_list_path,
129 | self.train_intrinsic_path,
130 | mode='train',
131 | min_overlap_score=self.min_overlap_score_train,
132 | pose_dir=self.train_pose_root)
133 | # setup multiple (optional) validation subsets
134 | if isinstance(self.val_list_path, (list, tuple)):
135 | self.val_dataset = []
136 | if not isinstance(self.val_npz_root, (list, tuple)):
137 | self.val_npz_root = [self.val_npz_root for _ in range(len(self.val_list_path))]
138 | for npz_list, npz_root in zip(self.val_list_path, self.val_npz_root):
139 | self.val_dataset.append(self._setup_dataset(
140 | self.val_data_root,
141 | npz_root,
142 | npz_list,
143 | self.val_intrinsic_path,
144 | mode='val',
145 | min_overlap_score=self.min_overlap_score_test,
146 | pose_dir=self.val_pose_root))
147 | else:
148 | self.val_dataset = self._setup_dataset(
149 | self.val_data_root,
150 | self.val_npz_root,
151 | self.val_list_path,
152 | self.val_intrinsic_path,
153 | mode='val',
154 | min_overlap_score=self.min_overlap_score_test,
155 | pose_dir=self.val_pose_root)
156 | logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
157 | else: # stage == 'test'
158 | self.test_dataset = self._setup_dataset(
159 | self.test_data_root,
160 | self.test_npz_root,
161 | self.test_list_path,
162 | self.test_intrinsic_path,
163 | mode='test',
164 | min_overlap_score=self.min_overlap_score_test,
165 | pose_dir=self.test_pose_root)
166 | logger.info(f'[rank:{self.rank}]: Test Dataset loaded!')
167 |
168 | def _setup_dataset(self,
169 | data_root,
170 | split_npz_root,
171 | scene_list_path,
172 | intri_path,
173 | mode='train',
174 | min_overlap_score=0.,
175 | pose_dir=None):
176 | """ Setup train / val / test set"""
177 | with open(scene_list_path, 'r') as f:
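            # The scene list is a plain text file with one scene per line; only the first
            # whitespace-separated token is kept, and for MegaDepth the '.npz' suffix is
            # appended later in _build_concat_dataset.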
178 | npz_names = [name.split()[0] for name in f.readlines()] 179 | 180 | if mode == 'train': 181 | local_npz_names = get_local_split(npz_names, self.world_size, self.rank, self.seed) 182 | else: 183 | local_npz_names = npz_names 184 | logger.info(f'[rank {self.rank}]: {len(local_npz_names)} scene(s) assigned.') 185 | 186 | dataset_builder = self._build_concat_dataset_parallel \ 187 | if self.parallel_load_data \ 188 | else self._build_concat_dataset 189 | return dataset_builder(data_root, local_npz_names, split_npz_root, intri_path, 190 | mode=mode, min_overlap_score=min_overlap_score, pose_dir=pose_dir) 191 | 192 | def _build_concat_dataset( 193 | self, 194 | data_root, 195 | npz_names, 196 | npz_dir, 197 | intrinsic_path, 198 | mode, 199 | min_overlap_score=0., 200 | pose_dir=None 201 | ): 202 | datasets = [] 203 | augment_fn = self.augment_fn if mode == 'train' else None 204 | data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source 205 | if str(data_source).lower() == 'megadepth': 206 | npz_names = [f'{n}.npz' for n in npz_names] 207 | for npz_name in tqdm(npz_names, 208 | desc=f'[rank:{self.rank}] loading {mode} datasets', 209 | disable=int(self.rank) != 0): 210 | # `ScanNetDataset`/`MegaDepthDataset` load all data from npz_path when initialized, which might take time. 211 | npz_path = osp.join(npz_dir, npz_name) 212 | if data_source == 'ScanNet': 213 | datasets.append( 214 | ScanNetDataset(data_root, 215 | npz_path, 216 | intrinsic_path, 217 | mode=mode, 218 | min_overlap_score=min_overlap_score, 219 | augment_fn=augment_fn, 220 | pose_dir=pose_dir)) 221 | elif data_source == 'MegaDepth': 222 | datasets.append( 223 | MegaDepthDataset(data_root, 224 | npz_path, 225 | mode=mode, 226 | min_overlap_score=min_overlap_score, 227 | img_resize=self.mgdpt_img_resize, 228 | df=self.mgdpt_df, 229 | img_padding=self.mgdpt_img_pad, 230 | depth_padding=self.mgdpt_depth_pad, 231 | augment_fn=augment_fn, 232 | coarse_scale=self.coarse_scale)) 233 | else: 234 | raise NotImplementedError() 235 | return ConcatDataset(datasets) 236 | 237 | def _build_concat_dataset_parallel( 238 | self, 239 | data_root, 240 | npz_names, 241 | npz_dir, 242 | intrinsic_path, 243 | mode, 244 | min_overlap_score=0., 245 | pose_dir=None, 246 | ): 247 | augment_fn = self.augment_fn if mode == 'train' else None 248 | data_source = self.trainval_data_source if mode in ['train', 'val'] else self.test_data_source 249 | if str(data_source).lower() == 'megadepth': 250 | npz_names = [f'{n}.npz' for n in npz_names] 251 | with tqdm_joblib(tqdm(desc=f'[rank:{self.rank}] loading {mode} datasets', 252 | total=len(npz_names), disable=int(self.rank) != 0)): 253 | if data_source == 'ScanNet': 254 | datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( 255 | delayed(lambda x: _build_dataset( 256 | ScanNetDataset, 257 | data_root, 258 | osp.join(npz_dir, x), 259 | intrinsic_path, 260 | mode=mode, 261 | min_overlap_score=min_overlap_score, 262 | augment_fn=augment_fn, 263 | pose_dir=pose_dir))(name) 264 | for name in npz_names) 265 | elif data_source == 'MegaDepth': 266 | # TODO: _pickle.PicklingError: Could not pickle the task to send it to the workers. 
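                # Parallel construction is disabled for MegaDepth because joblib cannot
                # pickle the dataset objects (see the TODO above); keep
                # parallel_load_data=False so the serial _build_concat_dataset path is
                # used for MegaDepth instead.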
267 | raise NotImplementedError() 268 | datasets = Parallel(n_jobs=math.floor(len(os.sched_getaffinity(0)) * 0.9 / comm.get_local_size()))( 269 | delayed(lambda x: _build_dataset( 270 | MegaDepthDataset, 271 | data_root, 272 | osp.join(npz_dir, x), 273 | mode=mode, 274 | min_overlap_score=min_overlap_score, 275 | img_resize=self.mgdpt_img_resize, 276 | df=self.mgdpt_df, 277 | img_padding=self.mgdpt_img_pad, 278 | depth_padding=self.mgdpt_depth_pad, 279 | augment_fn=augment_fn, 280 | coarse_scale=self.coarse_scale))(name) 281 | for name in npz_names) 282 | else: 283 | raise ValueError(f'Unknown dataset: {data_source}') 284 | return ConcatDataset(datasets) 285 | 286 | def train_dataloader(self): 287 | """ Build training dataloader for ScanNet / MegaDepth. """ 288 | assert self.data_sampler in ['scene_balance'] 289 | logger.info(f'[rank:{self.rank}/{self.world_size}]: Train Sampler and DataLoader re-init (should not re-init between epochs!).') 290 | if self.data_sampler == 'scene_balance': 291 | sampler = RandomConcatSampler(self.train_dataset, 292 | self.n_samples_per_subset, 293 | self.subset_replacement, 294 | self.shuffle, self.repeat, self.seed) 295 | else: 296 | sampler = None 297 | dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params) 298 | return dataloader 299 | 300 | def val_dataloader(self): 301 | """ Build validation dataloader for ScanNet / MegaDepth. """ 302 | logger.info(f'[rank:{self.rank}/{self.world_size}]: Val Sampler and DataLoader re-init.') 303 | if not isinstance(self.val_dataset, abc.Sequence): 304 | sampler = DistributedSampler(self.val_dataset, shuffle=False) 305 | return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params) 306 | else: 307 | dataloaders = [] 308 | for dataset in self.val_dataset: 309 | sampler = DistributedSampler(dataset, shuffle=False) 310 | dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params)) 311 | return dataloaders 312 | 313 | def test_dataloader(self, *args, **kwargs): 314 | logger.info(f'[rank:{self.rank}/{self.world_size}]: Test Sampler and DataLoader re-init.') 315 | sampler = DistributedSampler(self.test_dataset, shuffle=False) 316 | return DataLoader(self.test_dataset, sampler=sampler, **self.test_loader_params) 317 | 318 | 319 | def _build_dataset(dataset: Dataset, *args, **kwargs): 320 | return dataset(*args, **kwargs) 321 | --------------------------------------------------------------------------------
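The sketch below is not part of the repository; it only illustrates how MultiSceneDataModule is typically driven for testing. The argparse fields, the `config` object, and `lightning_model` are assumptions: `config` stands for a fully merged yacs CfgNode supplying the DATASET.*, TRAINER.*, and MATCHFORMER.RESOLUTION entries read in __init__ above, and `lightning_model` stands for the LightningModule defined in lightning_loftr.py.

from argparse import Namespace
import pytorch_lightning as pl
from model.data import MultiSceneDataModule

# Hypothetical argument object; only the attributes read by MultiSceneDataModule are set.
args = Namespace(batch_size=1, num_workers=4, pin_memory=True, parallel_load_data=False)

# `config` is assumed to be a yacs CfgNode merged from the project defaults and one of
# the data configs (e.g. config/data/scannet_test_1500.py); it is not constructed here.
data_module = MultiSceneDataModule(args, config)

# PyTorch Lightning calls setup('test') and test_dataloader() itself; the dataloaders
# build their own DistributedSampler, so run under a distributed Trainer.
trainer = pl.Trainer(gpus=1, accelerator='ddp', replace_sampler_ddp=False)
trainer.test(lightning_model, datamodule=data_module)  # `lightning_model`: the MatchFormer LightningModule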