├── fig └── fig.PNG ├── warmup_scheduler ├── __init__.py ├── run.py └── scheduler.py ├── requirements.txt ├── utils ├── __init__.py ├── dir_utils.py ├── dataset_utils.py ├── image_utils.py └── model_utils.py ├── losses.py ├── CR.py ├── options.py ├── dataloader.py ├── README.md ├── train.py └── model.py /fig/fig.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bzq-Hit/HoloFormer/HEAD/fig/fig.PNG -------------------------------------------------------------------------------- /warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from warmup_scheduler.scheduler import GradualWarmupScheduler 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision==0.15.2 3 | numpy 4 | einops 5 | tqdm 6 | timm 7 | natsort -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .dataset_utils import * 3 | from .model_utils import * 4 | from .image_utils import * -------------------------------------------------------------------------------- /utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class CharbonnierLoss(nn.Module): 7 | """Charbonnier Loss (L1)""" 8 | 9 | def __init__(self, eps=1e-3): 10 | super(CharbonnierLoss, self).__init__() 11 | self.eps = eps 12 | 13 | def forward(self, x, y): 14 | diff = x - y 15 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 16 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 17 | return loss 18 | 19 | class MSELoss(nn.Module): # MSE Loss 20 | """MSE Loss""" 21 | 22 | def __init__(self, eps=1e-3): 23 | super(MSELoss, self).__init__() 24 | self.eps = eps 25 | 26 | def forward(self, x, y): 27 | diff = x - y 28 | loss = torch.mean(diff * diff + self.eps*self.eps) 29 | return loss 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | # scheduler_warmup is chained with schduler_steplr 13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 14 | scheduler_warmup = 
GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 15 | 16 | # this zero gradient update is needed to avoid a warning message, issue #8. 17 | optim.zero_grad() 18 | optim.step() 19 | 20 | for epoch in range(1, 20): 21 | scheduler_warmup.step(epoch) 22 | print(epoch, optim.param_groups[0]['lr']) 23 | 24 | optim.step() # backward pass (update network) 25 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ### rotate and flip 4 | class Augment_RGB_torch: 5 | def __init__(self): 6 | pass 7 | def transform0(self, torch_tensor): 8 | return torch_tensor 9 | def transform1(self, torch_tensor): 10 | torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2]) 11 | return torch_tensor 12 | def transform2(self, torch_tensor): 13 | torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2]) 14 | return torch_tensor 15 | def transform3(self, torch_tensor): 16 | torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2]) 17 | return torch_tensor 18 | def transform4(self, torch_tensor): 19 | torch_tensor = torch_tensor.flip(-2) 20 | return torch_tensor 21 | def transform5(self, torch_tensor): 22 | torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2) 23 | return torch_tensor 24 | def transform6(self, torch_tensor): 25 | torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2) 26 | return torch_tensor 27 | def transform7(self, torch_tensor): 28 | torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2) 29 | return torch_tensor 30 | 31 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def myPSNR(tar_img, prd_img): 5 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) + 1e-8 6 | rmse = (imdff**2).mean().sqrt() 7 | ps = 20*torch.log10(1/rmse) 8 | return ps 9 | 10 | def batch_PSNR(img1, img2, average=True): 11 | PSNR = [] 12 | for im1, im2 in zip(img1, img2): 13 | psnr = myPSNR(im1, im2) 14 | PSNR.append(psnr) 15 | return sum(PSNR)/len(PSNR) if average else sum(PSNR) 16 | 17 | def mySSIM(tar_img, prd_img): 18 | tar_mean = F.avg_pool2d(tar_img, kernel_size=3, stride=1, padding=0) 19 | prd_mean = F.avg_pool2d(prd_img, kernel_size=3, stride=1, padding=0) 20 | 21 | tar_var = F.avg_pool2d(tar_img**2, kernel_size=3, stride=1, padding=0) - tar_mean**2 22 | prd_var = F.avg_pool2d(prd_img**2, kernel_size=3, stride=1, padding=0) - prd_mean**2 23 | 24 | covar = F.avg_pool2d(tar_img * prd_img, kernel_size=3, stride=1, padding=0) - tar_mean * prd_mean 25 | 26 | C1 = 0.01**2 27 | C2 = 0.03**2 28 | 29 | ssim = (2 * tar_mean * prd_mean + C1) * (2 * covar + C2) / ((tar_mean**2 + prd_mean**2 + C1) * (tar_var + prd_var + C2)) 30 | 31 | return ssim.mean() 32 | 33 | def batch_SSIM(img1, img2, average=True): 34 | SSIM = [] 35 | for im1, im2 in zip(img1, img2): 36 | ssim = mySSIM(im1, im2) 37 | SSIM.append(ssim) 38 | return sum(SSIM)/len(SSIM) if average else sum(SSIM) 39 | 40 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from collections import OrderedDict 5 | 6 | def 
save_checkpoint(model_dir, state, session): 7 | epoch = state['epoch'] 8 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 9 | torch.save(state, model_out_path) 10 | 11 | def load_checkpoint(model, weights): 12 | checkpoint = torch.load(weights) 13 | try: 14 | model.load_state_dict(checkpoint["state_dict"]) 15 | except: 16 | state_dict = checkpoint["state_dict"] 17 | new_state_dict = OrderedDict() 18 | for k, v in state_dict.items(): 19 | name = k[7:] if 'module.' in k else k 20 | new_state_dict[name] = v 21 | model.load_state_dict(new_state_dict) 22 | 23 | 24 | def load_checkpoint_multigpu(model, weights): 25 | checkpoint = torch.load(weights) 26 | state_dict = checkpoint["state_dict"] 27 | new_state_dict = OrderedDict() 28 | for k, v in state_dict.items(): 29 | name = k[7:] 30 | new_state_dict[name] = v 31 | model.load_state_dict(new_state_dict) 32 | 33 | def load_start_epoch(weights): 34 | checkpoint = torch.load(weights) 35 | epoch = checkpoint["epoch"] 36 | return epoch 37 | 38 | def load_optim(optimizer, weights): 39 | checkpoint = torch.load(weights) 40 | optimizer.load_state_dict(checkpoint['optimizer']) 41 | for p in optimizer.param_groups: lr = p['lr'] 42 | return lr 43 | 44 | 45 | 46 | def get_arch(opt): 47 | from model import HoloFormer 48 | 49 | arch = opt.arch 50 | 51 | if arch == 'HoloFormer': 52 | model_restoration = HoloFormer(embed_dim=32,win_size=8,token_projection='linear',token_mlp=opt.token_mlp, 53 | depths=[1, 2, 8, 8, 2, 8, 8, 2, 1] 54 | ) 55 | elif arch == 'HoloFormer_S': 56 | model_restoration = HoloFormer(embed_dim=32,win_size=8,token_projection='linear',token_mlp=opt.token_mlp, 57 | depths=[2, 2, 2, 2, 2, 2, 2, 2, 2] 58 | ) 59 | elif arch == 'HoloFormer_T': 60 | model_restoration = HoloFormer(embed_dim=16,win_size=8,token_projection='linear',token_mlp=opt.token_mlp, 61 | depths=[2, 2, 2, 2, 2, 2, 2, 2, 2] 62 | ) 63 | else: 64 | raise Exception("Arch error!") 65 | 66 | return model_restoration 67 | 68 | -------------------------------------------------------------------------------- /CR.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.nn.functional as fnn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from torchvision import models 8 | 9 | class Vgg19(torch.nn.Module): 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_pretrained_features = models.vgg19(pretrained=True).features 13 | 14 | self.slice1 = torch.nn.Sequential() 15 | self.slice2 = torch.nn.Sequential() 16 | self.slice3 = torch.nn.Sequential() 17 | self.slice4 = torch.nn.Sequential() 18 | self.slice5 = torch.nn.Sequential() 19 | for x in range(2): 20 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(2, 7): 22 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(7, 12): 24 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(12, 21): 26 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(21, 30): 28 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 29 | if not requires_grad: 30 | for param in self.parameters(): 31 | param.requires_grad = False 32 | 33 | def forward(self, X): 34 | h_relu1 = self.slice1(X) 35 | h_relu2 = self.slice2(h_relu1) 36 | h_relu3 = self.slice3(h_relu2) 37 | h_relu4 = self.slice4(h_relu3) 38 | h_relu5 = self.slice5(h_relu4) 39 
| return [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 40 | 41 | class ContrastLoss(nn.Module): 42 | def __init__(self, ablation=False): 43 | 44 | super(ContrastLoss, self).__init__() 45 | self.vgg = Vgg19().cuda() 46 | self.l1 = nn.L1Loss() 47 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 48 | self.ab = ablation 49 | self.conv1x1 = nn.Conv2d(2, 3, kernel_size=1) 50 | self.conv1x1 = self.conv1x1.cuda() 51 | 52 | def forward(self, a, p, n): 53 | # channel to 3 54 | a = self.conv1x1(a) 55 | p = self.conv1x1(p) 56 | n = self.conv1x1(n) 57 | 58 | a_vgg, p_vgg, n_vgg = self.vgg(a), self.vgg(p), self.vgg(n) 59 | loss = 0 60 | 61 | d_ap, d_an = 0, 0 62 | for i in range(len(a_vgg)): 63 | d_ap = self.l1(a_vgg[i], p_vgg[i].detach()) 64 | if not self.ab: 65 | d_an = self.l1(a_vgg[i], n_vgg[i].detach()) 66 | contrastive = d_ap / (d_an + 1e-7) 67 | else: 68 | contrastive = d_ap 69 | 70 | loss += self.weights[i] * contrastive 71 | return loss 72 | -------------------------------------------------------------------------------- /warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | class Options(): 2 | """docstring for Options""" 3 | def __init__(self): 4 | pass 5 | 6 | def init(self, parser): 7 | 8 | # global settings 9 | parser.add_argument('--batch_size', type=int, default=8, help='batch size') 10 | parser.add_argument('--nepoch', type=int, default=250, help='training epochs') 11 | parser.add_argument('--train_workers', type=int, default=4, help='train_dataloader workers') 12 | parser.add_argument('--eval_workers', type=int, default=4, help='eval_dataloader workers') 13 | parser.add_argument('--dataset', type=str, default ='Microscope') 14 | 15 | parser.add_argument('--optimizer', type=str, default ='adamw', help='optimizer for training') 16 | parser.add_argument('--lr_initial', type=float, default=0.0002, help='initial learning rate') 17 | parser.add_argument('--step_lr', type=int, default=50, help='step size for lr decay') 18 | parser.add_argument('--weight_decay', type=float, default=0.02, help='weight decay') 19 | parser.add_argument('--gpu', type=str, default='0,1', help='GPUs') 20 | parser.add_argument('--arch', type=str, default ='HoloFormer', help='archtechture') 21 | parser.add_argument('--dd_in', type=int, default=2, help='dd_in') 22 | 23 | # args for saving 24 | parser.add_argument('--save_dir', type=str, default ='./logs/', help='save dir') 25 | parser.add_argument('--checkpoint', type=int, default=50, help='checkpoint') 26 | parser.add_argument('--env', type=str, default ='_', help='env') 27 | 28 | # args for Uformer 29 | parser.add_argument('--norm_layer', type=str, default ='nn.LayerNorm', help='normalize layer in transformer') 30 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of emdeding features') 31 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention') 32 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection') 33 | parser.add_argument('--token_mlp', type=str,default='mlp', help='ffn') 34 | 35 | # args for vit 36 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim') 37 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 38 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 39 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 40 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 41 | 42 | # args for training 43 | parser.add_argument('--resume', action='store_true',default=False) 44 | parser.add_argument('--data_dir', type=str, default ='/home/zx/Desktop/bzq_exp/lensless/dataset_patch', help='dir of train and val 
data') 45 | parser.add_argument('--warmup', action='store_true', default=False, help='warmup') 46 | parser.add_argument('--warmup_epochs', type=int,default=3, help='epochs for warmup') 47 | 48 | # contrast loss 49 | parser.add_argument('--is_ab', type=bool, default=False, help='whether ablation study of contrast loss') 50 | parser.add_argument('--w_loss_1st', type=float, default=1, help='weight of 1st loss') 51 | parser.add_argument('--w_loss_contrast', type=float, default=0.01, help='weight of contrast loss') 52 | 53 | return parser 54 | 55 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.utils.data as data 4 | import os 5 | import random 6 | import numpy as np 7 | from utils import Augment_RGB_torch 8 | augment = Augment_RGB_torch() 9 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 10 | 11 | def get_training_data(data_dir): 12 | return DatasetFromFolder_realnimag_inverse_withoutnorm(data_dir, Train=True) 13 | 14 | 15 | def get_validation_data(data_dir): 16 | return DatasetFromFolder_realnimag_inverse_withoutnorm(data_dir, Train=False) 17 | 18 | 19 | class DatasetFromFolder_realnimag_inverse_withoutnorm(data.Dataset): 20 | def __init__(self, data_dir, Train = True): 21 | super(DatasetFromFolder_realnimag_inverse_withoutnorm, self).__init__() 22 | self.data_dir = data_dir 23 | self.Train = Train 24 | 25 | self.data_all_dir = os.listdir(data_dir) 26 | self.numpy_list = [] 27 | for i in self.data_all_dir: 28 | if i.endswith('.npy'): 29 | self.numpy_list.append(i) 30 | 31 | self.ob_real_list = [] 32 | for i in self.numpy_list: 33 | if 'ob_real' in i: 34 | self.ob_real_list.append(i) 35 | 36 | random.seed(2021) 37 | random.shuffle(self.ob_real_list) 38 | 39 | self.train_ob_real_list = self.ob_real_list[0:int(len(self.ob_real_list)*0.9)] 40 | self.val_ob_real_list = self.ob_real_list[int(len(self.ob_real_list)*0.9):] 41 | 42 | 43 | self.gt_dic = {} 44 | for i in self.ob_real_list: 45 | self.gt_dic[i] = (i.replace('ob_real', 'GT_real'), i.replace('ob_real', 'GT_imag')) 46 | 47 | self.ob_imag_dict = {} 48 | for i in self.ob_real_list: 49 | self.ob_imag_dict[i] = i.replace('ob_real', 'ob_imag') 50 | 51 | def __getitem__(self, index): 52 | if self.Train: 53 | ob_real_path = self.train_ob_real_list[index] 54 | ob_imag_path = self.ob_imag_dict[ob_real_path] 55 | gt_real_path = self.gt_dic[ob_real_path][0] 56 | gt_imag_path = self.gt_dic[ob_real_path][1] 57 | else: 58 | ob_real_path = self.val_ob_real_list[index] 59 | ob_imag_path = self.ob_imag_dict[ob_real_path] 60 | gt_real_path = self.gt_dic[ob_real_path][0] 61 | gt_imag_path = self.gt_dic[ob_real_path][1] 62 | 63 | ob_real = torch.from_numpy(np.load(os.path.join(self.data_dir, ob_real_path))) 64 | ob_imag = torch.from_numpy(np.load(os.path.join(self.data_dir, ob_imag_path))) 65 | gt_real = torch.from_numpy(np.load(os.path.join(self.data_dir, gt_real_path))) 66 | gt_imag = torch.from_numpy(np.load(os.path.join(self.data_dir, gt_imag_path))) 67 | 68 | if self.Train: 69 | apply_trans = transforms_aug[random.getrandbits(3)] 70 | ob_real = getattr(augment, apply_trans)(ob_real) 71 | ob_imag = getattr(augment, apply_trans)(ob_imag) 72 | gt_real = getattr(augment, apply_trans)(gt_real) 73 | gt_imag = getattr(augment, apply_trans)(gt_imag) 74 | return ob_real, ob_imag, gt_real, gt_imag 75 | else: 76 | return 
ob_real, ob_imag, gt_real, gt_imag 77 | 78 | def __len__(self): 79 | if self.Train: 80 | return len(self.train_ob_real_list) 81 | else: 82 | return len(self.val_ob_real_list) 83 | 84 | 85 | 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HoloFormer: Contrastive Regularization based Transformer for Holographic Image Reconstruction (IEEE TCI 2024) 2 | 3 | Ziqi Bai, [Xianming Liu](https://homepage.hit.edu.cn/xmliu), [Cheng Guo](https://scholar.google.com.hk/citations?hl=zh-CN&user=D_jtz9sAAAAJ&view_op=list_works), [Kui Jiang](https://homepage.hit.edu.cn/jiangkui?lang=zh), [Junjun Jiang](https://homepage.hit.edu.cn/jiangjunjun?lang=zh), [Xiangyang Ji](https://www.au.tsinghua.edu.cn/info/1111/1524.htm) 4 | 5 | --- 6 | 7 | Paper link: https://ieeexplore.ieee.org/document/10490149 8 | 9 | accepted by IEEE Transactions on Computational Imaging (IEEE TCI) 10 | 11 | --- 12 | 13 | *In this paper, we introduce HoloFormer, a novel hierarchical framework based on the self-attention mechanism for digital holographic reconstruction. To address the limitation of uniform representation in CNNs, we employ a window-based transformer block as its backbone. Compared to canonical transformer, our structure significantly reduces computational costs, making it particularly suitable for reconstructing high-resolution holograms. To enhance global feature learning capabilities of HoloFormer, a pyramid-like hierarchical structure within the encoder facilitates the learning of feature map representations across various scales. In the decoder, we adopt a dual-branch design to simultaneously reconstruct the real and imaginary parts of the complex amplitude without cross-talk between them. Additionally, supervised deep learning-based hologram reconstruction algorithms typically employ clear ground truth without twin-image for network training, limiting the exploration of contour features in degraded holograms with twin-image artifacts. To address this limitation, we integrate contrastive regularization during the training phase to maximize the utilization of mutual information. Extensive experimental results showcase the outstanding performance of HoloFormer compared to state-of-the-art techniques.* 14 | 15 | ![Image text](https://github.com/Bzq-Hit/HoloFormer/blob/main/fig/fig.PNG) 16 | 17 | --- 18 | 19 | ## Dependencies 20 | 21 | For dependencies, you can install them by 22 | 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | --- 28 | 29 | ## Data 30 | 31 | To train HoloFormer, you need to prepare your own hologram dataset, as we do not open-source the dataset used in this research paper. For detailed information about the dataset, please refer to Section Ⅳ.A of the research paper. 32 | 33 | Once you have your dataset ready, you can choose to use our provided dataloader by placing the path of your dataset in the appropriate location in `train.py`. This allows you to start training HoloFormer from scratch. Of course, you are also free to develop your own dataloader. 
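As a minimal sketch (not part of the original code), wiring your own data directory into the provided loaders looks like this; the path below is a placeholder, and your `.npy` patch filenames must contain the substrings `ob_real`, `ob_imag`, `GT_real`, and `GT_imag` as expected by `dataloader.py`:

```python
# Hedged usage sketch: point the provided dataloader at your own .npy patch folder,
# mirroring how train.py builds its loaders (batch size / workers follow options.py defaults).
from torch.utils.data import DataLoader
import dataloader

data_dir = '/path/to/your/dataset_patch'              # placeholder -- set opt.data_dir (or edit train.py)
train_set = dataloader.get_training_data(data_dir)    # first 90% of the shuffled patches, with flip/rotate augmentation
val_set = dataloader.get_validation_data(data_dir)    # remaining 10%, no augmentation

train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4, drop_last=True)
val_loader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=4, drop_last=False)
```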
34 | 35 | --- 36 | 37 | ## Training 38 | 39 | To train HoloFormer, you can begin the training by: 40 | 41 | ``` 42 | python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer 43 | ``` 44 | 45 | To train HoloFormer_S, you can begin the training by: 46 | 47 | ``` 48 | python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer_S 49 | ``` 50 | 51 | To train HoloFormer_T, you can begin the training by: 52 | 53 | ``` 54 | python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer_T 55 | ``` 56 | 57 | --- 58 | 59 | ## Citation 60 | 61 | If you find HoloFormer useful in your research, please consider citing: 62 | 63 | ``` 64 | @ARTICLE{10490149, 65 | author={Bai, Ziqi and Liu, Xianming and Guo, Cheng and Jiang, Kui and Jiang, Junjun and Ji, Xiangyang}, 66 | journal={IEEE Transactions on Computational Imaging}, 67 | title={HoloFormer: Contrastive Regularization Based Transformer for Holographic Image Reconstruction}, 68 | year={2024}, 69 | volume={10}, 70 | number={}, 71 | pages={560-573}, 72 | keywords={Image reconstruction;Transformers;Imaging;Holography;Task analysis;Image restoration;Self-supervised learning;Deep learning;holography;transformer;contrastive learning;computational imaging;lensless microscopy}, 73 | doi={10.1109/TCI.2024.3384809}} 74 | 75 | ``` 76 | 77 | --- 78 | 79 | ## Contact 80 | 81 | If you have any questions or suggestions regarding this project or the research paper, please feel free to contact the author, Ziqi Bai, at 21B951029@stu.hit.edu.cn. 82 | 83 | --- 84 | 85 | 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer""" 2 | """python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer_S""" 3 | """python3 train.py --nepoch 500 --warmup --w_loss_contrast 5 --arch HoloFormer_T""" 4 | 5 | import os 6 | import sys 7 | 8 | from torch.utils.tensorboard.writer import SummaryWriter 9 | from torchvision.utils import make_grid 10 | 11 | # add dir 12 | dir_name = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.join(dir_name,'..')) 14 | 15 | import argparse 16 | import options 17 | ######### parser ########### 18 | opt = options.Options().init(argparse.ArgumentParser(description='Lensless restore')).parse_args() 19 | print(opt) 20 | 21 | ######### Set GPUs ########### 22 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 23 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu 24 | import torch 25 | torch.backends.cudnn.benchmark = True 26 | 27 | import torch.nn as nn 28 | import torch.optim as optim 29 | from torch.utils.data import DataLoader 30 | 31 | import random 32 | import time 33 | import numpy as np 34 | from einops import rearrange, repeat 35 | import datetime 36 | from pdb import set_trace as stx 37 | 38 | import utils 39 | from losses import CharbonnierLoss 40 | from losses import MSELoss 41 | from CR import * 42 | 43 | from tqdm import tqdm 44 | from warmup_scheduler import GradualWarmupScheduler 45 | from torch.optim.lr_scheduler import StepLR 46 | from timm.utils import NativeScaler 47 | 48 | import dataloader 49 | 50 | 51 | ######### Logs dir ########### 52 | log_dir = os.path.join(opt.save_dir, 'lensless', opt.dataset, opt.arch+opt.env) 53 | if not os.path.exists(log_dir): 54 | os.makedirs(log_dir) 55 | print("Now time is : ",datetime.datetime.now().isoformat()) 56 | now = 
datetime.datetime.now().strftime("%Y-%m-%d_%H-%M") 57 | logname = os.path.join(log_dir, now, now +'.txt') 58 | model_dir = os.path.join(log_dir, now, 'models') 59 | board_dir = os.path.join(log_dir, now, 'board') 60 | utils.mkdir(model_dir) 61 | utils.mkdir(board_dir) 62 | 63 | ########## Set Seeds ########### 64 | random.seed(1234) 65 | np.random.seed(1234) 66 | torch.manual_seed(1234) 67 | torch.cuda.manual_seed_all(1234) 68 | 69 | ######### Model ########### 70 | model_restoration = utils.get_arch(opt) 71 | 72 | with open(logname,'a') as f: 73 | f.write(str(opt)+'\n') 74 | f.write(str(model_restoration)+'\n') 75 | 76 | ######### Optimizer ########### 77 | start_epoch = 1 78 | if opt.optimizer.lower() == 'adam': 79 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 80 | elif opt.optimizer.lower() == 'adamw': 81 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 82 | else: 83 | raise Exception("Error optimizer...") 84 | 85 | ######### DataParallel ########### 86 | model_restoration = torch.nn.DataParallel (model_restoration) 87 | model_restoration.cuda() 88 | 89 | ######### Scheduler ########### 90 | if opt.warmup: 91 | print("Using warmup and cosine strategy!") 92 | warmup_epochs = opt.warmup_epochs 93 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6) 94 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 95 | #scheduler.step() 96 | else: 97 | step = opt.step_lr 98 | print("Using StepLR,step={}!".format(step)) 99 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 100 | #scheduler.step() 101 | 102 | ######### Resume ########### 103 | if opt.resume: 104 | path_chk_rest = opt.pretrain_weights 105 | print("Resume from "+path_chk_rest) 106 | utils.load_checkpoint(model_restoration,path_chk_rest) 107 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 108 | lr = utils.load_optim(optimizer, path_chk_rest) 109 | 110 | # for p in optimizer.param_groups: p['lr'] = lr 111 | # warmup = False 112 | # new_lr = lr 113 | # print('------------------------------------------------------------------------------') 114 | # print("==> Resuming Training with learning rate:",new_lr) 115 | # print('------------------------------------------------------------------------------') 116 | for i in range(1, start_epoch): 117 | scheduler.step() 118 | new_lr = scheduler.get_lr()[0] 119 | print('------------------------------------------------------------------------------') 120 | print("==> Resuming Training with learning rate:", new_lr) 121 | print('------------------------------------------------------------------------------') 122 | 123 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-start_epoch+1, eta_min=1e-6) 124 | 125 | ######### Loss ########### 126 | c_loss = CharbonnierLoss().cuda() 127 | 128 | # add Contrast regularization 129 | contrast_loss = ContrastLoss(ablation=opt.is_ab) 130 | 131 | 132 | ######### DataLoader ########### 133 | print('===> Loading datasets') 134 | 135 | # load real exp data real and img part with augmentation 136 | opt.data_dir = '' # add your data dir 137 | 138 | train_dataset = dataloader.get_training_data(opt.data_dir) 139 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, 140 | num_workers=opt.train_workers, 
pin_memory=False, drop_last=True) 141 | val_dataset = dataloader.get_validation_data(opt.data_dir) 142 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.batch_size, shuffle=False, 143 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False) 144 | 145 | len_trainset = train_dataset.__len__() 146 | len_valset = val_dataset.__len__() 147 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset) 148 | with open(logname, 'a') as f: 149 | f.write("Sizeof training set: " + str(len_trainset) + ", sizeof validation set: " + str(len_valset)) 150 | 151 | ######### Tensorboard ########### 152 | writer = SummaryWriter(log_dir=board_dir) 153 | 154 | ######### train ########### 155 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch)) 156 | best_psnr = 0 157 | best_epoch = 0 158 | best_iter = 0 159 | eval_now = len(train_loader)//4 160 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now)) 161 | 162 | loss_scaler = NativeScaler() 163 | 164 | save_index = 0 # used in tensorboard saving -- train_batch_loss 165 | save_index_1 = 0 # used in tensorboard saving -- val_batch_loss 166 | 167 | torch.cuda.empty_cache() 168 | for epoch in range(start_epoch, opt.nepoch + 1): 169 | epoch_start_time = time.time() 170 | epoch_loss = 0 171 | eval_count = 0 172 | psnr_train_rgb = [] 173 | ssim_train_rgb = [] 174 | 175 | for i, (ob_real, ob_imag, gt_real, gt_imag) in enumerate(tqdm(train_loader)): 176 | # zero_grad 177 | optimizer.zero_grad() 178 | 179 | gt_real = torch.unsqueeze(gt_real, dim=1) 180 | gt_imag = torch.unsqueeze(gt_imag, dim=1) 181 | target = torch.cat((gt_real, gt_imag), dim=1) 182 | target = target.cuda() 183 | 184 | ob_real = torch.unsqueeze(ob_real, dim=1) 185 | ob_imag = torch.unsqueeze(ob_imag, dim=1) 186 | ft = torch.cat((ob_real, ob_imag), dim=1) 187 | input_ = ft.cuda() 188 | 189 | with torch.cuda.amp.autocast(): 190 | restored = model_restoration(input_.to(torch.float32)) 191 | 192 | # loss 193 | c_loss_ = c_loss(restored, target) 194 | contrast_loss_ = contrast_loss(restored, target.to(torch.float32), input_.to(torch.float32)) # anchor, positive, negative 195 | 196 | loss = opt.w_loss_1st * c_loss_ + opt.w_loss_contrast * contrast_loss_ 197 | loss_scaler( 198 | loss, optimizer,parameters=model_restoration.parameters()) 199 | 200 | writer.add_scalar('loss/train_batch_total_loss', loss.item(), save_index+1) 201 | writer.add_scalar('loss/train_batch_c_loss', c_loss_.item(), save_index+1) 202 | writer.add_scalar('loss/train_batch_contrast_loss', contrast_loss_.item(), save_index+1) 203 | 204 | save_index += 1 205 | epoch_loss +=loss.item() 206 | 207 | # psnr&ssim cal in trainset 208 | with torch.no_grad(): 209 | 210 | restored_norm = (restored - restored.min()) / (restored.max() - restored.min()) 211 | target_norm = (target - target.min()) / (target.max() - target.min()) 212 | 213 | restored_norm = torch.clamp(restored_norm, 0, 1) 214 | target_norm = torch.clamp(target_norm,0,1) 215 | 216 | psnr_train_rgb.append(utils.batch_PSNR(restored_norm, target_norm, False).item()) 217 | ssim_train_rgb.append(utils.batch_SSIM(restored_norm, target_norm, False).item()) 218 | 219 | #### Evaluation #### 220 | if (i+1)%eval_now==0 and i>0: 221 | eval_count += 1 222 | with torch.no_grad(): 223 | model_restoration.eval() 224 | 225 | # psnr&ssim cal in valset 226 | psnr_val_rgb = [] 227 | ssim_val_rgb = [] 228 | for ii, (ob_real_val, ob_imag_val, gt_real_val, gt_imag_val) in enumerate((val_loader)): 229 | 230 | gt_real_val = 
torch.unsqueeze(gt_real_val, dim=1) 231 | gt_imag_val = torch.unsqueeze(gt_imag_val, dim=1) 232 | target = torch.cat((gt_real_val, gt_imag_val), dim=1) 233 | target = target.cuda() 234 | 235 | ob_real_val = torch.unsqueeze(ob_real_val, dim=1) 236 | ob_imag_val = torch.unsqueeze(ob_imag_val, dim=1) 237 | ft_val = torch.cat((ob_real_val, ob_imag_val), dim=1) 238 | input_ = ft_val.cuda() 239 | 240 | with torch.cuda.amp.autocast(): 241 | restored = model_restoration(input_.to(torch.float32)) 242 | 243 | # loss 244 | loss_val_batch_contrast_loss = contrast_loss(restored, target.to(torch.float32), input_.to(torch.float32)) # anchor, positive, negative 245 | loss_val_batch_c = c_loss(restored, target) 246 | loss_val = opt.w_loss_1st * loss_val_batch_c + opt.w_loss_contrast * loss_val_batch_contrast_loss 247 | 248 | writer.add_scalar('loss/val_batch_c_loss', loss_val_batch_c.item(), save_index_1+1) 249 | writer.add_scalar('loss/val_batch_contrast_loss', loss_val_batch_contrast_loss.item(), save_index_1+1) 250 | writer.add_scalar('loss/val_batch_total_loss', loss_val.item(), save_index_1+1) 251 | save_index_1 += 1 252 | 253 | restored_norm = (restored - restored.min()) / (restored.max() - restored.min()) 254 | target_norm = (target - target.min()) / (target.max() - target.min()) 255 | 256 | restored_norm = torch.clamp(restored_norm,0,1) 257 | target_norm = torch.clamp(target_norm,0,1) 258 | 259 | psnr_val_rgb.append(utils.batch_PSNR(restored_norm, target_norm, False).item()) 260 | ssim_val_rgb.append(utils.batch_SSIM(restored_norm, target_norm, False).item()) 261 | 262 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset 263 | ssim_val_rgb = sum(ssim_val_rgb)/len_valset 264 | 265 | if psnr_val_rgb > best_psnr: 266 | best_psnr = psnr_val_rgb 267 | best_epoch = epoch 268 | best_iter = i 269 | torch.save({'epoch': epoch, 270 | 'state_dict': model_restoration.state_dict(), 271 | 'optimizer' : optimizer.state_dict() 272 | }, os.path.join(model_dir,"model_best.pth")) 273 | 274 | print("[Ep %d it %d\t PSNR: %.4f\t] ---- [best_Ep %d best_it %d Best_PSNR %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)) 275 | print("SSIM_val: %.4f\t" % (ssim_val_rgb)) 276 | with open(logname,'a') as f: 277 | f.write("[Ep %d it %d\t PSNR: %.4f\t] ---- [best_Ep %d best_it %d Best_PSNR %.4f] " \ 278 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n') 279 | f.write("SSIM_val: %.4f\t" % (ssim_val_rgb)+'\n') 280 | 281 | writer.add_scalar('metrics/val_psnr', psnr_val_rgb, epoch) 282 | writer.add_scalar('metrics/val_ssim', ssim_val_rgb, epoch) 283 | 284 | if eval_count == 4 and epoch%10==0: 285 | input_ft_grid = make_grid(input_[:, 0, :, :].unsqueeze(1), nrow=input_.shape[0]) 286 | target_real_grid = make_grid(target[:, 0, :, :].unsqueeze(1), nrow=target.shape[0]) 287 | restored_real_grid = make_grid(restored[:, 0, :, :].unsqueeze(1), nrow=restored.shape[0]) 288 | target_imag_grid = make_grid(target[:, 1, :, :].unsqueeze(1), nrow=target.shape[0]) 289 | restored_imag_grid = make_grid(restored[:, 1, :, :].unsqueeze(1), nrow=restored.shape[0]) 290 | 291 | input_ft_grid_norm = (input_ft_grid - input_ft_grid.min()) / (input_ft_grid.max() - input_ft_grid.min()) 292 | target_real_grid_norm = (target_real_grid - target_real_grid.min()) / (target_real_grid.max() - target_real_grid.min()) 293 | restored_real_grid_norm = (restored_real_grid - restored_real_grid.min()) / (restored_real_grid.max() - restored_real_grid.min()) 294 | target_imag_grid_norm = (target_imag_grid - target_imag_grid.min()) / 
(target_imag_grid.max() - target_imag_grid.min()) 295 | restored_imag_grid_norm = (restored_imag_grid - restored_imag_grid.min()) / (restored_imag_grid.max() - restored_imag_grid.min()) 296 | 297 | writer.add_image('image/input_ft_val', input_ft_grid_norm, epoch) 298 | writer.add_image('image/target_real_val', target_real_grid_norm, epoch) 299 | writer.add_image('image/restored_real_val', restored_real_grid_norm, epoch) 300 | writer.add_image('image/target_imag_val', target_imag_grid_norm, epoch) 301 | writer.add_image('image/restored_imag_val', restored_imag_grid_norm, epoch) 302 | 303 | model_restoration.train() 304 | torch.cuda.empty_cache() 305 | 306 | psnr_train_rgb = sum(psnr_train_rgb)/len_trainset 307 | ssim_train_rgb = sum(ssim_train_rgb)/len_trainset 308 | print("[Ep %d\t trainset_PSNR: %.4f\t] " % (epoch, psnr_train_rgb)) 309 | print("[Ep %d\t trainset_SSIM: %.4f\t] " % (epoch, ssim_train_rgb)) 310 | with open(logname,'a') as f: 311 | f.write("[Ep %d\t trainset_PSNR: %.4f\t] " % (epoch, psnr_train_rgb)+'\n') 312 | f.write("[Ep %d\t trainset_SSIM: %.4f\t] " % (epoch, ssim_train_rgb)+'\n') 313 | writer.add_scalar('metrics/train_psnr', psnr_train_rgb, epoch) 314 | writer.add_scalar('metrics/train_ssim', ssim_train_rgb, epoch) 315 | 316 | scheduler.step() 317 | 318 | writer.add_scalar('loss/train_epoch_loss', epoch_loss, epoch) 319 | 320 | if epoch % 10 == 0: 321 | input_ft_grid = make_grid(input_[:, 0, :, :].unsqueeze(1), nrow=input_.shape[0]) 322 | target_real_grid = make_grid(target[:, 0, :, :].unsqueeze(1), nrow=target.shape[0]) 323 | restored_real_grid = make_grid(restored[:, 0, :, :].unsqueeze(1), nrow=restored.shape[0]) 324 | target_imag_grid = make_grid(target[:, 1, :, :].unsqueeze(1), nrow=target.shape[0]) 325 | restored_imag_grid = make_grid(restored[:, 1, :, :].unsqueeze(1), nrow=restored.shape[0]) 326 | 327 | input_ft_grid_norm = (input_ft_grid - input_ft_grid.min()) / (input_ft_grid.max() - input_ft_grid.min()) 328 | target_real_grid_norm = (target_real_grid - target_real_grid.min()) / (target_real_grid.max() - target_real_grid.min()) 329 | restored_real_grid_norm = (restored_real_grid - restored_real_grid.min()) / (restored_real_grid.max() - restored_real_grid.min()) 330 | target_imag_grid_norm = (target_imag_grid - target_imag_grid.min()) / (target_imag_grid.max() - target_imag_grid.min()) 331 | restored_imag_grid_norm = (restored_imag_grid - restored_imag_grid.min()) / (restored_imag_grid.max() - restored_imag_grid.min()) 332 | 333 | writer.add_image('image/input_ft_train', input_ft_grid_norm, epoch) 334 | writer.add_image('image/target_real_train', target_real_grid_norm, epoch) 335 | writer.add_image('image/restored_real_train', restored_real_grid_norm, epoch) 336 | writer.add_image('image/target_imag_train', target_imag_grid_norm, epoch) 337 | writer.add_image('image/restored_imag_train', restored_imag_grid_norm, epoch) 338 | 339 | print("------------------------------------------------------------------") 340 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])) 341 | print("------------------------------------------------------------------") 342 | with open(logname,'a') as f: 343 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n') 344 | 345 | torch.save({'epoch': epoch, 346 | 'state_dict': model_restoration.state_dict(), 347 | 'optimizer' : optimizer.state_dict() 348 | 
}, os.path.join(model_dir,"model_latest.pth")) 349 | 350 | if epoch%opt.checkpoint == 0: 351 | torch.save({'epoch': epoch, 352 | 'state_dict': model_restoration.state_dict(), 353 | 'optimizer' : optimizer.state_dict() 354 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch))) 355 | print("Now time is : ",datetime.datetime.now().isoformat()) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from timm.models.layers import trunc_normal_, to_2tuple, DropPath 5 | import torch.utils.checkpoint as checkpoint 6 | import torch.nn.functional as F 7 | from einops import repeat, rearrange 8 | 9 | ######################################### 10 | ########### window operation############# 11 | def window_reverse(windows, win_size, H, W, dilation_rate=1): 12 | # B' ,Wh ,Ww ,C 13 | B = int(windows.shape[0] / (H * W / win_size / win_size)) 14 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) 15 | if dilation_rate !=1: 16 | x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww 17 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size) 18 | else: 19 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 20 | return x 21 | 22 | def window_partition(x, win_size, dilation_rate=1): 23 | B, H, W, C = x.shape 24 | if dilation_rate !=1: 25 | x = x.permute(0,3,1,2) # B, C, H, W 26 | assert type(dilation_rate) is int, 'dilation_rate should be a int' 27 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww 28 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww 29 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C 30 | else: 31 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) 32 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C 33 | return windows 34 | 35 | ######## Embedding for q,k,v ######## 36 | class LinearProjection(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | self.heads = heads 41 | self.to_q = nn.Linear(dim, inner_dim, bias = bias) 42 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 43 | self.dim = dim 44 | self.inner_dim = inner_dim 45 | 46 | def forward(self, x, attn_kv=None): 47 | B_, N, C = x.shape 48 | if attn_kv is not None: 49 | attn_kv = attn_kv.unsqueeze(0).repeat(B_,1,1) 50 | else: 51 | attn_kv = x 52 | N_kv = attn_kv.size(1) 53 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 54 | kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 55 | q = q[0] 56 | k, v = kv[0], kv[1] 57 | return q,k,v 58 | 59 | class SepConv2d(torch.nn.Module): 60 | def __init__(self, 61 | in_channels, 62 | out_channels, 63 | kernel_size, 64 | stride=1, 65 | padding=0, 66 | dilation=1,act_layer=nn.LeakyReLU): 67 | super(SepConv2d, self).__init__() 68 | self.depthwise = torch.nn.Conv2d(in_channels, 69 | in_channels, 70 | kernel_size=kernel_size, 71 | stride=stride, 72 | padding=padding, 73 | dilation=dilation, 74 | groups=in_channels) 75 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | self.act_layer = 
act_layer() if act_layer is not None else nn.Identity() 77 | self.in_channels = in_channels 78 | self.out_channels = out_channels 79 | self.kernel_size = kernel_size 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | x = self.depthwise(x) 84 | x = self.act_layer(x) 85 | x = self.pointwise(x) 86 | return x 87 | 88 | class ConvProjection(nn.Module): 89 | def __init__(self, dim, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0., 90 | last_stage=False,bias=True): 91 | 92 | super().__init__() 93 | 94 | inner_dim = dim_head * heads 95 | self.heads = heads 96 | pad = (kernel_size - q_stride)//2 97 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias) 98 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias) 99 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias) 100 | 101 | def forward(self, x, attn_kv=None): 102 | b, n, c, h = *x.shape, self.heads 103 | l = int(math.sqrt(n)) 104 | w = int(math.sqrt(n)) 105 | 106 | attn_kv = x if attn_kv is None else attn_kv 107 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w) 108 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w) 109 | # print(attn_kv) 110 | q = self.to_q(x) 111 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) 112 | 113 | k = self.to_k(attn_kv) 114 | v = self.to_v(attn_kv) 115 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h) 116 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h) 117 | return q,k,v 118 | 119 | ######################################### 120 | ########### window-based self-attention ############# 121 | class WindowAttention(nn.Module): 122 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 123 | 124 | super().__init__() 125 | self.dim = dim 126 | self.win_size = win_size # Wh, Ww 127 | self.num_heads = num_heads 128 | head_dim = dim // num_heads 129 | self.scale = qk_scale or head_dim ** -0.5 130 | 131 | # define a parameter table of relative position bias 132 | self.relative_position_bias_table = nn.Parameter( 133 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 134 | 135 | # get pair-wise relative position index for each token inside the window 136 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 137 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 138 | # coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 139 | 140 | coords = torch.meshgrid(coords_h, coords_w, indexing='xy') 141 | coords = torch.stack(coords) # changed 23.5.16 142 | 143 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 144 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 145 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 146 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 147 | relative_coords[:, :, 1] += self.win_size[1] - 1 148 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 149 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 150 | self.register_buffer("relative_position_index", relative_position_index) 151 | trunc_normal_(self.relative_position_bias_table, std=.02) 152 | 153 | if token_projection =='conv': 154 | self.qkv = ConvProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 155 | elif token_projection =='linear': 156 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 157 | else: 158 | 
raise Exception("Projection error!") 159 | 160 | self.token_projection = token_projection 161 | self.attn_drop = nn.Dropout(attn_drop) 162 | self.proj = nn.Linear(dim, dim) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | 165 | self.softmax = nn.Softmax(dim=-1) 166 | 167 | def forward(self, x, attn_kv=None, mask=None): 168 | B_, N, C = x.shape 169 | q, k, v = self.qkv(x,attn_kv) 170 | q = q * self.scale 171 | attn = (q @ k.transpose(-2, -1)) 172 | 173 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 174 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 175 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 176 | ratio = attn.size(-1)//relative_position_bias.size(-1) 177 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio) 178 | 179 | attn = attn + relative_position_bias.unsqueeze(0) 180 | 181 | if mask is not None: 182 | nW = mask.shape[0] 183 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 184 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0) 185 | attn = attn.view(-1, self.num_heads, N, N*ratio) 186 | attn = self.softmax(attn) 187 | else: 188 | attn = self.softmax(attn) 189 | 190 | attn = self.attn_drop(attn) 191 | 192 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 193 | x = self.proj(x) 194 | x = self.proj_drop(x) 195 | return x 196 | 197 | def extra_repr(self) -> str: 198 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 199 | 200 | # Input Projection 201 | class InputProj(nn.Module): 202 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU): 203 | super().__init__() 204 | self.proj = nn.Sequential( 205 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 206 | act_layer(inplace=True) 207 | ) 208 | if norm_layer is not None: 209 | self.norm = norm_layer(out_channel) 210 | else: 211 | self.norm = None 212 | self.in_channel = in_channel 213 | self.out_channel = out_channel 214 | 215 | def forward(self, x): 216 | B, C, H, W = x.shape 217 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 218 | if self.norm is not None: 219 | x = self.norm(x) 220 | return x 221 | 222 | # Output Projection 223 | class OutputProj(nn.Module): 224 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None,act_layer=None): 225 | super().__init__() 226 | self.proj = nn.Sequential( 227 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 228 | ) 229 | if act_layer is not None: 230 | self.proj.add_module(act_layer(inplace=True)) 231 | if norm_layer is not None: 232 | self.norm = norm_layer(out_channel) 233 | else: 234 | self.norm = None 235 | self.in_channel = in_channel 236 | self.out_channel = out_channel 237 | 238 | def forward(self, x): 239 | B, L, C = x.shape 240 | H = int(math.sqrt(L)) 241 | W = int(math.sqrt(L)) 242 | x = x.transpose(1, 2).view(B, C, H, W) 243 | x = self.proj(x) 244 | if self.norm is not None: 245 | x = self.norm(x) 246 | return x 247 | 248 | # Downsample Block 249 | class Downsample(nn.Module): 250 | def __init__(self, in_channel, out_channel): 251 | super(Downsample, self).__init__() 252 | self.conv = nn.Sequential( 253 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1), 254 | ) 255 | self.in_channel = 
in_channel 256 | self.out_channel = out_channel 257 | 258 | def forward(self, x): 259 | B, L, C = x.shape 260 | # import pdb;pdb.set_trace() 261 | H = int(math.sqrt(L)) 262 | W = int(math.sqrt(L)) 263 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 264 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 265 | return out 266 | 267 | # Upsample Block 268 | class Upsample(nn.Module): 269 | def __init__(self, in_channel, out_channel): 270 | super(Upsample, self).__init__() 271 | self.deconv = nn.Sequential( 272 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2), 273 | ) 274 | self.in_channel = in_channel 275 | self.out_channel = out_channel 276 | 277 | def forward(self, x): 278 | B, L, C = x.shape 279 | H = int(math.sqrt(L)) 280 | W = int(math.sqrt(L)) 281 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 282 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 283 | return out 284 | 285 | ######################################### 286 | ########### feed-forward network ############# 287 | class Mlp(nn.Module): 288 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.LeakyReLU, drop=0.): 289 | super().__init__() 290 | out_features = out_features or in_features 291 | hidden_features = hidden_features or in_features 292 | self.fc1 = nn.Linear(in_features, hidden_features) 293 | self.act = act_layer() 294 | self.fc2 = nn.Linear(hidden_features, out_features) 295 | self.drop = nn.Dropout(drop) 296 | self.in_features = in_features 297 | self.hidden_features = hidden_features 298 | self.out_features = out_features 299 | 300 | def forward(self, x): 301 | x = self.fc1(x) 302 | x = self.act(x) 303 | x = self.drop(x) 304 | x = self.fc2(x) 305 | x = self.drop(x) 306 | return x 307 | 308 | ######################################### 309 | ########### WinTransformer ############ 310 | class WinTransformerBlock(nn.Module): 311 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0, 312 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 313 | act_layer=nn.LeakyReLU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='mlp' 314 | ): 315 | super().__init__() 316 | self.dim = dim 317 | self.input_resolution = input_resolution 318 | self.num_heads = num_heads 319 | self.win_size = win_size 320 | self.shift_size = shift_size 321 | self.mlp_ratio = mlp_ratio 322 | self.token_mlp = token_mlp 323 | if min(self.input_resolution) <= self.win_size: 324 | self.shift_size = 0 325 | self.win_size = min(self.input_resolution) 326 | assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" 327 | 328 | self.norm1 = norm_layer(dim) 329 | self.attn = WindowAttention( 330 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads, 331 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 332 | token_projection=token_projection) 333 | 334 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 335 | self.norm2 = norm_layer(dim) 336 | mlp_hidden_dim = int(dim * mlp_ratio) 337 | if token_mlp in ['ffn','mlp']: 338 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) 339 | else: 340 | raise Exception("FFN error!") 341 | 342 | def with_pos_embed(self, tensor, pos): 343 | return tensor if pos is None else tensor + pos 344 | 345 | def extra_repr(self) -> str: 346 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 347 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 348 | 349 | def forward(self, x, mask=None): 350 | B, L, C = x.shape 351 | H = int(math.sqrt(L)) 352 | W = int(math.sqrt(L)) 353 | 354 | ## input mask 355 | if mask != None: 356 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1) 357 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 358 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 359 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size 360 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 361 | else: 362 | attn_mask = None 363 | 364 | ## shift mask 365 | if self.shift_size > 0: 366 | # calculate attention mask for SW-MSA 367 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x) 368 | h_slices = (slice(0, -self.win_size), 369 | slice(-self.win_size, -self.shift_size), 370 | slice(-self.shift_size, None)) 371 | w_slices = (slice(0, -self.win_size), 372 | slice(-self.win_size, -self.shift_size), 373 | slice(-self.shift_size, None)) 374 | cnt = 0 375 | for h in h_slices: 376 | for w in w_slices: 377 | shift_mask[:, h, w, :] = cnt 378 | cnt += 1 379 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 380 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 381 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size 382 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0)) 383 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask 384 | 385 | shortcut = x 386 | x = self.norm1(x) 387 | x = x.view(B, H, W, C) 388 | 389 | # cyclic shift 390 | if self.shift_size > 0: 391 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 392 | else: 393 | shifted_x = x 394 | 395 | # partition windows 396 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C 397 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C 398 | 399 | wmsa_in = x_windows 400 | 401 | # W-MSA/SW-MSA 402 | attn_windows = self.attn(wmsa_in, mask=attn_mask) # nW*B, win_size*win_size, C 403 | 404 | # merge windows 405 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) 406 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C 407 | 408 | # reverse cyclic shift 409 | if self.shift_size > 0: 410 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 411 | else: 412 | x = shifted_x 413 | x = x.view(B, H * W, C) 414 | 415 | # FFN 416 | x = shortcut + self.drop_path(x) 417 | x = x + self.drop_path(self.mlp(self.norm2(x))) 
418 |         del attn_mask
419 |         return x
420 | 
421 | #########################################
422 | ########### Basic layer of HoloFormer ################
423 | class BasicHoloformerLayer(nn.Module):
424 |     def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
425 |                  mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
426 |                  drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
427 |                  token_projection='linear', token_mlp='mlp', shift_flag=True
428 |                  ):
429 | 
430 |         super().__init__()
431 |         self.dim = dim
432 |         self.input_resolution = input_resolution
433 |         self.depth = depth
434 |         self.use_checkpoint = use_checkpoint
435 |         # build blocks
436 |         if shift_flag:
437 |             self.blocks = nn.ModuleList([
438 |                 WinTransformerBlock(dim=dim, input_resolution=input_resolution,
439 |                                     num_heads=num_heads, win_size=win_size,
440 |                                     shift_size=0 if (i % 2 == 0) else win_size // 2,
441 |                                     mlp_ratio=mlp_ratio,
442 |                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
443 |                                     drop=drop, attn_drop=attn_drop,
444 |                                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
445 |                                     norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp
446 |                                     )
447 |                 for i in range(depth)])
448 |         else:
449 |             self.blocks = nn.ModuleList([
450 |                 WinTransformerBlock(dim=dim, input_resolution=input_resolution,
451 |                                     num_heads=num_heads, win_size=win_size,
452 |                                     shift_size=0,
453 |                                     mlp_ratio=mlp_ratio,
454 |                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
455 |                                     drop=drop, attn_drop=attn_drop,
456 |                                     drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
457 |                                     norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp
458 |                                     )
459 |                 for i in range(depth)])
460 | 
461 |     def extra_repr(self) -> str:
462 |         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
463 | 
464 |     def forward(self, x, mask=None):
465 |         for blk in self.blocks:
466 |             if self.use_checkpoint:
467 |                 x = checkpoint.checkpoint(blk, x, mask)  # pass the mask through when checkpointing
468 |             else:
469 |                 x = blk(x, mask)
470 |         return x
471 | 
472 | class HoloFormer(nn.Module):
473 |     def __init__(self, img_size=256, in_chans=1, dd_in=2,
474 |                  embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
475 |                  win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
476 |                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
477 |                  norm_layer=nn.LayerNorm, patch_norm=True,
478 |                  use_checkpoint=False, token_projection='linear', token_mlp='mlp',
479 |                  dowsample=Downsample, upsample=Upsample, shift_flag=True,
480 |                  **kwargs):
481 |         super().__init__()
482 | 
483 |         self.num_enc_layers = len(depths)//2
484 |         self.num_dec_layers = len(depths)//2
485 |         self.embed_dim = embed_dim
486 |         self.patch_norm = patch_norm
487 |         self.mlp_ratio = mlp_ratio
488 |         self.token_projection = token_projection
489 |         self.mlp = token_mlp
490 |         self.win_size = win_size
491 |         self.reso = img_size
492 |         self.pos_drop_1 = nn.Dropout(p=drop_rate)
493 |         self.dd_in = dd_in
494 |         self.in_chans = in_chans
495 | 
496 |         # stochastic depth
497 |         enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
498 |         conv_dpr = [drop_path_rate]*depths[4]
499 |         dec_dpr = enc_dpr[::-1]
500 | 
501 |         # Input/Output
502 |         self.input_proj_1 = InputProj(in_channel=dd_in, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
503 |         self.output_proj_1 = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
504 |         self.output_proj_2 = OutputProj(in_channel=2*embed_dim,
out_channel=in_chans, kernel_size=3, stride=1) 505 | 506 | # Encoder 507 | self.encoderlayer_0 = BasicHoloformerLayer(dim=embed_dim, 508 | output_dim=embed_dim, 509 | input_resolution=(img_size, 510 | img_size), 511 | depth=depths[0], 512 | num_heads=num_heads[0], 513 | win_size=win_size, 514 | mlp_ratio=self.mlp_ratio, 515 | qkv_bias=qkv_bias, qk_scale=qk_scale, 516 | drop=drop_rate, attn_drop=attn_drop_rate, 517 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], 518 | norm_layer=norm_layer, 519 | use_checkpoint=use_checkpoint, 520 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 521 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2) 522 | self.encoderlayer_1 = BasicHoloformerLayer(dim=embed_dim*2, 523 | output_dim=embed_dim*2, 524 | input_resolution=(img_size // 2, 525 | img_size // 2), 526 | depth=depths[1], 527 | num_heads=num_heads[1], 528 | win_size=win_size, 529 | mlp_ratio=self.mlp_ratio, 530 | qkv_bias=qkv_bias, qk_scale=qk_scale, 531 | drop=drop_rate, attn_drop=attn_drop_rate, 532 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 533 | norm_layer=norm_layer, 534 | use_checkpoint=use_checkpoint, 535 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 536 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4) 537 | self.encoderlayer_2 = BasicHoloformerLayer(dim=embed_dim*4, 538 | output_dim=embed_dim*4, 539 | input_resolution=(img_size // (2 ** 2), 540 | img_size // (2 ** 2)), 541 | depth=depths[2], 542 | num_heads=num_heads[2], 543 | win_size=win_size, 544 | mlp_ratio=self.mlp_ratio, 545 | qkv_bias=qkv_bias, qk_scale=qk_scale, 546 | drop=drop_rate, attn_drop=attn_drop_rate, 547 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 548 | norm_layer=norm_layer, 549 | use_checkpoint=use_checkpoint, 550 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 551 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8) 552 | self.encoderlayer_3 = BasicHoloformerLayer(dim=embed_dim*8, 553 | output_dim=embed_dim*8, 554 | input_resolution=(img_size // (2 ** 3), 555 | img_size // (2 ** 3)), 556 | depth=depths[3], 557 | num_heads=num_heads[3], 558 | win_size=win_size, 559 | mlp_ratio=self.mlp_ratio, 560 | qkv_bias=qkv_bias, qk_scale=qk_scale, 561 | drop=drop_rate, attn_drop=attn_drop_rate, 562 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])], 563 | norm_layer=norm_layer, 564 | use_checkpoint=use_checkpoint, 565 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 566 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16) 567 | 568 | # Bottleneck 569 | self.conv = BasicHoloformerLayer(dim=embed_dim*16, 570 | output_dim=embed_dim*16, 571 | input_resolution=(img_size // (2 ** 4), 572 | img_size // (2 ** 4)), 573 | depth=depths[4], 574 | num_heads=num_heads[4], 575 | win_size=win_size, 576 | mlp_ratio=self.mlp_ratio, 577 | qkv_bias=qkv_bias, qk_scale=qk_scale, 578 | drop=drop_rate, attn_drop=attn_drop_rate, 579 | drop_path=conv_dpr, 580 | norm_layer=norm_layer, 581 | use_checkpoint=use_checkpoint, 582 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag) 583 | 584 | # Decoder_real 585 | self.upsample_0 = upsample(embed_dim*16, embed_dim*8) 586 | self.decoderlayer_0 = BasicHoloformerLayer(dim=embed_dim*16, 587 | output_dim=embed_dim*16, 588 | input_resolution=(img_size // (2 ** 3), 589 | img_size // (2 ** 3)), 590 | depth=depths[5], 591 | num_heads=num_heads[5], 592 | win_size=win_size, 593 | mlp_ratio=self.mlp_ratio, 594 | qkv_bias=qkv_bias, 
qk_scale=qk_scale, 595 | drop=drop_rate, attn_drop=attn_drop_rate, 596 | drop_path=dec_dpr[:depths[5]], 597 | norm_layer=norm_layer, 598 | use_checkpoint=use_checkpoint, 599 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 600 | ) 601 | self.upsample_1 = upsample(embed_dim*16, embed_dim*4) 602 | self.decoderlayer_1 = BasicHoloformerLayer(dim=embed_dim*8, 603 | output_dim=embed_dim*8, 604 | input_resolution=(img_size // (2 ** 2), 605 | img_size // (2 ** 2)), 606 | depth=depths[6], 607 | num_heads=num_heads[6], 608 | win_size=win_size, 609 | mlp_ratio=self.mlp_ratio, 610 | qkv_bias=qkv_bias, qk_scale=qk_scale, 611 | drop=drop_rate, attn_drop=attn_drop_rate, 612 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 613 | norm_layer=norm_layer, 614 | use_checkpoint=use_checkpoint, 615 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 616 | ) 617 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2) 618 | self.decoderlayer_2 = BasicHoloformerLayer(dim=embed_dim*4, 619 | output_dim=embed_dim*4, 620 | input_resolution=(img_size // 2, 621 | img_size // 2), 622 | depth=depths[7], 623 | num_heads=num_heads[7], 624 | win_size=win_size, 625 | mlp_ratio=self.mlp_ratio, 626 | qkv_bias=qkv_bias, qk_scale=qk_scale, 627 | drop=drop_rate, attn_drop=attn_drop_rate, 628 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 629 | norm_layer=norm_layer, 630 | use_checkpoint=use_checkpoint, 631 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 632 | ) 633 | self.upsample_3 = upsample(embed_dim*4, embed_dim) 634 | self.decoderlayer_3 = BasicHoloformerLayer(dim=embed_dim*2, 635 | output_dim=embed_dim*2, 636 | input_resolution=(img_size, 637 | img_size), 638 | depth=depths[8], 639 | num_heads=num_heads[8], 640 | win_size=win_size, 641 | mlp_ratio=self.mlp_ratio, 642 | qkv_bias=qkv_bias, qk_scale=qk_scale, 643 | drop=drop_rate, attn_drop=attn_drop_rate, 644 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 645 | norm_layer=norm_layer, 646 | use_checkpoint=use_checkpoint, 647 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 648 | ) 649 | 650 | # Decoder_imag 651 | self.upsample_0_imag = upsample(embed_dim*16, embed_dim*8) 652 | self.decoderlayer_0_imag = BasicHoloformerLayer(dim=embed_dim*16, 653 | output_dim=embed_dim*16, 654 | input_resolution=(img_size // (2 ** 3), 655 | img_size // (2 ** 3)), 656 | depth=depths[5], 657 | num_heads=num_heads[5], 658 | win_size=win_size, 659 | mlp_ratio=self.mlp_ratio, 660 | qkv_bias=qkv_bias, qk_scale=qk_scale, 661 | drop=drop_rate, attn_drop=attn_drop_rate, 662 | drop_path=dec_dpr[:depths[5]], 663 | norm_layer=norm_layer, 664 | use_checkpoint=use_checkpoint, 665 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 666 | ) 667 | self.upsample_1_imag = upsample(embed_dim*16, embed_dim*4) 668 | self.decoderlayer_1_imag = BasicHoloformerLayer(dim=embed_dim*8, 669 | output_dim=embed_dim*8, 670 | input_resolution=(img_size // (2 ** 2), 671 | img_size // (2 ** 2)), 672 | depth=depths[6], 673 | num_heads=num_heads[6], 674 | win_size=win_size, 675 | mlp_ratio=self.mlp_ratio, 676 | qkv_bias=qkv_bias, qk_scale=qk_scale, 677 | drop=drop_rate, attn_drop=attn_drop_rate, 678 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 679 | norm_layer=norm_layer, 680 | use_checkpoint=use_checkpoint, 681 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 682 | ) 683 | self.upsample_2_imag = upsample(embed_dim*8, embed_dim*2) 
684 | self.decoderlayer_2_imag = BasicHoloformerLayer(dim=embed_dim*4, 685 | output_dim=embed_dim*4, 686 | input_resolution=(img_size // 2, 687 | img_size // 2), 688 | depth=depths[7], 689 | num_heads=num_heads[7], 690 | win_size=win_size, 691 | mlp_ratio=self.mlp_ratio, 692 | qkv_bias=qkv_bias, qk_scale=qk_scale, 693 | drop=drop_rate, attn_drop=attn_drop_rate, 694 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 695 | norm_layer=norm_layer, 696 | use_checkpoint=use_checkpoint, 697 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 698 | ) 699 | self.upsample_3_imag = upsample(embed_dim*4, embed_dim) 700 | self.decoderlayer_3_imag = BasicHoloformerLayer(dim=embed_dim*2, 701 | output_dim=embed_dim*2, 702 | input_resolution=(img_size, 703 | img_size), 704 | depth=depths[8], 705 | num_heads=num_heads[8], 706 | win_size=win_size, 707 | mlp_ratio=self.mlp_ratio, 708 | qkv_bias=qkv_bias, qk_scale=qk_scale, 709 | drop=drop_rate, attn_drop=attn_drop_rate, 710 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 711 | norm_layer=norm_layer, 712 | use_checkpoint=use_checkpoint, 713 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag 714 | ) 715 | 716 | self.apply(self._init_weights) 717 | 718 | def _init_weights(self, m): 719 | if isinstance(m, nn.Linear): 720 | trunc_normal_(m.weight, std=.02) 721 | if isinstance(m, nn.Linear) and m.bias is not None: 722 | nn.init.constant_(m.bias, 0) 723 | elif isinstance(m, nn.LayerNorm): 724 | nn.init.constant_(m.bias, 0) 725 | nn.init.constant_(m.weight, 1.0) 726 | 727 | @torch.jit.ignore 728 | def no_weight_decay(self): 729 | return {'absolute_pos_embed'} 730 | 731 | @torch.jit.ignore 732 | def no_weight_decay_keywords(self): 733 | return {'relative_position_bias_table'} 734 | 735 | def extra_repr(self) -> str: 736 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" 737 | 738 | def forward(self, x, mask=None): 739 | # divide real and imag 740 | x_real = x[:,0,:,:].unsqueeze(1) # B 1 H W 741 | x_imag = x[:,1,:,:].unsqueeze(1) # B 1 H W 742 | 743 | # Input Projection of real and imag 744 | y = self.input_proj_1(x) # B H*W C ; complex B H*W C 2 745 | y = self.pos_drop_1(y) 746 | 747 | #Encoder 748 | conv0 = self.encoderlayer_0(y,mask=mask) 749 | pool0 = self.dowsample_0(conv0) 750 | conv1 = self.encoderlayer_1(pool0,mask=mask) 751 | pool1 = self.dowsample_1(conv1) 752 | conv2 = self.encoderlayer_2(pool1,mask=mask) 753 | pool2 = self.dowsample_2(conv2) 754 | conv3 = self.encoderlayer_3(pool2,mask=mask) 755 | pool3 = self.dowsample_3(conv3) 756 | 757 | # Bottleneck 758 | conv4 = self.conv(pool3, mask=mask) 759 | 760 | # Decoder_real 761 | up0 = self.upsample_0(conv4) 762 | deconv0 = torch.cat([up0,conv3],-1) 763 | deconv0 = self.decoderlayer_0(deconv0,mask=mask) 764 | 765 | up1 = self.upsample_1(deconv0) 766 | deconv1 = torch.cat([up1,conv2],-1) 767 | deconv1 = self.decoderlayer_1(deconv1,mask=mask) 768 | 769 | up2 = self.upsample_2(deconv1) 770 | deconv2 = torch.cat([up2,conv1],-1) 771 | deconv2 = self.decoderlayer_2(deconv2,mask=mask) 772 | 773 | up3 = self.upsample_3(deconv2) 774 | deconv3 = torch.cat([up3,conv0],-1) 775 | deconv3 = self.decoderlayer_3(deconv3,mask=mask) 776 | 777 | # Decoder_imag 778 | up0_imag = self.upsample_0_imag(conv4) 779 | deconv0_imag = torch.cat([up0_imag,conv3],-1) 780 | deconv0_imag = self.decoderlayer_0_imag(deconv0_imag,mask=mask) 781 | 782 | up1_imag = self.upsample_1_imag(deconv0_imag) 783 | 
deconv1_imag = torch.cat([up1_imag,conv2],-1) 784 | deconv1_imag = self.decoderlayer_1_imag(deconv1_imag,mask=mask) 785 | 786 | up2_imag = self.upsample_2_imag(deconv1_imag) 787 | deconv2_imag = torch.cat([up2_imag,conv1],-1) 788 | deconv2_imag = self.decoderlayer_2_imag(deconv2_imag,mask=mask) 789 | 790 | up3_imag = self.upsample_3_imag(deconv2_imag) 791 | deconv3_imag = torch.cat([up3_imag,conv0],-1) 792 | deconv3_imag = self.decoderlayer_3_imag(deconv3_imag,mask=mask) 793 | 794 | # Output Projection 795 | y_real_out = self.output_proj_1(deconv3) 796 | y_real_out = y_real_out + x_real 797 | 798 | y_imag_out = self.output_proj_2(deconv3_imag) 799 | y_imag_out = y_imag_out + x_imag 800 | 801 | return torch.cat([y_real_out,y_imag_out],1) 802 | --------------------------------------------------------------------------------
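The forward pass above splits the two input channels into real and imaginary parts, reconstructs each with its own decoder on top of a shared encoder and bottleneck, and concatenates the two residual-corrected outputs along the channel dimension. The following is a minimal smoke test of that interface; it is a sketch rather than part of the repository, and it assumes model.py is importable from the working directory with the default arguments shown in the class definition:

import torch
from model import HoloFormer

# Default configuration: 256x256 input, two input channels (real/imag),
# one output channel per decoder branch, embed_dim 32, window size 8.
net = HoloFormer(img_size=256, in_chans=1, dd_in=2, embed_dim=32, win_size=8)
net.eval()

# B x dd_in x H x W; channel 0 carries the real part, channel 1 the imaginary part.
x = torch.randn(1, 2, 256, 256)

with torch.no_grad():
    y = net(x)

print(y.shape)  # expected: torch.Size([1, 2, 256, 256]), real branch stacked on imag branch

Note that each WinTransformerBlock recovers the spatial size as sqrt(L), so inputs must stay square, and the resolution must remain divisible by the window size at every scale; with win_size=8 and four downsampling stages, the default img_size=256 satisfies this, and any multiple of 128 is expected to work.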