├── train_stage1 ├── teacher │ ├── config │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── megadepth_test_1500.py │ │ │ ├── scannet_test_1500.py │ │ │ └── base.py │ │ └── defaultmf.py │ └── model │ │ ├── backbone │ │ ├── __init__.py │ │ ├── fine_preprocess.py │ │ ├── fine_matching.py │ │ ├── coarse_matching.py │ │ ├── match_LA_large.py │ │ └── match_LA_lite.py │ │ ├── utils │ │ ├── dataloader.py │ │ ├── profiler.py │ │ ├── augment.py │ │ ├── misc.py │ │ ├── metrics.py │ │ └── comm.py │ │ ├── matchformer.py │ │ ├── datasets │ │ ├── sampler.py │ │ ├── scannet.py │ │ ├── megadepth.py │ │ └── dataset.py │ │ └── lightning_loftr.py ├── Extractor.py ├── Loss_stage1.py ├── train_stage1.py ├── utils.py ├── IVS_dataset.py └── reg_dataset.py ├── assets ├── SemLA.png ├── Visualization.png └── data_transform.py ├── .gitignore ├── model ├── SemLA.py ├── fusion.py └── reg.py ├── train_stage3 ├── Fusion.py ├── train_stage3.py ├── Loss_stage3.py ├── dataset.py └── utils.py ├── train_stage2 ├── Loss_stage2.py ├── train_stage2.py ├── Reg.py ├── utils.py └── dataset.py ├── inference_one_pair_images.py ├── test.py └── README.md /train_stage1/teacher/config/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/SemLA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiehousheng/SemLA/HEAD/assets/SemLA.png -------------------------------------------------------------------------------- /assets/Visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiehousheng/SemLA/HEAD/assets/Visualization.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | /**/__pycache__/ 3 | *.pyc 4 | *.DS_Store 5 | *.swp 6 | *.pth 7 | tmp.* 8 | */.ipynb_checkpoints/* 9 | 10 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .match_SEA_lite import Matchformer_SEA_lite 3 | 4 | 5 | 6 | def build_backbone(): 7 | 8 | return Matchformer_SEA_lite() 9 | -------------------------------------------------------------------------------- /train_stage1/teacher/config/data/megadepth_test_1500.py: -------------------------------------------------------------------------------- 1 | from config.data.base import cfg 2 | 3 | TEST_BASE_PATH = "data/megadepth/index" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" 9 | 10 | cfg.DATASET.MGDPT_IMG_RESIZE = 840 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 12 | -------------------------------------------------------------------------------- /train_stage1/teacher/config/data/scannet_test_1500.py: -------------------------------------------------------------------------------- 1 | from config.data.base import cfg 2 | 3 | TEST_BASE_PATH = "data/scannet/index" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = 
f"{TEST_BASE_PATH}/scannet_test.txt" 9 | cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" 10 | 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # --- PL-DATAMODULE --- 5 | 6 | def get_local_split(items: list, world_size: int, rank: int, seed: int): 7 | """ The local rank only loads a split of the dataset. """ 8 | n_items = len(items) 9 | items_permute = np.random.RandomState(seed).permutation(items) 10 | if n_items % world_size == 0: 11 | padded_items = items_permute 12 | else: 13 | padding = np.random.RandomState(seed).choice( 14 | items, 15 | world_size - (n_items % world_size), 16 | replace=True) 17 | padded_items = np.concatenate([items_permute, padding]) 18 | assert len(padded_items) % world_size == 0, \ 19 | f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' 20 | n_per_rank = len(padded_items) // world_size 21 | local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] 22 | 23 | return local_items 24 | -------------------------------------------------------------------------------- /train_stage1/teacher/config/data/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data config will be the last one merged into the main config. 3 | Setups in data configs will override all existed setups! 4 | """ 5 | 6 | from yacs.config import CfgNode as CN 7 | _CN = CN() 8 | _CN.DATASET = CN() 9 | _CN.TRAINER = CN() 10 | 11 | # training data config 12 | _CN.DATASET.TRAIN_DATA_ROOT = None 13 | _CN.DATASET.TRAIN_POSE_ROOT = None 14 | _CN.DATASET.TRAIN_NPZ_ROOT = None 15 | _CN.DATASET.TRAIN_LIST_PATH = None 16 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 17 | # validation set config 18 | _CN.DATASET.VAL_DATA_ROOT = None 19 | _CN.DATASET.VAL_POSE_ROOT = None 20 | _CN.DATASET.VAL_NPZ_ROOT = None 21 | _CN.DATASET.VAL_LIST_PATH = None 22 | _CN.DATASET.VAL_INTRINSIC_PATH = None 23 | 24 | # testing data config 25 | _CN.DATASET.TEST_DATA_ROOT = None 26 | _CN.DATASET.TEST_POSE_ROOT = None 27 | _CN.DATASET.TEST_NPZ_ROOT = None 28 | _CN.DATASET.TEST_LIST_PATH = None 29 | _CN.DATASET.TEST_INTRINSIC_PATH = None 30 | 31 | # dataset config 32 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 33 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 34 | 35 | cfg = _CN 36 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/profiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler 3 | from contextlib import contextmanager 4 | from pytorch_lightning.utilities import rank_zero_only 5 | 6 | 7 | class InferenceProfiler(SimpleProfiler): 8 | """ 9 | This profiler records duration of actions with cuda.synchronize() 10 | Use this in test time. 
11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.start = rank_zero_only(self.start) 16 | self.stop = rank_zero_only(self.stop) 17 | self.summary = rank_zero_only(self.summary) 18 | 19 | @contextmanager 20 | def profile(self, action_name: str) -> None: 21 | try: 22 | torch.cuda.synchronize() 23 | self.start(action_name) 24 | yield action_name 25 | finally: 26 | torch.cuda.synchronize() 27 | self.stop(action_name) 28 | 29 | 30 | def build_profiler(name): 31 | if name == 'inference': 32 | return InferenceProfiler() 33 | elif name == 'pytorch': 34 | from pytorch_lightning.profiler import PyTorchProfiler 35 | return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) 36 | elif name is None: 37 | return PassThroughProfiler() 38 | else: 39 | raise ValueError(f'Invalid profiler: {name}') 40 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/augment.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | 3 | 4 | class DarkAug(object): 5 | """ 6 | Extreme dark augmentation aiming at Aachen Day-Night 7 | """ 8 | 9 | def __init__(self) -> None: 10 | self.augmentor = A.Compose([ 11 | A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), 12 | A.Blur(p=0.1, blur_limit=(3, 9)), 13 | A.MotionBlur(p=0.2, blur_limit=(3, 25)), 14 | A.RandomGamma(p=0.1, gamma_limit=(15, 65)), 15 | A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) 16 | ], p=0.75) 17 | 18 | def __call__(self, x): 19 | return self.augmentor(image=x)['image'] 20 | 21 | 22 | class MobileAug(object): 23 | """ 24 | Random augmentations aiming at images of mobile/handhold devices. 25 | """ 26 | 27 | def __init__(self): 28 | self.augmentor = A.Compose([ 29 | A.MotionBlur(p=0.25), 30 | A.ColorJitter(p=0.5), 31 | A.RandomRain(p=0.1), # random occlusion 32 | A.RandomSunFlare(p=0.1), 33 | A.JpegCompression(p=0.25), 34 | A.ISONoise(p=0.25) 35 | ], p=1.0) 36 | 37 | def __call__(self, x): 38 | return self.augmentor(image=x)['image'] 39 | 40 | 41 | def build_augmentor(method=None, **kwargs): 42 | if method is not None: 43 | raise NotImplementedError('Using of augmentation functions are not supported yet!') 44 | if method == 'dark': 45 | return DarkAug() 46 | elif method == 'mobile': 47 | return MobileAug() 48 | elif method is None: 49 | return None 50 | else: 51 | raise ValueError(f'Invalid augmentation method: {method}') 52 | 53 | 54 | if __name__ == '__main__': 55 | augmentor = build_augmentor('FDA') 56 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/matchformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from teacher.model.backbone import build_backbone 4 | from teacher.model.backbone.fine_preprocess import FinePreprocess 5 | from teacher.model.backbone.coarse_matching import CoarseMatching 6 | from teacher.model.backbone.fine_matching import FineMatching 7 | from einops.einops import rearrange 8 | 9 | 10 | class Matchformer(nn.Module): 11 | 12 | def __init__(self, config): 13 | super().__init__() 14 | # Misc 15 | self.config = config 16 | self.backbone = build_backbone() 17 | self.coarse_matching = CoarseMatching(config['matchformer']['match_coarse']) 18 | 19 | 20 | 21 | def forward(self, data): 22 | data.update({ 23 | 'bs': data['image0'].size(0), 24 | 'hw0_i': data['image0'].shape[2:], 'hw1_i': 
data['image1'].shape[2:] 25 | }) 26 | 27 | mask_c0 = mask_c1 = None # mask is useful in training 28 | if 'mask0' in data: 29 | mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) 30 | 31 | if data['hw0_i'] == data['hw1_i']: 32 | feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) 33 | (feat_c0, feat_c1),(feat_f0, feat_f1) = feats_c.split(data['bs']),feats_f.split(data['bs']) 34 | else: 35 | (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) 36 | 37 | data.update({ 38 | 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], 39 | 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] 40 | }) 41 | 42 | feat_c0 = rearrange(feat_c0, 'n c h w -> n (h w) c') 43 | feat_c1 = rearrange(feat_c1, 'n c h w -> n (h w) c') 44 | 45 | # match coarse-level 46 | self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) 47 | 48 | def load_state_dict(self, state_dict, *args, **kwargs): 49 | for k in list(state_dict.keys()): 50 | if k.startswith('matcher.'): 51 | state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) 52 | return super().load_state_dict(state_dict, *args, **kwargs) 53 | 54 | -------------------------------------------------------------------------------- /model/SemLA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | from .reg import SemLA_Reg 6 | from .fusion import SemLA_Fusion 7 | 8 | class SemLA(nn.Module): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.backbone = SemLA_Reg() 13 | self.fusion = SemLA_Fusion() 14 | 15 | def forward(self, img_vi, img_ir, matchmode): 16 | # Select 'scene' mode when no semantic objects exist in the image 17 | if matchmode=='semantic': 18 | thr = 0.5 19 | elif matchmode=='scene': 20 | thr = 0 21 | feat_reg_vi_final, feat_reg_ir_final, feat_sa_vi, feat_sa_ir = self.backbone( 22 | torch.cat((img_vi, img_ir), dim=0)) 23 | 24 | 25 | sa_vi, sa_ir = feat_sa_vi.reshape(-1), feat_sa_ir.reshape(-1) 26 | sa_vi, sa_ir = torch.where(sa_vi > thr)[0], torch.where(sa_ir > thr)[0] 27 | 28 | feat_reg_vi = rearrange(feat_reg_vi_final, 'n c h w -> n (h w) c') 29 | feat_reg_ir = rearrange(feat_reg_ir_final, 'n c h w -> n (h w) c') 30 | 31 | feat_reg_vi, feat_reg_ir = feat_reg_vi[:, sa_vi], feat_reg_ir[:, sa_ir] 32 | feat_reg_vi, feat_reg_ir = map(lambda feat: feat / feat.shape[-1] ** .5, 33 | [feat_reg_vi, feat_reg_ir]) 34 | 35 | conf = torch.einsum("nlc,nsc->nls", feat_reg_vi, 36 | feat_reg_ir) / 0.1 37 | mask = conf > 0. 
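# Mutual nearest-neighbour matching: a candidate pair (l, s) is kept only if conf[l, s]
# is positive and is simultaneously the maximum of its row and of its column of the
# similarity matrix, i.e. the visible and infrared features select each other.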
38 | mask = mask \ 39 | * (conf == conf.max(dim=2, keepdim=True)[0]) \ 40 | * (conf == conf.max(dim=1, keepdim=True)[0]) 41 | 42 | mask_v, all_j_ids = mask.max(dim=2) 43 | b_ids, i_ids = torch.where(mask_v) 44 | j_ids = all_j_ids[b_ids, i_ids] 45 | i_ids = sa_vi[i_ids] 46 | j_ids = sa_ir[j_ids] 47 | 48 | mkpts0 = torch.stack( 49 | [i_ids % feat_sa_vi.shape[3], i_ids // feat_sa_vi.shape[3]], 50 | dim=1) * 8 51 | mkpts1 = torch.stack( 52 | [j_ids % feat_sa_vi.shape[3], j_ids // feat_sa_vi.shape[3]], 53 | dim=1) * 8 54 | 55 | sa_ir= F.interpolate(feat_sa_ir, scale_factor=8, mode='bilinear', align_corners=True) 56 | 57 | return mkpts0, mkpts1, feat_sa_vi, feat_sa_ir, sa_ir 58 | -------------------------------------------------------------------------------- /train_stage3/Fusion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from utils import CBR, DWConv, MLP, MLP2 4 | import torch.nn.functional as F 5 | 6 | class SemLA_Fusion(nn.Module): 7 | """ 8 | The fusion section of SemLA 9 | """ 10 | def __init__(self): 11 | super().__init__() 12 | 13 | self.fuse1 = CR(1, 8) 14 | self.fuse2 = CR(8, 8) 15 | self.fuse3 = CR(8, 16) 16 | self.fuse4 = CR(16, 16) 17 | self.fuse5 = JConv(48, 1) 18 | 19 | self.acitve = nn.Tanh() 20 | 21 | for m in self.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 24 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 25 | nn.init.constant_(m.weight, 1) 26 | nn.init.constant_(m.bias, 0) 27 | 28 | def forward(self, x): 29 | feat1 = self.fuse1(x) 30 | feat2 = self.fuse2(feat1) 31 | feat3 = self.fuse3(feat2) 32 | feat4 = self.fuse4(feat3) 33 | 34 | feat_fuse = torch.cat((feat1, feat2, feat3, feat4), dim=1) 35 | 36 | feat_fuse = self.fuse5(feat_fuse) 37 | feat_fuse = self.acitve(feat_fuse) 38 | result = (feat_fuse + 1) / 2 39 | 40 | return result 41 | 42 | class JConv(nn.Module): 43 | """Joint Convolutional blocks 44 | 45 | Args: 46 | 'x' (torch.Tensor): (N, C, H, W) 47 | """ 48 | def __init__(self, in_channels, out_channels): 49 | super(JConv, self).__init__() 50 | self.feat_trans = CBR(in_channels, out_channels) 51 | self.dwconv = DWConv(out_channels) 52 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 53 | self.mlp = MLP(out_channels, bias=True) 54 | 55 | def forward(self, x): 56 | x = self.feat_trans(x) 57 | x = x + self.dwconv(x) 58 | out = self.norm(x) 59 | x = x + self.mlp(out) 60 | return x 61 | 62 | 63 | 64 | class CR(nn.Module): 65 | """Convolution with Leaky ReLU 66 | 67 | Args: 68 | 'x' (torch.Tensor): (N, C, H, W) 69 | """ 70 | def __init__(self, in_channels, out_channels, stride=1): 71 | super().__init__() 72 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 73 | 74 | def forward(self, x): 75 | return F.leaky_relu(self.conv(x), negative_slope=0.2) -------------------------------------------------------------------------------- /train_stage3/train_stage3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from dataset import Dataset 5 | from Loss_stage3 import Loss_stage3 6 | 7 | def train_fuse(model_stage3, optimizer, data_loader_sa): 8 | device = 'cuda' 9 | for data_iter_step, img in enumerate(data_loader_sa): 10 | loss_ssim, loss_int = model_stage3(img.to(device)) 11 | 12 | # ssim loss and intensity loss 13 | loss_fusion = 0.7 * 
loss_ssim + 0.3 * loss_int 14 | if data_iter_step % 10 == 0: 15 | print('data_iter:', data_iter_step, 'loss_ssim:', loss_ssim.item(), 'loss_int:', loss_int.item()) 16 | optimizer.zero_grad() 17 | loss_fusion.backward() 18 | optimizer.step() 19 | torch.cuda.synchronize() 20 | 21 | 22 | if __name__ == '__main__': 23 | # Configuring dataset paths 24 | path2COCO = "" 25 | path2COCO_CPSTN = "" 26 | 27 | # Configure the size of the training image 28 | train_size = (320, 240) 29 | 30 | # Device for training: 'cuda' or 'cpu' 31 | device = 'cuda' 32 | 33 | # COCO Dataset 34 | dataset = Dataset(path2COCO, path2COCO_CPSTN, train_size_w = train_size[0], train_size_h = train_size[1]) 35 | 36 | dataset_sampler = torch.utils.data.RandomSampler(dataset) 37 | 38 | data_loader = torch.utils.data.DataLoader( 39 | dataset, 40 | sampler=dataset_sampler, 41 | batch_size=128, 42 | num_workers=8, 43 | pin_memory=True, 44 | drop_last=True 45 | ) 46 | 47 | 48 | model_stage3 = Loss_stage3() 49 | model_stage3.to(device) 50 | 51 | for name, param in model_stage3.named_parameters(): 52 | if param.requires_grad: 53 | print(name) 54 | 55 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_stage3.parameters()), lr=1e-4) 56 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998) 57 | print(optimizer) 58 | 59 | model_stage3.train(True) 60 | 61 | epochs = 20 62 | print('Start training!') 63 | 64 | # Train SAF 65 | print('Train Fusion!') 66 | for epoch in range(0, epochs): 67 | print('current epoch:', epoch) 68 | train_fuse(model_stage3, optimizer, data_loader) 69 | scheduler.step() 70 | torch.save(model_stage3.state_dict(), './weights/fusion_{}epoch.ckpt'.format(epoch+1)) 71 | 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /model/fusion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from model.utils import CBR, DWConv, MLP 4 | import torch.nn.functional as F 5 | 6 | class SemLA_Fusion(nn.Module): 7 | """ 8 | The registration section of SemLA 9 | """ 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | self.fuse1 = CR(1, 8) 15 | self.fuse2 = CR(8, 8) 16 | self.fuse3 = CR(8, 16) 17 | self.fuse4 = CR(16, 16) 18 | self.fuse5 = JConv(48, 1) 19 | self.acitve = nn.Tanh() 20 | 21 | for m in self.modules(): 22 | if isinstance(m, nn.Conv2d): 23 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 24 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 25 | nn.init.constant_(m.weight, 1) 26 | nn.init.constant_(m.bias, 0) 27 | 28 | def forward(self, x, mask, matchmode = 'semantic'): 29 | bs = x.shape[0] 30 | 31 | feat1 = self.fuse1(x) 32 | feat2 = self.fuse2(feat1) 33 | feat3 = self.fuse3(feat2) 34 | feat4 = self.fuse4(feat3) 35 | 36 | featfuse=torch.cat((feat1,feat2,feat3,feat4),dim=1) 37 | 38 | (vifeat, irfeat) = featfuse.split(int(bs / 2)) 39 | 40 | if matchmode == 'semantic': 41 | featfuse = irfeat * mask * 0.7 + vifeat * mask * 0.6 + vifeat * (1 - mask) 42 | elif matchmode == 'scene': 43 | featfuse = irfeat * 0.6 + vifeat * 0.6 44 | 45 | featfuse = self.fuse5(featfuse) 46 | featfuse = self.acitve(featfuse) 47 | featfuse = (featfuse + 1) / 2 48 | 49 | return featfuse 50 | 51 | 52 | class JConv(nn.Module): 53 | """Joint Convolutional blocks 54 | 55 | Args: 56 | 'x' (torch.Tensor): (N, C, H, W) 57 | """ 58 | def __init__(self, in_channels, out_channels): 59 | super(JConv, self).__init__() 60 | self.feat_trans = 
CBR(in_channels, out_channels) 61 | self.dwconv = DWConv(out_channels) 62 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 63 | self.mlp = MLP(out_channels, bias=True) 64 | 65 | def forward(self, x): 66 | x = self.feat_trans(x) 67 | x = x + self.dwconv(x) 68 | out = self.norm(x) 69 | x = x + self.mlp(out) 70 | return x 71 | 72 | 73 | 74 | class CR(nn.Module): 75 | """Convolution with Leaky ReLU 76 | 77 | Args: 78 | 'x' (torch.Tensor): (N, C, H, W) 79 | """ 80 | def __init__(self, in_channels, out_channels, stride=1): 81 | super().__init__() 82 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 83 | 84 | def forward(self, x): 85 | return F.leaky_relu(self.conv(x), negative_slope=0.2) -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/fine_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange, repeat 5 | 6 | 7 | class FinePreprocess(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | self.config = config 12 | self.cat_c_feat = config['fine_concat_coarse_feat'] 13 | self.W = self.config['fine_window_size'] 14 | 15 | d_model_c = self.config['coarse']['d_model'] 16 | d_model_f = self.config['fine']['d_model'] 17 | self.d_model_f = d_model_f 18 | if self.cat_c_feat: 19 | self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) 20 | self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) 21 | 22 | self._reset_parameters() 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") 28 | 29 | def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): 30 | W = self.W 31 | stride = data['hw0_f'][0] // data['hw0_c'][0] 32 | 33 | data.update({'W': W}) 34 | if data['b_ids'].shape[0] == 0: 35 | feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 36 | feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 37 | return feat0, feat1 38 | 39 | # 1. unfold(crop) all local windows 40 | feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) 41 | feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 42 | feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) 43 | feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 44 | 45 | # 2. 
select only the predicted matches 46 | feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] 47 | feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] 48 | 49 | # option: use coarse-level loftr feature as context: concat and linear 50 | if self.cat_c_feat: 51 | feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], 52 | feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] 53 | feat_cf_win = self.merge_feat(torch.cat([ 54 | torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] 55 | repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] 56 | ], -1)) 57 | feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) 58 | 59 | return feat_f0_unfold, feat_f1_unfold 60 | -------------------------------------------------------------------------------- /train_stage2/Loss_stage2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | from Reg import SemLA_Reg 6 | 7 | class Loss_stage2(nn.Module): 8 | """Calculate the loss of the first stage of training, including registration loss and semantic awareness loss 9 | 10 | Args: 11 | 'inputs': (torch.Tensor): (N, C, H, W) 12 | 'targets': (torch.Tensor): (N, C, H, W) 13 | """ 14 | def __init__(self): 15 | super().__init__() 16 | self.backbone = SemLA_Reg() 17 | 18 | def forward(self, img_vi, img_ir, conf_gt, str_conf_gt): 19 | feat_reg_vi_final, feat_reg_ir_final, feat_reg_vi_str, feat_reg_ir_str = self.backbone(torch.cat((img_vi, img_ir), dim=0)) 20 | 21 | feat_reg_vi_final = rearrange(feat_reg_vi_final, 'n c h w -> n (h w) c') 22 | feat_reg_ir_final = rearrange(feat_reg_ir_final, 'n c h w -> n (h w) c') 23 | 24 | feat_reg_vi_final, feat_reg_ir_final = map(lambda feat: feat / feat.shape[-1] ** .5, 25 | [feat_reg_vi_final, feat_reg_ir_final]) 26 | 27 | # The registration loss is implemented based on the code of [LoFTR](https://github.com/zju3dv/LoFTR) 28 | conf_0 = torch.einsum("nlc,nsc->nls", feat_reg_vi_final, 29 | feat_reg_ir_final) / 0.1 30 | 31 | # dual-softmax operator 32 | conf_0 = F.softmax(conf_0, 1) * F.softmax(conf_0, 2) 33 | conf_0 = torch.clamp(conf_0, 1e-6, 1 - 1e-6) 34 | 35 | pos_mask_0, neg_mask_0 = conf_gt == 1, conf_gt == 0 36 | alpha = 0.25 37 | gamma = 2.0 38 | 39 | pos_conf_0 = conf_0[pos_mask_0] 40 | loss_0 = - alpha * torch.pow(1 - pos_conf_0, gamma) * pos_conf_0.log() 41 | 42 | # registration loss 43 | loss_0 = loss_0.mean() 44 | 45 | 46 | 47 | feat_reg_vi_str = rearrange(feat_reg_vi_str, 'n c h w -> n (h w) c') 48 | feat_reg_ir_str = rearrange(feat_reg_ir_str, 'n c h w -> n (h w) c') 49 | 50 | feat_reg_vi_str, feat_reg_ir_str = map(lambda feat: feat / feat.shape[-1] ** .5, 51 | [feat_reg_vi_str, feat_reg_ir_str]) 52 | conf_1 = torch.einsum("nlc,nsc->nls", feat_reg_vi_str, 53 | feat_reg_ir_str) / 0.1 54 | conf_1 = F.softmax(conf_1, 1) * F.softmax(conf_1, 2) 55 | conf_1 = torch.clamp(conf_1, 1e-6, 1 - 1e-6) 56 | 57 | pos_mask_1, neg_mask_1 = str_conf_gt == 1, str_conf_gt == 0 58 | pos_conf_1 = conf_1[pos_mask_1] 59 | 60 | loss_1 = - alpha * torch.pow(1 - pos_conf_1, gamma) * pos_conf_1.log() 61 | loss_1 = loss_1.mean() 62 | 63 | 64 | 65 | return loss_0, loss_1 66 | 67 | -------------------------------------------------------------------------------- /inference_one_pair_images.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from 
model.SemLA import SemLA 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from einops.einops import rearrange 7 | from model.utils import YCbCr2RGB, RGB2YCrCb, make_matching_figure 8 | 9 | # Test on a pair of images 10 | if __name__ == '__main__': 11 | # config 12 | reg_weight_path = "" 13 | fusion_weight_path = "" 14 | 15 | # img0 is visible image, and img1 is infrared image 16 | img0_pth = "" 17 | img1_pth = "" 18 | 19 | match_mode = 'semantic' # 'semantic' or 'scene' 20 | 21 | matcher = SemLA() 22 | # Loading the weights of the registration model 23 | matcher.load_state_dict(torch.load(reg_weight_path),strict=False) 24 | 25 | # Loading the weights of the fusion model 26 | matcher.load_state_dict(torch.load(fusion_weight_path), strict=False) 27 | 28 | matcher = matcher.eval().cuda() 29 | 30 | img0_raw = cv2.imread(img0_pth) 31 | img0_raw = cv2.cvtColor(img0_raw, cv2.COLOR_BGR2RGB) 32 | img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE) 33 | img0_raw = cv2.resize(img0_raw, (640, 480)) # input size shuold be divisible by 8 34 | img1_raw = cv2.resize(img1_raw, (640, 480)) 35 | 36 | img0 = torch.from_numpy(img0_raw)[None].cuda() / 255. 37 | img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255. 38 | 39 | img0 = rearrange(img0, 'n h w c -> n c h w') 40 | vi_Y, vi_Cb, vi_Cr = RGB2YCrCb(img0) 41 | 42 | mkpts0, mkpts1, feat_sa_vi, feat_sa_ir, sa_ir = matcher(vi_Y, img1, matchmode=match_mode) 43 | mkpts0 = mkpts0.cpu().numpy() 44 | mkpts1 = mkpts1.cpu().numpy() 45 | 46 | _, prediction = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC,5) 47 | prediction = np.array(prediction, dtype=bool).reshape([-1]) 48 | mkpts0_tps = mkpts0[prediction] 49 | mkpts1_tps = mkpts1[prediction] 50 | tps = cv2.createThinPlateSplineShapeTransformer() 51 | mkpts0_tps_ransac = mkpts0_tps.reshape(1, -1, 2) 52 | mkpts1_tps_ransac = mkpts1_tps.reshape(1, -1, 2) 53 | 54 | matches = [] 55 | for j in range(1, mkpts0.shape[0] + 1): 56 | matches.append(cv2.DMatch(j, j, 0)) 57 | 58 | tps.estimateTransformation(mkpts0_tps_ransac, mkpts1_tps_ransac, matches) 59 | img1_raw_trans = tps.warpImage(img1_raw) 60 | sa_ir = tps.warpImage(sa_ir[0][0].detach().cpu().numpy()) 61 | sa_ir = torch.from_numpy(sa_ir)[None][None].cuda() 62 | 63 | img1_trans = torch.from_numpy(img1_raw_trans)[None][None].cuda() / 255. 
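# Fusion step: the visible Y channel and the TPS-warped infrared image are concatenated
# along the batch dimension, and the warped semantic map `sa_ir` is passed as the mask
# that weights the two modalities in 'semantic' mode (see model/fusion.py).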
64 | fuse = matcher.fusion(torch.cat((vi_Y, img1_trans), dim=0), sa_ir, matchmode=match_mode).detach() 65 | 66 | fuse = YCbCr2RGB(fuse, vi_Cb, vi_Cr) 67 | fuse = fuse.detach().cpu()[0] 68 | fuse = rearrange(fuse, ' c h w -> h w c').detach().cpu().numpy() 69 | 70 | fig = make_matching_figure(fuse, img0_raw, img1_raw, mkpts0_tps, mkpts1_tps) 71 | plt.show() 72 | 73 | 74 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/fine_matching.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from kornia.geometry.subpix import dsnt 6 | from kornia.utils.grid import create_meshgrid 7 | 8 | 9 | class FineMatching(nn.Module): 10 | """FineMatching with s2d paradigm""" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, feat_f0, feat_f1, data): 16 | """ 17 | Args: 18 | feat0 (torch.Tensor): [M, WW, C] 19 | feat1 (torch.Tensor): [M, WW, C] 20 | data (dict) 21 | Update: 22 | data (dict):{ 23 | 'expec_f' (torch.Tensor): [M, 3], 24 | 'mkpts0_f' (torch.Tensor): [M, 2], 25 | 'mkpts1_f' (torch.Tensor): [M, 2]} 26 | """ 27 | M, WW, C = feat_f0.shape 28 | W = int(math.sqrt(WW)) 29 | scale = data['hw0_i'][0] / data['hw0_f'][0] 30 | self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale 31 | 32 | # corner case: if no coarse matches found 33 | if M == 0: 34 | assert self.training == False, "M is always >0, when training, see coarse_matching.py" 35 | # logger.warning('No matches found in coarse-level.') 36 | data.update({ 37 | 'expec_f': torch.empty(0, 3, device=feat_f0.device), 38 | 'mkpts0_f': data['mkpts0_c'], 39 | 'mkpts1_f': data['mkpts1_c'], 40 | }) 41 | return 42 | 43 | feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] 44 | sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) 45 | softmax_temp = 1. 
/ C**.5 46 | heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) 47 | 48 | # compute coordinates from heatmap 49 | coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] 50 | grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] 51 | 52 | # compute std over 53 | var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] 54 | std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability 55 | 56 | # for fine-level supervision 57 | data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) 58 | 59 | # compute absolute kpt coords 60 | self.get_fine_match(coords_normalized, data) 61 | 62 | @torch.no_grad() 63 | def get_fine_match(self, coords_normed, data): 64 | W, WW, C, scale = self.W, self.WW, self.C, self.scale 65 | 66 | # mkpts0_f and mkpts1_f 67 | mkpts0_f = data['mkpts0_c'] 68 | scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale 69 | mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] 70 | 71 | data.update({ 72 | "mkpts0_f": mkpts0_f, 73 | "mkpts1_f": mkpts1_f 74 | }) 75 | -------------------------------------------------------------------------------- /train_stage1/Extractor.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from utils import CBR, DWConv, MLP, DWT_2D 3 | 4 | 5 | class JConv(nn.Module): 6 | """Joint Convolutional blocks 7 | 8 | Args: 9 | 'x' (torch.Tensor): (N, C, H, W) 10 | """ 11 | def __init__(self, in_channels, out_channels): 12 | super(JConv, self).__init__() 13 | self.feat_trans = CBR(in_channels, out_channels) 14 | self.dwconv = DWConv(out_channels) 15 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 16 | self.mlp = MLP(out_channels, bias=True) 17 | 18 | def forward(self, x): 19 | x = self.feat_trans(x) 20 | x = x + self.dwconv(x) 21 | out = self.norm(x) 22 | x = x + self.mlp(out) 23 | return x 24 | 25 | 26 | class Feature_Extraction(nn.Module): 27 | """ 28 | Feature Extraction Layer in SemLA 29 | Extraction of registration features and semantic awareness maps 30 | 31 | Args: 32 | 'x' (torch.Tensor): (N, C, H, W) 33 | 'train_mode' (String) 34 | """ 35 | def __init__(self): 36 | super().__init__() 37 | 38 | # Discrete Wavelet Transform (For feature map downsampling) 39 | self.dwt = DWT_2D(wave='haar') 40 | 41 | self.reg0 = JConv(1, 8) 42 | self.reg1 = JConv(32, 16) 43 | self.reg2 = JConv(64, 32) 44 | self.reg3 = JConv(128, 256) 45 | self.pred_reg = nn.Sequential(JConv(256, 256), JConv(256, 256), JConv(256, 256)) 46 | 47 | self.sa0 = JConv(256, 256) 48 | self.sa1 = JConv(256, 128) 49 | self.sa2 = JConv(128, 32) 50 | self.sa3 = JConv(32, 1) 51 | self.pred_sa = nn.Sigmoid() 52 | 53 | 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 57 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 58 | nn.init.constant_(m.weight, 1) 59 | nn.init.constant_(m.bias, 0) 60 | 61 | def forward(self, x, mode): 62 | x0 = self.reg0(x) 63 | x1 = self.reg1(self.dwt(x0)) 64 | x2 = self.reg2(self.dwt(x1)) 65 | x3 = self.reg3(self.dwt(x2)) 66 | feat_reg = self.pred_reg(x3) 67 | 68 | # Training the registration of SemLA 69 | if mode == 'train_reg': 70 | return feat_reg 71 | 72 | # Training the semantic awareness of SemLA 73 | elif mode == 'train_sa': 74 | y0 = self.sa0(feat_reg) 75 | y1 = 
self.sa1(y0) 76 | y2 = self.sa2(y1) 77 | y3 = self.sa3(y2) 78 | feat_sa = self.pred_sa(y3) 79 | return feat_sa 80 | 81 | # Testing the registration and semantic awareness of SemLA (Other modules in SemLA are not included) 82 | elif mode == 'test': 83 | y0 = self.sa0(feat_reg) 84 | y1 = self.sa1(y0) 85 | y2 = self.sa2(y1) 86 | y3 = self.sa3(y2) 87 | feat_sa = self.pred_sa(y3) 88 | return feat_reg, feat_sa 89 | 90 | -------------------------------------------------------------------------------- /train_stage1/Loss_stage1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | from Extractor import Feature_Extraction 6 | 7 | class BCELoss(nn.Module): 8 | """cross-entropy loss function 9 | 10 | Args: 11 | 'inputs': (torch.Tensor): (N, C, H, W) 12 | 'targets': (torch.Tensor): (N, C, H, W) 13 | """ 14 | 15 | def __init__(self): 16 | super(BCELoss, self).__init__() 17 | self.BCE_loss = nn.BCELoss(reduction='none') 18 | 19 | def forward(self, inputs, targets): 20 | bce_loss = self.BCE_loss(inputs, targets) 21 | 22 | # weighting loss 23 | weight = torch.zeros_like(targets, dtype=torch.float32) 24 | weight = weight.fill_(0.9) 25 | weight[targets > 0] = 1.7 26 | bce_loss = torch.mean(bce_loss * weight) 27 | 28 | return bce_loss 29 | 30 | 31 | class Loss_stage1(nn.Module): 32 | """Calculate the loss of the first stage of training, including registration loss and semantic awareness loss 33 | 34 | Args: 35 | 'inputs': (torch.Tensor): (N, C, H, W) 36 | 'targets': (torch.Tensor): (N, C, H, W) 37 | """ 38 | def __init__(self): 39 | super().__init__() 40 | self.backbone = Feature_Extraction() 41 | self.bceloss = BCELoss() 42 | 43 | def forward(self, reg_vi, reg_ir, conf_gt, sa_vi, sa_ir, sa_vi_gt, sa_ir_gt): 44 | bs_reg = reg_vi.shape[0] 45 | feat_reg = self.backbone(torch.cat((reg_vi, reg_ir), dim=0), mode='train_reg') 46 | (feat_reg_vi, feat_reg_ir) = feat_reg.split(bs_reg) 47 | 48 | feat_reg_vi = rearrange(feat_reg_vi, 'n c h w -> n (h w) c') 49 | feat_reg_ir = rearrange(feat_reg_ir, 'n c h w -> n (h w) c') 50 | 51 | feat_reg_vi, feat_reg_ir = map(lambda feat: feat / feat.shape[-1] ** .5, 52 | [feat_reg_vi, feat_reg_ir]) 53 | 54 | # The registration loss is implemented based on the code of [LoFTR](https://github.com/zju3dv/LoFTR) 55 | conf = torch.einsum("nlc,nsc->nls", feat_reg_vi, 56 | feat_reg_ir) / 0.1 57 | 58 | # dual-softmax operator 59 | conf = F.softmax(conf, 1) * F.softmax(conf, 2) 60 | conf = torch.clamp(conf, 1e-6, 1 - 1e-6) 61 | 62 | pos_mask, neg_mask = conf_gt == 1, conf_gt == 0 63 | 64 | alpha = 0.25 65 | gamma = 2.0 66 | 67 | pos_conf = conf[pos_mask] 68 | loss_reg = - alpha * torch.pow(1 - pos_conf, gamma) * pos_conf.log() 69 | 70 | # registration loss 71 | loss_reg = loss_reg.mean() 72 | bs_sa = sa_vi.shape[0] 73 | 74 | 75 | feat_sa = self.backbone(torch.cat((sa_vi, sa_ir), dim=0), mode='train_sa') 76 | (feat_sa_vi, feat_sa_ir) = feat_sa.split(bs_sa) 77 | 78 | loss_sa_vi = self.bceloss(feat_sa_vi, sa_vi_gt) 79 | loss_sa_ir = self.bceloss(feat_sa_ir, sa_ir_gt) 80 | 81 | # semantic awareness loss 82 | loss_sa = loss_sa_vi + loss_sa_ir 83 | 84 | return loss_reg, loss_sa, conf 85 | 86 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | from model.SemLA import SemLA 4 | import 
matplotlib.pyplot as plt 5 | import numpy as np 6 | from einops.einops import rearrange 7 | from model.utils import YCbCr2RGB, RGB2YCrCb, make_matching_figure 8 | import os 9 | 10 | # Test on dataset 11 | if __name__ == '__main__': 12 | # config 13 | vi_path = "" 14 | ir_path = "" 15 | result_path = "" 16 | reg_weight_path = "" 17 | fusion_weight_path = "" 18 | match_mode = 'scene' # 'semantic' or 'scene' 19 | 20 | matcher = SemLA() 21 | # Loading the weights of the registration model 22 | matcher.load_state_dict(torch.load(reg_weight_path), strict=False) 23 | 24 | # Loading the weights of the fusion model 25 | matcher.load_state_dict(torch.load(fusion_weight_path), strict=False) 26 | 27 | matcher = matcher.eval().cuda() 28 | 29 | image_path = os.listdir(vi_path) 30 | # vi_path is the visible dataset path, ir_path is the infrared dataset path 31 | for image_name in image_path: 32 | img0_pth = os.path.join(vi_path, image_name) 33 | img1_pth = os.path.join(ir_path, image_name) 34 | 35 | img0_raw = cv2.imread(img0_pth) 36 | img0_raw = cv2.cvtColor(img0_raw, cv2.COLOR_BGR2RGB) 37 | img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE) 38 | img0_raw = cv2.resize(img0_raw, (320, 240)) # input size shuold be divisible by 8 39 | img1_raw = cv2.resize(img1_raw, (320, 240)) 40 | 41 | img0 = torch.from_numpy(img0_raw)[None].cuda() / 255. 42 | img1 = torch.from_numpy(img1_raw)[None][None].cuda() / 255. 43 | 44 | img0 = rearrange(img0, 'n h w c -> n c h w') 45 | vi_Y, vi_Cb, vi_Cr = RGB2YCrCb(img0) 46 | 47 | mkpts0, mkpts1, feat_sa_vi, feat_sa_ir, sa_ir = matcher(vi_Y, img1, matchmode=match_mode) 48 | mkpts0 = mkpts0.cpu().numpy() 49 | mkpts1 = mkpts1.cpu().numpy() 50 | 51 | _, prediction = cv2.findHomography(mkpts0, mkpts1, cv2.RANSAC, 5) 52 | prediction = np.array(prediction, dtype=bool).reshape([-1]) 53 | mkpts0_tps = mkpts0[prediction] 54 | mkpts1_tps = mkpts1[prediction] 55 | tps = cv2.createThinPlateSplineShapeTransformer() 56 | mkpts0_tps = mkpts0_tps.reshape(1, -1, 2) 57 | mkpts1_tps = mkpts1_tps.reshape(1, -1, 2) 58 | 59 | matches = [] 60 | for j in range(1, mkpts0.shape[0] + 1): 61 | matches.append(cv2.DMatch(j, j, 0)) 62 | 63 | tps.estimateTransformation(mkpts0_tps, mkpts1_tps, matches) 64 | img1_raw_trans = tps.warpImage(img1_raw) 65 | sa_ir = tps.warpImage(sa_ir[0][0].detach().cpu().numpy()) 66 | sa_ir = torch.from_numpy(sa_ir)[None][None].cuda() 67 | 68 | img1_trans = torch.from_numpy(img1_raw_trans)[None][None].cuda() / 255. 69 | fuse = matcher.fusion(torch.cat((vi_Y, img1_trans), dim=0), sa_ir, matchmode=match_mode).detach() 70 | 71 | fuse = YCbCr2RGB(fuse, vi_Cb, vi_Cr) 72 | fuse = fuse.detach().cpu()[0] 73 | fuse = rearrange(fuse, ' c h w -> h w c').detach().cpu().numpy() 74 | 75 | fig = make_matching_figure(fuse, img0_raw, img1_raw, mkpts0, mkpts1, path=os.path.join(result_path, image_name)) 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemLA 2 | **H. Xie**, Y. Zhang, J. Qiu, X. Zhai, X. Liu, Y. Yang, S. Zhao, Y. Luo, and J. Zhong, “**Semantics lead all: Towards unified image registration and fusion from a semantic perspective,” Information Fusion**, p. 101835, 2023. [Paper](https://www.sciencedirect.com/science/article/abs/pii/S1566253523001513) 3 | 4 |

 5 | ![SemLA](assets/SemLA.png) 6 | 7 | ![Visualization](assets/Visualization.png)

8 | 9 | ## Note 10 | We have updated the existing bugs in the original code. Please download the current project and weights again for testing and training.【07/10】 11 | 12 | ## Data preparation 13 | 1. Download the [COCO](https://cocodataset.org/#download) dataset to ```.\datasets\COCO\``` (path2COCO) 14 | 2. Download the [IVS](https://github.com/xiehousheng/IVS_data) dataset to ```.\datasets\IVS\``` (path2IVS) 15 | 3. Download the label of [IVS](https://github.com/xiehousheng/IVS_data) dataset to ```.\datasets\IVS_Label\``` (path2IVS_Label) 16 | 4. Generate pseudo-infrared images for each image in the COCO dataset using [CPSTN](https://github.com/wdhudiekou/UMF-CMGR/tree/main/CPSTN) and store the results in ```.\datasets\COCO_CPSTN\``` (path2COCO_CPSTN) 17 | 5. Generate pseudo-infrared images for each image in the IVS dataset using [CPSTN](https://github.com/wdhudiekou/UMF-CMGR/tree/main/CPSTN) and store the results in ```.\datasets\IVS_CPSTN\``` ((path2IVS_CPSTN)) 18 | 19 | ## Installation 20 | 21 | The code is implemented in `python=3.6`, as well as `pytorch=1.9` and `opencv-python=4.6.0.66`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch dependencies. Installing PyTorch with CUDA support is strongly recommended. 22 | 23 | ## Training 24 | 1. Train stage1: Registration and semantic feature extraction. ```cd train_stage1``` and configuring dataset paths, then run ```python train_stage1.py``` 25 | 26 | 2. Train stage2: Training CSC and SSR modules. ```cd train_stage2``` and configuring dataset paths, then run ```python train_stage2.py``` 27 | 28 | 3. Train stage3: Training fusion module. ```cd train_stage3``` and configuring dataset paths, then run ```python train_stage3.py``` 29 | 30 | ## Test 31 | Download pre-trained models on [Google Drive](https://drive.google.com/drive/folders/1Lh9UFXWP5bvt_MVwYa9ZPA7g_lOEGrxz?usp=sharing) or [Baidu Yun](https://pan.baidu.com/s/1x3v4BlWgwEH31p5lwL42Ew?pwd=qriy) and configure the path ```reg_weight_path```, ```fusion_weight_path```. We provide two matching modes, one is semantic object-oriented matching, setting ```matchmode = "semantic"```, and the other is global image oriented matching, setting ```matchmode = "scene"```. 32 | ### On a dataset 33 | Configuring dataset paths, then run ```python test.py``` 34 | ### On a pair of images 35 | Configuring images path, then run ```python inference_one_pair_images.py``` 36 | 37 | ## Citation 38 | 39 | If this code is useful for your research, please cite our paper. 
40 | ```bibtex 41 | @article{xie2023semantics, 42 | title={Semantics lead all: Towards unified image registration and fusion from a semantic perspective}, 43 | author={Xie, Housheng and Zhang, Yukuan and Qiu, Junhui and Zhai, Xiangshuai and Liu, Xuedong and Yang, Yang and Zhao, Shan and Luo, Yongfang and Zhong, Jianbo}, 44 | journal={Information Fusion}, 45 | pages={101835}, 46 | year={2023}, 47 | publisher={Elsevier} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /train_stage3/Loss_stage3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | from Fusion import SemLA_Fusion 6 | from math import exp 7 | from torch.autograd import Variable 8 | 9 | class Loss_stage3(nn.Module): 10 | """Calculate the reconstruction loss of image fusion 11 | 12 | Args: 13 | 'img': (torch.Tensor): (N, 1, H, W) 14 | """ 15 | def __init__(self): 16 | super().__init__() 17 | self.fusion = SemLA_Fusion() 18 | self.ssim = SSIMLoss() 19 | 20 | def forward(self, img): 21 | fusion_result = self.fusion(img) 22 | 23 | ssim_loss = self.ssim(img, fusion_result) 24 | intensity_loss = F.l1_loss(fusion_result, img) 25 | 26 | return ssim_loss, intensity_loss 27 | 28 | 29 | 30 | 31 | def gaussian(window_size, sigma): 32 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 33 | return gauss/gauss.sum() 34 | 35 | 36 | def create_window(window_size, channel): 37 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 38 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 39 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 40 | return window 41 | 42 | def _ssim(img1, img2, window, window_size, channel, size_average=True, mask=1): 43 | mu1 = F.conv2d(img1, window, padding=window_size//2, groups=channel) 44 | mu2 = F.conv2d(img2, window, padding=window_size//2, groups=channel) 45 | 46 | mu1_sq = mu1.pow(2) 47 | mu2_sq = mu2.pow(2) 48 | mu1_mu2 = mu1*mu2 49 | 50 | sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq 51 | sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq 52 | sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 53 | 54 | C1 = 0.01**2 55 | C2 = 0.03**2 56 | 57 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 58 | ssim_map = ssim_map*mask 59 | 60 | ssim_map = torch.clamp((1.0 - ssim_map) / 2, min=0, max=1) 61 | if size_average: 62 | return ssim_map.mean() 63 | else: 64 | return ssim_map.mean(1).mean(1).mean(1) 65 | 66 | class SSIMLoss(torch.nn.Module): 67 | def __init__(self, window_size=11, size_average=True): 68 | super(SSIMLoss, self).__init__() 69 | self.window_size = window_size 70 | self.size_average = size_average 71 | self.channel = 1 72 | self.window = create_window(window_size, self.channel) 73 | 74 | def forward(self, img1, img2, mask=1): 75 | (_, channel, _, _) = img1.size() 76 | if channel == self.channel and self.window.data.type() == img1.data.type(): 77 | window = self.window 78 | else: 79 | window = create_window(self.window_size, channel) 80 | 81 | if img1.is_cuda: 82 | window = window.cuda(img1.get_device()) 83 | window = window.type_as(img1) 84 | 85 | self.window = window 86 | self.channel = channel 87 | mask = 
torch.logical_and(img1>0,img2>0).float() 88 | for i in range(self.window_size//2): 89 | mask = (F.conv2d(mask, window, padding=self.window_size//2, groups=channel)>0.8).float() 90 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average, mask=mask) 91 | 92 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import contextlib 3 | import joblib 4 | from typing import Union 5 | from loguru import _Logger, logger 6 | from itertools import chain 7 | 8 | import torch 9 | from yacs.config import CfgNode as CN 10 | from pytorch_lightning.utilities import rank_zero_only 11 | 12 | 13 | def lower_config(yacs_cfg): 14 | if not isinstance(yacs_cfg, CN): 15 | return yacs_cfg 16 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 17 | 18 | 19 | def upper_config(dict_cfg): 20 | if not isinstance(dict_cfg, dict): 21 | return dict_cfg 22 | return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} 23 | 24 | 25 | def log_on(condition, message, level): 26 | if condition: 27 | assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] 28 | logger.log(level, message) 29 | 30 | 31 | def get_rank_zero_only_logger(logger: _Logger): 32 | if rank_zero_only.rank == 0: 33 | return logger 34 | else: 35 | for _level in logger._core.levels.keys(): 36 | level = _level.lower() 37 | setattr(logger, level, 38 | lambda x: None) 39 | logger._log = lambda x: None 40 | return logger 41 | 42 | 43 | def setup_gpus(gpus: Union[str, int]) -> int: 44 | """ A temporary fix for pytorch-lighting 1.3.x """ 45 | gpus = str(gpus) 46 | gpu_ids = [] 47 | 48 | if ',' not in gpus: 49 | n_gpus = int(gpus) 50 | return n_gpus if n_gpus != -1 else torch.cuda.device_count() 51 | else: 52 | gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] 53 | 54 | # setup environment variables 55 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 56 | if visible_devices is None: 57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 58 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) 59 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 60 | logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') 61 | else: 62 | logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') 63 | return len(gpu_ids) 64 | 65 | 66 | def flattenList(x): 67 | return list(chain(*x)) 68 | 69 | 70 | @contextlib.contextmanager 71 | def tqdm_joblib(tqdm_object): 72 | """Context manager to patch joblib to report into tqdm progress bar given as argument 73 | 74 | Usage: 75 | with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: 76 | Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) 77 | 78 | When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) 79 | ret_vals = Parallel(n_jobs=args.world_size)( 80 | delayed(lambda x: _compute_cov_score(pid, *x))(param) 81 | for param in tqdm(combinations(image_ids, 2), 82 | desc=f'Computing cov_score of [{pid}]', 83 | total=len(image_ids)*(len(image_ids)-1)/2)) 84 | Src: https://stackoverflow.com/a/58936697 85 | """ 86 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): 87 | def __init__(self, *args, **kwargs): 88 | super().__init__(*args, **kwargs) 89 | 90 | def __call__(self, *args, 
**kwargs): 91 | tqdm_object.update(n=self.batch_size) 92 | return super().__call__(*args, **kwargs) 93 | 94 | old_batch_callback = joblib.parallel.BatchCompletionCallBack 95 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback 96 | try: 97 | yield tqdm_object 98 | finally: 99 | joblib.parallel.BatchCompletionCallBack = old_batch_callback 100 | tqdm_object.close() 101 | 102 | -------------------------------------------------------------------------------- /train_stage1/teacher/config/defaultmf.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | def lower_config(yacs_cfg): 3 | if not isinstance(yacs_cfg, CN): 4 | return yacs_cfg 5 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 6 | _CN = CN() 7 | 8 | _CN.MATCHFORMER = CN() 9 | _CN.MATCHFORMER.BACKBONE_TYPE = 'litesea'# litela,largela,litesea,largesea 10 | _CN.MATCHFORMER.SCENS = 'outdoor' # indoor, outdoor 11 | _CN.MATCHFORMER.RESOLUTION = (8,4) #(8,2),(8,4) 12 | _CN.MATCHFORMER.FINE_WINDOW_SIZE = 5 13 | _CN.MATCHFORMER.FINE_CONCAT_COARSE_FEAT = True 14 | 15 | _CN.MATCHFORMER.COARSE = CN() 16 | _CN.MATCHFORMER.COARSE.D_MODEL = 192 17 | _CN.MATCHFORMER.COARSE.D_FFN = 192 18 | 19 | _CN.MATCHFORMER.MATCH_COARSE = CN() 20 | _CN.MATCHFORMER.MATCH_COARSE.THR = 0.2 21 | _CN.MATCHFORMER.MATCH_COARSE.BORDER_RM = 0 22 | _CN.MATCHFORMER.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 23 | _CN.MATCHFORMER.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 24 | _CN.MATCHFORMER.MATCH_COARSE.SKH_ITERS = 3 25 | _CN.MATCHFORMER.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 26 | _CN.MATCHFORMER.MATCH_COARSE.SKH_PREFILTER = False 27 | _CN.MATCHFORMER.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 28 | _CN.MATCHFORMER.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 29 | _CN.MATCHFORMER.MATCH_COARSE.SPARSE_SPVS = True 30 | 31 | _CN.MATCHFORMER.FINE = CN() 32 | _CN.MATCHFORMER.FINE.D_MODEL = 128 33 | _CN.MATCHFORMER.FINE.D_FFN = 128 34 | 35 | ############## Dataset ############## 36 | _CN.DATASET = CN() 37 | # 1. data config 38 | # training and validating 39 | _CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] 40 | _CN.DATASET.TRAIN_DATA_ROOT = None 41 | _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) 42 | _CN.DATASET.TRAIN_NPZ_ROOT = None 43 | _CN.DATASET.TRAIN_LIST_PATH = None 44 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 45 | _CN.DATASET.VAL_DATA_ROOT = None 46 | _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) 47 | _CN.DATASET.VAL_NPZ_ROOT = None 48 | _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file 49 | _CN.DATASET.VAL_INTRINSIC_PATH = None 50 | # testing 51 | _CN.DATASET.TEST_DATA_SOURCE = None 52 | _CN.DATASET.TEST_DATA_ROOT = None 53 | _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) 54 | _CN.DATASET.TEST_NPZ_ROOT = None 55 | _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file 56 | _CN.DATASET.TEST_INTRINSIC_PATH = None 57 | 58 | # 2. dataset config 59 | # general options 60 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score 61 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 62 | _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] 63 | 64 | # MegaDepth options 65 | _CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square. 
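# NOTE: config/data/megadepth_test_1500.py overrides MGDPT_IMG_RESIZE to 840 for the
# MegaDepth 1500-pair test split.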
66 | _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE 67 | _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 68 | _CN.DATASET.MGDPT_DF = 8 69 | 70 | # geometric metrics and pose solver 71 | _CN.TRAINER = CN() 72 | _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) 73 | _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] 74 | _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] 75 | _CN.TRAINER.RANSAC_PIXEL_THR = 0.5 76 | _CN.TRAINER.RANSAC_CONF = 0.99999 77 | _CN.TRAINER.RANSAC_MAX_ITERS = 10000 78 | _CN.TRAINER.USE_MAGSACPP = False 79 | 80 | # data sampler 81 | _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] 82 | _CN.TRAINER.N_SAMPLES_PER_SUBSET = 200 83 | _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not 84 | _CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not 85 | _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data 86 | _CN.TRAINER.SEED = 66 87 | 88 | default_cfg = lower_config(_CN) 89 | def get_cfg_defaults(): 90 | """Get a yacs CfgNode object with default values for my_project.""" 91 | # Return a clone so that the defaults will not be altered 92 | # This is for the "local variable" use pattern 93 | return _CN.clone() 94 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/coarse_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange 5 | 6 | INF = 1e9 7 | 8 | def mask_border(m, b: int, v): 9 | """ Mask borders with value 10 | Args: 11 | m (torch.Tensor): [N, H0, W0, H1, W1] 12 | b (int) 13 | v (m.dtype) 14 | """ 15 | if b <= 0: 16 | return 17 | 18 | m[:, :b] = v 19 | m[:, :, :b] = v 20 | m[:, :, :, :b] = v 21 | m[:, :, :, :, :b] = v 22 | m[:, -b:] = v 23 | m[:, :, -b:] = v 24 | m[:, :, :, -b:] = v 25 | m[:, :, :, :, -b:] = v 26 | 27 | 28 | def mask_border_with_padding(m, bd, v, p_m0, p_m1): 29 | if bd <= 0: 30 | return 31 | 32 | m[:, :bd] = v 33 | m[:, :, :bd] = v 34 | m[:, :, :, :bd] = v 35 | m[:, :, :, :, :bd] = v 36 | 37 | h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() 38 | h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() 39 | for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): 40 | m[b_idx, h0 - bd:] = v 41 | m[b_idx, :, w0 - bd:] = v 42 | m[b_idx, :, :, h1 - bd:] = v 43 | m[b_idx, :, :, :, w1 - bd:] = v 44 | 45 | 46 | def compute_max_candidates(p_m0, p_m1): 47 | """Compute the max candidates of all pairs within a batch 48 | 49 | Args: 50 | p_m0, p_m1 (torch.Tensor): padded masks 51 | """ 52 | h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] 53 | h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] 54 | max_cand = torch.sum( 55 | torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) 56 | return max_cand 57 | 58 | 59 | 60 | class CoarseMatching(nn.Module): 61 | def __init__(self, config): 62 | super().__init__() 63 | self.config = config 64 | # general config 65 | self.thr = config['thr'] 66 | self.border_rm = config['border_rm'] 67 | # -- # for trainig fine-level LoFTR 68 | self.train_coarse_percent = config['train_coarse_percent'] 69 | self.train_pad_num_gt_min = 
config['train_pad_num_gt_min'] 70 | 71 | # we provide 2 options for differentiable matching 72 | self.match_type = config['match_type'] 73 | if self.match_type == 'dual_softmax': 74 | self.temperature = config['dsmax_temperature'] 75 | else: 76 | raise NotImplementedError() 77 | 78 | def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): 79 | """ 80 | Args: 81 | feat0 (torch.Tensor): [N, L, C] 82 | feat1 (torch.Tensor): [N, S, C] 83 | data (dict) 84 | mask_c0 (torch.Tensor): [N, L] (optional) 85 | mask_c1 (torch.Tensor): [N, S] (optional) 86 | Update: 87 | data (dict): { 88 | 'b_ids' (torch.Tensor): [M'], 89 | 'i_ids' (torch.Tensor): [M'], 90 | 'j_ids' (torch.Tensor): [M'], 91 | 'gt_mask' (torch.Tensor): [M'], 92 | 'mkpts0_c' (torch.Tensor): [M, 2], 93 | 'mkpts1_c' (torch.Tensor): [M, 2], 94 | 'mconf' (torch.Tensor): [M]} 95 | NOTE: M' != M during training. 96 | """ 97 | N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) 98 | 99 | # normalize 100 | feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, 101 | [feat_c0, feat_c1]) 102 | 103 | if self.match_type == 'dual_softmax': 104 | sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, 105 | feat_c1) / self.temperature 106 | if mask_c0 is not None: 107 | sim_matrix.masked_fill_( 108 | ~(mask_c0[..., None] * mask_c1[:, None]).bool(), 109 | -INF) 110 | conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) 111 | 112 | 113 | data.update({'conf_matrix': conf_matrix}) 114 | data.update({'sim_matrix': sim_matrix}) 115 | # predict coarse matches from conf_matrix 116 | # data.update(**self.get_coarse_match(conf_matrix, data)) 117 | 118 | 119 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler, ConcatDataset 3 | 4 | 5 | class RandomConcatSampler(Sampler): 6 | """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset 7 | in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. 8 | However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. 9 | 10 | For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. 11 | Args: 12 | shuffle (bool): shuffle the random sampled indices across all sub-datsets. 13 | repeat (int): repeatedly use the sampled indices multiple times for training. 14 | [arXiv:1902.05509, arXiv:1901.09335] 15 | NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) 16 | NOTE: This sampler behaves differently with DistributedSampler. 17 | It assume the dataset is splitted across ranks instead of replicated. 18 | TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. 
19 | ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 20 | """ 21 | def __init__(self, 22 | data_source: ConcatDataset, 23 | n_samples_per_subset: int, 24 | subset_replacement: bool=True, 25 | shuffle: bool=True, 26 | repeat: int=1, 27 | seed: int=None): 28 | if not isinstance(data_source, ConcatDataset): 29 | raise TypeError("data_source should be torch.utils.data.ConcatDataset") 30 | 31 | self.data_source = data_source 32 | self.n_subset = len(self.data_source.datasets) 33 | self.n_samples_per_subset = n_samples_per_subset 34 | self.n_samples = self.n_subset * self.n_samples_per_subset * repeat 35 | self.subset_replacement = subset_replacement 36 | self.repeat = repeat 37 | self.shuffle = shuffle 38 | self.generator = torch.manual_seed(seed) 39 | assert self.repeat >= 1 40 | 41 | def __len__(self): 42 | return self.n_samples 43 | 44 | def __iter__(self): 45 | indices = [] 46 | # sample from each sub-dataset 47 | for d_idx in range(self.n_subset): 48 | low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] 49 | high = self.data_source.cumulative_sizes[d_idx] 50 | if self.subset_replacement: 51 | rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), 52 | generator=self.generator, dtype=torch.int64) 53 | else: # sample without replacement 54 | len_subset = len(self.data_source.datasets[d_idx]) 55 | rand_tensor = torch.randperm(len_subset, generator=self.generator) + low 56 | if len_subset >= self.n_samples_per_subset: 57 | rand_tensor = rand_tensor[:self.n_samples_per_subset] 58 | else: # padding with replacement 59 | rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), 60 | generator=self.generator, dtype=torch.int64) 61 | rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) 62 | indices.append(rand_tensor) 63 | indices = torch.cat(indices) 64 | if self.shuffle: # shuffle the sampled dataset (from multiple subsets) 65 | rand_tensor = torch.randperm(len(indices), generator=self.generator) 66 | indices = indices[rand_tensor] 67 | 68 | # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) 69 | if self.repeat > 1: 70 | repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] 71 | if self.shuffle: 72 | _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] 73 | repeat_indices = map(_choice, repeat_indices) 74 | indices = torch.cat([indices, *repeat_indices], 0) 75 | 76 | assert indices.shape[0] == self.n_samples 77 | return iter(indices.tolist()) 78 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/lightning_loftr.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from loguru import logger 3 | from pathlib import Path 4 | 5 | import torch 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | 9 | from .matchformer import Matchformer 10 | from .utils.metrics import ( 11 | compute_symmetrical_epipolar_errors, 12 | compute_pose_errors, 13 | aggregate_metrics 14 | ) 15 | from .utils.comm import gather 16 | from .utils.misc import lower_config, flattenList 17 | from .utils.profiler import PassThroughProfiler 18 | 19 | 20 | class PL_LoFTR(pl.LightningModule): 21 | def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None): 22 | super().__init__() 23 | # Misc 24 | self.config = config # full config 25 | _config = 
lower_config(self.config) 26 | self.profiler = profiler or PassThroughProfiler() 27 | 28 | # Matcher: LoFTR 29 | self.matcher = Matchformer(config=_config['matchformer']) 30 | 31 | # Pretrained weights 32 | if pretrained_ckpt: 33 | self.matcher.load_state_dict({k.replace('matcher.',''):v for k,v in torch.load(pretrained_ckpt, map_location='cpu').items()}) 34 | logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint") 35 | 36 | # Testing 37 | self.dump_dir = dump_dir 38 | 39 | 40 | def _compute_metrics(self, batch): 41 | with self.profiler.profile("Copmute metrics"): 42 | compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match 43 | compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair 44 | 45 | rel_pair_names = list(zip(*batch['pair_names'])) 46 | bs = batch['image0'].size(0) 47 | metrics = { 48 | # to filter duplicate pairs caused by DistributedSampler 49 | 'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)], 50 | 'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)], 51 | 'R_errs': batch['R_errs'], 52 | 't_errs': batch['t_errs'], 53 | 'inliers': batch['inliers']} 54 | ret_dict = {'metrics': metrics} 55 | return ret_dict, rel_pair_names 56 | 57 | 58 | def test_step(self, batch, batch_idx): 59 | with self.profiler.profile("LoFTR"): 60 | self.matcher(batch) 61 | 62 | ret_dict, rel_pair_names = self._compute_metrics(batch) 63 | 64 | with self.profiler.profile("dump_results"): 65 | if self.dump_dir is not None: 66 | # dump results for further analysis 67 | keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf', 'epi_errs'} 68 | pair_names = list(zip(*batch['pair_names'])) 69 | bs = batch['image0'].shape[0] 70 | dumps = [] 71 | for b_id in range(bs): 72 | item = {} 73 | mask = batch['m_bids'] == b_id 74 | item['pair_names'] = pair_names[b_id] 75 | item['identifier'] = '#'.join(rel_pair_names[b_id]) 76 | for key in keys_to_save: 77 | item[key] = batch[key][mask].cpu().numpy() 78 | for key in ['R_errs', 't_errs', 'inliers']: 79 | item[key] = batch[key][b_id] 80 | dumps.append(item) 81 | ret_dict['dumps'] = dumps 82 | 83 | return ret_dict 84 | 85 | def test_epoch_end(self, outputs): 86 | # metrics: dict of list, numpy 87 | _metrics = [o['metrics'] for o in outputs] 88 | metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} 89 | 90 | # [{key: [{...}, *#bs]}, *#batch] 91 | if self.dump_dir is not None: 92 | Path(self.dump_dir).mkdir(parents=True, exist_ok=True) 93 | _dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch] 94 | dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch] 95 | logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}') 96 | 97 | if self.trainer.global_rank == 0: 98 | print(self.profiler.summary()) 99 | val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR) 100 | logger.info('\n' + pprint.pformat(val_metrics_4tb)) 101 | if self.dump_dir is not None: 102 | np.save(Path(self.dump_dir) / 'LoFTR_pred_eval', dumps) -------------------------------------------------------------------------------- /train_stage2/train_stage2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from dataset import Dataset 5 | from Loss_stage2 import Loss_stage2 6 | 7 | def train_csc(model_stage2, optimizer, data_loader_sa, stage1_weight_path, epoch): 8 | device 
= 'cuda' 9 | optimizer.zero_grad() 10 | # Read the required data set for registration 11 | for data_iter_step, (img_vi, img_ir, conf_gt, str_conf_gt) in enumerate(data_loader_sa): 12 | 13 | 14 | loss_0, loss_1 = model_stage2(img_vi.to(device), img_ir.to(device), conf_gt.to(device), str_conf_gt.to(device)) 15 | loss = loss_0 + loss_1 * (1 + 0.03 * epoch) 16 | 17 | if data_iter_step % 10 == 0: 18 | print('data_iter:', data_iter_step, 'loss_0:', loss_0.item(), 'loss_1:', loss_1.item()) 19 | 20 | loss.backward() 21 | optimizer.step() 22 | 23 | model_stage2.load_state_dict(torch.load(stage1_weight_path), strict=False) 24 | torch.cuda.synchronize() 25 | 26 | 27 | def train_ssr(model_stage2, optimizer, data_loader_sa, csc_weight_wossr, epoch): 28 | device = 'cuda' 29 | # Read the required data set for registration 30 | for data_iter_step, (img_vi, img_ir, conf_gt, str_conf_gt) in enumerate(data_loader_sa): 31 | 32 | loss_0, loss_1 = model_stage2(img_vi.to(device), img_ir.to(device), conf_gt.to(device), str_conf_gt.to(device)) 33 | 34 | loss = loss_0 35 | 36 | if data_iter_step % 10 == 0: 37 | print('data_iter:', data_iter_step, 'loss_0:', loss_0.item(), 'loss_1:', loss_1.item()) 38 | optimizer.zero_grad() 39 | loss.backward() 40 | optimizer.step() 41 | 42 | model_stage2.load_state_dict(csc_weight_wossr, strict=False) 43 | torch.cuda.synchronize() 44 | 45 | 46 | if __name__ == '__main__': 47 | # Configuring dataset paths 48 | path2IVS = "" 49 | path2IVS_CPSTN = "" 50 | path2IVS_Label = "" 51 | 52 | # Configure the size of the training image 53 | train_size = (320, 240) 54 | 55 | # Load the model weights obtained from the first stage of training 56 | stage1_weight_path = "./weights/stage1_14epoch.ckpt" 57 | 58 | # Device for training: 'cuda' or 'cpu' 59 | device = 'cuda' 60 | 61 | dataset = Dataset(path2IVS, path2IVS_CPSTN, path2IVS_Label, train_size_w = train_size[0], train_size_h = train_size[1]) 62 | 63 | dataset_sampler = torch.utils.data.RandomSampler(dataset) 64 | 65 | data_loader = torch.utils.data.DataLoader( 66 | dataset, 67 | sampler=dataset_sampler, 68 | batch_size=64, 69 | num_workers=8, 70 | pin_memory=True, 71 | drop_last=True 72 | ) 73 | 74 | 75 | model_stage2 = Loss_stage2() 76 | model_stage2.to(device) 77 | model_stage2.load_state_dict(torch.load(stage1_weight_path), strict=False) 78 | 79 | for k ,v in model_stage2.named_parameters(): 80 | 81 | if ('csc0' in k) or ('csc1' in k) or ('ssr' in k): 82 | 83 | v.requires_grad = True 84 | else: 85 | v.requires_grad = False 86 | 87 | for name, param in model_stage2.named_parameters(): 88 | if param.requires_grad: 89 | print(name) 90 | 91 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_stage2.parameters()), lr=4e-5) 92 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998) 93 | print(optimizer) 94 | 95 | model_stage2.train(True) 96 | 97 | epochs_train_csc = 2 98 | print('Start training!') 99 | 100 | # Train CSC 101 | print('Train CSC!') 102 | for epoch in range(0, epochs_train_csc): 103 | print('current epoch:', epoch) 104 | train_csc(model_stage2, optimizer, data_loader, stage1_weight_path, epoch) 105 | scheduler.step() 106 | torch.save(model_stage2.state_dict(), './weights/stage2_csc_{}epoch.ckpt'.format(epoch+1)) 107 | 108 | # Train SSR 109 | epochs_train_ssr = 10 110 | print('Train SSR!') 111 | csc_weight_path = "./weights/stage2_csc_2epoch.ckpt" 112 | csc_weight = torch.load(csc_weight_path) 113 | csc_weight_wossr = {key: value for key, value in csc_weight.items() if 'ssr' not in key} 114 | 
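# Editor's note (not part of the original script): `csc_weight_wossr` drops every
# entry whose key contains 'ssr', so the per-epoch `load_state_dict(..., strict=False)`
# call inside train_ssr restores the CSC (and all other) weights from the stage-2
# CSC checkpoint while leaving the SSR parameters with their freshly learned values.
# A minimal sketch of the same filtering pattern on a toy state_dict (hypothetical
# keys, for illustration only):
#
#   full_sd = {'csc0.qkv.weight': 0.1, 'ssr.attention.weight': 0.2}
#   kept_sd = {k: v for k, v in full_sd.items() if 'ssr' not in k}
#   # kept_sd == {'csc0.qkv.weight': 0.1}; strict=False tolerates the missing
#   # 'ssr.*' keys, so those parameters are simply left untouched.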
115 | for epoch in range(0, epochs_train_ssr): 116 | print('current epoch:', epoch) 117 | train_ssr(model_stage2, optimizer, data_loader, csc_weight_wossr, epoch) 118 | scheduler.step() 119 | torch.save(model_stage2.state_dict(), './weights/stage2_ssr_{}epoch.ckpt'.format(epoch+1)) 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /train_stage1/train_stage1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from teacher.model.matchformer import Matchformer 5 | from teacher.config.defaultmf import default_cfg 6 | from reg_dataset import RegDataset 7 | from IVS_dataset import IVSDataset 8 | from Loss_stage1 import Loss_stage1 9 | 10 | 11 | def train_one_epoch(model_stage1, teacher, optimizer, data_loader_reg, data_loader_sa): 12 | device = 'cuda' 13 | dataloader_iterator = iter(data_loader_sa) 14 | 15 | # Read the required data set for registration 16 | for data_iter_step, (reg_img0, reg_img1, reg_conf_gt) in enumerate(data_loader_reg): 17 | 18 | # Read the required data set for semantic awareness 19 | try: 20 | (sa_img0, sa_conf_gt0, sa_img1, sa_conf_gt1) = next(dataloader_iterator) 21 | except StopIteration: 22 | dataloader_iterator = iter(data_loader_sa) 23 | (sa_img0, sa_conf_gt0, sa_img1, sa_conf_gt1) = next(dataloader_iterator) 24 | 25 | 26 | batch = {'image0': reg_img0.to(device), 'image1': reg_img1.to(device)} 27 | with torch.no_grad(): 28 | teacher(batch) 29 | teacher_pred = batch['conf_matrix'] 30 | teacher_pred = torch.clamp(teacher_pred, 1e-6, 1 - 1e-6) 31 | 32 | optimizer.zero_grad() 33 | loss_reg, loss_sa, student_pred = \ 34 | model_stage1(reg_img0.to(device), reg_img1.to(device), reg_conf_gt.to(device), 35 | sa_img0.to(device), sa_img1.to(device), sa_conf_gt0.to(device), sa_conf_gt1.to(device)) 36 | 37 | 38 | stu_loss = nn.KLDivLoss(reduction='batchmean') 39 | student_loss = stu_loss(torch.log(student_pred), teacher_pred) 40 | 41 | if data_iter_step % 10 == 0: 42 | print('data_iter:', data_iter_step, 'reg_loss:', loss_reg.item(), 'sa_loss:', loss_sa.item(), 'student_loss:', 43 | student_loss.item()) 44 | 45 | 46 | loss_stage1 = loss_reg+ loss_sa*0.4+ student_loss * 0.003 47 | loss_stage1.backward() 48 | 49 | optimizer.step() 50 | torch.cuda.synchronize() 51 | 52 | 53 | if __name__ == '__main__': 54 | # Configuring dataset paths 55 | path2COCO = "" 56 | path2COCO_CPSTN = "" 57 | path2IVS = "" 58 | path2IVS_CPSTN = "" 59 | path2IVS_Label = "" 60 | MatchFormer_weight_path = "" 61 | 62 | # Configure the size of the training image 63 | train_size = (320, 240) 64 | 65 | # Device for training: 'cuda' or 'cpu' 66 | device = 'cuda' 67 | 68 | # dataset for training registration 69 | Reg_data = RegDataset(path2COCO, path2COCO_CPSTN, train_size_w = train_size[0], train_size_h = train_size[1]) 70 | 71 | # dataset for training semantic awareness 72 | Sa_data = IVSDataset(path2IVS, 73 | path2IVS_CPSTN, 74 | path2IVS_Label, train_size_w = train_size[0], train_size_h = train_size[1]) 75 | 76 | 77 | Reg_sampler = torch.utils.data.RandomSampler(Reg_data) 78 | Sa_sampler = torch.utils.data.RandomSampler(Sa_data) 79 | 80 | data_loader_reg = torch.utils.data.DataLoader( 81 | Reg_data, 82 | sampler=Reg_sampler, 83 | batch_size=64, 84 | num_workers=8, 85 | pin_memory=True, 86 | drop_last=True 87 | ) 88 | data_loader_sa = torch.utils.data.DataLoader( 89 | Sa_data, 90 | sampler=Sa_sampler, 91 | batch_size=32, 92 | 
num_workers=8, 93 | pin_memory=True, 94 | drop_last=True 95 | ) 96 | 97 | model_stage1 = Loss_stage1() 98 | model_stage1.to(device) 99 | 100 | for name, param in model_stage1.named_parameters(): 101 | if param.requires_grad: 102 | print(name) 103 | 104 | 105 | # Knowledge distillation using MatchFormer 106 | use_registration_teacher = True 107 | if use_registration_teacher == True: 108 | teacher = Matchformer(config=default_cfg) 109 | 110 | # loading the weights of matchformer 111 | teacher.load_state_dict(torch.load(MatchFormer_weight_path), strict=False) 112 | teacher.eval().to(device) 113 | 114 | optimizer = torch.optim.Adam(model_stage1.parameters(), lr=3e-4) 115 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.908) 116 | print(optimizer) 117 | 118 | model_stage1.train(True) 119 | 120 | epochs = 13 121 | print('Start training!') 122 | for epoch in range(0, epochs): 123 | print('current epoch:', epoch) 124 | train_one_epoch(model_stage1, teacher, optimizer, data_loader_reg, data_loader_sa) 125 | scheduler.step() 126 | torch.save(model_stage1.backbone.state_dict(), './weights/stage1_{}epoch.ckpt'.format(epoch+1)) 127 | 128 | 129 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/datasets/scannet.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from typing import Dict 3 | from unicodedata import name 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils as utils 8 | from numpy.linalg import inv 9 | from .dataset import ( 10 | read_scannet_gray, 11 | read_scannet_depth, 12 | read_scannet_pose 13 | ) 14 | 15 | 16 | class ScanNetDataset(utils.data.Dataset): 17 | def __init__(self, 18 | root_dir, 19 | npz_path, 20 | intrinsic_path, 21 | mode='train', 22 | min_overlap_score=0.4, 23 | augment_fn=None, 24 | pose_dir=None, 25 | **kwargs): 26 | """Manage one scene of ScanNet Dataset. 27 | Args: 28 | root_dir (str): ScanNet root directory that contains scene folders. 29 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 30 | intrinsic_path (str): path to depth-camera intrinsic file. 31 | mode (str): options are ['train', 'val', 'test']. 32 | augment_fn (callable, optional): augments images with pre-defined visual effects. 33 | pose_dir (str): ScanNet root directory that contains all poses. 34 | (we use a separate (optional) pose_dir since we store images and poses separately.) 
35 | """ 36 | super().__init__() 37 | self.root_dir = root_dir 38 | self.pose_dir = pose_dir if pose_dir is not None else root_dir 39 | self.mode = mode 40 | 41 | # prepare data_names, intrinsics and extrinsics(T) 42 | with np.load(npz_path) as data: 43 | self.data_names = data['name'] 44 | if 'score' in data.keys() and mode not in ['val' or 'test']: 45 | kept_mask = data['score'] > min_overlap_score 46 | self.data_names = self.data_names[kept_mask] 47 | self.intrinsics = dict(np.load(intrinsic_path)) 48 | 49 | # for training LoFTR 50 | self.augment_fn = augment_fn if mode == 'train' else None 51 | 52 | def __len__(self): 53 | return len(self.data_names) 54 | 55 | def _read_abs_pose(self, scene_name, name): 56 | pth = osp.join(self.pose_dir, 57 | scene_name, 58 | 'pose', f'{name}.txt') 59 | return read_scannet_pose(pth) 60 | 61 | def _compute_rel_pose(self, scene_name, name0, name1): 62 | pose0 = self._read_abs_pose(scene_name, name0) 63 | pose1 = self._read_abs_pose(scene_name, name1) 64 | 65 | return np.matmul(pose1, inv(pose0)) # (4, 4) 66 | 67 | def __getitem__(self, idx): 68 | data_name = self.data_names[idx] 69 | scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name 70 | scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' 71 | 72 | # read the grayscale image which will be resized to (1, 480, 640) 73 | img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') 74 | img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') 75 | 76 | # TODO: Support augmentation & handle seeds for each worker correctly. 77 | image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) 78 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 79 | image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) 80 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 81 | 82 | # read the depthmap which is stored as (480, 640) 83 | if self.mode in ['train', 'val']: 84 | depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) 85 | depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) 86 | else: 87 | depth0 = depth1 = torch.tensor([]) 88 | 89 | # read the intrinsic of depthmap 90 | K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) 91 | 92 | # read and compute relative poses 93 | T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), 94 | dtype=torch.float32) 95 | T_1to0 = T_0to1.inverse() 96 | 97 | data = { 98 | 'image0': image0, # (1, h, w) 99 | 'depth0': depth0, # (h, w) 100 | 'image1': image1, 101 | 'depth1': depth1, 102 | 'T_0to1': T_0to1, # (4, 4) 103 | 'T_1to0': T_1to0, 104 | 'K0': K_0, # (3, 3) 105 | 'K1': K_1, 106 | 'dataset_name': 'ScanNet', 107 | 'scene_id': scene_name, 108 | 'pair_id': idx, 109 | 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), 110 | osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) 111 | } 112 | 113 | return data 114 | -------------------------------------------------------------------------------- /train_stage1/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import pywt 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | def conv1x1(in_channels, out_channels, stride=1): 8 | """1 x 1 convolution""" 9 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, 
stride=stride, padding=0, bias=False) 10 | 11 | 12 | def conv3x3(in_channels, out_channels, stride=1): 13 | """3 x 3 convolution""" 14 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | 17 | class CBR(nn.Module): 18 | """3 x 3 convolution block 19 | 20 | Args: 21 | 'x': (torch.Tensor): (N, C, H, W) 22 | """ 23 | def __init__(self, in_channels, planes, stride=1): 24 | super().__init__() 25 | self.conv = conv3x3(in_channels, planes, stride) 26 | self.bn = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | return self.relu(self.bn(self.conv(x))) 31 | 32 | 33 | class DWConv(nn.Module): 34 | """DepthWise convolution block 35 | 36 | Args: 37 | 'x': (torch.Tensor): (N, C, H, W) 38 | """ 39 | def __init__(self, out_channels): 40 | super().__init__() 41 | self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, 42 | padding=1, groups=out_channels, bias=False) 43 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 44 | self.act = nn.ReLU(inplace=True) 45 | self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) 46 | 47 | def forward(self, x): 48 | out = self.group_conv3x3(x) 49 | out = self.norm(out) 50 | out = self.act(out) 51 | out = self.projection(out) 52 | return out 53 | 54 | 55 | class MLP(nn.Module): 56 | """MLP Layer 57 | 58 | Args: 59 | 'x': (torch.Tensor): (N, C, H, W) 60 | """ 61 | def __init__(self, out_channels, bias=True): 62 | super().__init__() 63 | self.conv1 = nn.Conv2d(out_channels, out_channels * 2, kernel_size=1, bias=bias) 64 | self.act = nn.ReLU(inplace=True) 65 | self.conv2 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=bias) 66 | 67 | def forward(self, x): 68 | x = self.conv1(x) 69 | x = self.act(x) 70 | x = self.conv2(x) 71 | return x 72 | 73 | # This class is implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 74 | class DWT_2D(nn.Module): 75 | """Discrete Wavelet Transform for feature maps downsampling 76 | 77 | Args: 78 | 'x': (torch.Tensor): (N, C, H, W) 79 | """ 80 | def __init__(self, wave): 81 | super(DWT_2D, self).__init__() 82 | w = pywt.Wavelet(wave) 83 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 84 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 85 | 86 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) 87 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) 88 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) 89 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) 90 | 91 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0)) 92 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0)) 93 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0)) 94 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0)) 95 | 96 | self.w_ll = self.w_ll.to(dtype=torch.float32) 97 | self.w_lh = self.w_lh.to(dtype=torch.float32) 98 | self.w_hl = self.w_hl.to(dtype=torch.float32) 99 | self.w_hh = self.w_hh.to(dtype=torch.float32) 100 | 101 | def forward(self, x): 102 | return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh) 103 | 104 | 105 | # This class is implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 
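# Editor's note (not part of the original file): DWT_2D / DWT_Function implement a
# single-level Haar decomposition as grouped stride-2 convolutions, so the LL/LH/HL/HH
# sub-bands are stacked along the channel axis and the spatial size is halved:
# (N, C, H, W) -> (N, 4C, H/2, W/2). A minimal shape check, assuming the classes
# defined in this module and even H, W:
#
#   x = torch.randn(1, 8, 64, 64)
#   y = DWT_2D(wave='haar')(x)
#   assert y.shape == (1, 32, 32, 32)   # 4 * 8 channels, half the spatial size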
106 | class DWT_Function(Function): 107 | @staticmethod 108 | def forward(ctx, x, w_ll, w_lh, w_hl, w_hh): 109 | x = x.contiguous() 110 | ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh) 111 | ctx.shape = x.shape 112 | 113 | dim = x.shape[1] 114 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim) 115 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim) 116 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim) 117 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim) 118 | x = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1) 119 | return x 120 | 121 | @staticmethod 122 | def backward(ctx, dx): 123 | if ctx.needs_input_grad[0]: 124 | w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors 125 | B, C, H, W = ctx.shape 126 | dx = dx.view(B, 4, -1, H // 2, W // 2) 127 | 128 | dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2) 129 | filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1) 130 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C) 131 | 132 | return dx, None, None, None, None -------------------------------------------------------------------------------- /train_stage3/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | import numpy as np 7 | from PIL import Image, ImageFilter 8 | from torchvision import transforms as tfs 9 | 10 | def get_dst_point(perspective, IMAGE_SHAPE): 11 | a = random.random() 12 | b = random.random() 13 | c = random.random() 14 | d = random.random() 15 | e = random.random() 16 | f = random.random() 17 | 18 | if random.random() > 0.5: 19 | left_top_x = perspective * a 20 | left_top_y = perspective * b 21 | right_top_x = 0.9 + perspective * c 22 | right_top_y = perspective * d 23 | left_bottom_x = perspective * a 24 | left_bottom_y = 0.9 + perspective * e 25 | right_bottom_x = 0.9 + perspective * c 26 | right_bottom_y = 0.9 + perspective * f 27 | else: 28 | left_top_x = perspective * a 29 | left_top_y = perspective * b 30 | right_top_x = 0.9 + perspective * c 31 | right_top_y = perspective * d 32 | left_bottom_x = perspective * e 33 | left_bottom_y = 0.9 + perspective * b 34 | right_bottom_x = 0.9 + perspective * f 35 | right_bottom_y = 0.9 + perspective * d 36 | 37 | dst_point = np.array([(IMAGE_SHAPE[1] * left_top_x, IMAGE_SHAPE[0] * left_top_y, 1), 38 | (IMAGE_SHAPE[1] * right_top_x, IMAGE_SHAPE[0] * right_top_y, 1), 39 | (IMAGE_SHAPE[1] * left_bottom_x, IMAGE_SHAPE[0] * left_bottom_y, 1), 40 | (IMAGE_SHAPE[1] * right_bottom_x, IMAGE_SHAPE[0] * right_bottom_y, 1)], dtype='float32') 41 | return dst_point 42 | 43 | class PilGaussianBlur(ImageFilter.Filter): 44 | name = "GaussianBlur" 45 | 46 | def __init__(self, radius=2): 47 | self.radius = radius 48 | 49 | def filter(self, image): 50 | return image.gaussian_blur(self.radius) 51 | 52 | def enhance(img0, img1, IMAGE_SHAPE): 53 | # The four vertices of the image 54 | src_point = np.array([[0, 0], 55 | [IMAGE_SHAPE[1] - 1, 0], 56 | [0, IMAGE_SHAPE[0] - 1], 57 | [IMAGE_SHAPE[1] - 1, IMAGE_SHAPE[0] - 1]], dtype=np.float32) 58 | 59 | # Perspective Information 60 | dst_point = get_dst_point(0.1, IMAGE_SHAPE) 61 | 62 | # Rotation and scale transformation 63 | rotation = 40 64 | rot = random.randint(-rotation, rotation) 65 | scale = 1.2 + random.randint(-90, 100) * 0.01 66 | 67 | 
center_offset = 40 68 | center = (IMAGE_SHAPE[1] / 2 + random.randint(-center_offset, center_offset), 69 | IMAGE_SHAPE[0] / 2 + random.randint(-center_offset, center_offset)) 70 | 71 | RS_mat = cv2.getRotationMatrix2D(center, rot, scale) 72 | f_point = np.matmul(dst_point, RS_mat.T).astype('float32') 73 | mat = cv2.getPerspectiveTransform(src_point, f_point) 74 | out_img0, out_img1 = cv2.warpPerspective(img0, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])),\ 75 | cv2.warpPerspective(img1, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])) 76 | 77 | 78 | return out_img0, out_img1, mat 79 | 80 | 81 | class Dataset(Dataset): 82 | def __init__(self, img_path, trans_path, train_size_w, train_size_h): 83 | super(Dataset, self).__init__() 84 | self.img_path = img_path 85 | self.trans_path = trans_path 86 | self.trans = transforms.Compose([ 87 | transforms.ToTensor() 88 | ]) 89 | 90 | data_floder = os.listdir(self.img_path) 91 | self.data_floder = data_floder 92 | 93 | self.train_size_w = int(train_size_w) 94 | self.train_size_h = int(train_size_h) 95 | 96 | self.feat_size_w = int(self.train_size_w / 8) 97 | self.feat_size_h = int(self.train_size_h / 8) 98 | 99 | self.feat_size_wh = int(self.feat_size_w * self.feat_size_h) 100 | 101 | def __getitem__(self, idx): 102 | item = self.data_floder[idx] 103 | 104 | 105 | 106 | seed = random.random() 107 | if seed < 0.7: 108 | img = cv2.imread(os.path.join(self.img_path, item), cv2.IMREAD_GRAYSCALE) 109 | img = cv2.resize(img, (self.train_size_w, self.train_size_h)) 110 | else: 111 | img = cv2.imread(os.path.join(self.trans_path, item[-16:-3] + 'png'), cv2.IMREAD_GRAYSCALE) 112 | img = cv2.resize(img, (self.train_size_w, self.train_size_h)) 113 | 114 | seed = random.random() 115 | if seed < 0.5: 116 | a = random.randint(-1, 1) 117 | img = cv2.flip(img, flipCode=a) 118 | 119 | seed = random.random() 120 | if seed < 0.2: 121 | (h, w) = img.shape[:2] 122 | center = (w // 2, h // 2) 123 | M = cv2.getRotationMatrix2D(center, random.randint(-10, 10), random.randint(3, 8) * 0.1) 124 | img = cv2.warpAffine(img, M, (w, h)) 125 | 126 | img = Image.fromarray(img) 127 | img = tfs.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.4, hue=0.1)(img) 128 | 129 | return self.trans(img) 130 | 131 | def __len__(self): 132 | return len(self.data_floder) 133 | 134 | 135 | def motion_blur(image, degree=15, angle=45): 136 | image = np.array(image) 137 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 138 | motion_blur_kernel = np.diag(np.ones(degree)) 139 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 140 | 141 | motion_blur_kernel = motion_blur_kernel / degree 142 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 143 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 144 | blurred = np.array(blurred, dtype=np.uint8) 145 | return blurred 146 | 147 | 148 | def gasuss_noise(image, mean=0, var=0.001): 149 | image = np.array(image / 255, dtype=float) 150 | noise = np.random.normal(mean, var ** 0.5, image.shape) 151 | out = image + noise 152 | out = np.clip(out, 0.0, 1.0) 153 | out = np.uint8(out * 255) 154 | return out -------------------------------------------------------------------------------- /train_stage1/teacher/model/datasets/megadepth.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | from loguru import logger 7 | from .dataset import 
read_megadepth_gray, read_megadepth_depth 8 | 9 | 10 | class MegaDepthDataset(Dataset): 11 | def __init__(self, 12 | root_dir, 13 | npz_path, 14 | mode='train', 15 | min_overlap_score=0.4, 16 | img_resize=None, 17 | df=None, 18 | img_padding=False, 19 | depth_padding=False, 20 | augment_fn=None, 21 | **kwargs): 22 | """ 23 | Manage one scene(npz_path) of MegaDepth dataset. 24 | 25 | Args: 26 | root_dir (str): megadepth root directory that has `phoenix`. 27 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 28 | mode (str): options are ['train', 'val', 'test'] 29 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. 30 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. 31 | This is useful during training with batches and testing with memory intensive algorithms. 32 | df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. 33 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. 34 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. 35 | augment_fn (callable, optional): augments images with pre-defined visual effects. 36 | """ 37 | super().__init__() 38 | self.root_dir = root_dir 39 | self.mode = mode 40 | self.scene_id = npz_path.split('.')[0] 41 | 42 | # prepare scene_info and pair_info 43 | if mode == 'test' and min_overlap_score != 0: 44 | logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") 45 | min_overlap_score = 0 46 | self.scene_info = np.load(npz_path, allow_pickle=True) 47 | self.pair_infos = self.scene_info['pair_infos'].copy() 48 | del self.scene_info['pair_infos'] 49 | self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] 50 | 51 | # parameters for image resizing, padding and depthmap padding 52 | if mode == 'train': 53 | assert img_resize is not None and img_padding and depth_padding 54 | self.img_resize = img_resize 55 | self.df = df 56 | self.img_padding = img_padding 57 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 58 | 59 | # for training LoFTR 60 | self.augment_fn = augment_fn if mode == 'train' else None 61 | self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) 62 | 63 | def __len__(self): 64 | return len(self.pair_infos) 65 | 66 | def __getitem__(self, idx): 67 | (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] 68 | 69 | # read grayscale image and mask. (1, h, w) and (h, w) 70 | img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) 71 | img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) 72 | 73 | # TODO: Support augmentation & handle seeds for each worker correctly. 74 | image0, mask0, scale0 = read_megadepth_gray( 75 | img_name0, self.img_resize, self.df, self.img_padding, None) 76 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 77 | image1, mask1, scale1 = read_megadepth_gray( 78 | img_name1, self.img_resize, self.df, self.img_padding, None) 79 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 80 | 81 | # read depth. 
shape: (h, w) 82 | if self.mode in ['train', 'val']: 83 | depth0 = read_megadepth_depth( 84 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) 85 | depth1 = read_megadepth_depth( 86 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) 87 | else: 88 | depth0 = depth1 = torch.tensor([]) 89 | 90 | # read intrinsics of original size 91 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) 92 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) 93 | 94 | # read and compute relative poses 95 | T0 = self.scene_info['poses'][idx0] 96 | T1 = self.scene_info['poses'][idx1] 97 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) 98 | T_1to0 = T_0to1.inverse() 99 | 100 | data = { 101 | 'image0': image0, # (1, h, w) 102 | 'depth0': depth0, # (h, w) 103 | 'image1': image1, 104 | 'depth1': depth1, 105 | 'T_0to1': T_0to1, # (4, 4) 106 | 'T_1to0': T_1to0, 107 | 'K0': K_0, # (3, 3) 108 | 'K1': K_1, 109 | 'scale0': scale0, # [scale_w, scale_h] 110 | 'scale1': scale1, 111 | 'dataset_name': 'MegaDepth', 112 | 'scene_id': self.scene_id, 113 | 'pair_id': idx, 114 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), 115 | } 116 | 117 | # for LoFTR training 118 | if mask0 is not None: # img_padding is True 119 | if self.coarse_scale: 120 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), 121 | scale_factor=self.coarse_scale, 122 | mode='nearest', 123 | recompute_scale_factor=False)[0].bool() 124 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) 125 | 126 | return data 127 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from collections import OrderedDict 5 | from loguru import logger 6 | from kornia.geometry.epipolar import numeric 7 | from kornia.geometry.conversions import convert_points_to_homogeneous 8 | 9 | 10 | # --- METRICS --- 11 | 12 | def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): 13 | # angle error between 2 vectors 14 | t_gt = T_0to1[:3, 3] 15 | n = np.linalg.norm(t) * np.linalg.norm(t_gt) 16 | t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) 17 | t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity 18 | if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging 19 | t_err = 0 20 | 21 | # angle error between 2 rotation matrices 22 | R_gt = T_0to1[:3, :3] 23 | cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 24 | cos = np.clip(cos, -1., 1.) # handle numercial errors 25 | R_err = np.rad2deg(np.abs(np.arccos(cos))) 26 | 27 | return t_err, R_err 28 | 29 | 30 | def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): 31 | """Squared symmetric epipolar distance. 32 | This can be seen as a biased estimation of the reprojection error. 
33 | Args: 34 | pts0 (torch.Tensor): [N, 2] 35 | E (torch.Tensor): [3, 3] 36 | """ 37 | pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 38 | pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 39 | pts0 = convert_points_to_homogeneous(pts0) 40 | pts1 = convert_points_to_homogeneous(pts1) 41 | 42 | Ep0 = pts0 @ E.T # [N, 3] 43 | p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] 44 | Etp1 = pts1 @ E # [N, 3] 45 | 46 | d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N 47 | return d 48 | 49 | 50 | def compute_symmetrical_epipolar_errors(data): 51 | """ 52 | Update: 53 | data (dict):{"epi_errs": [M]} 54 | """ 55 | Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) 56 | E_mat = Tx @ data['T_0to1'][:, :3, :3] 57 | 58 | m_bids = data['m_bids'] 59 | pts0 = data['mkpts0_f'] 60 | pts1 = data['mkpts1_f'] 61 | 62 | epi_errs = [] 63 | for bs in range(Tx.size(0)): 64 | mask = m_bids == bs 65 | epi_errs.append( 66 | symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) 67 | epi_errs = torch.cat(epi_errs, dim=0) 68 | 69 | data.update({'epi_errs': epi_errs}) 70 | 71 | 72 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): 73 | if len(kpts0) < 5: 74 | return None 75 | # normalize keypoints 76 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 77 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 78 | 79 | # normalize ransac threshold 80 | ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) 81 | 82 | # compute pose with cv2 83 | E, mask = cv2.findEssentialMat( 84 | kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) 85 | if E is None: 86 | print("\nE is None while trying to recover pose.\n") 87 | return None 88 | 89 | # recover pose from E 90 | best_num_inliers = 0 91 | ret = None 92 | for _E in np.split(E, len(E) / 3): 93 | n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) 94 | if n > best_num_inliers: 95 | ret = (R, t[:, 0], mask.ravel() > 0) 96 | best_num_inliers = n 97 | 98 | return ret 99 | 100 | 101 | def compute_pose_errors(data, config): 102 | """ 103 | Update: 104 | data (dict):{ 105 | "R_errs" List[float]: [N] 106 | "t_errs" List[float]: [N] 107 | "inliers" List[np.ndarray]: [N] 108 | } 109 | """ 110 | pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 111 | conf = config.TRAINER.RANSAC_CONF # 0.99999 112 | data.update({'R_errs': [], 't_errs': [], 'inliers': []}) 113 | 114 | m_bids = data['m_bids'].cpu().numpy() 115 | pts0 = data['mkpts0_f'].cpu().numpy() 116 | pts1 = data['mkpts1_f'].cpu().numpy() 117 | K0 = data['K0'].cpu().numpy() 118 | K1 = data['K1'].cpu().numpy() 119 | T_0to1 = data['T_0to1'].cpu().numpy() 120 | 121 | for bs in range(K0.shape[0]): 122 | mask = m_bids == bs 123 | ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) 124 | 125 | if ret is None: 126 | data['R_errs'].append(np.inf) 127 | data['t_errs'].append(np.inf) 128 | data['inliers'].append(np.array([]).astype(np.bool)) 129 | else: 130 | R, t, inliers = ret 131 | t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) 132 | data['R_errs'].append(R_err) 133 | data['t_errs'].append(t_err) 134 | data['inliers'].append(inliers) 135 | 136 | 137 | # --- METRIC AGGREGATION --- 138 | 139 | def error_auc(errors, thresholds): 140 | """ 141 | Args: 142 | errors (list): [N,] 143 | thresholds (list) 144 | """ 145 | errors = [0] + sorted(list(errors)) 
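# Editor's note (not part of the original file): the lines below build an empirical
# CDF of the pose errors (recall as a function of error), truncate it at each angular
# threshold, append the threshold point, and integrate with np.trapz, dividing by the
# threshold so that an all-zero error list yields an AUC of 1.0. Note also that the
# `thresholds` argument is overridden with [5, 10, 20] a few lines further down.
# A simplified numeric sketch (toy values, omitting the appended threshold point):
#
#   errs = [0.0, 1.0, 2.0, 3.0, 4.0]      # sorted errors with 0 prepended
#   rec = np.linspace(0, 1, len(errs))    # [0, .25, .5, .75, 1]
#   np.trapz(rec, errs) / 5               # ~= 0.4 for the 5-degree threshold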
146 | recall = list(np.linspace(0, 1, len(errors))) 147 | 148 | aucs = [] 149 | thresholds = [5, 10, 20] 150 | for thr in thresholds: 151 | last_index = np.searchsorted(errors, thr) 152 | y = recall[:last_index] + [recall[last_index-1]] 153 | x = errors[:last_index] + [thr] 154 | aucs.append(np.trapz(y, x) / thr) 155 | 156 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} 157 | 158 | 159 | def epidist_prec(errors, thresholds, ret_dict=False): 160 | precs = [] 161 | for thr in thresholds: 162 | prec_ = [] 163 | for errs in errors: 164 | correct_mask = errs < thr 165 | prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) 166 | precs.append(np.mean(prec_) if len(prec_) > 0 else 0) 167 | if ret_dict: 168 | return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} 169 | else: 170 | return precs 171 | 172 | 173 | def aggregate_metrics(metrics, epi_err_thr=5e-4): 174 | """ Aggregate metrics for the whole dataset: 175 | (This method should be called once per dataset) 176 | 1. AUC of the pose error (angular) at the threshold [5, 10, 20] 177 | 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) 178 | """ 179 | # filter duplicates 180 | unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) 181 | unq_ids = list(unq_ids.values()) 182 | logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') 183 | 184 | # pose auc 185 | angular_thresholds = [5, 10, 20] 186 | pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] 187 | aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) 188 | 189 | # matching precision 190 | dist_thresholds = [epi_err_thr] 191 | precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) 192 | 193 | return {**aucs, **precs} 194 | -------------------------------------------------------------------------------- /model/reg.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from model.utils import CBR, DWConv, MLP, MLP2, DWT_2D, StructureAttention 3 | from einops.einops import rearrange 4 | import torch 5 | 6 | 7 | class SemLA_Reg(nn.Module): 8 | """ 9 | The registration section of SemLA 10 | """ 11 | 12 | def __init__(self): 13 | super().__init__() 14 | # Discrete Wavelet Transform (For feature map downsampling) 15 | self.dwt = DWT_2D(wave='haar') 16 | 17 | self.reg0 = JConv(1, 8) 18 | self.reg1 = JConv(32, 16) 19 | self.reg2 = JConv(64, 32) 20 | self.reg3 = JConv(128, 256) 21 | self.pred_reg = nn.Sequential(JConv(256, 256), JConv(256, 256), JConv(256, 256)) 22 | 23 | self.sa0 = JConv(256, 256) 24 | self.sa1 = JConv(256, 128) 25 | self.sa2 = JConv(128, 32) 26 | self.sa3 = JConv(32, 1) 27 | self.pred_sa = nn.Sigmoid() 28 | 29 | self.csc0 = CrossModalAttention(256) 30 | self.csc1 = CrossModalAttention(256) 31 | 32 | self.ssr = SemanticStructureRepresentation() 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | 41 | def forward(self, x): 42 | # Extraction of registration features 43 | x0 = self.reg0(x) 44 | x1 = self.reg1(self.dwt(x0)) 45 | x2 = self.reg2(self.dwt(x1)) 46 | x3 = self.reg3(self.dwt(x2)) 47 | feat_reg = self.pred_reg(x3) 48 | 49 | bs2 = feat_reg.shape[0] 50 | (feat_reg_vi, 
feat_reg_ir) = feat_reg.split(int(bs2 / 2)) 51 | h = feat_reg.shape[2] 52 | w = feat_reg.shape[3] 53 | 54 | # Predicting semantic awareness maps for infrared images 55 | feat_sa_ir = self.sa0(feat_reg_ir) 56 | feat_sa_ir = self.sa1(feat_sa_ir) 57 | feat_sa_ir = self.sa2(feat_sa_ir) 58 | feat_sa_ir = self.sa3(feat_sa_ir) 59 | feat_sa_ir = self.pred_sa(feat_sa_ir) 60 | 61 | # Flatten 62 | feat_sa_ir_flatten = rearrange(feat_sa_ir, 'n c h w -> n c (h w)') 63 | feat_reg_vi_flatten_ = rearrange(feat_reg_vi, 'n c h w -> n (h w) c') 64 | feat_reg_ir_flatten = rearrange(feat_reg_ir, 'n c h w -> n (h w) c') 65 | 66 | # Feature Similarity Calculation 67 | feat_reg_vi_flatten, feat_reg_ir_flatten = map(lambda feat: feat / feat.shape[-1] ** .5, 68 | [feat_reg_vi_flatten_, feat_reg_ir_flatten]) 69 | attention = torch.einsum("nlc,nsc->nls", feat_reg_vi_flatten, 70 | feat_reg_ir_flatten) / 0.1 71 | attention = attention.softmax(dim=1) 72 | 73 | # Generate cross-modal guidance information 74 | attention = torch.einsum("nls,ncs->nls", attention, feat_sa_ir_flatten) 75 | attention = torch.sum(attention, dim=2) 76 | 77 | # Calibration of semantic features of visible images 78 | feat_reg_vi_ca = self.csc0(feat_reg_vi_flatten_, attention * 1.5) 79 | feat_reg_vi_ca = self.csc1(feat_reg_vi_ca, attention * 1.5) 80 | feat_reg_vi_ca = rearrange(feat_reg_vi_ca, 'n (h w) c -> n c h w', h=h, w=w) 81 | 82 | # Predicting semantic awareness maps for visible images 83 | feat_sa_vi = self.sa0(feat_reg_vi_ca) 84 | feat_sa_vi = self.sa1(feat_sa_vi) 85 | feat_sa_vi = self.sa2(feat_sa_vi) 86 | feat_sa_vi = self.sa3(feat_sa_vi) 87 | feat_sa_vi = self.pred_sa(feat_sa_vi) 88 | 89 | # Semantic structure representation learning 90 | feat_reg_vi_str, feat_reg_ir_str = self.ssr(feat_sa_vi, feat_sa_ir) 91 | feat_reg_vi_final = feat_reg_vi + feat_reg_vi_str 92 | feat_reg_ir_final = feat_reg_ir + feat_reg_ir_str 93 | 94 | return feat_reg_vi_final, feat_reg_ir_final, feat_sa_vi, feat_sa_ir 95 | 96 | 97 | class JConv(nn.Module): 98 | """Joint Convolutional blocks 99 | 100 | Args: 101 | 'x' (torch.Tensor): (N, C, H, W) 102 | """ 103 | 104 | def __init__(self, in_channels, out_channels): 105 | super(JConv, self).__init__() 106 | self.feat_trans = CBR(in_channels, out_channels) 107 | self.dwconv = DWConv(out_channels) 108 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 109 | self.mlp = MLP(out_channels, bias=True) 110 | 111 | def forward(self, x): 112 | x = self.feat_trans(x) 113 | x = x + self.dwconv(x) 114 | out = self.norm(x) 115 | x = x + self.mlp(out) 116 | return x 117 | 118 | 119 | class CrossModalAttention(nn.Module): 120 | """Cross-modal semantic calibration 121 | 122 | Args: 123 | 'feat' (torch.Tensor): (N, L) 124 | 'attention' (torch.Tensor): (N, L, C) 125 | """ 126 | 127 | def __init__(self, dim, mlp_ratio=2, qkv_bias=False): 128 | super(CrossModalAttention, self).__init__() 129 | self.qkv = nn.Linear(dim, dim, bias=qkv_bias) 130 | self.proj_out = nn.Linear(dim, dim) 131 | self.norm1 = nn.LayerNorm(dim) 132 | self.norm2 = nn.LayerNorm(dim) 133 | self.mlp = MLP2(dim, mlp_ratio) 134 | 135 | def forward(self, feat, attention): 136 | shortcut = feat 137 | feat = self.norm1(feat) 138 | feat = self.qkv(feat) 139 | x = torch.einsum('nl, nlc -> nlc', attention, feat) 140 | x = self.proj_out(x) 141 | x = x + shortcut 142 | x = x + self.mlp(self.norm2(x)) 143 | return x 144 | 145 | 146 | class SemanticStructureRepresentation(nn.Module): 147 | """Cross-modal semantic calibration 148 | 149 | Args: 150 | 'feat' (torch.Tensor): (N, L) 151 
| 'attention' (torch.Tensor): (N, L, C) 152 | """ 153 | 154 | def __init__(self): 155 | super(SemanticStructureRepresentation, self).__init__() 156 | self.grid_embedding = JConv(2, 256) 157 | self.semantic_embedding = JConv(256, 256) 158 | self.attention = StructureAttention(256, 8) 159 | 160 | def forward(self, feat_sa_vi, feat_sa_ir): 161 | feat_h = feat_sa_vi.shape[2] 162 | feat_w = feat_sa_vi.shape[3] 163 | # Predefined spatial grid 164 | xs = torch.linspace(0, feat_h - 1, feat_h) 165 | ys = torch.linspace(0, feat_w - 1, feat_w) 166 | xs = xs / (feat_h - 1) 167 | ys = ys / (feat_w - 1) 168 | grid = torch.stack(torch.meshgrid([xs, ys]), dim=-1).unsqueeze(0).repeat(int(feat_sa_vi.shape[0]), 1, 1, 169 | 1).cuda() 170 | 171 | h = grid.shape[1] 172 | w = grid.shape[2] 173 | grid = rearrange(grid, 'n h w c -> n c h w') 174 | 175 | # Embedding position information into a high-dimensional space 176 | grid = self.grid_embedding(grid) 177 | 178 | # Embedding semantic information 179 | semantic_grid_vi = grid * feat_sa_vi 180 | semantic_grid_ir = grid * feat_sa_ir 181 | 182 | semantic_grid_vi = self.semantic_embedding(semantic_grid_vi) 183 | semantic_grid_ir = self.semantic_embedding(semantic_grid_ir) 184 | 185 | semantic_grid_vi, semantic_grid_ir = rearrange(semantic_grid_vi, 'n c h w -> n (h w) c'), rearrange( 186 | semantic_grid_ir, 'n c h w -> n (h w) c') 187 | 188 | semantic_grid_vi = self.attention(semantic_grid_vi) 189 | semantic_grid_ir = self.attention(semantic_grid_ir) 190 | semantic_grid_vi, semantic_grid_ir = rearrange(semantic_grid_vi, 'n (h w) c -> n c h w', h=h, w=w), rearrange( 191 | semantic_grid_ir, 'n (h w) c -> n c h w', h=h, w=w) 192 | 193 | return semantic_grid_vi, semantic_grid_ir 194 | -------------------------------------------------------------------------------- /train_stage2/Reg.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from utils import CBR, DWConv, MLP, MLP2, DWT_2D, StructureAttention 3 | from einops.einops import rearrange 4 | import torch 5 | 6 | class SemLA_Reg(nn.Module): 7 | """ 8 | The registration section of SemLA 9 | """ 10 | def __init__(self): 11 | super().__init__() 12 | # Discrete Wavelet Transform (For feature map downsampling) 13 | self.dwt = DWT_2D(wave='haar') 14 | 15 | self.reg0 = JConv(1, 8) 16 | self.reg1 = JConv(32, 16) 17 | self.reg2 = JConv(64, 32) 18 | self.reg3 = JConv(128, 256) 19 | self.pred_reg = nn.Sequential(JConv(256, 256), JConv(256, 256), JConv(256, 256)) 20 | 21 | self.sa0 = JConv(256, 256) 22 | self.sa1 = JConv(256, 128) 23 | self.sa2 = JConv(128, 32) 24 | self.sa3 = JConv(32, 1) 25 | self.pred_sa = nn.Sigmoid() 26 | 27 | self.csc0 = CrossModalAttention(256) 28 | self.csc1 = CrossModalAttention(256) 29 | 30 | self.ssr = SemanticStructureRepresentation() 31 | 32 | 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 36 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 37 | nn.init.constant_(m.weight, 1) 38 | nn.init.constant_(m.bias, 0) 39 | 40 | def forward(self, x): 41 | # Extraction of registration features 42 | x0 = self.reg0(x) 43 | x1 = self.reg1(self.dwt(x0)) 44 | x2 = self.reg2(self.dwt(x1)) 45 | x3 = self.reg3(self.dwt(x2)) 46 | feat_reg = self.pred_reg(x3) 47 | 48 | bs2 = feat_reg.shape[0] 49 | (feat_reg_vi, feat_reg_ir) = feat_reg.split(int(bs2 / 2)) 50 | h = feat_reg.shape[2] 51 | w = feat_reg.shape[3] 52 | 53 | # Predicting semantic awareness maps for infrared 
images 54 | feat_sa_ir = self.sa0(feat_reg_ir) 55 | feat_sa_ir = self.sa1(feat_sa_ir) 56 | feat_sa_ir = self.sa2(feat_sa_ir) 57 | feat_sa_ir = self.sa3(feat_sa_ir) 58 | feat_sa_ir = self.pred_sa(feat_sa_ir) 59 | 60 | # Flatten 61 | feat_sa_ir_flatten = rearrange(feat_sa_ir, 'n c h w -> n c (h w)') 62 | feat_reg_vi_flatten_ = rearrange(feat_reg_vi, 'n c h w -> n (h w) c') 63 | feat_reg_ir_flatten = rearrange(feat_reg_ir, 'n c h w -> n (h w) c') 64 | 65 | # Feature Similarity Calculation 66 | feat_reg_vi_flatten, feat_reg_ir_flatten = map(lambda feat: feat / feat.shape[-1] ** .5, 67 | [feat_reg_vi_flatten_, feat_reg_ir_flatten]) 68 | attention = torch.einsum("nlc,nsc->nls", feat_reg_vi_flatten, 69 | feat_reg_ir_flatten) / 0.1 70 | attention = attention.softmax(dim=1) 71 | 72 | # Generate cross-modal guidance information 73 | attention = torch.einsum("nls,ncs->nls", attention, feat_sa_ir_flatten) 74 | attention = torch.sum(attention, dim=2) 75 | 76 | # Calibration of semantic features of visible images 77 | feat_reg_vi_ca = self.csc0(feat_reg_vi_flatten_, attention * 1.5) 78 | feat_reg_vi_ca = self.csc1(feat_reg_vi_ca, attention * 1.5) 79 | feat_reg_vi_ca = rearrange(feat_reg_vi_ca, 'n (h w) c -> n c h w', h=h, w=w) 80 | 81 | # Predicting semantic awareness maps for visible images 82 | feat_sa_vi = self.sa0(feat_reg_vi_ca) 83 | feat_sa_vi = self.sa1(feat_sa_vi) 84 | feat_sa_vi = self.sa2(feat_sa_vi) 85 | feat_sa_vi = self.sa3(feat_sa_vi) 86 | feat_sa_vi = self.pred_sa(feat_sa_vi) 87 | 88 | # Semantic structure representation learning 89 | feat_reg_vi_str, feat_reg_ir_str = self.ssr(feat_sa_vi, feat_sa_ir) 90 | feat_reg_vi_final = feat_reg_vi + feat_reg_vi_str 91 | feat_reg_ir_final = feat_reg_ir + feat_reg_ir_str 92 | 93 | return feat_reg_vi_final, feat_reg_ir_final, feat_reg_vi_str, feat_reg_ir_str 94 | 95 | class JConv(nn.Module): 96 | """Joint Convolutional blocks 97 | 98 | Args: 99 | 'x' (torch.Tensor): (N, C, H, W) 100 | """ 101 | def __init__(self, in_channels, out_channels): 102 | super(JConv, self).__init__() 103 | self.feat_trans = CBR(in_channels, out_channels) 104 | self.dwconv = DWConv(out_channels) 105 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 106 | self.mlp = MLP(out_channels, bias=True) 107 | 108 | def forward(self, x): 109 | x = self.feat_trans(x) 110 | x = x + self.dwconv(x) 111 | out = self.norm(x) 112 | x = x + self.mlp(out) 113 | return x 114 | 115 | class CrossModalAttention(nn.Module): 116 | """Cross-modal semantic calibration 117 | 118 | Args: 119 | 'feat' (torch.Tensor): (N, L) 120 | 'attention' (torch.Tensor): (N, L, C) 121 | """ 122 | def __init__(self, dim, mlp_ratio=2, qkv_bias=False): 123 | super(CrossModalAttention, self).__init__() 124 | self.qkv = nn.Linear(dim, dim, bias=qkv_bias) 125 | self.proj_out = nn.Linear(dim, dim) 126 | self.norm1 = nn.LayerNorm(dim) 127 | self.norm2 = nn.LayerNorm(dim) 128 | self.mlp = MLP2(dim, mlp_ratio) 129 | 130 | def forward(self, feat, attention): 131 | shortcut = feat 132 | feat = self.norm1(feat) 133 | feat = self.qkv(feat) 134 | x = torch.einsum('nl, nlc -> nlc', attention, feat) 135 | x = self.proj_out(x) 136 | x = x + shortcut 137 | x = x + self.mlp(self.norm2(x)) 138 | return x 139 | 140 | 141 | class SemanticStructureRepresentation(nn.Module): 142 | """Cross-modal semantic calibration 143 | 144 | Args: 145 | 'feat' (torch.Tensor): (N, L) 146 | 'attention' (torch.Tensor): (N, L, C) 147 | """ 148 | def __init__(self): 149 | super(SemanticStructureRepresentation, self).__init__() 150 | self.grid_embedding = 
JConv(2, 256) 151 | self.semantic_embedding = JConv(256, 256) 152 | self.attention = StructureAttention(256, 8) 153 | 154 | def forward(self, feat_sa_vi, feat_sa_ir): 155 | h = feat_sa_ir.shape[2] 156 | w = feat_sa_ir.shape[3] 157 | 158 | # Predefined spatial grid 159 | xs = torch.linspace(0, h - 1, h) 160 | ys = torch.linspace(0, w - 1, w) 161 | xs = xs / (h - 1) 162 | ys = ys / (w - 1) 163 | grid = torch.stack(torch.meshgrid([xs, ys]), dim=-1).unsqueeze(0).repeat(int(feat_sa_vi.shape[0]), 1, 1, 164 | 1).cuda() 165 | 166 | h = grid.shape[1] 167 | w = grid.shape[2] 168 | grid = rearrange(grid, 'n h w c -> n c h w') 169 | 170 | # Embedding position information into a high-dimensional space 171 | grid = self.grid_embedding(grid) 172 | 173 | # Embedding semantic information 174 | semantic_grid_vi = grid * feat_sa_vi 175 | semantic_grid_ir = grid * feat_sa_ir 176 | 177 | semantic_grid_vi = self.semantic_embedding(semantic_grid_vi) 178 | semantic_grid_ir = self.semantic_embedding(semantic_grid_ir) 179 | 180 | semantic_grid_vi, semantic_grid_ir = rearrange(semantic_grid_vi, 'n c h w -> n (h w) c'), rearrange( 181 | semantic_grid_ir, 'n c h w -> n (h w) c') 182 | 183 | semantic_grid_vi = self.attention(semantic_grid_vi) 184 | semantic_grid_ir = self.attention(semantic_grid_ir) 185 | semantic_grid_vi, semantic_grid_ir = rearrange(semantic_grid_vi, 'n (h w) c -> n c h w', h=h, w=w), rearrange( 186 | semantic_grid_ir, 'n (h w) c -> n c h w', h=h, w=w) 187 | 188 | return semantic_grid_vi, semantic_grid_ir 189 | -------------------------------------------------------------------------------- /train_stage1/IVS_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | import numpy as np 7 | 8 | 9 | def get_dst_point(perspective, IMAGE_SHAPE): 10 | a = random.random() 11 | b = random.random() 12 | c = random.random() 13 | d = random.random() 14 | e = random.random() 15 | f = random.random() 16 | 17 | if random.random() > 0.5: 18 | left_top_x = perspective * a 19 | left_top_y = perspective * b 20 | right_top_x = 0.9 + perspective * c 21 | right_top_y = perspective * d 22 | left_bottom_x = perspective * a 23 | left_bottom_y = 0.9 + perspective * e 24 | right_bottom_x = 0.9 + perspective * c 25 | right_bottom_y = 0.9 + perspective * f 26 | else: 27 | left_top_x = perspective * a 28 | left_top_y = perspective * b 29 | right_top_x = 0.9 + perspective * c 30 | right_top_y = perspective * d 31 | left_bottom_x = perspective * e 32 | left_bottom_y = 0.9 + perspective * b 33 | right_bottom_x = 0.9 + perspective * f 34 | right_bottom_y = 0.9 + perspective * d 35 | 36 | dst_point = np.array([(IMAGE_SHAPE[1] * left_top_x, IMAGE_SHAPE[0] * left_top_y, 1), 37 | (IMAGE_SHAPE[1] * right_top_x, IMAGE_SHAPE[0] * right_top_y, 1), 38 | (IMAGE_SHAPE[1] * left_bottom_x, IMAGE_SHAPE[0] * left_bottom_y, 1), 39 | (IMAGE_SHAPE[1] * right_bottom_x, IMAGE_SHAPE[0] * right_bottom_y, 1)], dtype='float32') 40 | return dst_point 41 | 42 | 43 | def enhance(img, label, img2, label2, IMAGE_SHAPE): 44 | # The four vertices of the image 45 | src_point = np.array([[0, 0], 46 | [IMAGE_SHAPE[1] - 1, 0], 47 | [0, IMAGE_SHAPE[0] - 1], 48 | [IMAGE_SHAPE[1] - 1, IMAGE_SHAPE[0] - 1]], dtype=np.float32) 49 | 50 | # Perspective Information 51 | dst_point = get_dst_point(0.2, IMAGE_SHAPE) 52 | 53 | # Rotation and scale transformation 54 | rotation = 40 55 | rot = random.randint(-rotation, 
rotation) 56 | scale = 1.2 + random.randint(-90, 100) * 0.01 57 | 58 | center_offset = 40 59 | center = (IMAGE_SHAPE[1] / 2 + random.randint(-center_offset, center_offset), 60 | IMAGE_SHAPE[0] / 2 + random.randint(-center_offset, center_offset)) 61 | 62 | RS_mat = cv2.getRotationMatrix2D(center, rot, scale) 63 | f_point = np.matmul(dst_point, RS_mat.T).astype('float32') 64 | mat = cv2.getPerspectiveTransform(src_point, f_point) 65 | out_img, out_img2 = cv2.warpPerspective(img, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])), cv2.warpPerspective(label, mat, 66 | ( 67 | IMAGE_SHAPE[ 68 | 1], 69 | IMAGE_SHAPE[ 70 | 0])) 71 | 72 | out_img3, out_img4 = cv2.warpPerspective(img2, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])), cv2.warpPerspective(label2, 73 | mat, 74 | ( 75 | IMAGE_SHAPE[ 76 | 1], 77 | IMAGE_SHAPE[ 78 | 0])) 79 | 80 | return out_img, out_img2, out_img3, out_img4 81 | 82 | 83 | class IVSDataset(Dataset): 84 | def __init__(self, img_vi_path, img_ir_path, label_path, train_size_w, train_size_h): 85 | super(IVSDataset, self).__init__() 86 | self.img_vi_path = img_vi_path 87 | self.img_ir_path = img_ir_path 88 | self.label_path = label_path 89 | 90 | self.trans1 = transforms.Compose([ 91 | transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.4, hue=0.2), 92 | transforms.Grayscale(), 93 | transforms.ToTensor() 94 | ]) 95 | self.trans2 = transforms.Compose([ 96 | transforms.ColorJitter(brightness=0., contrast=0.1, saturation=0., hue=0.), 97 | transforms.ToTensor() 98 | ]) 99 | 100 | data_floder = os.listdir(self.img_vi_path) 101 | self.data_floder = data_floder 102 | 103 | self.train_size_w = int(train_size_w) 104 | self.train_size_h = int(train_size_h) 105 | 106 | self.feat_size_w = int(self.train_size_w / 8) 107 | self.feat_size_h = int(self.train_size_h / 8) 108 | 109 | def __getitem__(self, idx): 110 | item = self.data_floder[idx] 111 | 112 | imagevi, labelvi = cv2.imread(os.path.join(self.img_vi_path, item)), cv2.imread( 113 | os.path.join(self.label_path, item), cv2.IMREAD_GRAYSCALE) 114 | 115 | imageir, labelir = cv2.imread(os.path.join(self.img_ir_path, item)), cv2.imread( 116 | os.path.join(self.label_path, item), cv2.IMREAD_GRAYSCALE) 117 | 118 | imagevi = cv2.resize(imagevi, (self.train_size_w, self.train_size_h)) 119 | labelvi = cv2.resize(labelvi, (self.train_size_w, self.train_size_h)) 120 | imageir = cv2.resize(imageir, (self.train_size_w, self.train_size_h)) 121 | labelir = cv2.resize(labelir, (self.train_size_w, self.train_size_h)) 122 | 123 | seed = random.random() 124 | if seed < 0.25: 125 | imagevi = motion_blur(imagevi, degree=random.randint(6, 15), angle=random.randint(-45, 45)) 126 | seed = random.random() 127 | if seed < 0.25: 128 | imageir = motion_blur(imageir, degree=random.randint(6, 15), angle=random.randint(-45, 45)) 129 | seed = random.random() 130 | if seed < 0.25: 131 | imagevi = gasuss_noise(imagevi, mean=0, var=0.001) 132 | seed = random.random() 133 | if seed < 0.25: 134 | imageir = gasuss_noise(imageir, mean=0, var=0.001) 135 | 136 | imagevi, labelvi, imageir, labelir = enhance(imagevi, labelvi, imageir, labelir, 137 | [self.train_size_h, self.train_size_w]) 138 | 139 | imagevi, labelvi = transforms.ToPILImage()(imagevi), transforms.ToPILImage()( 140 | cv2.resize(labelvi, (self.feat_size_w, self.feat_size_h))) 141 | imageir, labelir = transforms.ToPILImage()(imageir), transforms.ToPILImage()( 142 | cv2.resize(labelir, (self.feat_size_w, self.feat_size_h))) 143 | 144 | return self.trans1(imagevi), self.trans2(labelvi), self.trans1(imageir), 
self.trans2(labelir) 145 | 146 | def __len__(self): 147 | return len(self.data_floder) 148 | 149 | 150 | def motion_blur(image, degree=15, angle=45): 151 | image = np.array(image) 152 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 153 | motion_blur_kernel = np.diag(np.ones(degree)) 154 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 155 | 156 | motion_blur_kernel = motion_blur_kernel / degree 157 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 158 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 159 | blurred = np.array(blurred, dtype=np.uint8) 160 | return blurred 161 | 162 | 163 | def gasuss_noise(image, mean=0, var=0.001): 164 | image = np.array(image / 255, dtype=float) 165 | noise = np.random.normal(mean, var ** 0.5, image.shape) 166 | out = image + noise 167 | out = np.clip(out, 0.0, 1.0) 168 | out = np.uint8(out * 255) 169 | return out -------------------------------------------------------------------------------- /train_stage1/teacher/model/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | from loguru import logger 3 | 4 | import cv2 5 | import numpy as np 6 | import h5py 7 | import torch 8 | from numpy.linalg import inv 9 | 10 | 11 | try: 12 | # for internel use only 13 | from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT 14 | except Exception: 15 | MEGADEPTH_CLIENT = SCANNET_CLIENT = None 16 | 17 | # --- DATA IO --- 18 | 19 | def load_array_from_s3( 20 | path, client, cv_type, 21 | use_h5py=False, 22 | ): 23 | byte_str = client.Get(path) 24 | try: 25 | if not use_h5py: 26 | raw_array = np.fromstring(byte_str, np.uint8) 27 | data = cv2.imdecode(raw_array, cv_type) 28 | else: 29 | f = io.BytesIO(byte_str) 30 | data = np.array(h5py.File(f, 'r')['/depth']) 31 | except Exception as ex: 32 | print(f"==> Data loading failure: {path}") 33 | raise ex 34 | 35 | assert data is not None 36 | return data 37 | 38 | 39 | def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): 40 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ 41 | else cv2.IMREAD_COLOR 42 | if str(path).startswith('s3://'): 43 | image = load_array_from_s3(str(path), client, cv_type) 44 | else: 45 | image = cv2.imread(str(path), cv_type) 46 | 47 | if augment_fn is not None: 48 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | image = augment_fn(image) 51 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 52 | return image # (h, w) 53 | 54 | def imread_rgb(path, augment_fn=None, client=SCANNET_CLIENT): 55 | cv_type = cv2.IMREAD_COLOR 56 | if str(path).startswith('s3://'): 57 | image = load_array_from_s3(str(path), client, cv_type) 58 | else: 59 | image = cv2.imread(str(path), cv_type) 60 | 61 | if augment_fn is not None: 62 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 63 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | image = augment_fn(image) 65 | 66 | return image # (3, h, w) 67 | 68 | def get_resized_wh(w, h, resize=None): 69 | if resize is not None: # resize the longer edge 70 | scale = resize / max(h, w) 71 | w_new, h_new = int(round(w*scale)), int(round(h*scale)) 72 | else: 73 | w_new, h_new = w, h 74 | return w_new, h_new 75 | 76 | 77 | def get_divisible_wh(w, h, df=None): 78 | if df is not None: 79 | w_new, h_new = map(lambda x: int(x // df * df), [w, h]) 80 | else: 81 | w_new, h_new = w, h 82 | return w_new, h_new 83 | 84 | 85 | def pad_bottom_right(inp, pad_size, ret_mask=False): 86 | assert 
isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" 87 | mask = None 88 | if inp.ndim == 2: 89 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) 90 | padded[:inp.shape[0], :inp.shape[1]] = inp 91 | if ret_mask: 92 | mask = np.zeros((pad_size, pad_size), dtype=bool) 93 | mask[:inp.shape[0], :inp.shape[1]] = True 94 | elif inp.ndim == 3: 95 | padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype) 96 | padded[:, :inp.shape[1], :inp.shape[2]] = inp 97 | if ret_mask: 98 | mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool) 99 | mask[:, :inp.shape[1], :inp.shape[2]] = True 100 | else: 101 | raise NotImplementedError() 102 | return padded, mask 103 | 104 | # --- MEGADEPTH --- 105 | def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): 106 | """ 107 | Args: 108 | resize (int, optional): the longer edge of resized images. None for no resize. 109 | padding (bool): If set to 'True', zero-pad resized images to squared size. 110 | augment_fn (callable, optional): augments images with pre-defined visual effects 111 | Returns: 112 | image (torch.tensor): (1, h, w) 113 | mask (torch.tensor): (h, w) 114 | scale (torch.tensor): [w/w_new, h/h_new] 115 | """ 116 | # read image 117 | image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 118 | 119 | # resize image 120 | w, h = image.shape[1], image.shape[0] 121 | w_new, h_new = get_resized_wh(w, h, resize) 122 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 123 | 124 | image = cv2.resize(image, (w_new, h_new)) 125 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 126 | 127 | if padding: # padding 128 | pad_to = max(h_new, w_new) 129 | image, mask = pad_bottom_right(image, pad_to, ret_mask=True) 130 | else: 131 | mask = None 132 | image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized 133 | mask = torch.from_numpy(mask) 134 | 135 | return image, mask, scale 136 | 137 | 138 | def read_megadepth_rgb(path, resize=None, df=None, padding=False, augment_fn=None): 139 | """ 140 | Args: 141 | resize (int, optional): the longer edge of resized images. None for no resize. 142 | padding (bool): If set to 'True', zero-pad resized images to squared size. 
143 | augment_fn (callable, optional): augments images with pre-defined visual effects 144 | Returns: 145 | image (torch.tensor): (3, h, w) 146 | mask (torch.tensor): (h, w) 147 | scale (torch.tensor): [w/w_new, h/h_new] 148 | """ 149 | # read image 150 | image_gray = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 151 | image = imread_rgb(path, augment_fn, client=MEGADEPTH_CLIENT) 152 | 153 | # resize image 154 | w, h = image.shape[1], image.shape[0] 155 | w_new, h_new = get_resized_wh(w, h, resize) 156 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 157 | 158 | image = cv2.resize(image, (w_new, h_new)) 159 | image = image.transpose(2,0,1) 160 | image_gray = cv2.resize(image_gray, (w_new, h_new)) 161 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 162 | 163 | if padding: # padding 164 | pad_to = max(h_new, w_new) 165 | image, _ = pad_bottom_right(image, pad_to, ret_mask=True) 166 | _, mask = pad_bottom_right(image_gray, pad_to, ret_mask=True) 167 | else: 168 | mask = None 169 | 170 | image = torch.from_numpy(image).float() / 255 # (3, h, w) -> (3, h, w) normalized 171 | mask = torch.from_numpy(mask) 172 | return image, mask, scale 173 | 174 | def read_megadepth_depth(path, pad_to=None): 175 | if str(path).startswith('s3://'): 176 | depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) 177 | else: 178 | depth = np.array(h5py.File(path, 'r')['depth']) 179 | if pad_to is not None: 180 | depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) 181 | depth = torch.from_numpy(depth).float() # (h, w) 182 | return depth 183 | 184 | 185 | # --- ScanNet --- 186 | 187 | def read_scannet_gray(path, resize=(640, 480), augment_fn=None): 188 | """ 189 | Args: 190 | resize (tuple): align image to depthmap, in (w, h). 191 | augment_fn (callable, optional): augments images with pre-defined visual effects 192 | Returns: 193 | image (torch.tensor): (1, h, w) 194 | mask (torch.tensor): (h, w) 195 | scale (torch.tensor): [w/w_new, h/h_new] 196 | """ 197 | # read and resize image 198 | image = imread_gray(path, augment_fn) 199 | image = cv2.resize(image, resize) 200 | 201 | # (h, w) -> (1, h, w) and normalized 202 | image = torch.from_numpy(image).float()[None] / 255 203 | return image 204 | 205 | 206 | def read_scannet_depth(path): 207 | if str(path).startswith('s3://'): 208 | depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) 209 | else: 210 | depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) 211 | depth = depth / 1000 212 | depth = torch.from_numpy(depth).float() # (h, w) 213 | return depth 214 | 215 | 216 | def read_scannet_pose(path): 217 | """ Read ScanNet's Camera2World pose and transform it to World2Camera. 218 | 219 | Returns: 220 | pose_w2c (np.ndarray): (4, 4) 221 | """ 222 | cam2world = np.loadtxt(path, delimiter=' ') 223 | world2cam = inv(cam2world) 224 | return world2cam 225 | 226 | 227 | def read_scannet_intrinsic(path): 228 | """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 229 | """ 230 | intrinsic = np.loadtxt(path, delimiter=' ') 231 | return intrinsic[:-1, :-1] 232 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | [Copied from detectron2] 4 | This file contains primitives for multi-gpu communication. 5 | This is useful when doing distributed training. 
6 | """ 7 | 8 | import functools 9 | import logging 10 | import numpy as np 11 | import pickle 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert _LOCAL_PROCESS_GROUP is not None 48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 49 | 50 | 51 | def get_local_size() -> int: 52 | """ 53 | Returns: 54 | The size of the per-machine process group, 55 | i.e. the number of processes per machine. 56 | """ 57 | if not dist.is_available(): 58 | return 1 59 | if not dist.is_initialized(): 60 | return 1 61 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 62 | 63 | 64 | def is_main_process() -> bool: 65 | return get_rank() == 0 66 | 67 | 68 | def synchronize(): 69 | """ 70 | Helper function to synchronize (barrier) among all processes when 71 | using distributed training 72 | """ 73 | if not dist.is_available(): 74 | return 75 | if not dist.is_initialized(): 76 | return 77 | world_size = dist.get_world_size() 78 | if world_size == 1: 79 | return 80 | dist.barrier() 81 | 82 | 83 | @functools.lru_cache() 84 | def _get_global_gloo_group(): 85 | """ 86 | Return a process group based on gloo backend, containing all the ranks 87 | The result is cached. 88 | """ 89 | if dist.get_backend() == "nccl": 90 | return dist.new_group(backend="gloo") 91 | else: 92 | return dist.group.WORLD 93 | 94 | 95 | def _serialize_to_tensor(data, group): 96 | backend = dist.get_backend(group) 97 | assert backend in ["gloo", "nccl"] 98 | device = torch.device("cpu" if backend == "gloo" else "cuda") 99 | 100 | buffer = pickle.dumps(data) 101 | if len(buffer) > 1024 ** 3: 102 | logger = logging.getLogger(__name__) 103 | logger.warning( 104 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 105 | get_rank(), len(buffer) / (1024 ** 3), device 106 | ) 107 | ) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to(device=device) 110 | return tensor 111 | 112 | 113 | def _pad_to_largest_tensor(tensor, group): 114 | """ 115 | Returns: 116 | list[int]: size of the tensor, on each rank 117 | Tensor: padded tensor that has the max size 118 | """ 119 | world_size = dist.get_world_size(group=group) 120 | assert ( 121 | world_size >= 1 122 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
123 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 124 | size_list = [ 125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 126 | ] 127 | dist.all_gather(size_list, local_size, group=group) 128 | 129 | size_list = [int(size.item()) for size in size_list] 130 | 131 | max_size = max(size_list) 132 | 133 | # we pad the tensor because torch all_gather does not support 134 | # gathering tensors of different shapes 135 | if local_size != max_size: 136 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 137 | tensor = torch.cat((tensor, padding), dim=0) 138 | return size_list, tensor 139 | 140 | 141 | def all_gather(data, group=None): 142 | """ 143 | Run all_gather on arbitrary picklable data (not necessarily tensors). 144 | 145 | Args: 146 | data: any picklable object 147 | group: a torch process group. By default, will use a group which 148 | contains all ranks on gloo backend. 149 | 150 | Returns: 151 | list[data]: list of data gathered from each rank 152 | """ 153 | if get_world_size() == 1: 154 | return [data] 155 | if group is None: 156 | group = _get_global_gloo_group() 157 | if dist.get_world_size(group) == 1: 158 | return [data] 159 | 160 | tensor = _serialize_to_tensor(data, group) 161 | 162 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 163 | max_size = max(size_list) 164 | 165 | # receiving Tensor from all ranks 166 | tensor_list = [ 167 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 168 | ] 169 | dist.all_gather(tensor_list, tensor, group=group) 170 | 171 | data_list = [] 172 | for size, tensor in zip(size_list, tensor_list): 173 | buffer = tensor.cpu().numpy().tobytes()[:size] 174 | data_list.append(pickle.loads(buffer)) 175 | 176 | return data_list 177 | 178 | 179 | def gather(data, dst=0, group=None): 180 | """ 181 | Run gather on arbitrary picklable data (not necessarily tensors). 182 | 183 | Args: 184 | data: any picklable object 185 | dst (int): destination rank 186 | group: a torch process group. By default, will use a group which 187 | contains all ranks on gloo backend. 188 | 189 | Returns: 190 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 191 | an empty list. 192 | """ 193 | if get_world_size() == 1: 194 | return [data] 195 | if group is None: 196 | group = _get_global_gloo_group() 197 | if dist.get_world_size(group=group) == 1: 198 | return [data] 199 | rank = dist.get_rank(group=group) 200 | 201 | tensor = _serialize_to_tensor(data, group) 202 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 203 | 204 | # receiving Tensor from all ranks 205 | if rank == dst: 206 | max_size = max(size_list) 207 | tensor_list = [ 208 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 209 | ] 210 | dist.gather(tensor, tensor_list, dst=dst, group=group) 211 | 212 | data_list = [] 213 | for size, tensor in zip(size_list, tensor_list): 214 | buffer = tensor.cpu().numpy().tobytes()[:size] 215 | data_list.append(pickle.loads(buffer)) 216 | return data_list 217 | else: 218 | dist.gather(tensor, [], dst=dst, group=group) 219 | return [] 220 | 221 | 222 | def shared_random_seed(): 223 | """ 224 | Returns: 225 | int: a random number that is the same across all workers. 226 | If workers need a shared RNG, they can use this shared seed to 227 | create one. 228 | 229 | All workers must call this function, otherwise it will deadlock. 
230 | """ 231 | ints = np.random.randint(2 ** 31) 232 | all_ints = all_gather(ints) 233 | return all_ints[0] 234 | 235 | 236 | def reduce_dict(input_dict, average=True): 237 | """ 238 | Reduce the values in the dictionary from all processes so that process with rank 239 | 0 has the reduced results. 240 | 241 | Args: 242 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 243 | average (bool): whether to do average or sum 244 | 245 | Returns: 246 | a dict with the same keys as input_dict, after reduction. 247 | """ 248 | world_size = get_world_size() 249 | if world_size < 2: 250 | return input_dict 251 | with torch.no_grad(): 252 | names = [] 253 | values = [] 254 | # sort the keys so that they are consistent across processes 255 | for k in sorted(input_dict.keys()): 256 | names.append(k) 257 | values.append(input_dict[k]) 258 | values = torch.stack(values, dim=0) 259 | dist.reduce(values, dst=0) 260 | if dist.get_rank() == 0 and average: 261 | # only main process gets accumulated, so only divide by 262 | # world_size in this case 263 | values /= world_size 264 | reduced_dict = {k: v for k, v in zip(names, values)} 265 | return reduced_dict 266 | -------------------------------------------------------------------------------- /train_stage3/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import pywt 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | def conv1x1(in_channels, out_channels, stride=1): 8 | """1 x 1 convolution""" 9 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) 10 | 11 | 12 | def conv3x3(in_channels, out_channels, stride=1): 13 | """3 x 3 convolution""" 14 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | 17 | class CBR(nn.Module): 18 | """3 x 3 convolution block 19 | 20 | Args: 21 | 'x': (torch.Tensor): (N, C, H, W) 22 | """ 23 | def __init__(self, in_channels, planes, stride=1): 24 | super().__init__() 25 | self.conv = conv3x3(in_channels, planes, stride) 26 | self.bn = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | return self.relu(self.bn(self.conv(x))) 31 | 32 | 33 | class DWConv(nn.Module): 34 | """DepthWise convolution block 35 | 36 | Args: 37 | 'x': (torch.Tensor): (N, C, H, W) 38 | """ 39 | def __init__(self, out_channels): 40 | super().__init__() 41 | self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, 42 | padding=1, groups=out_channels, bias=False) 43 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 44 | self.act = nn.ReLU(inplace=True) 45 | self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) 46 | 47 | def forward(self, x): 48 | out = self.group_conv3x3(x) 49 | out = self.norm(out) 50 | out = self.act(out) 51 | out = self.projection(out) 52 | return out 53 | 54 | 55 | class MLP(nn.Module): 56 | """MLP Layer 57 | 58 | Args: 59 | 'x': (torch.Tensor): (N, C, H, W) 60 | """ 61 | def __init__(self, out_channels, bias=True): 62 | super().__init__() 63 | self.conv1 = nn.Conv2d(out_channels, out_channels * 2, kernel_size=1, bias=bias) 64 | self.act = nn.ReLU(inplace=True) 65 | self.conv2 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=bias) 66 | 67 | def forward(self, x): 68 | x = self.conv1(x) 69 | x = self.act(x) 70 | x = self.conv2(x) 71 | return x 72 | 73 | # This class is 
implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 74 | class DWT_2D(nn.Module): 75 | """Discrete Wavelet Transform for feature maps downsampling 76 | 77 | Args: 78 | 'x': (torch.Tensor): (N, C, H, W) 79 | """ 80 | def __init__(self, wave): 81 | super(DWT_2D, self).__init__() 82 | w = pywt.Wavelet(wave) 83 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 84 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 85 | 86 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) 87 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) 88 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) 89 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) 90 | 91 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0)) 92 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0)) 93 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0)) 94 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0)) 95 | 96 | self.w_ll = self.w_ll.to(dtype=torch.float32) 97 | self.w_lh = self.w_lh.to(dtype=torch.float32) 98 | self.w_hl = self.w_hl.to(dtype=torch.float32) 99 | self.w_hh = self.w_hh.to(dtype=torch.float32) 100 | 101 | def forward(self, x): 102 | return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh) 103 | 104 | 105 | # This class is implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 106 | class DWT_Function(Function): 107 | @staticmethod 108 | def forward(ctx, x, w_ll, w_lh, w_hl, w_hh): 109 | x = x.contiguous() 110 | ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh) 111 | ctx.shape = x.shape 112 | 113 | dim = x.shape[1] 114 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim) 115 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim) 116 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim) 117 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim) 118 | x = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1) 119 | return x 120 | 121 | @staticmethod 122 | def backward(ctx, dx): 123 | if ctx.needs_input_grad[0]: 124 | w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors 125 | B, C, H, W = ctx.shape 126 | dx = dx.view(B, 4, -1, H // 2, W // 2) 127 | 128 | dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2) 129 | filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1) 130 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C) 131 | 132 | return dx, None, None, None, None 133 | 134 | class MLP2(nn.Module): 135 | def __init__(self, in_features, mlp_ratio=4): 136 | super(MLP2, self).__init__() 137 | hidden_features = in_features * mlp_ratio 138 | 139 | self.fc = nn.Sequential( 140 | nn.Linear(in_features, hidden_features), 141 | nn.GELU(), 142 | nn.Linear(hidden_features, in_features) 143 | ) 144 | 145 | def forward(self, x): 146 | return self.fc(x) 147 | 148 | class StructureAttention(nn.Module): 149 | # This class is implemented by [LoFTR](https://github.com/zju3dv/LoFTR). 
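# Added commentary describing the layer below (a reading of the code, not an official
# description): this is a LoFTR-style self-attention encoder layer — bias-free q/k/v
# linear projections, multi-head linear attention, a merge projection followed by
# LayerNorm, and a feed-forward MLP applied to the concatenation [x, message];
# the layer output is the residual sum x + message.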
150 | def __init__(self, d_model, nhead): 151 | super(StructureAttention, self).__init__() 152 | self.dim = d_model // nhead 153 | self.nhead = nhead 154 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 155 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 156 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 157 | self.attention = LinearAttention() 158 | self.merge = nn.Linear(d_model, d_model, bias=False) 159 | 160 | # feed-forward network 161 | self.mlp = nn.Sequential( 162 | nn.Linear(d_model * 2, d_model * 1, bias=False), 163 | nn.ReLU(True), 164 | nn.Linear(d_model * 1, d_model, bias=False), 165 | ) 166 | 167 | # norm and dropout 168 | self.norm1 = nn.LayerNorm(d_model) 169 | self.norm2 = nn.LayerNorm(d_model) 170 | 171 | def forward(self, x): 172 | """ 173 | Args: 174 | x (torch.Tensor): [N, L, C] 175 | source (torch.Tensor): [N, S, C] 176 | """ 177 | bs = x.size(0) 178 | query, key, value = x, x, x 179 | 180 | # multi-head attention 181 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 182 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] 183 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 184 | message = self.attention(query, key, value) # [N, L, (H, D)] 185 | message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] 186 | message = self.norm1(message) 187 | 188 | # feed-forward network 189 | message = self.mlp(torch.cat([x, message], dim=2)) 190 | message = self.norm2(message) 191 | 192 | return x + message 193 | 194 | def elu_feature_map(x): 195 | return torch.nn.functional.elu(x) + 1 196 | 197 | class LinearAttention(nn.Module): 198 | def __init__(self, eps=1e-6): 199 | super().__init__() 200 | self.feature_map = elu_feature_map 201 | self.eps = eps 202 | 203 | def forward(self, queries, keys, values): 204 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 205 | Args: 206 | queries: [N, L, H, D] 207 | keys: [N, S, H, D] 208 | values: [N, S, H, D] 209 | q_mask: [N, L] 210 | kv_mask: [N, S] 211 | Returns: 212 | queried_values: (N, L, H, D) 213 | """ 214 | Q = self.feature_map(queries) 215 | K = self.feature_map(keys) 216 | v_length = values.size(1) 217 | values = values / v_length # prevent fp16 overflow 218 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 219 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 220 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 221 | 222 | return queried_values.contiguous() 223 | 224 | class MLP2(nn.Module): 225 | def __init__(self, in_features, mlp_ratio=4): 226 | super(MLP2, self).__init__() 227 | hidden_features = in_features * mlp_ratio 228 | 229 | self.fc = nn.Sequential( 230 | nn.Linear(in_features, hidden_features), 231 | nn.GELU(), 232 | nn.Linear(hidden_features, in_features) 233 | ) 234 | 235 | def forward(self, x): 236 | return self.fc(x) -------------------------------------------------------------------------------- /train_stage2/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import pywt 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | def conv1x1(in_channels, out_channels, stride=1): 8 | """1 x 1 convolution""" 9 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False) 10 | 11 | 12 | def conv3x3(in_channels, out_channels, stride=1): 13 | """3 x 3 convolution""" 14 | return 
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | 17 | class CBR(nn.Module): 18 | """3 x 3 convolution block 19 | 20 | Args: 21 | 'x': (torch.Tensor): (N, C, H, W) 22 | """ 23 | def __init__(self, in_channels, planes, stride=1): 24 | super().__init__() 25 | self.conv = conv3x3(in_channels, planes, stride) 26 | self.bn = nn.BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | 29 | def forward(self, x): 30 | return self.relu(self.bn(self.conv(x))) 31 | 32 | 33 | class DWConv(nn.Module): 34 | """DepthWise convolution block 35 | 36 | Args: 37 | 'x': (torch.Tensor): (N, C, H, W) 38 | """ 39 | def __init__(self, out_channels): 40 | super().__init__() 41 | self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, 42 | padding=1, groups=out_channels, bias=False) 43 | self.norm = nn.BatchNorm2d(out_channels, eps=1e-5) 44 | self.act = nn.ReLU(inplace=True) 45 | self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False) 46 | 47 | def forward(self, x): 48 | out = self.group_conv3x3(x) 49 | out = self.norm(out) 50 | out = self.act(out) 51 | out = self.projection(out) 52 | return out 53 | 54 | 55 | class MLP(nn.Module): 56 | """MLP Layer 57 | 58 | Args: 59 | 'x': (torch.Tensor): (N, C, H, W) 60 | """ 61 | def __init__(self, out_channels, bias=True): 62 | super().__init__() 63 | self.conv1 = nn.Conv2d(out_channels, out_channels * 2, kernel_size=1, bias=bias) 64 | self.act = nn.ReLU(inplace=True) 65 | self.conv2 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=bias) 66 | 67 | def forward(self, x): 68 | x = self.conv1(x) 69 | x = self.act(x) 70 | x = self.conv2(x) 71 | return x 72 | 73 | # This class is implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 74 | class DWT_2D(nn.Module): 75 | """Discrete Wavelet Transform for feature maps downsampling 76 | 77 | Args: 78 | 'x': (torch.Tensor): (N, C, H, W) 79 | """ 80 | def __init__(self, wave): 81 | super(DWT_2D, self).__init__() 82 | w = pywt.Wavelet(wave) 83 | dec_hi = torch.Tensor(w.dec_hi[::-1]) 84 | dec_lo = torch.Tensor(w.dec_lo[::-1]) 85 | 86 | w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1) 87 | w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1) 88 | w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1) 89 | w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1) 90 | 91 | self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0)) 92 | self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0)) 93 | self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0)) 94 | self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0)) 95 | 96 | self.w_ll = self.w_ll.to(dtype=torch.float32) 97 | self.w_lh = self.w_lh.to(dtype=torch.float32) 98 | self.w_hl = self.w_hl.to(dtype=torch.float32) 99 | self.w_hh = self.w_hh.to(dtype=torch.float32) 100 | 101 | def forward(self, x): 102 | return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh) 103 | 104 | 105 | # This class is implemented by [Wave-ViT](https://github.com/YehLi/ImageNetModel/blob/main/classification/torch_wavelets.py). 
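# Added commentary with a minimal usage sketch of DWT_2D (the shapes are illustrative
# assumptions, not taken from the original file; the input must live on the same
# device as the module's filter buffers):
#   dwt = DWT_2D(wave='haar')
#   x = torch.randn(1, 32, 64, 64)   # (N, C, H, W) float tensor
#   y = dwt(x)                       # (1, 4 * 32, 32, 32): [LL, LH, HL, HH] stacked along channels
# DWT_Function below implements the forward pass as four stride-2 grouped convolutions
# (one per Haar sub-band) and reconstructs the input gradient in backward() with a
# grouped conv_transpose2d that reuses the same filters.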
106 | class DWT_Function(Function): 107 | @staticmethod 108 | def forward(ctx, x, w_ll, w_lh, w_hl, w_hh): 109 | x = x.contiguous() 110 | ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh) 111 | ctx.shape = x.shape 112 | 113 | dim = x.shape[1] 114 | x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim) 115 | x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim) 116 | x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim) 117 | x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim) 118 | x = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1) 119 | return x 120 | 121 | @staticmethod 122 | def backward(ctx, dx): 123 | if ctx.needs_input_grad[0]: 124 | w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors 125 | B, C, H, W = ctx.shape 126 | dx = dx.view(B, 4, -1, H // 2, W // 2) 127 | 128 | dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2) 129 | filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1) 130 | dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C) 131 | 132 | return dx, None, None, None, None 133 | 134 | class MLP2(nn.Module): 135 | def __init__(self, in_features, mlp_ratio=4): 136 | super(MLP2, self).__init__() 137 | hidden_features = in_features * mlp_ratio 138 | 139 | self.fc = nn.Sequential( 140 | nn.Linear(in_features, hidden_features), 141 | nn.GELU(), 142 | nn.Linear(hidden_features, in_features) 143 | ) 144 | 145 | def forward(self, x): 146 | return self.fc(x) 147 | 148 | class StructureAttention(nn.Module): 149 | # This class is implemented by [LoFTR](https://github.com/zju3dv/LoFTR). 150 | def __init__(self, d_model, nhead): 151 | super(StructureAttention, self).__init__() 152 | self.dim = d_model // nhead 153 | self.nhead = nhead 154 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 155 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 156 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 157 | self.attention = LinearAttention() 158 | self.merge = nn.Linear(d_model, d_model, bias=False) 159 | 160 | # feed-forward network 161 | self.mlp = nn.Sequential( 162 | nn.Linear(d_model * 2, d_model * 1, bias=False), 163 | nn.ReLU(True), 164 | nn.Linear(d_model * 1, d_model, bias=False), 165 | ) 166 | 167 | # norm and dropout 168 | self.norm1 = nn.LayerNorm(d_model) 169 | self.norm2 = nn.LayerNorm(d_model) 170 | 171 | def forward(self, x): 172 | """ 173 | Args: 174 | x (torch.Tensor): [N, L, C] 175 | source (torch.Tensor): [N, S, C] 176 | """ 177 | bs = x.size(0) 178 | query, key, value = x, x, x 179 | 180 | # multi-head attention 181 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 182 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] 183 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 184 | message = self.attention(query, key, value) # [N, L, (H, D)] 185 | message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] 186 | message = self.norm1(message) 187 | 188 | # feed-forward network 189 | message = self.mlp(torch.cat([x, message], dim=2)) 190 | message = self.norm2(message) 191 | 192 | return x + message 193 | 194 | def elu_feature_map(x): 195 | return torch.nn.functional.elu(x) + 1 196 | 197 | class LinearAttention(nn.Module): 198 | def __init__(self, eps=1e-6): 199 | super().__init__() 200 | self.feature_map = elu_feature_map 201 | self.eps = eps 202 | 203 | def forward(self, queries, keys, values): 204 | """ 
Multi-Head linear attention proposed in "Transformers are RNNs" 205 | Args: 206 | queries: [N, L, H, D] 207 | keys: [N, S, H, D] 208 | values: [N, S, H, D] 209 | q_mask: [N, L] 210 | kv_mask: [N, S] 211 | Returns: 212 | queried_values: (N, L, H, D) 213 | """ 214 | Q = self.feature_map(queries) 215 | K = self.feature_map(keys) 216 | v_length = values.size(1) 217 | values = values / v_length # prevent fp16 overflow 218 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 219 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 220 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 221 | 222 | return queried_values.contiguous() 223 | 224 | class MLP2(nn.Module): 225 | def __init__(self, in_features, mlp_ratio=4): 226 | super(MLP2, self).__init__() 227 | hidden_features = in_features * mlp_ratio 228 | 229 | self.fc = nn.Sequential( 230 | nn.Linear(in_features, hidden_features), 231 | nn.GELU(), 232 | nn.Linear(hidden_features, in_features) 233 | ) 234 | 235 | def forward(self, x): 236 | return self.fc(x) -------------------------------------------------------------------------------- /train_stage2/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import random 4 | from torch.utils.data import Dataset 5 | from torchvision import transforms 6 | import numpy as np 7 | from PIL import Image, ImageFilter 8 | from torchvision import transforms as tfs 9 | 10 | def get_dst_point(perspective, IMAGE_SHAPE): 11 | a = random.random() 12 | b = random.random() 13 | c = random.random() 14 | d = random.random() 15 | e = random.random() 16 | f = random.random() 17 | 18 | if random.random() > 0.5: 19 | left_top_x = perspective * a 20 | left_top_y = perspective * b 21 | right_top_x = 0.9 + perspective * c 22 | right_top_y = perspective * d 23 | left_bottom_x = perspective * a 24 | left_bottom_y = 0.9 + perspective * e 25 | right_bottom_x = 0.9 + perspective * c 26 | right_bottom_y = 0.9 + perspective * f 27 | else: 28 | left_top_x = perspective * a 29 | left_top_y = perspective * b 30 | right_top_x = 0.9 + perspective * c 31 | right_top_y = perspective * d 32 | left_bottom_x = perspective * e 33 | left_bottom_y = 0.9 + perspective * b 34 | right_bottom_x = 0.9 + perspective * f 35 | right_bottom_y = 0.9 + perspective * d 36 | 37 | dst_point = np.array([(IMAGE_SHAPE[1] * left_top_x, IMAGE_SHAPE[0] * left_top_y, 1), 38 | (IMAGE_SHAPE[1] * right_top_x, IMAGE_SHAPE[0] * right_top_y, 1), 39 | (IMAGE_SHAPE[1] * left_bottom_x, IMAGE_SHAPE[0] * left_bottom_y, 1), 40 | (IMAGE_SHAPE[1] * right_bottom_x, IMAGE_SHAPE[0] * right_bottom_y, 1)], dtype='float32') 41 | return dst_point 42 | 43 | class PilGaussianBlur(ImageFilter.Filter): 44 | name = "GaussianBlur" 45 | 46 | def __init__(self, radius=2): 47 | self.radius = radius 48 | 49 | def filter(self, image): 50 | return image.gaussian_blur(self.radius) 51 | 52 | def enhance(img0, img1, IMAGE_SHAPE): 53 | # The four vertices of the image 54 | src_point = np.array([[0, 0], 55 | [IMAGE_SHAPE[1] - 1, 0], 56 | [0, IMAGE_SHAPE[0] - 1], 57 | [IMAGE_SHAPE[1] - 1, IMAGE_SHAPE[0] - 1]], dtype=np.float32) 58 | 59 | # Perspective Information 60 | dst_point = get_dst_point(0.1, IMAGE_SHAPE) 61 | 62 | # Rotation and scale transformation 63 | rotation = 40 64 | rot = random.randint(-rotation, rotation) 65 | scale = 1.2 + random.randint(-90, 100) * 0.01 66 | 67 | center_offset = 40 68 | center = (IMAGE_SHAPE[1] / 2 + random.randint(-center_offset, 
center_offset), 69 | IMAGE_SHAPE[0] / 2 + random.randint(-center_offset, center_offset)) 70 | 71 | RS_mat = cv2.getRotationMatrix2D(center, rot, scale) 72 | f_point = np.matmul(dst_point, RS_mat.T).astype('float32') 73 | mat = cv2.getPerspectiveTransform(src_point, f_point) 74 | out_img0, out_img1 = cv2.warpPerspective(img0, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])),\ 75 | cv2.warpPerspective(img1, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])) 76 | 77 | 78 | return out_img0, out_img1, mat 79 | 80 | 81 | class Dataset(Dataset): 82 | def __init__(self, img_vi_path, img_ir_path, label_path, train_size_w, train_size_h): 83 | super(Dataset, self).__init__() 84 | self.img_vi_path = img_vi_path 85 | self.img_ir_path = img_ir_path 86 | self.label_path = label_path 87 | 88 | self.trans = transforms.Compose([ 89 | transforms.Grayscale(), 90 | transforms.ToTensor() 91 | ]) 92 | 93 | data_floder = os.listdir(self.img_vi_path) 94 | self.data_floder = data_floder 95 | 96 | self.train_size_w = int(train_size_w) 97 | self.train_size_h = int(train_size_h) 98 | 99 | self.feat_size_w = int(self.train_size_w / 8) 100 | self.feat_size_h = int(self.train_size_h / 8) 101 | 102 | self.feat_size_wh = int(self.feat_size_w * self.feat_size_h) 103 | 104 | def __getitem__(self, idx): 105 | item = self.data_floder[idx] 106 | 107 | imagevi, imageir = cv2.imread(os.path.join(self.img_vi_path, item), cv2.IMREAD_GRAYSCALE), cv2.imread( 108 | os.path.join(self.img_ir_path, item), cv2.IMREAD_GRAYSCALE) 109 | 110 | label = cv2.imread(os.path.join(self.label_path, item), cv2.IMREAD_GRAYSCALE) 111 | 112 | imagevi = cv2.resize(imagevi, (self.train_size_w, self.train_size_h)) 113 | imageir = cv2.resize(imageir, (self.train_size_w, self.train_size_h)) 114 | label = cv2.resize(label, (self.train_size_w, self.train_size_h)) 115 | 116 | seed = random.random() 117 | if seed < 0.2: 118 | (h, w) = imagevi.shape[:2] 119 | center = (w // 2, h // 2) 120 | M = cv2.getRotationMatrix2D(center, random.randint(-10, 10), random.randint(3, 8) * 0.1) 121 | imagevi = cv2.warpAffine(imagevi, M, (w, h)) 122 | imageir = cv2.warpAffine(imageir, M, (w, h)) 123 | label = cv2.warpAffine(label, M, (w, h)) 124 | 125 | seed = random.random() 126 | if seed < 0.25: 127 | imagevi = motion_blur(imagevi, degree=random.randint(5, 14), angle=random.randint(-45, 45)) 128 | seed = random.random() 129 | if seed < 0.25: 130 | imageir = motion_blur(imageir, degree=random.randint(5, 14), angle=random.randint(-45, 45)) 131 | seed = random.random() 132 | if seed < 0.25: 133 | imagevi = gasuss_noise(imagevi, mean=0, var=0.001) 134 | seed = random.random() 135 | if seed < 0.25: 136 | imageir = gasuss_noise(imageir, mean=0, var=0.001) 137 | 138 | label_trans, imageir, mat = enhance(label, imageir, [self.train_size_h, self.train_size_w]) 139 | 140 | mask_idx = cv2.resize(label, (self.feat_size_w, self.feat_size_h)) 141 | mask_idx = np.asarray(mask_idx).reshape(-1) 142 | mask_idx = np.where(mask_idx < 20)[0] 143 | 144 | mask_trans_idx = cv2.resize(label_trans, (self.feat_size_w, self.feat_size_h)) 145 | mask_trans_idx = np.asarray(mask_trans_idx).reshape(-1) 146 | mask_trans_idx = np.where(mask_trans_idx < 200)[0] 147 | 148 | imagevi = Image.fromarray(imagevi) 149 | imageir = Image.fromarray(imageir) 150 | 151 | seed = random.random() 152 | if seed < 0.25: 153 | imagevi = imagevi.filter(PilGaussianBlur(radius=random.randint(1, 2))) 154 | 155 | imagevi = tfs.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.4, hue=0.1)(imagevi) 156 | imageir = tfs.ColorJitter(brightness=0.4, 
contrast=0.5, saturation=0.4, hue=0.1)(imageir) 157 | 158 | imagevi = np.asarray(imagevi) 159 | imageir = np.asarray(imageir) 160 | 161 | 162 | point1 = np.array(list(range(0, self.feat_size_wh))) 163 | point1 = 8 * np.stack([point1 % self.feat_size_w, point1 // self.feat_size_w], axis=1).reshape(1, -1, 2).astype(np.float32) 164 | 165 | point2 = cv2.perspectiveTransform(point1, mat).reshape(-1, 2) 166 | point1 = point1.reshape(-1, 2) 167 | 168 | mask0 = np.where(point1[:, 0] < self.train_size_w - 8, True, False) * np.where(point1[:, 0] > 8, True, False) \ 169 | * np.where(point1[:, 1] > 8, True, False) * np.where(point1[:, 1] < self.train_size_h - 8, True, False) 170 | 171 | mask1 = np.where(point2[:, 0] < self.train_size_w - 8, True, False) * np.where(point2[:, 0] > 8, True, False) \ 172 | * np.where(point2[:, 1] > 8, True, False) * np.where(point2[:, 1] < self.train_size_h - 8, True, False) 173 | 174 | mask1 = mask0 * mask1 175 | point1 = point1.astype(np.int32) 176 | point2 = point2.astype(np.int32) 177 | point1 = point1[mask1] 178 | point2 = point2[mask1] 179 | 180 | point1 = point1 // 8 181 | point2 = point2 // 8 182 | 183 | mask = np.where( 184 | np.all((point2[:, :2] >= (0, 0)) & (point2[:, :2] < (self.feat_size_w, self.feat_size_h)), 185 | axis=1)) 186 | point1 = point1[mask][:, :2] 187 | point2 = point2[mask][:, :2] 188 | 189 | mkpts0 = point1[:, 1] * self.feat_size_w + point1[:, 0] 190 | mkpts1 = point2[:, 1] * self.feat_size_w + point2[:, 0] 191 | 192 | gt_conf_matrix = np.zeros([self.feat_size_wh, self.feat_size_wh], dtype=float) 193 | gt_conf_matrix[mkpts0, mkpts1] = 1.0 194 | gt = gt_conf_matrix 195 | 196 | gt[mask_idx, :] = 0 197 | gt[:, mask_trans_idx] = 0 198 | 199 | imagevi, imageir = transforms.ToPILImage()(imagevi), transforms.ToPILImage()(imageir) 200 | 201 | 202 | 203 | return self.trans(imagevi), self.trans(imageir), gt_conf_matrix, gt 204 | 205 | def __len__(self): 206 | return len(self.data_floder) 207 | 208 | 209 | def motion_blur(image, degree=15, angle=45): 210 | image = np.array(image) 211 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 212 | motion_blur_kernel = np.diag(np.ones(degree)) 213 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 214 | 215 | motion_blur_kernel = motion_blur_kernel / degree 216 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 217 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 218 | blurred = np.array(blurred, dtype=np.uint8) 219 | return blurred 220 | 221 | 222 | def gasuss_noise(image, mean=0, var=0.001): 223 | image = np.array(image / 255, dtype=float) 224 | noise = np.random.normal(mean, var ** 0.5, image.shape) 225 | out = image + noise 226 | out = np.clip(out, 0.0, 1.0) 227 | out = np.uint8(out * 255) 228 | return out -------------------------------------------------------------------------------- /assets/data_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | from matplotlib import pyplot as plt 6 | import torch 7 | import torch.nn 8 | import cv2 9 | import numpy as np 10 | from pathlib import Path 11 | from PIL import Image, ImageFilter 12 | import random 13 | from torchvision import transforms as tfs 14 | import matplotlib 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | 19 | 20 | def enhance(img): 21 | IMAGE_SHAPE = [240,320] 22 | 23 | src_point = np.array([[ 0, 0], 24 | [IMAGE_SHAPE[1]-1, 0], 25 | [ 0, 
IMAGE_SHAPE[0]-1], 26 | [IMAGE_SHAPE[1]-1, IMAGE_SHAPE[0]-1]], dtype=np.float32) 27 | 28 | dst_point = get_dst_point(0.1, IMAGE_SHAPE) 29 | 30 | # rot = random.randint(-2, 2) * config['homographic']['rotation'] + random.randint(0, 15) 31 | rotation = 30 32 | rot = random.randint(-rotation, rotation) 33 | 34 | rot = random.randint(-30, 30) 35 | 36 | while rot >= -10 and rot <= 10: 37 | rot = random.randint(-30, 30) 38 | 39 | 40 | 41 | 42 | # sc = random.randint(-50, 50) 43 | # 44 | # if sc >= -20 and sc <= 20: 45 | # sc = random.randint(-30, 30) 46 | 47 | 48 | # scale = 1.2 - config['homographic']['scale'] * random.random() 49 | scale = 1.0 + random.randint(-50, 50) * 0.01 50 | 51 | center_offset = 40 52 | center = (IMAGE_SHAPE[1] / 2 + random.randint(-center_offset, center_offset), 53 | IMAGE_SHAPE[0] / 2 + random.randint(-center_offset, center_offset)) 54 | 55 | RS_mat = cv2.getRotationMatrix2D(center, rot, scale) 56 | f_point = np.matmul(dst_point, RS_mat.T).astype('float32') 57 | mat = cv2.getPerspectiveTransform(src_point, f_point) 58 | out_img = cv2.warpPerspective(img, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])) 59 | 60 | return out_img, mat 61 | 62 | 63 | def get_dst_point(perspective, IMAGE_SHAPE): 64 | a = random.random() 65 | b = random.random() 66 | c = random.random() 67 | d = random.random() 68 | e = random.random() 69 | f = random.random() 70 | 71 | if random.random() > 0.5: 72 | left_top_x = perspective*a 73 | left_top_y = perspective*b 74 | right_top_x = 0.9+perspective*c 75 | right_top_y = perspective*d 76 | left_bottom_x = perspective*a 77 | left_bottom_y = 0.9 + perspective*e 78 | right_bottom_x = 0.9 + perspective*c 79 | right_bottom_y = 0.9 + perspective*f 80 | else: 81 | left_top_x = perspective*a 82 | left_top_y = perspective*b 83 | right_top_x = 0.9+perspective*c 84 | right_top_y = perspective*d 85 | left_bottom_x = perspective*e 86 | left_bottom_y = 0.9 + perspective*b 87 | right_bottom_x = 0.9 + perspective*f 88 | right_bottom_y = 0.9 + perspective*d 89 | 90 | dst_point = np.array([(IMAGE_SHAPE[1]*left_top_x,IMAGE_SHAPE[0]*left_top_y,1), 91 | (IMAGE_SHAPE[1]*right_top_x, IMAGE_SHAPE[0]*right_top_y,1), 92 | (IMAGE_SHAPE[1]*left_bottom_x,IMAGE_SHAPE[0]*left_bottom_y,1), 93 | (IMAGE_SHAPE[1]*right_bottom_x,IMAGE_SHAPE[0]*right_bottom_y,1)],dtype = 'float32') 94 | return dst_point 95 | 96 | def make_matching_figure( 97 | img0, img1, mkpts0, mkpts1, color=None, 98 | kpts0=None, kpts1=None, text=[], dpi=100, path=None): 99 | # draw image pair 100 | assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}' 101 | fig, axes = plt.subplots(1,2, figsize=(10, 6), dpi=dpi) 102 | axes[0].imshow(img0, cmap='gray') 103 | axes[1].imshow(img1, cmap='gray') 104 | 105 | for i in range(2): # clear all frames 106 | axes[i].get_yaxis().set_ticks([]) 107 | axes[i].get_xaxis().set_ticks([]) 108 | for spine in axes[i].spines.values(): 109 | spine.set_visible(False) 110 | plt.tight_layout(pad=1) 111 | 112 | if kpts0 is not None: 113 | assert kpts1 is not None 114 | axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) 115 | axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) 116 | 117 | # draw matches 118 | if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: 119 | fig.canvas.draw() 120 | transFigure = fig.transFigure.inverted() 121 | fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) 122 | fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) 123 | fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), 124 | (fkpts0[i, 1], fkpts1[i, 1]), 125 | transform=fig.transFigure, c=(124/255,252/255,0), linewidth=1) 126 | for i in range(len(mkpts0))] 127 | 128 | axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=(124/255,252/255,0), s=4) 129 | axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=(124/255,252/255,0), s=4) 130 | 131 | # put txts 132 | # txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' 133 | # fig.text( 134 | # 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, 135 | # fontsize=15, va='top', ha='left', color=txt_color) 136 | 137 | # save or return figure 138 | if path: 139 | plt.savefig(str(path), bbox_inches='tight', pad_inches=0) 140 | plt.close() 141 | else: 142 | return fig 143 | 144 | class PilGaussianBlur(ImageFilter.Filter): 145 | name = "GaussianBlur" 146 | 147 | def __init__(self, radius=2): 148 | self.radius = radius 149 | 150 | def filter(self, image): 151 | return image.gaussian_blur(self.radius) 152 | 153 | def gasuss_noise(image, mean=0, var=0.001): 154 | 155 | image = np.array(image/255, dtype=float) 156 | noise = np.random.normal(mean, var ** 0.5, image.shape) 157 | out = image + noise 158 | out = np.clip(out, 0.0, 1.0) 159 | out = np.uint8(out * 255) 160 | return out 161 | 162 | def motion_blur(image, degree=15, angle=45): 163 | image = np.array(image) 164 | 165 | 166 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 167 | motion_blur_kernel = np.diag(np.ones(degree)) 168 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 169 | 170 | motion_blur_kernel = motion_blur_kernel / degree 171 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 172 | 173 | # convert to uint8 174 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 175 | blurred = np.array(blurred, dtype=np.uint8) 176 | return blurred 177 | # if __name__ == '__main__': 178 | # image_path = os.listdir(r'D:\SLANet\dataset\MSRS\ir') 179 | # for image_name in image_path: 180 | # image=cv2.imread(r'D:\SLANet\dataset\MSRS\ir'+'\\'+image_name) 181 | # image=cv2.resize(image,(320,240)) 182 | # cv2.imwrite(r'D:\SLANet\dataset\MSRS\irr'+'\\'+image_name,image) 183 | 184 | 185 | if __name__ == '__main__': 186 | image_path = os.listdir(r'D:\SLANet\dataset\m3fd\vii') 187 | for image_name in image_path: 188 | image=cv2.imread(r'D:\SLANet\dataset\m3fd\transir'+'\\'+image_name) 189 | image=cv2.resize(image,(320,240)) 190 | cv2.imwrite(r'D:\SLANet\dataset\m3fd\tranir'+'\\'+image_name,image) 191 | 192 | # image1 = cv2.imread(r'D:\SLANet\dataset\Ir' + '\\' + image_name) 193 | # image1 = cv2.resize(image1, (320, 240)) 194 | 
# cv2.imwrite(r'D:\SLANet\dataset\m3fd\irr' + '\\' + image_name, image1) 195 | 196 | 197 | 198 | 199 | if __name__ == '__main__': 200 | image_path = os.listdir(r'D:\SLANet\dataset\m3fd\vii') 201 | for image_name in image_path: 202 | # image = cv2.imread(r'D:\SLANet\dataset\MSRS\vii' + '\\' + image_name) 203 | # image = gasuss_noise(image, mean=0, var=0.001) 204 | # cv2.imwrite(r'D:\SLANet\dataset\MSRS\vii' + '\\' + image_name, image) 205 | 206 | image=cv2.imread(r'D:\SLANet\dataset\m3fd\irr'+'\\'+image_name) 207 | # image=cv2.resize(image,(320,240)) 208 | out_img, mat=enhance(image) 209 | # out_img = gasuss_noise(out_img, mean=0, var=0.001) 210 | cv2.imwrite(r'D:\SLANet\dataset\m3fd\transir'+'\\'+image_name,out_img) 211 | np.save(r'D:\SLANet\dataset\m3fd\mat'+'\\' + image_name + '.npy', mat) 212 | 213 | # # 214 | # 215 | # 216 | # 217 | # 218 | # 219 | # 220 | # 221 | # 222 | # 223 | 224 | # point1 = np.array(([[290, 100], [223, 320], [4, 12], [9, 16]])).reshape(1, -1, 2).astype(np.float32) 225 | # point2 = cv2.perspectiveTransform(point1, mat).reshape(-1, 2) 226 | # 227 | # if __name__ == '__main__': 228 | # image_path = os.listdir(r'D:\SLANet\dataset\MSRS\irr') 229 | # for image_name in image_path: 230 | # ir = cv2.imread(r'D:\SLANet\dataset\MSRS\transir' + '\\' + image_name, cv2.IMREAD_GRAYSCALE) 231 | # vi = cv2.imread(r'D:\SLANet\dataset\MSRS\vii' + '\\' + image_name, cv2.IMREAD_GRAYSCALE) 232 | # 233 | # mat = np.load(r'D:\SLANet\dataset\MSRS\mat'+'\\' + image_name + '.npy') 234 | # 235 | # point1 = np.array(([[210, 100], [223, 150], [4, 12], [9, 16]])).reshape(1, -1, 2).astype(np.float32) 236 | # point2 = cv2.perspectiveTransform(point1, mat).reshape(-1, 2) 237 | # point1=point1.reshape(-1, 2) 238 | # 239 | # fig = make_matching_figure(vi, ir, point1, point2) 240 | # 241 | # # fuse=ir*0.5+vi*0.5 242 | # # plt.imshow(fuse,cmap='gray') 243 | # 244 | # plt.show() 245 | 246 | # if __name__ == '__main__': 247 | # image_path = os.listdir(r'D:\SLANet\dataset\MSRS\vi') 248 | # for image_name in image_path: 249 | # image = cv2.imread(r'D:\SLANet\dataset\MSRS\MSRS-master\crop_LR_visible' + '\\' + image_name) 250 | # image = cv2.resize(image, (320, 240)) 251 | # cv2.imwrite(r'D:\SLANet\dataset\MSRS\vii' + '\\' + image_name, image) 252 | 253 | 254 | 255 | 256 | 257 | 258 | -------------------------------------------------------------------------------- /train_stage1/reg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | from torchvision import transforms 4 | import cv2 5 | import numpy as np 6 | from PIL import Image, ImageFilter 7 | import random 8 | from torchvision import transforms as tfs 9 | 10 | """ 11 | Implementation based on the modification of 12 | https://github.com/791136190/UnsuperPoint_PyTorch/blob/main/Unsuper/dataset/coco.py 13 | """ 14 | 15 | def get_dst_point(perspective, IMAGE_SHAPE): 16 | a = random.random() 17 | b = random.random() 18 | c = random.random() 19 | d = random.random() 20 | e = random.random() 21 | f = random.random() 22 | 23 | if random.random() > 0.5: 24 | left_top_x = perspective * a 25 | left_top_y = perspective * b 26 | right_top_x = 0.9 + perspective * c 27 | right_top_y = perspective * d 28 | left_bottom_x = perspective * a 29 | left_bottom_y = 0.9 + perspective * e 30 | right_bottom_x = 0.9 + perspective * c 31 | right_bottom_y = 0.9 + perspective * f 32 | else: 33 | left_top_x = perspective * a 34 | left_top_y = perspective * b 35 | right_top_x = 0.9 + 
perspective * c 36 | right_top_y = perspective * d 37 | left_bottom_x = perspective * e 38 | left_bottom_y = 0.9 + perspective * b 39 | right_bottom_x = 0.9 + perspective * f 40 | right_bottom_y = 0.9 + perspective * d 41 | 42 | dst_point = np.array([(IMAGE_SHAPE[1] * left_top_x, IMAGE_SHAPE[0] * left_top_y, 1), 43 | (IMAGE_SHAPE[1] * right_top_x, IMAGE_SHAPE[0] * right_top_y, 1), 44 | (IMAGE_SHAPE[1] * left_bottom_x, IMAGE_SHAPE[0] * left_bottom_y, 1), 45 | (IMAGE_SHAPE[1] * right_bottom_x, IMAGE_SHAPE[0] * right_bottom_y, 1)], dtype='float32') 46 | return dst_point 47 | 48 | 49 | def enhance(img, IMAGE_SHAPE): 50 | # The four vertices of the image 51 | src_point = np.array([[0, 0], 52 | [IMAGE_SHAPE[1] - 1, 0], 53 | [0, IMAGE_SHAPE[0] - 1], 54 | [IMAGE_SHAPE[1] - 1, IMAGE_SHAPE[0] - 1]], dtype=np.float32) 55 | 56 | # Perspective Information 57 | dst_point = get_dst_point(0.2, IMAGE_SHAPE) 58 | 59 | # Rotation and scale transformation 60 | rotation = 25 61 | rot = random.randint(-rotation, rotation) 62 | scale = 1.2 + random.randint(-90, 100) * 0.01 63 | 64 | center_offset = 40 65 | center = (IMAGE_SHAPE[1] / 2 + random.randint(-center_offset, center_offset), 66 | IMAGE_SHAPE[0] / 2 + random.randint(-center_offset, center_offset)) 67 | 68 | RS_mat = cv2.getRotationMatrix2D(center, rot, scale) 69 | f_point = np.matmul(dst_point, RS_mat.T).astype('float32') 70 | mat = cv2.getPerspectiveTransform(src_point, f_point) 71 | out_img = cv2.warpPerspective(img, mat, (IMAGE_SHAPE[1], IMAGE_SHAPE[0])) 72 | 73 | return out_img, mat, f_point 74 | 75 | 76 | class PilGaussianBlur(ImageFilter.Filter): 77 | name = "GaussianBlur" 78 | 79 | def __init__(self, radius=2): 80 | self.radius = radius 81 | 82 | def filter(self, image): 83 | return image.gaussian_blur(self.radius) 84 | 85 | 86 | def motion_blur(image, degree=15, angle=45): 87 | image = np.array(image) 88 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 89 | motion_blur_kernel = np.diag(np.ones(degree)) 90 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 91 | motion_blur_kernel = motion_blur_kernel / degree 92 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 93 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 94 | blurred = np.array(blurred, dtype=np.uint8) 95 | return blurred 96 | 97 | 98 | def resize_img(img, IMAGE_SHAPE): 99 | h, w = img.shape[:2] 100 | if h < IMAGE_SHAPE[0] or w < IMAGE_SHAPE[1]: 101 | new_h = IMAGE_SHAPE[0] 102 | new_w = IMAGE_SHAPE[1] 103 | h = new_h 104 | w = new_w 105 | img = cv2.resize(img, (new_w, new_h)) 106 | new_h, new_w = IMAGE_SHAPE 107 | try: 108 | top = np.random.randint(0, h - new_h + 1) 109 | left = np.random.randint(0, w - new_w + 1) 110 | except: 111 | print(h, new_h, w, new_w) 112 | raise 113 | if len(img.shape) == 2: 114 | img = img[top: top + new_h, left: left + new_w] # crop image 115 | else: 116 | img = img[top: top + new_h, left: left + new_w, :] 117 | return img 118 | 119 | 120 | def gasuss_noise(image, mean=0, var=0.001): 121 | image = np.array(image / 255, dtype=float) 122 | noise = np.random.normal(mean, var ** 0.5, image.shape) 123 | out = image + noise 124 | out = np.clip(out, 0.0, 1.0) 125 | out = np.uint8(out * 255) 126 | return out 127 | 128 | 129 | class RegDataset(Dataset): 130 | def __init__(self, data_path_vi, data_path_ir, train_size_w, train_size_h): 131 | super(RegDataset, self).__init__() 132 | 133 | self.trans = transforms.Compose([ 134 | transforms.ToTensor() 135 | ]) 136 | 137 | self.data_floder = os.listdir(data_path_vi) 
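        # Note: the sample list is taken from the visible-image folder; __getitem__ expects the
        # pseudo-infrared folder to hold files with the same basenames, saved as .png.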
138 | # data = [] 139 | # for item in data_floder: 140 | # data.append((os.path.join(data_path_vi, item))) 141 | self.data_path_vi = data_path_vi 142 | self.data_path_ir = data_path_ir 143 | 144 | self.train_size_w = int(train_size_w) 145 | self.train_size_h = int(train_size_h) 146 | self.num_of_points = int((self.train_size_w / 8) * (self.train_size_h / 8)) 147 | 148 | self.feat_size_w = self.train_size_w/8 149 | self.feat_size_h = self.train_size_h/8 150 | 151 | 152 | 153 | 154 | def __getitem__(self, idx): 155 | image_name = self.data_floder[idx] 156 | 157 | 158 | # visible images, obtained from COCO dataset 159 | image_vi = cv2.imread(os.path.join(self.data_path_vi, image_name), cv2.IMREAD_GRAYSCALE) 160 | image_vi = cv2.resize(image_vi, (self.train_size_w, self.train_size_h)) 161 | 162 | # pseudo-infrared image, obtained by CPSTN 163 | # refer to the Details of Implementation section of the SemLA paper 164 | image_ir = cv2.imread(os.path.join(self.data_path_ir, image_name[:-3]+'png'), cv2.IMREAD_GRAYSCALE) 165 | image_ir = cv2.resize(image_ir, (self.train_size_w, self.train_size_h)) 166 | 167 | 168 | # data enhancement 169 | seed = random.random() 170 | if seed < 0.25: 171 | (h, w) = image_vi.shape[:2] 172 | center = (w // 2, h // 2) 173 | M = cv2.getRotationMatrix2D(center, random.randint(-20, 20), random.randint(10,40) * 0.1) 174 | image_vi = cv2.warpAffine(image_vi, M, (w, h)) 175 | image_ir = cv2.warpAffine(image_ir, M, (w, h)) 176 | 177 | seed = random.random() 178 | if seed < 0.25: 179 | image_vi = motion_blur(image_vi, degree=random.randint(7, 13), angle=random.randint(-45, 45)) 180 | seed = random.random() 181 | if seed < 0.25: 182 | image_ir = motion_blur(image_ir, degree=random.randint(7, 13), angle=random.randint(-45, 45)) 183 | 184 | seed = random.random() 185 | if seed < 0.25: 186 | image_vi = gasuss_noise(image_vi, mean=0, var=0.001) 187 | seed = random.random() 188 | if seed < 0.25: 189 | image_ir = gasuss_noise(image_ir, mean=0, var=0.001) 190 | 191 | image_vi = resize_img(image_vi, [self.train_size_h, self.train_size_w]) 192 | image_ir = resize_img(image_ir, [self.train_size_h, self.train_size_w]) # reshape the image 193 | image_ir, mat, f_point = enhance(image_ir, [self.train_size_h, self.train_size_w]) 194 | 195 | 196 | # cv2 -> PIL 197 | image_vi = Image.fromarray(image_vi) # rgb 198 | image_ir = Image.fromarray(image_ir) # rgb 199 | 200 | seed = random.random() 201 | if seed < 0.25: 202 | image_vi = image_vi.filter(PilGaussianBlur(radius=random.randint(1, 2))) 203 | 204 | image_vi = tfs.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.4, hue=0.1)(image_vi) 205 | 206 | seed = random.random() 207 | if seed < 0.25: 208 | image_ir = image_ir.filter(PilGaussianBlur(radius=random.randint(1, 2))) 209 | 210 | image_ir = tfs.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)(image_ir) 211 | 212 | # PIL -> cv2 213 | image_vi = np.asarray(image_vi) 214 | image_ir = np.asarray(image_ir) 215 | 216 | 217 | # Generate ground truth confidence matrix based on the synthesized transformation matrix 218 | point1 = np.array(list(range(0, self.num_of_points))) 219 | point1 = 8 * np.stack([point1 % self.feat_size_w, point1 // self.feat_size_w], axis=1).reshape(1, -1, 2).astype(np.float32) 220 | 221 | point2 = cv2.perspectiveTransform(point1, mat).reshape(-1, 2) 222 | point1 = point1.reshape(-1, 2) 223 | 224 | mask0 = np.where(point1[:, 0] < (self.train_size_w-8), True, False) * np.where(point1[:, 0] > 8, True, False) \ 225 | * np.where(point1[:, 1] > 8, True, 
False) * np.where(point1[:, 1] < (self.train_size_h-8), True, False) 226 | 227 | mask1 = np.where(point2[:, 0] < (self.train_size_w-8), True, False) * np.where(point2[:, 0] > 8, True, False) \ 228 | * np.where(point2[:, 1] > 8, True, False) * np.where(point2[:, 1] < (self.train_size_h-8), True, False) 229 | 230 | mask1 = mask0 * mask1 231 | point1 = point1.astype(np.int32) 232 | point2 = point2.astype(np.int32) 233 | point1 = point1[mask1] 234 | point2 = point2[mask1] 235 | 236 | point1 = point1 // 8 237 | point2 = point2 // 8 238 | 239 | mask = np.where( 240 | np.all((point2[:, :2] >= (0, 0)) & (point2[:, :2] < (self.feat_size_w, self.feat_size_h)), 241 | axis=1)) 242 | point1 = point1[mask][:, :2] 243 | point2 = point2[mask][:, :2] 244 | 245 | mkpts0 = point1[:, 1] * self.feat_size_w + point1[:, 0] 246 | mkpts1 = point2[:, 1] * self.feat_size_w + point2[:, 0] 247 | 248 | gt_conf_matrix = np.zeros([self.num_of_points, self.num_of_points], dtype=float) 249 | gt_conf_matrix[mkpts0.astype(np.int32), mkpts1.astype(np.int32)] = 1.0 250 | 251 | image_vi = transforms.ToPILImage()(image_vi) 252 | image_ir = transforms.ToPILImage()(image_ir) 253 | 254 | return (self.trans(image_vi), self.trans(image_ir), gt_conf_matrix) 255 | 256 | def __len__(self): 257 | return len(self.data_floder) 258 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/match_LA_large.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6 , cross = False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
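        # When cross is True this layer performs cross-attention: the batch is assumed to stack the
        # two images as [feat0; feat1], and forward() swaps the key/value halves (MiniB = B // 2) so
        # each image attends to the other. The attention itself is the linear variant built on the
        # elu(x) + 1 feature map, so its cost grows linearly with the number of tokens.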
55 | 56 | self.cross = cross 57 | self.feature_map = elu_feature_map 58 | self.eps = eps 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | 64 | def forward(self, x): 65 | x_q, x_kv = x, x 66 | B, N, C = x_q.shape 67 | MiniB = B // 2 68 | query = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 1, 2, 3) 69 | kv = self.kv(x_kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) 70 | 71 | if self.cross == True: 72 | k1, k2 = kv[0].split(MiniB) 73 | v1, v2 = kv[1].split(MiniB) 74 | key = torch.cat([k2, k1], dim=0) 75 | value = torch.cat([v2, v1], dim=0) 76 | else: 77 | key, value = kv[0], kv[1] 78 | 79 | Q = self.feature_map(query) 80 | K = self.feature_map(key) 81 | v_length = value.size(1) 82 | value = value / v_length 83 | KV = torch.einsum("nshd,nshv->nhdv", K, value) 84 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 85 | x = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 86 | x = x.contiguous().view(B, -1, C) 87 | 88 | return x 89 | 90 | class Block(nn.Module): 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,cross = False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, 97 | num_heads=num_heads, qkv_bias=qkv_bias, cross= cross) 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm = norm_layer(dim ) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | 104 | def forward(self, x,H,W): 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | x = x + self.drop_path(self.mlp(self.norm(x),H,W)) 107 | 108 | return x 109 | 110 | class Positional(nn.Module): 111 | def __init__(self, dim): 112 | super().__init__() 113 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 114 | self.sigmoid = nn.Sigmoid() 115 | 116 | def forward(self, x): 117 | return x * self.sigmoid(self.pa_conv(x)) 118 | 119 | class PatchEmbed(nn.Module): 120 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 121 | super().__init__() 122 | img_size = to_2tuple(img_size) 123 | patch_size = to_2tuple(patch_size) 124 | self.img_size = img_size 125 | self.patch_size = patch_size 126 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 127 | self.num_patches = self.H * self.W 128 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 129 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 130 | 131 | self.with_pos = with_pos 132 | if self.with_pos: 133 | self.pos = Positional(embed_dim) 134 | 135 | self.norm = nn.LayerNorm(embed_dim) 136 | 137 | def forward(self, x): 138 | x = self.proj(x) 139 | if self.with_pos: 140 | x = self.pos(x) 141 | _, _, H, W = x.shape 142 | x = x.flatten(2).transpose(1, 2) 143 | x = self.norm(x) 144 | 145 | return x, H, W 146 | 147 | class AttentionBlock(nn.Module): 148 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, 149 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 150 | super().__init__() 151 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = 
stride, in_chans=in_chans, 152 | embed_dim=embed_dims) 153 | self.block = nn.ModuleList([Block( 154 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, 155 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 156 | for i in range(depths)]) 157 | self.norm = norm_layer(embed_dims) 158 | 159 | def forward(self, x): 160 | B = x.shape[0] 161 | x, H, W = self.patch_embed(x) 162 | for i, blk in enumerate(self.block): 163 | x = blk(x,H,W) 164 | x = self.norm(x) 165 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 166 | 167 | return x 168 | 169 | class Matchformer_LA_large(nn.Module): 170 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[8, 8, 8, 8], 171 | stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 172 | super().__init__() 173 | #Attention 174 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 175 | embed_dims=embed_dims[0],stride=2,depths=3, cross=stage1_cross) 176 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 177 | embed_dims=embed_dims[1],stride=2,depths=3, cross=stage2_cross) 178 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 179 | embed_dims=embed_dims[2],stride=2,depths=3, cross=stage3_cross) 180 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 181 | embed_dims=embed_dims[3],stride=2,depths=3, cross=stage4_cross) 182 | 183 | #FPN 184 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 185 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 186 | self.layer3_outconv2 = nn.Sequential( 187 | conv3x3(embed_dims[3], embed_dims[3]), 188 | nn.BatchNorm2d(embed_dims[3]), 189 | nn.LeakyReLU(), 190 | conv3x3(embed_dims[3], embed_dims[2]), 191 | ) 192 | 193 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 194 | self.layer2_outconv2 = nn.Sequential( 195 | conv3x3(embed_dims[2], embed_dims[2]), 196 | nn.BatchNorm2d(embed_dims[2]), 197 | nn.LeakyReLU(), 198 | conv3x3(embed_dims[2], embed_dims[1]), 199 | ) 200 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 201 | self.layer1_outconv2 = nn.Sequential( 202 | conv3x3(embed_dims[1], embed_dims[1]), 203 | nn.BatchNorm2d(embed_dims[1]), 204 | nn.LeakyReLU(), 205 | conv3x3(embed_dims[1], embed_dims[0]), 206 | ) 207 | 208 | self.apply(self._init_weights) 209 | 210 | def _init_weights(self, m): 211 | if isinstance(m, nn.Linear): 212 | trunc_normal_(m.weight, std=.02) 213 | if isinstance(m, nn.Linear) and m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.LayerNorm): 216 | nn.init.constant_(m.bias, 0) 217 | nn.init.constant_(m.weight, 1.0) 218 | elif isinstance(m, nn.Conv2d): 219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 220 | fan_out //= m.groups 221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | def forward(self, x): 226 | # stage 1 # 1/2 227 | x = self.AttentionBlock1(x) 228 | out1 = x 229 | # stage 2 # 1/4 230 | x = self.AttentionBlock2(x) 231 | out2 = x 232 | # stage 3 # 1/8 233 | x = self.AttentionBlock3(x) 234 | out3 = x 235 | # stage 3 # 
1/16 236 | x = self.AttentionBlock4(x) 237 | out4 = x 238 | 239 | #FPN 240 | c4_out = self.layer4_outconv(out4) 241 | _,_,H,W = out3.shape 242 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 243 | c3_out = self.layer3_outconv(out3) 244 | _,_,H,W = out2.shape 245 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 246 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 247 | c2_out = self.layer2_outconv(out2) 248 | _,_,H,W = out1.shape 249 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x) 250 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True) 251 | c1_out = self.layer1_outconv(out1) 252 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x) 253 | 254 | return c3_out,c1_out 255 | -------------------------------------------------------------------------------- /train_stage1/teacher/model/backbone/match_LA_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | import math 7 | 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class DWConv(nn.Module): 16 | def __init__(self, dim=768): 17 | super(DWConv, self).__init__() 18 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 19 | 20 | def forward(self, x, H, W): 21 | B, N, C = x.shape 22 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 23 | x = self.dwconv(x) 24 | x = x.flatten(2).transpose(1, 2) 25 | 26 | return x 27 | 28 | class Mlp(nn.Module): 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.dwconv = DWConv(hidden_features) 35 | self.act = act_layer() 36 | self.fc2 = nn.Linear(hidden_features, out_features) 37 | self.drop = nn.Dropout(drop) 38 | 39 | def forward(self, x,H,W): 40 | x = self.fc1(x) 41 | x = self.dwconv(x, H, W) 42 | x = self.act(x) 43 | x = self.drop(x) 44 | x = self.fc2(x) 45 | x = self.drop(x) 46 | return x 47 | 48 | def elu_feature_map(x): 49 | return torch.nn.functional.elu(x) + 1 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, dim, num_heads=8, qkv_bias=False, eps=1e-6 , cross = False): 53 | super().__init__() 54 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
55 | 56 | self.cross = cross 57 | self.feature_map = elu_feature_map 58 | self.eps = eps 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 63 | 64 | def forward(self, x): 65 | x_q, x_kv = x, x 66 | B, N, C = x_q.shape 67 | MiniB = B // 2 68 | query = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 1, 2, 3) 69 | kv = self.kv(x_kv).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) 70 | 71 | if self.cross == True: 72 | k1, k2 = kv[0].split(MiniB) 73 | v1, v2 = kv[1].split(MiniB) 74 | key = torch.cat([k2, k1], dim=0) 75 | value = torch.cat([v2, v1], dim=0) 76 | else: 77 | key, value = kv[0], kv[1] 78 | 79 | Q = self.feature_map(query) 80 | K = self.feature_map(key) 81 | v_length = value.size(1) 82 | value = value / v_length 83 | KV = torch.einsum("nshd,nshv->nhdv", K, value) 84 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 85 | x = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 86 | x = x.contiguous().view(B, -1, C) 87 | 88 | return x 89 | 90 | class Block(nn.Module): 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,cross = False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, 97 | num_heads=num_heads, qkv_bias=qkv_bias, cross= cross) 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm = norm_layer(dim ) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | 104 | def forward(self, x,H,W): 105 | x = x + self.drop_path(self.attn(self.norm1(x))) 106 | x = x + self.drop_path(self.mlp(self.norm(x),H,W)) 107 | 108 | return x 109 | 110 | class Positional(nn.Module): 111 | def __init__(self, dim): 112 | super().__init__() 113 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 114 | self.sigmoid = nn.Sigmoid() 115 | 116 | def forward(self, x): 117 | return x * self.sigmoid(self.pa_conv(x)) 118 | 119 | class PatchEmbed(nn.Module): 120 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, with_pos=True): 121 | super().__init__() 122 | img_size = to_2tuple(img_size) 123 | patch_size = to_2tuple(patch_size) 124 | self.img_size = img_size 125 | self.patch_size = patch_size 126 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 127 | self.num_patches = self.H * self.W 128 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 129 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 130 | 131 | self.with_pos = with_pos 132 | if self.with_pos: 133 | self.pos = Positional(embed_dim) 134 | 135 | self.norm = nn.LayerNorm(embed_dim) 136 | 137 | def forward(self, x): 138 | x = self.proj(x) 139 | if self.with_pos: 140 | x = self.pos(x) 141 | _, _, H, W = x.shape 142 | x = x.flatten(2).transpose(1, 2) 143 | x = self.norm(x) 144 | 145 | return x, H, W 146 | 147 | class AttentionBlock(nn.Module): 148 | def __init__(self, img_size=224, in_chans=1, embed_dims=128, patch_size=7, num_heads=1, mlp_ratios=4, 149 | qkv_bias=True, drop_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),stride=2, depths=1, cross=[False,False,True]): 150 | super().__init__() 151 | self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, stride = 
stride, in_chans=in_chans, 152 | embed_dim=embed_dims) 153 | self.block = nn.ModuleList([Block( 154 | dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, 155 | drop=drop_rate, drop_path=0, norm_layer=norm_layer, cross=cross[i]) 156 | for i in range(depths)]) 157 | self.norm = norm_layer(embed_dims) 158 | 159 | def forward(self, x): 160 | B = x.shape[0] 161 | x, H, W = self.patch_embed(x) 162 | for i, blk in enumerate(self.block): 163 | x = blk(x,H,W) 164 | x = self.norm(x) 165 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 166 | 167 | return x 168 | 169 | class Matchformer_LA_lite(nn.Module): 170 | def __init__(self, img_size=224, in_chans=1, embed_dims=[128, 192, 256, 512], num_heads=[8, 8, 8, 8], 171 | stage1_cross = [False,False,True],stage2_cross = [False,False,True],stage3_cross = [False,True,True],stage4_cross = [False,True,True]): 172 | super().__init__() 173 | #Attention 174 | self.AttentionBlock1 = AttentionBlock(img_size=img_size // 2, patch_size=7, num_heads= num_heads[0], mlp_ratios=4, in_chans=in_chans, 175 | embed_dims=embed_dims[0],stride=4,depths=3, cross=stage1_cross) 176 | self.AttentionBlock2 = AttentionBlock(img_size=img_size // 4, patch_size=3, num_heads= num_heads[1], mlp_ratios=4, in_chans=embed_dims[0], 177 | embed_dims=embed_dims[1],stride=2,depths=3, cross=stage2_cross) 178 | self.AttentionBlock3 = AttentionBlock(img_size=img_size // 16,patch_size=3, num_heads= num_heads[2], mlp_ratios=4, in_chans=embed_dims[1], 179 | embed_dims=embed_dims[2],stride=2,depths=3, cross=stage3_cross) 180 | self.AttentionBlock4 = AttentionBlock(img_size=img_size // 32,patch_size=3, num_heads= num_heads[3], mlp_ratios=4, in_chans=embed_dims[2], 181 | embed_dims=embed_dims[3],stride=2,depths=3, cross=stage4_cross) 182 | 183 | #FPN 184 | self.layer4_outconv = conv1x1(embed_dims[3], embed_dims[3]) 185 | self.layer3_outconv = conv1x1(embed_dims[2], embed_dims[3]) 186 | self.layer3_outconv2 = nn.Sequential( 187 | conv3x3(embed_dims[3], embed_dims[3]), 188 | nn.BatchNorm2d(embed_dims[3]), 189 | nn.LeakyReLU(), 190 | conv3x3(embed_dims[3], embed_dims[2]), 191 | ) 192 | 193 | self.layer2_outconv = conv1x1(embed_dims[1], embed_dims[2]) 194 | self.layer2_outconv2 = nn.Sequential( 195 | conv3x3(embed_dims[2], embed_dims[2]), 196 | nn.BatchNorm2d(embed_dims[2]), 197 | nn.LeakyReLU(), 198 | conv3x3(embed_dims[2], embed_dims[1]), 199 | ) 200 | self.layer1_outconv = conv1x1(embed_dims[0], embed_dims[1]) 201 | self.layer1_outconv2 = nn.Sequential( 202 | conv3x3(embed_dims[1], embed_dims[1]), 203 | nn.BatchNorm2d(embed_dims[1]), 204 | nn.LeakyReLU(), 205 | conv3x3(embed_dims[1], embed_dims[0]), 206 | ) 207 | 208 | self.apply(self._init_weights) 209 | 210 | def _init_weights(self, m): 211 | if isinstance(m, nn.Linear): 212 | trunc_normal_(m.weight, std=.02) 213 | if isinstance(m, nn.Linear) and m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | elif isinstance(m, nn.LayerNorm): 216 | nn.init.constant_(m.bias, 0) 217 | nn.init.constant_(m.weight, 1.0) 218 | elif isinstance(m, nn.Conv2d): 219 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 220 | fan_out //= m.groups 221 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 222 | if m.bias is not None: 223 | m.bias.data.zero_() 224 | 225 | def forward(self, x): 226 | # stage 1 # 1/4 227 | x = self.AttentionBlock1(x) 228 | out1 = x 229 | # stage 2 # 1/8 230 | x = self.AttentionBlock2(x) 231 | out2 = x 232 | # stage 3 # 1/16 233 | x = self.AttentionBlock3(x) 234 | out3 = x 235 | # stage 3 # 
1/32 236 | x = self.AttentionBlock4(x) 237 | out4 = x 238 | 239 | #FPN 240 | c4_out = self.layer4_outconv(out4) 241 | _,_,H,W = out3.shape 242 | c4_out_2x = F.interpolate(c4_out, size =(H,W), mode='bilinear', align_corners=True) 243 | c3_out = self.layer3_outconv(out3) 244 | _,_,H,W = out2.shape 245 | c3_out = self.layer3_outconv2(c3_out +c4_out_2x) 246 | c3_out_2x = F.interpolate(c3_out, size =(H,W), mode='bilinear', align_corners=True) 247 | c2_out = self.layer2_outconv(out2) 248 | _,_,H,W = out1.shape 249 | c2_out = self.layer2_outconv2(c2_out +c3_out_2x) 250 | c2_out_2x = F.interpolate(c2_out, size =(H,W), mode='bilinear', align_corners=True) 251 | c1_out = self.layer1_outconv(out1) 252 | c1_out = self.layer1_outconv2(c1_out+c2_out_2x) 253 | 254 | return c2_out,c1_out --------------------------------------------------------------------------------
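Usage note (illustrative, not part of the repository): match_LA_large.py and match_LA_lite.py define the same linear-attention encoder with an FPN head; they differ mainly in the stride of the first patch embedding (2 vs. 4) and in which pyramid levels are returned, so the large variant yields 1/8- and 1/2-resolution features while the lite variant yields 1/8- and 1/4-resolution features. The sketch below shows how the two backbones might be exercised on a dummy image pair; the 480x640 input size, the import paths, and the [img0; img1] batch layout (required by the cross-attention layers, which swap key/value halves along the batch) are assumptions made for this example.

# minimal smoke test -- assumes torch and timm are installed and that this script sits next to
# train_stage1/teacher/model/backbone/match_LA_large.py and match_LA_lite.py
import torch

from match_LA_large import Matchformer_LA_large
from match_LA_lite import Matchformer_LA_lite

if __name__ == "__main__":
    img0 = torch.randn(1, 1, 480, 640)        # grayscale image 0 (H and W divisible by 32)
    img1 = torch.randn(1, 1, 480, 640)        # grayscale image 1
    x = torch.cat([img0, img1], dim=0)        # the pair is stacked along the batch dimension

    with torch.no_grad():
        coarse, fine = Matchformer_LA_lite()(x)
        print(coarse.shape, fine.shape)       # 1/8- and 1/4-resolution feature maps

        coarse, fine = Matchformer_LA_large()(x)
        print(coarse.shape, fine.shape)       # 1/8- and 1/2-resolution feature maps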