├── doc ├── res.jpg ├── details.jpg └── pipeline.jpg ├── warmup_scheduler ├── __init__.py ├── run.py └── scheduler.py ├── evaluation └── measure_shadow.m ├── utils ├── __init__.py ├── loader.py ├── dir_utils.py ├── dataset_utils.py ├── model_utils.py ├── bundle_submissions.py ├── antialias.py └── image_utils.py ├── requirements.txt ├── LICENSE ├── losses.py ├── options.py ├── dataset.py ├── README.md ├── test.py ├── train.py └── model.py /doc/res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoLanqing/ShadowFormer/HEAD/doc/res.jpg -------------------------------------------------------------------------------- /doc/details.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoLanqing/ShadowFormer/HEAD/doc/details.jpg -------------------------------------------------------------------------------- /doc/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoLanqing/ShadowFormer/HEAD/doc/pipeline.jpg -------------------------------------------------------------------------------- /warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /evaluation/measure_shadow.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuoLanqing/ShadowFormer/HEAD/evaluation/measure_shadow.m -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .dataset_utils import * 3 | from .image_utils import * 4 | from .model_utils import * 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | matplotlib 4 | scikit-image 5 | opencv-python 6 | yacs 7 | joblib 8 | natsort 9 | h5py 10 | tqdm 11 | einops 12 | linformer 13 | timm 14 | ptflops 15 | dataclasses -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dataset import DataLoaderTrain, DataLoaderVal 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options, None) 7 | 8 | def get_validation_data(rgb_dir): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, None) 11 | -------------------------------------------------------------------------------- /utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x 
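For orientation, a minimal usage sketch (not part of the repository) showing how the helpers in `utils/loader.py` above are typically wired to a PyTorch `DataLoader`; the dataset path, patch size, and batch size are illustrative assumptions that mirror the defaults in `options.py` and the training code in `train.py` further below:

```python
# Usage sketch only: assumes the ISTD dataset has been downloaded to ../ISTD_Dataset.
from torch.utils.data import DataLoader
from utils.loader import get_training_data, get_validation_data

train_set = get_training_data('../ISTD_Dataset/train', {'patch_size': 320})  # matches --train_ps
val_set = get_validation_data('../ISTD_Dataset/test')

train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

# Each training sample is (clean, noisy, mask, clean_filename, noisy_filename), see dataset.py.
clean, noisy, mask, clean_name, noisy_name = next(iter(train_loader))
print(clean.shape, noisy.shape, mask.shape)  # [4, 3, 320, 320], [4, 3, 320, 320], [4, 1, 320, 320]
```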
-------------------------------------------------------------------------------- /warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | # scheduler_warmup is chained with schduler_steplr 13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 15 | 16 | # this zero gradient update is needed to avoid a warning message, issue #8. 17 | optim.zero_grad() 18 | optim.step() 19 | 20 | for epoch in range(1, 20): 21 | scheduler_warmup.step(epoch) 22 | print(epoch, optim.param_groups[0]['lr']) 23 | 24 | optim.step() # backward pass (update network) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 GuoLanqing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def tv_loss(x, beta = 0.5, reg_coeff = 5): 7 | '''Calculates TV loss for an image `x`. 8 | 9 | Args: 10 | x: image, torch.Variable of torch.Tensor 11 | beta: See https://arxiv.org/abs/1412.0035 (fig. 
2) to see effect of `beta` 12 | ''' 13 | dh = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2) 14 | dw = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2) 15 | a,b,c,d=x.shape 16 | return reg_coeff*(torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta))/(a*b*c*d)) 17 | 18 | 19 | class TVLoss(nn.Module): 20 | def __init__(self, tv_loss_weight=1): 21 | super(TVLoss, self).__init__() 22 | self.tv_loss_weight = tv_loss_weight 23 | 24 | def forward(self, x): 25 | batch_size = x.size()[0] 26 | h_x = x.size()[2] 27 | w_x = x.size()[3] 28 | count_h = self.tensor_size(x[:, :, 1:, :]) 29 | count_w = self.tensor_size(x[:, :, :, 1:]) 30 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 31 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 32 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 33 | 34 | @staticmethod 35 | def tensor_size(t): 36 | return t.size()[1] * t.size()[2] * t.size()[3] 37 | 38 | 39 | 40 | class CharbonnierLoss(nn.Module): 41 | """Charbonnier Loss (L1)""" 42 | 43 | def __init__(self, eps=1e-3): 44 | super(CharbonnierLoss, self).__init__() 45 | self.eps = eps 46 | 47 | def forward(self, x, y): 48 | diff = x - y 49 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 50 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 51 | return loss 52 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | ### rotate and flip 5 | class Augment_RGB_torch: 6 | def __init__(self): 7 | pass 8 | def transform0(self, torch_tensor): 9 | return torch_tensor 10 | def transform1(self, torch_tensor): 11 | torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2]) 12 | return torch_tensor 13 | def transform2(self, torch_tensor): 14 | torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2]) 15 | return torch_tensor 16 | def transform3(self, torch_tensor): 17 | torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2]) 18 | return torch_tensor 19 | def transform4(self, torch_tensor): 20 | torch_tensor = torch_tensor.flip(-2) 21 | return torch_tensor 22 | def transform5(self, torch_tensor): 23 | torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2) 24 | return torch_tensor 25 | def transform6(self, torch_tensor): 26 | torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2) 27 | return torch_tensor 28 | def transform7(self, torch_tensor): 29 | torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2) 30 | return torch_tensor 31 | 32 | 33 | ### mix two images 34 | class MixUp_AUG: 35 | def __init__(self): 36 | self.dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2])) 37 | 38 | def aug(self, rgb_gt, rgb_noisy, gray_mask): 39 | bs = rgb_gt.size(0) 40 | indices = torch.randperm(bs) 41 | rgb_gt2 = rgb_gt[indices] 42 | rgb_noisy2 = rgb_noisy[indices] 43 | gray_mask2 = gray_mask[indices] 44 | # gray_contour2 = gray_mask[indices] 45 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 46 | 47 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 48 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 49 | gray_mask = lam * gray_mask + (1-lam) * gray_mask2 50 | # gray_mask = torch.where(gray_mask>0.01, torch.ones_like(gray_mask), torch.zeros_like(gray_mask)) 51 | # gray_contour = lam * gray_contour + (1-lam) * gray_contour2 52 | return rgb_gt, rgb_noisy, gray_mask 53 | 
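A short, hypothetical usage sketch of the two augmentation helpers above, mirroring how `dataset.py` (one random rotation/flip per sample) and `train.py` (batch-level mixup after epoch 5) call them; the random tensors and their sizes are placeholders, and the mixup part needs a GPU because `MixUp_AUG.aug` moves its sampled weights with `.cuda()`:

```python
import random
import torch
from utils.dataset_utils import Augment_RGB_torch, MixUp_AUG

augment = Augment_RGB_torch()
# The eight rotation/flip transforms, chosen uniformly at random per sample (as in dataset.py).
transforms_aug = [m for m in dir(augment) if callable(getattr(augment, m)) and not m.startswith('_')]

img = torch.rand(3, 64, 64)                   # C, H, W tensor in [0, 1]
name = transforms_aug[random.getrandbits(3)]  # one of transform0 ... transform7
img_aug = getattr(augment, name)(img)         # rotated/flipped copy, same shape

# Batch-level mixup of ground truths, shadow images and masks (as in train.py after epoch 5).
if torch.cuda.is_available():
    gt = torch.rand(4, 3, 64, 64).cuda()
    shadow = torch.rand(4, 3, 64, 64).cuda()
    mask = torch.rand(4, 1, 64, 64).cuda()
    gt, shadow, mask = MixUp_AUG().aug(gt, shadow, mask)
```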
-------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from collections import OrderedDict 5 | 6 | def freeze(model): 7 | for p in model.parameters(): 8 | p.requires_grad=False 9 | 10 | def unfreeze(model): 11 | for p in model.parameters(): 12 | p.requires_grad=True 13 | 14 | def is_frozen(model): 15 | x = [p.requires_grad for p in model.parameters()] 16 | return not all(x) 17 | 18 | def save_checkpoint(model_dir, state, session): 19 | epoch = state['epoch'] 20 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 21 | torch.save(state, model_out_path) 22 | 23 | def load_checkpoint(model, weights): 24 | checkpoint = torch.load(weights) 25 | try: 26 | model.load_state_dict(checkpoint["state_dict"]) 27 | except: 28 | state_dict = checkpoint["state_dict"] 29 | new_state_dict = OrderedDict() 30 | for k, v in state_dict.items(): 31 | name = k[7:] if 'module.' in k else k 32 | new_state_dict[name] = v 33 | model.load_state_dict(new_state_dict) 34 | 35 | 36 | def load_checkpoint_multigpu(model, weights): 37 | checkpoint = torch.load(weights) 38 | state_dict = checkpoint["state_dict"] 39 | new_state_dict = OrderedDict() 40 | for k, v in state_dict.items(): 41 | name = k[7:] 42 | new_state_dict[name] = v 43 | model.load_state_dict(new_state_dict) 44 | 45 | def load_start_epoch(weights): 46 | checkpoint = torch.load(weights) 47 | epoch = checkpoint["epoch"] 48 | return epoch 49 | 50 | def load_optim(optimizer, weights): 51 | checkpoint = torch.load(weights) 52 | optimizer.load_state_dict(checkpoint['optimizer']) 53 | for p in optimizer.param_groups: lr = p['lr'] 54 | return lr 55 | 56 | def get_arch(opt): 57 | from model import UNet,ShadowFormer 58 | arch = opt.arch 59 | 60 | print('You choose '+arch+'...') 61 | if arch == 'UNet': 62 | model_restoration = UNet(dim=opt.embed_dim) 63 | elif arch == 'ShadowFormer': 64 | model_restoration = ShadowFormer(img_size=opt.train_ps,embed_dim=opt.embed_dim,win_size=opt.win_size,token_projection=opt.token_projection,token_mlp=opt.token_mlp) 65 | else: 66 | raise Exception("Arch error!") 67 | 68 | return model_restoration -------------------------------------------------------------------------------- /warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class Options(): 6 | """docstring for Options""" 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def init(self, parser): 12 | # global settings 13 | parser.add_argument('--batch_size', type=int, default=7, help='batch size') 14 | parser.add_argument('--nepoch', type=int, default=500, help='training epochs') 15 | parser.add_argument('--train_workers', type=int, default=0, help='train_dataloader workers') 16 | parser.add_argument('--eval_workers', type=int, default=8, help='eval_dataloader workers') 17 | parser.add_argument('--dataset', type=str, default='ISTD') 18 | parser.add_argument('--pretrain_weights', type=str, default='./log/model_best.pth', 19 | help='path of pretrained_weights') 20 | parser.add_argument('--optimizer', type=str, default='adamw', help='optimizer for training') 21 | parser.add_argument('--lr_initial', type=float, default=0.0002, help='initial learning rate') 22 | parser.add_argument('--weight_decay', type=float, default=0.02, help='weight decay') 23 | parser.add_argument('--gpu', type=str, default='2', help='GPUs') 24 | parser.add_argument('--arch', type=str, default='ShadowFormer', help='archtechture') 25 | parser.add_argument('--mode', 
type=str, default='shadow', help='image restoration mode') 26 | 27 | # args for saving 28 | parser.add_argument('--save_dir', type=str, default='./log', help='save dir') 29 | parser.add_argument('--save_images', action='store_true', default=False) 30 | parser.add_argument('--env', type=str, default='_istd', help='env') 31 | parser.add_argument('--checkpoint', type=int, default=50, help='checkpoint') 32 | 33 | # args for Uformer 34 | parser.add_argument('--norm_layer', type=str, default='nn.LayerNorm', help='normalize layer in transformer') 35 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of emdeding features') 36 | parser.add_argument('--win_size', type=int, default=10, help='window size of self-attention') 37 | parser.add_argument('--token_projection', type=str, default='linear', help='linear/conv token projection') 38 | parser.add_argument('--token_mlp', type=str, default='leff', help='ffn/leff token mlp') 39 | parser.add_argument('--att_se', action='store_true', default=False, help='se after sa') 40 | 41 | # args for vit 42 | parser.add_argument('--vit_dim', type=int, default=320, help='vit hidden_dim') 43 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 44 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 45 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 46 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 47 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection') 48 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection') 49 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module') 50 | 51 | # args for training 52 | parser.add_argument('--train_ps', type=int, default=320, help='patch size of training sample') 53 | parser.add_argument('--resume', action='store_true', default=False) 54 | parser.add_argument('--train_dir', type=str, default='../ISTD_Dataset/train', help='dir of train data') 55 | parser.add_argument('--val_dir', type=str, default='../ISTD_Dataset/test', help='dir of train data') 56 | parser.add_argument('--warmup', action='store_true', default=True, help='warmup') 57 | parser.add_argument('--warmup_epochs', type=int, default=3, help='epochs for warmup') 58 | 59 | return parser 60 | -------------------------------------------------------------------------------- /utils/bundle_submissions.py: -------------------------------------------------------------------------------- 1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) 2 | 3 | # This file is part of the implementation as described in the CVPR 2017 paper: 4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. 5 | # Please see the file LICENSE.txt for the license governing this code. 6 | 7 | 8 | import numpy as np 9 | import scipy.io as sio 10 | import os 11 | import h5py 12 | 13 | def bundle_submissions_raw(submission_folder,session): 14 | ''' 15 | Bundles submission data for raw denoising 16 | submission_folder Folder where denoised images reside 17 | Output is written to /bundled/. Please submit 18 | the content of this folder. 
19 | ''' 20 | 21 | out_folder = os.path.join(submission_folder, session) 22 | # out_folder = os.path.join(submission_folder, "bundled/") 23 | try: 24 | os.mkdir(out_folder) 25 | except:pass 26 | 27 | israw = True 28 | eval_version="1.0" 29 | 30 | for i in range(50): 31 | Idenoised = np.zeros((20,), dtype=np.object) 32 | for bb in range(20): 33 | filename = '%04d_%02d.mat'%(i+1,bb+1) 34 | s = sio.loadmat(os.path.join(submission_folder,filename)) 35 | Idenoised_crop = s["Idenoised_crop"] 36 | Idenoised[bb] = Idenoised_crop 37 | filename = '%04d.mat'%(i+1) 38 | sio.savemat(os.path.join(out_folder, filename), 39 | {"Idenoised": Idenoised, 40 | "israw": israw, 41 | "eval_version": eval_version}, 42 | ) 43 | 44 | def bundle_submissions_srgb(submission_folder,session): 45 | ''' 46 | Bundles submission data for sRGB denoising 47 | 48 | submission_folder Folder where denoised images reside 49 | Output is written to /bundled/. Please submit 50 | the content of this folder. 51 | ''' 52 | out_folder = os.path.join(submission_folder, session) 53 | # out_folder = os.path.join(submission_folder, "bundled/") 54 | try: 55 | os.mkdir(out_folder) 56 | except:pass 57 | israw = False 58 | eval_version="1.0" 59 | 60 | for i in range(50): 61 | Idenoised = np.zeros((20,), dtype=np.object) 62 | for bb in range(20): 63 | filename = '%04d_%02d.mat'%(i+1,bb+1) 64 | s = sio.loadmat(os.path.join(submission_folder,filename)) 65 | Idenoised_crop = s["Idenoised_crop"] 66 | Idenoised[bb] = Idenoised_crop 67 | filename = '%04d.mat'%(i+1) 68 | sio.savemat(os.path.join(out_folder, filename), 69 | {"Idenoised": Idenoised, 70 | "israw": israw, 71 | "eval_version": eval_version}, 72 | ) 73 | 74 | 75 | 76 | def bundle_submissions_srgb_v1(submission_folder,session): 77 | ''' 78 | Bundles submission data for sRGB denoising 79 | 80 | submission_folder Folder where denoised images reside 81 | Output is written to /bundled/. Please submit 82 | the content of this folder. 83 | ''' 84 | out_folder = os.path.join(submission_folder, session) 85 | # out_folder = os.path.join(submission_folder, "bundled/") 86 | try: 87 | os.mkdir(out_folder) 88 | except:pass 89 | israw = False 90 | eval_version="1.0" 91 | 92 | for i in range(50): 93 | Idenoised = np.zeros((20,), dtype=np.object) 94 | for bb in range(20): 95 | filename = '%04d_%d.mat'%(i+1,bb+1) 96 | s = sio.loadmat(os.path.join(submission_folder,filename)) 97 | Idenoised_crop = s["Idenoised_crop"] 98 | Idenoised[bb] = Idenoised_crop 99 | filename = '%04d.mat'%(i+1) 100 | sio.savemat(os.path.join(out_folder, filename), 101 | {"Idenoised": Idenoised, 102 | "israw": israw, 103 | "eval_version": eval_version}, 104 | ) -------------------------------------------------------------------------------- /utils/antialias.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Adobe Inc. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4 | # 4.0 International Public License. To view a copy of this license, visit 5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 
6 | 7 | 8 | 9 | ######## https://github.com/adobe/antialiased-cnns/blob/master/models_lpf/__init__.py 10 | 11 | 12 | 13 | import torch 14 | import torch.nn.parallel 15 | import numpy as np 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | class Downsample(nn.Module): 20 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 21 | super(Downsample, self).__init__() 22 | self.filt_size = filt_size 23 | self.pad_off = pad_off 24 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] 25 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] 26 | self.stride = stride 27 | self.off = int((self.stride-1)/2.) 28 | self.channels = channels 29 | 30 | # print('Filter size [%i]'%filt_size) 31 | if(self.filt_size==1): 32 | a = np.array([1.,]) 33 | elif(self.filt_size==2): 34 | a = np.array([1., 1.]) 35 | elif(self.filt_size==3): 36 | a = np.array([1., 2., 1.]) 37 | elif(self.filt_size==4): 38 | a = np.array([1., 3., 3., 1.]) 39 | elif(self.filt_size==5): 40 | a = np.array([1., 4., 6., 4., 1.]) 41 | elif(self.filt_size==6): 42 | a = np.array([1., 5., 10., 10., 5., 1.]) 43 | elif(self.filt_size==7): 44 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 45 | 46 | filt = torch.Tensor(a[:,None]*a[None,:]) 47 | filt = filt/torch.sum(filt) 48 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) 49 | 50 | self.pad = get_pad_layer(pad_type)(self.pad_sizes) 51 | 52 | def forward(self, inp): 53 | if(self.filt_size==1): 54 | if(self.pad_off==0): 55 | return inp[:,:,::self.stride,::self.stride] 56 | else: 57 | return self.pad(inp)[:,:,::self.stride,::self.stride] 58 | else: 59 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 60 | 61 | def get_pad_layer(pad_type): 62 | if(pad_type in ['refl','reflect']): 63 | PadLayer = nn.ReflectionPad2d 64 | elif(pad_type in ['repl','replicate']): 65 | PadLayer = nn.ReplicationPad2d 66 | elif(pad_type=='zero'): 67 | PadLayer = nn.ZeroPad2d 68 | else: 69 | print('Pad type [%s] not recognized'%pad_type) 70 | return PadLayer 71 | 72 | 73 | class Downsample1D(nn.Module): 74 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 75 | super(Downsample1D, self).__init__() 76 | self.filt_size = filt_size 77 | self.pad_off = pad_off 78 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] 79 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 80 | self.stride = stride 81 | self.off = int((self.stride - 1) / 2.) 
82 | self.channels = channels 83 | 84 | # print('Filter size [%i]' % filt_size) 85 | if(self.filt_size == 1): 86 | a = np.array([1., ]) 87 | elif(self.filt_size == 2): 88 | a = np.array([1., 1.]) 89 | elif(self.filt_size == 3): 90 | a = np.array([1., 2., 1.]) 91 | elif(self.filt_size == 4): 92 | a = np.array([1., 3., 3., 1.]) 93 | elif(self.filt_size == 5): 94 | a = np.array([1., 4., 6., 4., 1.]) 95 | elif(self.filt_size == 6): 96 | a = np.array([1., 5., 10., 10., 5., 1.]) 97 | elif(self.filt_size == 7): 98 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 99 | 100 | filt = torch.Tensor(a) 101 | filt = filt / torch.sum(filt) 102 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) 103 | 104 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) 105 | 106 | def forward(self, inp): 107 | if(self.filt_size == 1): 108 | if(self.pad_off == 0): 109 | return inp[:, :, ::self.stride] 110 | else: 111 | return self.pad(inp)[:, :, ::self.stride] 112 | else: 113 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 114 | 115 | 116 | def get_pad_layer_1d(pad_type): 117 | if(pad_type in ['refl', 'reflect']): 118 | PadLayer = nn.ReflectionPad1d 119 | elif(pad_type in ['repl', 'replicate']): 120 | PadLayer = nn.ReplicationPad1d 121 | elif(pad_type == 'zero'): 122 | PadLayer = nn.ZeroPad1d 123 | else: 124 | print('Pad type [%s] not recognized' % pad_type) 125 | return PadLayer -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import cv2 5 | from skimage.color import rgb2lab 6 | 7 | def is_numpy_file(filename): 8 | return any(filename.endswith(extension) for extension in [".npy"]) 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in [".jpg"]) 12 | 13 | def is_png_file(filename): 14 | return any(filename.endswith(extension) for extension in [".png"]) 15 | 16 | def is_pkl_file(filename): 17 | return any(filename.endswith(extension) for extension in [".pkl"]) 18 | 19 | def load_pkl(filename_): 20 | with open(filename_, 'rb') as f: 21 | ret_dict = pickle.load(f) 22 | return ret_dict 23 | 24 | def save_dict(dict_, filename_): 25 | with open(filename_, 'wb') as f: 26 | pickle.dump(dict_, f) 27 | 28 | def load_npy(filepath): 29 | img = np.load(filepath) 30 | return img 31 | 32 | def load_img(filepath): 33 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 34 | # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA) 35 | img = img.astype(np.float32) 36 | img = img/255. 37 | return img 38 | 39 | def load_val_img(filepath): 40 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 41 | # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA) 42 | resized_img = img.astype(np.float32) 43 | resized_img = resized_img/255. 44 | return resized_img 45 | 46 | def load_mask(filepath): 47 | img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE) 48 | # img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA) 49 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2YUV) 50 | # kernel = np.ones((8,8), np.uint8) 51 | # erosion = cv2.erode(img, kernel, iterations=1) 52 | # dilation = cv2.dilate(img, kernel, iterations=1) 53 | # contour = dilation - erosion 54 | img = img.astype(np.float32) 55 | # contour = contour.astype(np.float32) 56 | # contour = contour/255. 57 | img = img/255. 
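    # The mask is kept as a soft float map in [0, 1] here; callers that need a hard
    # mask (e.g. the metric computation in test.py) binarise it themselves.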
58 | return img 59 | 60 | def load_val_mask(filepath): 61 | img = cv2.imread(filepath, 0) 62 | resized_img = img 63 | # resized_img = cv2.resize(img, [256, 256], interpolation=cv2.INTER_AREA) 64 | resized_img = resized_img.astype(np.float32) 65 | resized_img = resized_img/255. 66 | return resized_img 67 | 68 | def save_img(img, filepath): 69 | cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 70 | 71 | def myPSNR(tar_img, prd_img): 72 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 73 | rmse = (imdff**2).mean().sqrt() 74 | ps = 20*torch.log10(1/rmse) 75 | return ps 76 | 77 | def batch_PSNR(img1, img2, average=True): 78 | PSNR = [] 79 | for im1, im2 in zip(img1, img2): 80 | psnr = myPSNR(im1, im2) 81 | PSNR.append(psnr) 82 | return sum(PSNR)/len(PSNR) if average else sum(PSNR) 83 | 84 | def tensor2im(input_image, imtype=np.uint8): 85 | """"Converts a Tensor array into a numpy image array. 86 | Parameters: 87 | input_image (tensor) -- the input image tensor array 88 | imtype (type) -- the desired type of the converted numpy array 89 | """ 90 | if not isinstance(input_image, np.ndarray): 91 | 92 | if isinstance(input_image, torch.Tensor): # get the data from a variable 93 | image_tensor = input_image.data 94 | else: 95 | return input_image 96 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 97 | if image_numpy.shape[0] == 1: # grayscale to RGB 98 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 99 | # image_numpy = image_numpy.convert('L') 100 | 101 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 102 | else: # if it is a numpy array, do nothing 103 | image_numpy = input_image 104 | # image_numpy = 105 | return np.clip(image_numpy, 0, 255).astype(imtype) 106 | 107 | def calc_RMSE(real_img, fake_img): 108 | # convert to LAB color space 109 | real_lab = rgb2lab(real_img) 110 | fake_lab = rgb2lab(fake_img) 111 | return real_lab - fake_lab 112 | 113 | def tensor2uint(img): 114 | img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() 115 | if img.ndim == 3: 116 | img = np.transpose(img, (1, 2, 0)) 117 | return np.uint8((img*255.0).round()) 118 | 119 | def imsave(img, img_path): 120 | img = np.squeeze(img) 121 | if img.ndim == 3: 122 | img = img[:, :, [2, 1, 0]] 123 | cv2.imwrite(img_path, img) 124 | 125 | # def yCbCr2rgb(input_im): 126 | # im_flat = input_im.contiguous().view(-1, 3).float() 127 | # mat = torch.tensor([[1.164, 1.164, 1.164], 128 | # [0, -0.392, 2.017], 129 | # [1.596, -0.813, 0]]) 130 | # bias = torch.tensor([-16.0/255.0, -128.0/255.0, -128.0/255.0]) 131 | # temp = (im_flat + bias).mm(mat) 132 | # out = temp.view(3, list(input_im.size())[1], list(input_im.size())[2]) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from utils import is_png_file, load_img, load_val_img, load_mask, load_val_mask, Augment_RGB_torch 6 | import torch.nn.functional as F 7 | import random 8 | 9 | augment = Augment_RGB_torch() 10 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 11 | 12 | ################################################################################################## 13 | class DataLoaderTrain(Dataset): 14 | def __init__(self, rgb_dir, img_options=None, 
target_transform=None): 15 | super(DataLoaderTrain, self).__init__() 16 | 17 | self.target_transform = target_transform 18 | 19 | gt_dir = 'train_C' 20 | input_dir = 'train_A' 21 | mask_dir = 'train_B' 22 | 23 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 24 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 25 | mask_files = sorted(os.listdir(os.path.join(rgb_dir, mask_dir))) 26 | 27 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)] 28 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)] 29 | self.mask_filenames = [os.path.join(rgb_dir, mask_dir, x) for x in mask_files if is_png_file(x)] 30 | 31 | self.img_options = img_options 32 | 33 | self.tar_size = len(self.clean_filenames) # get the size of target 34 | 35 | def __len__(self): 36 | return self.tar_size 37 | 38 | def __getitem__(self, index): 39 | tar_index = index % self.tar_size 40 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 41 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 42 | mask = load_mask(self.mask_filenames[tar_index]) 43 | mask = torch.from_numpy(np.float32(mask)) 44 | 45 | clean = clean.permute(2,0,1) 46 | noisy = noisy.permute(2,0,1) 47 | 48 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 49 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 50 | mask_filename = os.path.split(self.mask_filenames[tar_index])[-1] 51 | 52 | #Crop Input and Target 53 | ps = self.img_options['patch_size'] 54 | H = clean.shape[1] 55 | W = clean.shape[2] 56 | # r = np.random.randint(0, H - ps) if not H-ps else 0 57 | # c = np.random.randint(0, W - ps) if not H-ps else 0 58 | if H-ps==0: 59 | r=0 60 | c=0 61 | else: 62 | r = np.random.randint(0, H - ps) 63 | c = np.random.randint(0, W - ps) 64 | clean = clean[:, r:r + ps, c:c + ps] 65 | noisy = noisy[:, r:r + ps, c:c + ps] 66 | mask = mask[r:r + ps, c:c + ps] 67 | 68 | apply_trans = transforms_aug[random.getrandbits(3)] 69 | 70 | clean = getattr(augment, apply_trans)(clean) 71 | noisy = getattr(augment, apply_trans)(noisy) 72 | mask = getattr(augment, apply_trans)(mask) 73 | mask = torch.unsqueeze(mask, dim=0) 74 | return clean, noisy, mask, clean_filename, noisy_filename 75 | 76 | ################################################################################################## 77 | class DataLoaderVal(Dataset): 78 | def __init__(self, rgb_dir, target_transform=None): 79 | super(DataLoaderVal, self).__init__() 80 | 81 | self.target_transform = target_transform 82 | 83 | gt_dir = 'test_C' 84 | input_dir = 'test_A' 85 | mask_dir = 'test_B' 86 | 87 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 88 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 89 | mask_files = sorted(os.listdir(os.path.join(rgb_dir, mask_dir))) 90 | 91 | 92 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)] 93 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)] 94 | self.mask_filenames = [os.path.join(rgb_dir, mask_dir, x) for x in mask_files if is_png_file(x)] 95 | 96 | 97 | self.tar_size = len(self.clean_filenames) 98 | 99 | def __len__(self): 100 | return self.tar_size 101 | 102 | def __getitem__(self, index): 103 | tar_index = index % self.tar_size 104 | 105 | 106 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 
107 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 108 | mask = load_mask(self.mask_filenames[tar_index]) 109 | mask = torch.from_numpy(np.float32(mask)) 110 | 111 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 112 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 113 | mask_filename = os.path.split(self.mask_filenames[tar_index])[-1] 114 | 115 | clean = clean.permute(2,0,1) 116 | noisy = noisy.permute(2,0,1) 117 | mask = torch.unsqueeze(mask, dim=0) 118 | 119 | return clean, noisy, mask, clean_filename, noisy_filename 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ShadowFormer (AAAI'23) 2 | This is the official implementation of the AAAI 2023 paper [ShadowFormer: Global Context Helps Image Shadow Removal](https://arxiv.org/pdf/2302.01650.pdf). 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/shadowformer-global-context-helps-image/shadow-removal-on-istd)](https://paperswithcode.com/sota/shadow-removal-on-istd?p=shadowformer-global-context-helps-image) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/shadowformer-global-context-helps-image/shadow-removal-on-adjusted-istd)](https://paperswithcode.com/sota/shadow-removal-on-adjusted-istd?p=shadowformer-global-context-helps-image) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/shadowformer-global-context-helps-image/shadow-removal-on-srd)](https://paperswithcode.com/sota/shadow-removal-on-srd?p=shadowformer-global-context-helps-image) 7 | 8 | #### News 9 | * **Feb 24, 2023**: Release the pretrained models for ISTD and ISTD+. 10 | * **Feb 18, 2023**: Release the training and testing codes. 11 | * **Feb 17, 2023**: Add the testing results and the description of our work. 12 | 13 | ## Introduction 14 | To tackle image shadow removal problem, we propose a novel transformer-based method, dubbed ShadowFormer, for exploiting non-shadow 15 | regions to help shadow region restoration. A multi-scale channel attention framework is employed to hierarchically 16 | capture the global information. Based on that, we propose a Shadow-Interaction Module (SIM) with Shadow-Interaction Attention (SIA) in the bottleneck stage to effectively model the context correlation between shadow and non-shadow regions. 17 | For more details, please refer to our [original paper](https://arxiv.org/pdf/2302.01650.pdf) 18 | 19 |
![Pipeline](./doc/pipeline.jpg)
20 | 21 |
![Details](./doc/details.jpg)
22 | 23 | ## Requirement 24 | * Python 3.7 25 | * Pytorch 1.7 26 | * CUDA 11.1 27 | ```bash 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ## Datasets 32 | * ISTD [[link]](https://github.com/DeepInsight-PCALab/ST-CGAN) 33 | * ISTD+ [[link]](https://github.com/cvlab-stonybrook/SID) 34 | * SRD [[Training]](https://drive.google.com/file/d/1W8vBRJYDG9imMgr9I2XaA13tlFIEHOjS/view)[[Testing]](https://drive.google.com/file/d/1GTi4BmQ0SJ7diDMmf-b7x2VismmXtfTo/view) 35 | 36 | ## Pretrained models 37 | [ISTD](https://drive.google.com/file/d/1bHbkHxY5D5905BMw2jzvkzgXsFPKzSq4/view?usp=share_link) | [ISTD+](https://drive.google.com/file/d/10pBsJenoWGriZ9kjWOcE4l4Kzg-F1TFd/view?usp=share_link) | [SRD]() 38 | 39 | Please download the corresponding pretrained model and modify the `weights` in `test.py`. 40 | 41 | ## Test 42 | You can directly test the performance of the pre-trained model as follows 43 | 1. Modify the paths to dataset and pre-trained model. You need to modify the following path in the `test.py` 44 | ```python 45 | input_dir # shadow image input path -- Line 27 46 | weights # pretrained model path -- Line 31 47 | ``` 48 | 2. Test the model 49 | ```python 50 | python test.py --save_images 51 | ``` 52 | You can check the output in `./results`. 53 | 54 | ## Train 55 | 1. Download datasets and set the following structure 56 | ``` 57 | |-- ISTD_Dataset 58 | |-- train 59 | |-- train_A # shadow image 60 | |-- train_B # shadow mask 61 | |-- train_C # shadow-free GT 62 | |-- test 63 | |-- test_A # shadow image 64 | |-- test_B # shadow mask 65 | |-- test_C # shadow-free GT 66 | ``` 67 | 2. You need to modify the following terms in `option.py` 68 | ```python 69 | train_dir # training set path 70 | val_dir # testing set path 71 | gpu: 0 # Our model can be trained using a single RTX A5000 GPU. You can also train the model using multiple GPUs by adding more GPU ids in it. 72 | ``` 73 | 3. Train the network 74 | If you want to train the network on 256X256 images: 75 | ```python 76 | python train.py --warmup --win_size 8 --train_ps 256 77 | ``` 78 | or you want to train on original resolution, e.g., 480X640 for ISTD: 79 | ```python 80 | python train.py --warmup --win_size 10 --train_ps 320 81 | ``` 82 | 83 | ## Evaluation 84 | The results reported in the paper are calculated by the `matlab` script used in [previous method](https://github.com/zhuyr97/AAAI2022_Unfolding_Network_Shadow_Removal/tree/master/codes). Details refer to `evaluation/measure_shadow.m`. 85 | We also provide the `python` code for calculating the metrics in `test.py`, using `python test.py --cal_metrics` to print. 86 | 87 | ## Results 88 | #### Evaluation on ISTD 89 | The evaluation results on ISTD are as follows 90 | | Method | PSNR | SSIM | RMSE | 91 | | :-- | :--: | :--: | :--: | 92 | | ST-CGAN | 27.44 | 0.929 | 6.65 | 93 | | DSC | 29.00 | 0.944 | 5.59 | 94 | | DHAN | 29.11 | 0.954 | 5.66 | 95 | | Fu et al. | 27.19 | 0.945 | 5.88 | 96 | | Zhu et al. | 29.85 | 0.960 | 4.27 | 97 | | **ShadowFormer (Ours)** | **32.21** | **0.968** | **4.09** | 98 | 99 | #### Visual Results 100 |
![Results](./doc/res.jpg)
101 | 102 | #### Testing results 103 | The testing results on dataset ISTD, ISTD+, SRD are: [results](https://drive.google.com/file/d/1zcv7KBCIKgk-CGQJCWnM2YAKcSAj8Sc4/view?usp=share_link) 104 | 105 | ## References 106 | Our implementation is based on [Uformer](https://github.com/ZhendongWang6/Uformer) and [Restormer](https://github.com/swz30/Restormer). We would like to thank them. 107 | 108 | Citation 109 | ----- 110 | Preprint available [here](https://arxiv.org/pdf/2302.01650.pdf). 111 | 112 | In case of use, please cite our publication: 113 | 114 | L. Guo, S. Huang, D. Liu, H. Cheng and B. Wen, "ShadowFormer: Global Context Helps Image Shadow Removal," AAAI 2023. 115 | 116 | Bibtex: 117 | ``` 118 | @article{guo2023shadowformer, 119 | title={ShadowFormer: Global Context Helps Image Shadow Removal}, 120 | author={Guo, Lanqing and Huang, Siyu and Liu, Ding and Cheng, Hao and Wen, Bihan}, 121 | journal={arXiv preprint arXiv:2302.01650}, 122 | year={2023} 123 | } 124 | ``` 125 | 126 | ## Contact 127 | If you have any questions, please contact lanqing001@e.ntu.edu.sg 128 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os,sys 4 | import argparse 5 | from tqdm import tqdm 6 | from einops import rearrange, repeat 7 | 8 | import torch.nn as nn 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | # from ptflops import get_model_complexity_info 13 | 14 | import scipy.io as sio 15 | from utils.loader import get_validation_data 16 | import utils 17 | import cv2 18 | from model import UNet 19 | 20 | from skimage import img_as_float32, img_as_ubyte 21 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 22 | from skimage.metrics import structural_similarity as ssim_loss 23 | from sklearn.metrics import mean_squared_error as mse_loss 24 | 25 | parser = argparse.ArgumentParser(description='RGB denoising evaluation on the validation set of SIDD') 26 | parser.add_argument('--input_dir', default='../ISTD_Dataset/test/', 27 | type=str, help='Directory of validation images') 28 | parser.add_argument('--result_dir', default='./results/', 29 | type=str, help='Directory for results') 30 | parser.add_argument('--weights', default='./log/ShadowFormer_istd/models/model_best.pth', 31 | type=str, help='Path to weights') 32 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 33 | parser.add_argument('--arch', default='ShadowFormer', type=str, help='arch') 34 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader') 35 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 36 | parser.add_argument('--cal_metrics', action='store_true', help='Measure denoised images with GT') 37 | parser.add_argument('--embed_dim', type=int, default=32, help='number of data loading workers') 38 | parser.add_argument('--win_size', type=int, default=10, help='number of data loading workers') 39 | parser.add_argument('--token_projection', type=str, default='linear', help='linear/conv token projection') 40 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff token mlp') 41 | # args for vit 42 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim') 43 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 44 | 
parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 45 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 46 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 47 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection') 48 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection') 49 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module') 50 | parser.add_argument('--train_ps', type=int, default=320, help='patch size of training sample') 51 | parser.add_argument('--tile', type=int, default=None, help='Tile size (e.g 720). None means testing on the original resolution image') 52 | parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles') 53 | args = parser.parse_args() 54 | 55 | 56 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 58 | 59 | utils.mkdir(args.result_dir) 60 | 61 | test_dataset = get_validation_data(args.input_dir) 62 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False) 63 | 64 | model_restoration = utils.get_arch(args) 65 | model_restoration = torch.nn.DataParallel(model_restoration) 66 | 67 | utils.load_checkpoint(model_restoration, args.weights) 68 | print("===>Testing using weights: ", args.weights) 69 | 70 | model_restoration.cuda() 71 | model_restoration.eval() 72 | 73 | img_multiple_of = 8 * args.win_size 74 | 75 | with torch.no_grad(): 76 | psnr_val_rgb = [] 77 | ssim_val_rgb = [] 78 | rmse_val_rgb = [] 79 | psnr_val_s = [] 80 | ssim_val_s = [] 81 | psnr_val_ns = [] 82 | ssim_val_ns = [] 83 | rmse_val_s = [] 84 | rmse_val_ns = [] 85 | for ii, data_test in enumerate(tqdm(test_loader), 0): 86 | rgb_gt = data_test[0].numpy().squeeze().transpose((1, 2, 0)) 87 | rgb_noisy = data_test[1].cuda() 88 | mask = data_test[2].cuda() 89 | filenames = data_test[3] 90 | 91 | # Pad the input if not_multiple_of win_size * 8 92 | height, width = rgb_noisy.shape[2], rgb_noisy.shape[3] 93 | H, W = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of, ( 94 | (width + img_multiple_of) // img_multiple_of) * img_multiple_of 95 | padh = H - height if height % img_multiple_of != 0 else 0 96 | padw = W - width if width % img_multiple_of != 0 else 0 97 | rgb_noisy = F.pad(rgb_noisy, (0, padw, 0, padh), 'reflect') 98 | mask = F.pad(mask, (0, padw, 0, padh), 'reflect') 99 | 100 | if args.tile is None: 101 | rgb_restored = model_restoration(rgb_noisy, mask) 102 | else: 103 | # test the image tile by tile 104 | b, c, h, w = rgb_noisy.shape 105 | tile = min(args.tile, h, w) 106 | assert tile % 8 == 0, "tile size should be multiple of 8" 107 | tile_overlap = args.tile_overlap 108 | 109 | stride = tile - tile_overlap 110 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile] 111 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile] 112 | E = torch.zeros(b, c, h, w).type_as(rgb_noisy) 113 | W = torch.zeros_like(E) 114 | 115 | for h_idx in h_idx_list: 116 | for w_idx in w_idx_list: 117 | in_patch = rgb_noisy[..., h_idx:h_idx + tile, w_idx:w_idx + tile] 118 | mask_patch = mask[..., h_idx:h_idx + tile, w_idx:w_idx + tile] 119 | out_patch = model_restoration(in_patch, mask_patch) 120 | out_patch_mask = torch.ones_like(out_patch) 121 | 122 | E[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch) 123 | 
W[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch_mask) 124 | restored = E.div_(W) 125 | 126 | rgb_restored = torch.clamp(rgb_restored, 0, 1).cpu().numpy().squeeze().transpose((1, 2, 0)) 127 | 128 | # Unpad the output 129 | rgb_restored = rgb_restored[:height, :width, :] 130 | 131 | if args.cal_metrics: 132 | bm = torch.where(mask == 0, torch.zeros_like(mask), torch.ones_like(mask)) #binarize mask 133 | bm = np.expand_dims(bm.cpu().numpy().squeeze(), axis=2) 134 | 135 | # calculate SSIM in gray space 136 | gray_restored = cv2.cvtColor(rgb_restored, cv2.COLOR_RGB2GRAY) 137 | gray_gt = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY) 138 | ssim_val_rgb.append(ssim_loss(gray_restored, gray_gt, channel_axis=None)) 139 | ssim_val_ns.append(ssim_loss(gray_restored * (1 - bm.squeeze()), gray_gt * (1 - bm.squeeze()), channel_axis=None)) 140 | ssim_val_s.append(ssim_loss(gray_restored * bm.squeeze(), gray_gt * bm.squeeze(), channel_axis=None)) 141 | 142 | psnr_val_rgb.append(psnr_loss(rgb_restored, rgb_gt)) 143 | psnr_val_ns.append(psnr_loss(rgb_restored * (1 - bm), rgb_gt * (1 - bm))) 144 | psnr_val_s.append(psnr_loss(rgb_restored * bm, rgb_gt * bm)) 145 | 146 | # calculate the RMSE in LAB space 147 | rmse_temp = np.abs(cv2.cvtColor(rgb_restored, cv2.COLOR_RGB2LAB) - cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2LAB)).mean() * 3 148 | rmse_val_rgb.append(rmse_temp) 149 | rmse_temp_s = np.abs(cv2.cvtColor(rgb_restored * bm, cv2.COLOR_RGB2LAB) - cv2.cvtColor(rgb_gt * bm, cv2.COLOR_RGB2LAB)).sum() / bm.sum() 150 | rmse_temp_ns = np.abs(cv2.cvtColor(rgb_restored * (1-bm), cv2.COLOR_RGB2LAB) - cv2.cvtColor(rgb_gt * (1-bm), 151 | cv2.COLOR_RGB2LAB)).sum() / (1-bm).sum() 152 | rmse_val_s.append(rmse_temp_s) 153 | rmse_val_ns.append(rmse_temp_ns) 154 | 155 | 156 | if args.save_images: 157 | utils.save_img(rgb_restored*255.0, os.path.join(args.result_dir, filenames[0])) 158 | 159 | if args.cal_metrics: 160 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset) 161 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset) 162 | psnr_val_s = sum(psnr_val_s)/len(test_dataset) 163 | ssim_val_s = sum(ssim_val_s)/len(test_dataset) 164 | psnr_val_ns = sum(psnr_val_ns)/len(test_dataset) 165 | ssim_val_ns = sum(ssim_val_ns)/len(test_dataset) 166 | rmse_val_rgb = sum(rmse_val_rgb) / len(test_dataset) 167 | rmse_val_s = sum(rmse_val_s) / len(test_dataset) 168 | rmse_val_ns = sum(rmse_val_ns) / len(test_dataset) 169 | print("PSNR: %f, SSIM: %f, RMSE: %f " %(psnr_val_rgb, ssim_val_rgb, rmse_val_rgb)) 170 | print("SPSNR: %f, SSSIM: %f, SRMSE: %f " %(psnr_val_s, ssim_val_s, rmse_val_s)) 171 | print("NSPSNR: %f, NSSSIM: %f, NSRMSE: %f " %(psnr_val_ns, ssim_val_ns, rmse_val_ns)) 172 | 173 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | # add dir 5 | dir_name = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(os.path.join(dir_name,'./auxiliary/')) 7 | print(dir_name) 8 | 9 | import argparse 10 | import options 11 | ######### parser ########### 12 | opt = options.Options().init(argparse.ArgumentParser(description='image denoising')).parse_args() 13 | print(opt) 14 | 15 | import utils 16 | ######### Set GPUs ########### 17 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 18 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu 19 | import torch 20 | torch.backends.cudnn.benchmark = True 21 | # from piqa import SSIM 22 | # device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 23 | # print(device) 24 | import torch.nn as nn 25 | import torch.optim as optim 26 | from torch.utils.data import DataLoader 27 | from natsort import natsorted 28 | import glob 29 | import random 30 | import time 31 | import numpy as np 32 | from einops import rearrange, repeat 33 | import datetime 34 | from pdb import set_trace as stx 35 | from utils import save_img 36 | from losses import CharbonnierLoss 37 | 38 | from tqdm import tqdm 39 | from warmup_scheduler import GradualWarmupScheduler 40 | from torch.optim.lr_scheduler import StepLR 41 | from timm.utils import NativeScaler 42 | 43 | from utils.loader import get_training_data, get_validation_data 44 | 45 | 46 | ######### Logs dir ########### 47 | log_dir = os.path.join(dir_name, 'log', opt.arch+opt.env) 48 | if not os.path.exists(log_dir): 49 | os.makedirs(log_dir) 50 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt') 51 | print("Now time is : ", datetime.datetime.now().isoformat()) 52 | result_dir = os.path.join(log_dir, 'results') 53 | model_dir = os.path.join(log_dir, 'models') 54 | utils.mkdir(result_dir) 55 | utils.mkdir(model_dir) 56 | 57 | # ######### Set Seeds ########### 58 | random.seed(1234) 59 | np.random.seed(1234) 60 | torch.manual_seed(1234) 61 | torch.cuda.manual_seed_all(1234) 62 | 63 | 64 | 65 | ######### Model ########### 66 | model_restoration = utils.get_arch(opt) 67 | 68 | with open(logname,'a') as f: 69 | f.write(str(opt)+'\n') 70 | f.write(str(model_restoration)+'\n') 71 | 72 | ######### Optimizer ########### 73 | start_epoch = 1 74 | if opt.optimizer.lower() == 'adam': 75 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 76 | elif opt.optimizer.lower() == 'adamw': 77 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 78 | else: 79 | raise Exception("Error optimizer...") 80 | 81 | 82 | ######### DataParallel ########### 83 | model_restoration = torch.nn.DataParallel (model_restoration) 84 | model_restoration.cuda() 85 | 86 | ######### Resume ########### 87 | if opt.resume: 88 | path_chk_rest = opt.pretrain_weights 89 | utils.load_checkpoint(model_restoration,path_chk_rest) 90 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 91 | # lr = utils.load_optim(optimizer, path_chk_rest) 92 | # 93 | # for p in optimizer.param_groups: p['lr'] = lr 94 | # warmup = False 95 | # new_lr = lr 96 | # print('------------------------------------------------------------------------------') 97 | # print("==> Resuming Training with learning rate:",new_lr) 98 | # print('------------------------------------------------------------------------------') 99 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-start_epoch+1, eta_min=1e-6) 100 | 101 | # ######### Scheduler ########### 102 | if opt.warmup: 103 | print("Using warmup and cosine strategy!") 104 | warmup_epochs = opt.warmup_epochs 105 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6) 106 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 107 | scheduler.step() 108 | else: 109 | step = 50 110 | print("Using StepLR,step={}!".format(step)) 111 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 112 | scheduler.step() 113 | 114 | 115 | ######### Loss ########### 116 | criterion = 
CharbonnierLoss().cuda() 117 | 118 | ######### DataLoader ########### 119 | print('===> Loading datasets') 120 | img_options_train = {'patch_size':opt.train_ps} 121 | train_dataset = get_training_data(opt.train_dir, img_options_train) 122 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, 123 | num_workers=opt.train_workers, pin_memory=True, drop_last=False) 124 | 125 | val_dataset = get_validation_data(opt.val_dir) 126 | val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, 127 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False) 128 | 129 | len_trainset = train_dataset.__len__() 130 | len_valset = val_dataset.__len__() 131 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset) 132 | 133 | ######### train ########### 134 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch)) 135 | best_psnr = 0 136 | best_epoch = 0 137 | best_iter = 0 138 | eval_now = 1000 139 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now)) 140 | 141 | loss_scaler = NativeScaler() 142 | torch.cuda.empty_cache() 143 | ii=0 144 | index = 0 145 | for epoch in range(start_epoch, opt.nepoch + 1): 146 | epoch_start_time = time.time() 147 | epoch_loss = 0 148 | train_id = 1 149 | epoch_ssim_loss = 0 150 | for i, data in enumerate(train_loader, 0): 151 | # zero_grad 152 | index += 1 153 | optimizer.zero_grad() 154 | target = data[0].cuda() 155 | input_ = data[1].cuda() 156 | mask = data[2].cuda() 157 | if epoch > 5: 158 | target, input_, mask = utils.MixUp_AUG().aug(target, input_, mask) 159 | with torch.cuda.amp.autocast(): 160 | restored = model_restoration(input_, mask) 161 | restored = torch.clamp(restored,0,1) 162 | loss = criterion(restored, target) 163 | loss_scaler( 164 | loss, optimizer,parameters=model_restoration.parameters()) 165 | epoch_loss +=loss.item() 166 | #### Evaluation #### 167 | if (index+1)%eval_now==0 and i>0: 168 | eval_shadow_rmse = 0 169 | eval_nonshadow_rmse = 0 170 | eval_rmse = 0 171 | with torch.no_grad(): 172 | model_restoration.eval() 173 | psnr_val_rgb = [] 174 | for ii, data_val in enumerate((val_loader), 0): 175 | target = data_val[0].cuda() 176 | input_ = data_val[1].cuda() 177 | mask = data_val[2].cuda() 178 | filenames = data_val[3] 179 | with torch.cuda.amp.autocast(): 180 | restored = model_restoration(input_, mask) 181 | restored = torch.clamp(restored,0,1) 182 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False).item()) 183 | 184 | psnr_val_rgb = sum(psnr_val_rgb)/len(val_loader) 185 | if psnr_val_rgb > best_psnr: 186 | best_psnr = psnr_val_rgb 187 | best_epoch = epoch 188 | best_iter = i 189 | torch.save({'epoch': epoch, 190 | 'state_dict': model_restoration.state_dict(), 191 | 'optimizer' : optimizer.state_dict() 192 | }, os.path.join(model_dir,"model_best.pth")) 193 | print("[Ep %d it %d\t PSNR : %.4f] " % (epoch, i, psnr_val_rgb)) 194 | with open(logname,'a') as f: 195 | f.write("[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " \ 196 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n') 197 | model_restoration.train() 198 | torch.cuda.empty_cache() 199 | scheduler.step() 200 | 201 | print("------------------------------------------------------------------") 202 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss,scheduler.get_lr()[0])) 203 | 
print("------------------------------------------------------------------") 204 | with open(logname,'a') as f: 205 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n') 206 | 207 | torch.save({'epoch': epoch, 208 | 'state_dict': model_restoration.state_dict(), 209 | 'optimizer' : optimizer.state_dict() 210 | }, os.path.join(model_dir,"model_latest.pth")) 211 | 212 | if epoch%opt.checkpoint == 0: 213 | torch.save({'epoch': epoch, 214 | 'state_dict': model_restoration.state_dict(), 215 | 'optimizer' : optimizer.state_dict() 216 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch))) 217 | print("Now time is : ",datetime.datetime.now().isoformat()) 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | import math 9 | import numpy as np 10 | import time 11 | from torch import einsum 12 | import cv2 13 | import scipy.misc 14 | import utils 15 | ######################################### 16 | class ConvBlock(nn.Module): 17 | def __init__(self, in_channel, out_channel, strides=1): 18 | super(ConvBlock, self).__init__() 19 | self.strides = strides 20 | self.in_channel=in_channel 21 | self.out_channel=out_channel 22 | self.block = nn.Sequential( 23 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1), 24 | nn.LeakyReLU(inplace=True), 25 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1), 26 | nn.LeakyReLU(inplace=True), 27 | ) 28 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0) 29 | 30 | def forward(self, x): 31 | out1 = self.block(x) 32 | out2 = self.conv11(x) 33 | out = out1 + out2 34 | return out 35 | 36 | def flops(self, H, W): 37 | flops = H*W*self.in_channel*self.out_channel*(3*3+1)+H*W*self.out_channel*self.out_channel*3*3 38 | return flops 39 | 40 | class UNet(nn.Module): 41 | def __init__(self, block=ConvBlock,dim=32): 42 | super(UNet, self).__init__() 43 | 44 | self.dim = dim 45 | self.ConvBlock1 = ConvBlock(4, dim, strides=1) 46 | self.pool1 = nn.Conv2d(dim,dim,kernel_size=4, stride=2, padding=1) 47 | 48 | self.ConvBlock2 = block(dim, dim*2, strides=1) 49 | self.pool2 = nn.Conv2d(dim*2,dim*2,kernel_size=4, stride=2, padding=1) 50 | 51 | self.ConvBlock3 = block(dim*2, dim*4, strides=1) 52 | self.pool3 = nn.Conv2d(dim*4,dim*4,kernel_size=4, stride=2, padding=1) 53 | 54 | self.ConvBlock4 = block(dim*4, dim*8, strides=1) 55 | self.pool4 = nn.Conv2d(dim*8, dim*8,kernel_size=4, stride=2, padding=1) 56 | 57 | self.ConvBlock5 = block(dim*8, dim*16, strides=1) 58 | 59 | self.upv6 = nn.ConvTranspose2d(dim*16, dim*8, 2, stride=2) 60 | self.ConvBlock6 = block(dim*16, dim*8, strides=1) 61 | 62 | self.upv7 = nn.ConvTranspose2d(dim*8, dim*4, 2, stride=2) 63 | self.ConvBlock7 = block(dim*8, dim*4, strides=1) 64 | 65 | self.upv8 = nn.ConvTranspose2d(dim*4, dim*2, 2, stride=2) 66 | self.ConvBlock8 = block(dim*4, dim*2, strides=1) 67 | 68 | self.upv9 = nn.ConvTranspose2d(dim*2, dim, 2, stride=2) 69 | self.ConvBlock9 = block(dim*2, dim, strides=1) 70 | 71 | self.conv10 = 
nn.Conv2d(dim, 3, kernel_size=3, stride=1, padding=1) 72 | 73 | def forward(self, x): 74 | conv1 = self.ConvBlock1(x) 75 | pool1 = self.pool1(conv1) 76 | 77 | conv2 = self.ConvBlock2(pool1) 78 | pool2 = self.pool2(conv2) 79 | 80 | conv3 = self.ConvBlock3(pool2) 81 | pool3 = self.pool3(conv3) 82 | 83 | conv4 = self.ConvBlock4(pool3) 84 | pool4 = self.pool4(conv4) 85 | 86 | conv5 = self.ConvBlock5(pool4) 87 | 88 | up6 = self.upv6(conv5) 89 | up6 = torch.cat([up6, conv4], 1) 90 | conv6 = self.ConvBlock6(up6) 91 | 92 | up7 = self.upv7(conv6) 93 | up7 = torch.cat([up7, conv3], 1) 94 | conv7 = self.ConvBlock7(up7) 95 | 96 | up8 = self.upv8(conv7) 97 | up8 = torch.cat([up8, conv2], 1) 98 | conv8 = self.ConvBlock8(up8) 99 | 100 | up9 = self.upv9(conv8) 101 | up9 = torch.cat([up9, conv1], 1) 102 | conv9 = self.ConvBlock9(up9) 103 | 104 | conv10 = self.conv10(conv9) 105 | out = x + conv10 106 | 107 | return out 108 | 109 | def flops(self, H, W): 110 | flops = 0 111 | flops += self.ConvBlock1.flops(H, W) 112 | flops += H/2*W/2*self.dim*self.dim*4*4 113 | flops += self.ConvBlock2.flops(H/2, W/2) 114 | flops += H/4*W/4*self.dim*2*self.dim*2*4*4 115 | flops += self.ConvBlock3.flops(H/4, W/4) 116 | flops += H/8*W/8*self.dim*4*self.dim*4*4*4 117 | flops += self.ConvBlock4.flops(H/8, W/8) 118 | flops += H/16*W/16*self.dim*8*self.dim*8*4*4 119 | 120 | flops += self.ConvBlock5.flops(H/16, W/16) 121 | 122 | flops += H/8*W/8*self.dim*16*self.dim*8*2*2 123 | flops += self.ConvBlock6.flops(H/8, W/8) 124 | flops += H/4*W/4*self.dim*8*self.dim*4*2*2 125 | flops += self.ConvBlock7.flops(H/4, W/4) 126 | flops += H/2*W/2*self.dim*4*self.dim*2*2*2 127 | flops += self.ConvBlock8.flops(H/2, W/2) 128 | flops += H*W*self.dim*2*self.dim*2*2 129 | flops += self.ConvBlock9.flops(H, W) 130 | 131 | flops += H*W*self.dim*3*3*3 132 | return flops 133 | 134 | ######################################### 135 | class PosCNN(nn.Module): 136 | def __init__(self, in_chans, embed_dim=768, s=1): 137 | super(PosCNN, self).__init__() 138 | self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim)) 139 | self.s = s 140 | 141 | def forward(self, x, H=None, W=None): 142 | B, N, C = x.shape 143 | H = H or int(math.sqrt(N)) 144 | W = W or int(math.sqrt(N)) 145 | feat_token = x 146 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) 147 | if self.s == 1: 148 | x = self.proj(cnn_feat) + cnn_feat 149 | else: 150 | x = self.proj(cnn_feat) 151 | x = x.flatten(2).transpose(1, 2) 152 | return x 153 | 154 | def no_weight_decay(self): 155 | return ['proj.%d.weight' % i for i in range(4)] 156 | 157 | class SELayer(nn.Module): 158 | def __init__(self, channel, reduction=16): 159 | super(SELayer, self).__init__() 160 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 161 | self.fc = nn.Sequential( 162 | nn.Linear(channel, channel // reduction, bias=False), 163 | nn.ReLU(inplace=True), 164 | nn.Linear(channel // reduction, channel, bias=False), 165 | nn.Sigmoid() 166 | ) 167 | 168 | def forward(self, x): # x: [B, N, C] 169 | x = torch.transpose(x, 1, 2) # [B, C, N] 170 | b, c, _ = x.size() 171 | y = self.avg_pool(x).view(b, c) 172 | y = self.fc(y).view(b, c, 1) 173 | x = x * y.expand_as(x) 174 | x = torch.transpose(x, 1, 2) # [B, N, C] 175 | return x 176 | 177 | class SepConv2d(torch.nn.Module): 178 | def __init__(self, 179 | in_channels, 180 | out_channels, 181 | kernel_size, 182 | stride=1, 183 | padding=0, 184 | dilation=1,act_layer=nn.ReLU): 185 | super(SepConv2d, self).__init__() 186 | self.depthwise = 
torch.nn.Conv2d(in_channels, 187 | in_channels, 188 | kernel_size=kernel_size, 189 | stride=stride, 190 | padding=padding, 191 | dilation=dilation, 192 | groups=in_channels) 193 | self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1) 194 | self.act_layer = act_layer() if act_layer is not None else nn.Identity() 195 | self.in_channels = in_channels 196 | self.out_channels = out_channels 197 | self.kernel_size = kernel_size 198 | self.stride = stride 199 | 200 | def forward(self, x): 201 | x = self.depthwise(x) 202 | x = self.act_layer(x) 203 | x = self.pointwise(x) 204 | return x 205 | 206 | def flops(self, H, W): 207 | flops = 0 208 | flops += H*W*self.in_channels*self.kernel_size**2/self.stride**2 209 | flops += H*W*self.in_channels*self.out_channels 210 | return flops 211 | 212 | ########################################################################## 213 | ## Channel Attention Layer 214 | class CALayer(nn.Module): 215 | def __init__(self, channel, reduction=16, bias=False): 216 | super(CALayer, self).__init__() 217 | # global average pooling: feature --> point 218 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 219 | # feature channel downscale and upscale --> channel weight 220 | self.conv_du = nn.Sequential( 221 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 222 | nn.ReLU(inplace=True), 223 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 224 | nn.Sigmoid() 225 | ) 226 | 227 | def forward(self, x): 228 | y = self.avg_pool(x) 229 | y = self.conv_du(y) 230 | return x * y 231 | 232 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 233 | return nn.Conv2d( 234 | in_channels, out_channels, kernel_size, 235 | padding=(kernel_size//2), bias=bias, stride = stride) 236 | 237 | ########################################################################## 238 | ## Channel Attention Block (CAB) 239 | class CAB(nn.Module): 240 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 241 | super(CAB, self).__init__() 242 | modules_body = [] 243 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 244 | modules_body.append(act) 245 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 246 | 247 | self.CA = CALayer(n_feat, reduction, bias=bias) 248 | self.body = nn.Sequential(*modules_body) 249 | 250 | def forward(self, x): 251 | res = self.body(x) 252 | res = self.CA(res) 253 | res += x 254 | return res 255 | 256 | ######################################### 257 | ######## Embedding for q,k,v ######## 258 | class ConvProjection(nn.Module): 259 | def __init__(self, dim, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0., 260 | last_stage=False,bias=True): 261 | 262 | super().__init__() 263 | 264 | inner_dim = dim_head * heads 265 | self.heads = heads 266 | pad = (kernel_size - q_stride)//2 267 | self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad, bias) 268 | self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad, bias) 269 | self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad, bias) 270 | 271 | def forward(self, x, attn_kv=None): 272 | b, n, c, h = *x.shape, self.heads 273 | l = int(math.sqrt(n)) 274 | w = int(math.sqrt(n)) 275 | 276 | attn_kv = x if attn_kv is None else attn_kv 277 | x = rearrange(x, 'b (l w) c -> b c l w', l=l, w=w) 278 | attn_kv = rearrange(attn_kv, 'b (l w) c -> b c l w', l=l, w=w) 279 | # print(attn_kv) 280 | q = self.to_q(x) 281 | q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h) 282 | 
283 | k = self.to_k(attn_kv) 284 | v = self.to_v(attn_kv) 285 | k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h) 286 | v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h) 287 | return q,k,v 288 | 289 | def flops(self, H, W): 290 | flops = 0 291 | flops += self.to_q.flops(H, W) 292 | flops += self.to_k.flops(H, W) 293 | flops += self.to_v.flops(H, W) 294 | return flops 295 | 296 | class LinearProjection(nn.Module): 297 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 298 | super().__init__() 299 | inner_dim = dim_head * heads 300 | self.heads = heads 301 | self.to_q = nn.Linear(dim, inner_dim, bias = bias) 302 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 303 | self.dim = dim 304 | self.inner_dim = inner_dim 305 | 306 | def forward(self, x, attn_kv=None): 307 | B_, N, C = x.shape 308 | attn_kv = x if attn_kv is None else attn_kv 309 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 310 | kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 311 | q = q[0] 312 | k, v = kv[0], kv[1] 313 | return q,k,v 314 | 315 | def flops(self, H, W): 316 | flops = H*W*self.dim*self.inner_dim*3 317 | return flops 318 | 319 | class LinearProjection_Concat_kv(nn.Module): 320 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 321 | super().__init__() 322 | inner_dim = dim_head * heads 323 | self.heads = heads 324 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = bias) 325 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 326 | self.dim = dim 327 | self.inner_dim = inner_dim 328 | 329 | def forward(self, x, attn_kv=None): 330 | B_, N, C = x.shape 331 | attn_kv = x if attn_kv is None else attn_kv 332 | qkv_dec = self.to_qkv(x).reshape(B_, N, 3, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 333 | kv_enc = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 334 | q, k_d, v_d = qkv_dec[0], qkv_dec[1], qkv_dec[2] # make torchscript happy (cannot use tensor as tuple) 335 | k_e, v_e = kv_enc[0], kv_enc[1] 336 | k = torch.cat((k_d,k_e),dim=2) 337 | v = torch.cat((v_d,v_e),dim=2) 338 | return q,k,v 339 | 340 | def flops(self, H, W): 341 | flops = H*W*self.dim*self.inner_dim*5 342 | return flops 343 | 344 | ######################################### 345 | 346 | ########### SIA ############# 347 | class WindowAttention(nn.Module): 348 | def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., 349 | proj_drop=0., se_layer=False): 350 | 351 | super().__init__() 352 | self.dim = dim 353 | self.win_size = win_size # Wh, Ww 354 | self.num_heads = num_heads 355 | head_dim = dim // num_heads 356 | self.scale = qk_scale or head_dim ** -0.5 357 | 358 | # define a parameter table of relative position bias 359 | self.relative_position_bias_table = nn.Parameter( 360 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 361 | 362 | # get pair-wise relative position index for each token inside the window 363 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 364 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 365 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 366 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 367 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 368 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, 
Wh*Ww, 2 369 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 370 | relative_coords[:, :, 1] += self.win_size[1] - 1 371 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 372 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 373 | self.register_buffer("relative_position_index", relative_position_index) 374 | 375 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 376 | if token_projection == 'conv': 377 | self.qkv = ConvProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 378 | elif token_projection == 'linear_concat': 379 | self.qkv = LinearProjection_Concat_kv(dim, num_heads, dim // num_heads, bias=qkv_bias) 380 | else: 381 | self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) 382 | 383 | self.token_projection = token_projection 384 | self.attn_drop = nn.Dropout(attn_drop) 385 | self.proj = nn.Linear(dim, dim) 386 | # self.se_layer = SELayer(dim) 387 | self.ll = nn.Identity() 388 | self.proj_drop = nn.Dropout(proj_drop) 389 | 390 | trunc_normal_(self.relative_position_bias_table, std=.02) 391 | self.softmax = nn.Softmax(dim=-1) 392 | 393 | def forward(self, x, xm, attn_kv=None, mask=None): 394 | B_, N, C = x.shape 395 | # x = self.se_layer(x) 396 | one = torch.ones_like(xm) 397 | zero = torch.zeros_like(xm) 398 | xm = torch.where(xm < 0.1, one, one*2) 399 | mm = xm @ xm.transpose(-2, -1) 400 | one = torch.ones_like(mm) 401 | mm = torch.where(mm==2, one, one*0.2) 402 | mm = torch.unsqueeze(mm, dim=1) 403 | q, k, v = self.qkv(x, attn_kv) 404 | q = q * self.scale 405 | attn = (q @ k.transpose(-2, -1)) * mm 406 | 407 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 408 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 409 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 410 | ratio = attn.size(-1) // relative_position_bias.size(-1) 411 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio) 412 | 413 | attn = attn + relative_position_bias.unsqueeze(0) 414 | 415 | if mask is not None: 416 | nW = mask.shape[0] 417 | mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio) 418 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) 419 | attn = attn.view(-1, self.num_heads, N, N * ratio) 420 | attn = self.softmax(attn) 421 | else: 422 | attn = self.softmax(attn) 423 | 424 | attn = self.attn_drop(attn) 425 | 426 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 427 | x = self.proj(x) 428 | x = self.ll(x) 429 | x = self.proj_drop(x) 430 | return x 431 | 432 | def extra_repr(self) -> str: 433 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 434 | 435 | def flops(self, H, W): 436 | # calculate flops for 1 window with token length of N 437 | # print(N, self.dim) 438 | flops = 0 439 | N = self.win_size[0] * self.win_size[1] 440 | nW = H * W / N 441 | # qkv = self.qkv(x) 442 | # flops += N * self.dim * 3 * self.dim 443 | flops += self.qkv.flops(H, W) 444 | # attn = (q @ k.transpose(-2, -1)) 445 | if self.token_projection != 'linear_concat': 446 | flops += nW * self.num_heads * N * (self.dim // self.num_heads) * N 447 | # x = (attn @ v) 448 | flops += nW * self.num_heads * N * N * (self.dim // self.num_heads) 449 | else: 450 | flops += nW * self.num_heads * N * (self.dim // self.num_heads) * N * 2 451 | # x = (attn @ v) 452 | flops += nW * self.num_heads * N * N * 2 * 
(self.dim // self.num_heads) 453 | # x = self.proj(x) 454 | flops += nW * N * self.dim * self.dim 455 | print("W-MSA:{%.2f}" % (flops / 1e9)) 456 | return flops 457 | 458 | 459 | ######################################### 460 | ########### feed-forward network ############# 461 | class Mlp(nn.Module): 462 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 463 | super().__init__() 464 | out_features = out_features or in_features 465 | hidden_features = hidden_features or in_features 466 | self.fc1 = nn.Linear(in_features, hidden_features) 467 | self.act = act_layer() 468 | self.fc2 = nn.Linear(hidden_features, out_features) 469 | self.drop = nn.Dropout(drop) 470 | self.in_features = in_features 471 | self.hidden_features = hidden_features 472 | self.out_features = out_features 473 | 474 | def forward(self, x): 475 | x = self.fc1(x) 476 | x = self.act(x) 477 | x = self.drop(x) 478 | x = self.fc2(x) 479 | x = self.drop(x) 480 | return x 481 | 482 | def flops(self, H, W): 483 | flops = 0 484 | # fc1 485 | flops += H*W*self.in_features*self.hidden_features 486 | # fc2 487 | flops += H*W*self.hidden_features*self.out_features 488 | print("MLP:{%.2f}"%(flops/1e9)) 489 | return flops 490 | 491 | class LeFF(nn.Module): 492 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0.): 493 | super().__init__() 494 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), 495 | act_layer()) 496 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1), 497 | act_layer()) 498 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) 499 | self.dim = dim 500 | self.hidden_dim = hidden_dim 501 | 502 | def forward(self, x, img_size=(128,128)): 503 | # bs x hw x c 504 | bs, hw, c = x.size() 505 | # hh = int(math.sqrt(hw)) 506 | hh = img_size[0] 507 | ww = img_size[1] 508 | 509 | x = self.linear1(x) 510 | 511 | # spatial restore 512 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = ww) 513 | # bs,hidden_dim,32x32 514 | 515 | x = self.dwconv(x) 516 | 517 | # flaten 518 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = ww) 519 | 520 | x = self.linear2(x) 521 | 522 | return x 523 | 524 | def flops(self, H, W): 525 | flops = 0 526 | # fc1 527 | flops += H*W*self.dim*self.hidden_dim 528 | # dwconv 529 | flops += H*W*self.hidden_dim*3*3 530 | # fc2 531 | flops += H*W*self.hidden_dim*self.dim 532 | print("LeFF:{%.2f}"%(flops/1e9)) 533 | return flops 534 | 535 | ######################################### 536 | ########### window operation############# 537 | def window_partition(x, win_size, dilation_rate=1): 538 | B, H, W, C = x.shape 539 | if dilation_rate !=1: 540 | x = x.permute(0,3,1,2) # B, C, H, W 541 | assert type(dilation_rate) is int, 'dilation_rate should be a int' 542 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww 543 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww 544 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C 545 | else: 546 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) 547 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C 548 | return windows 549 | 550 | def window_reverse(windows, win_size, H, W, dilation_rate=1): 551 | # B' ,Wh ,Ww ,C 552 | B = int(windows.shape[0] / (H * W / win_size / win_size)) 553 | x = windows.view(B, H // win_size, W // 
win_size, win_size, win_size, -1) 554 | if dilation_rate !=1: 555 | x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww 556 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size) 557 | else: 558 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 559 | return x 560 | 561 | ######################################### 562 | # Downsample Block 563 | class Downsample(nn.Module): 564 | def __init__(self, in_channel, out_channel): 565 | super(Downsample, self).__init__() 566 | self.conv = nn.Sequential( 567 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1), 568 | ) 569 | self.in_channel = in_channel 570 | self.out_channel = out_channel 571 | 572 | def forward(self, x, img_size=(128,128)): 573 | B, L, C = x.shape 574 | H = img_size[0] 575 | W = img_size[1] 576 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 577 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 578 | return out 579 | 580 | def flops(self, H, W): 581 | flops = 0 582 | # conv 583 | flops += H/2*W/2*self.in_channel*self.out_channel*4*4 584 | print("Downsample:{%.2f}"%(flops/1e9)) 585 | return flops 586 | 587 | # Upsample Block 588 | class Upsample(nn.Module): 589 | def __init__(self, in_channel, out_channel): 590 | super(Upsample, self).__init__() 591 | self.deconv = nn.Sequential( 592 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2), 593 | ) 594 | self.in_channel = in_channel 595 | self.out_channel = out_channel 596 | 597 | def forward(self, x, img_size=(128,128)): 598 | B, L, C = x.shape 599 | H = img_size[0] 600 | W = img_size[1] 601 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 602 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 603 | return out 604 | 605 | def flops(self, H, W): 606 | flops = 0 607 | # conv 608 | flops += H*2*W*2*self.in_channel*self.out_channel*2*2 609 | print("Upsample:{%.2f}"%(flops/1e9)) 610 | return flops 611 | 612 | # Input Projection 613 | class InputProj(nn.Module): 614 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU): 615 | super().__init__() 616 | self.proj = nn.Sequential( 617 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 618 | act_layer(inplace=True) 619 | ) 620 | if norm_layer is not None: 621 | self.norm = norm_layer(out_channel) 622 | else: 623 | self.norm = None 624 | self.in_channel = in_channel 625 | self.out_channel = out_channel 626 | 627 | def forward(self, x): 628 | B, C, H, W = x.shape 629 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C 630 | if self.norm is not None: 631 | x = self.norm(x) 632 | return x 633 | 634 | def flops(self, H, W): 635 | flops = 0 636 | # conv 637 | flops += H*W*self.in_channel*self.out_channel*3*3 638 | 639 | if self.norm is not None: 640 | flops += H*W*self.out_channel 641 | print("Input_proj:{%.2f}"%(flops/1e9)) 642 | return flops 643 | 644 | # Output Projection 645 | class OutputProj(nn.Module): 646 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None,act_layer=None): 647 | super().__init__() 648 | self.proj = nn.Sequential( 649 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size//2), 650 | ) 651 | if act_layer is not None: 652 | self.proj.add_module(act_layer(inplace=True)) 653 | if norm_layer is not None: 654 | self.norm = norm_layer(out_channel) 655 | else: 656 | 
self.norm = None 657 | self.in_channel = in_channel 658 | self.out_channel = out_channel 659 | 660 | def forward(self, x, img_size=(128,128)): 661 | B, L, C = x.shape 662 | H = img_size[0] 663 | W = img_size[1] 664 | # H = int(math.sqrt(L)) 665 | # W = int(math.sqrt(L)) 666 | x = x.transpose(1, 2).view(B, C, H, W) 667 | x = self.proj(x) 668 | if self.norm is not None: 669 | x = self.norm(x) 670 | return x 671 | 672 | def flops(self, H, W): 673 | flops = 0 674 | # conv 675 | flops += H*W*self.in_channel*self.out_channel*3*3 676 | 677 | if self.norm is not None: 678 | flops += H*W*self.out_channel 679 | print("Output_proj:{%.2f}"%(flops/1e9)) 680 | return flops 681 | 682 | 683 | ######################################### 684 | ########### CA Transformer ############# 685 | class CATransformerBlock(nn.Module): 686 | def __init__(self, dim, input_resolution, num_heads, win_size=10, shift_size=0, 687 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 688 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, token_projection='linear', token_mlp='leff', 689 | se_layer=False): 690 | super().__init__() 691 | self.dim = dim 692 | self.input_resolution = input_resolution 693 | self.num_heads = num_heads 694 | self.win_size = win_size 695 | self.shift_size = shift_size 696 | self.mlp_ratio = mlp_ratio 697 | self.token_mlp = token_mlp 698 | if min(self.input_resolution) <= self.win_size: 699 | self.shift_size = 0 700 | self.win_size = min(self.input_resolution) 701 | assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" 702 | 703 | self.norm1 = norm_layer(dim) 704 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 705 | self.norm2 = norm_layer(dim) 706 | mlp_hidden_dim = int(dim * mlp_ratio) 707 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, 708 | drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop) 709 | self.CAB = CAB(dim, kernel_size=3, reduction=4, bias=False, act=nn.PReLU()) 710 | 711 | 712 | def extra_repr(self) -> str: 713 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 714 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 715 | 716 | def forward(self, x, xm, mask=None, img_size=(128, 128)): 717 | B, L, C = x.shape 718 | H = img_size[0] 719 | W = img_size[1] 720 | assert L == W * H, \ 721 | f"Input image size ({H}*{W} doesn't match model ({L})." 
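        # (added, hedged annotation) CATransformerBlock applies no window attention:
        # below, the tokens are LayerNorm-ed, reshaped back to a B x C x H x W map,
        # passed through the channel-attention block (self.CAB), flattened again,
        # and combined with the shortcut before the LeFF/MLP feed-forward residual.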
722 | 723 | shortcut = x 724 | x = self.norm1(x) 725 | 726 | # spatial restore 727 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h=H, w=W) 728 | # bs,hidden_dim,32x32 729 | 730 | x = self.CAB(x) 731 | 732 | # flaten 733 | x = rearrange(x, ' b c h w -> b (h w) c', h=H, w=W) 734 | x = x.view(B, H * W, C) 735 | 736 | # FFN 737 | x = shortcut + self.drop_path(x) 738 | x = x + self.drop_path(self.mlp(self.norm2(x), img_size=img_size)) 739 | 740 | return x 741 | 742 | def flops(self): 743 | flops = 0 744 | H, W = self.input_resolution 745 | # norm1 746 | flops += self.dim * H * W 747 | # W-MSA/SW-MSA 748 | flops += self.attn.flops(H, W) 749 | # norm2 750 | flops += self.dim * H * W 751 | # mlp 752 | flops += self.mlp.flops(H, W) 753 | print("LeWin:{%.2f}" % (flops / 1e9)) 754 | return flops 755 | 756 | ######################################### 757 | ########### SIM Transformer ############# 758 | class SIMTransformerBlock(nn.Module): 759 | def __init__(self, dim, input_resolution, num_heads, win_size=10, shift_size=0, 760 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 761 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff',se_layer=False): 762 | super().__init__() 763 | self.dim = dim 764 | self.input_resolution = input_resolution 765 | self.num_heads = num_heads 766 | self.win_size = win_size 767 | self.shift_size = shift_size 768 | self.mlp_ratio = mlp_ratio 769 | self.token_mlp = token_mlp 770 | if min(self.input_resolution) <= self.win_size: 771 | self.shift_size = 0 772 | self.win_size = min(self.input_resolution) 773 | assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size" 774 | 775 | self.norm1 = norm_layer(dim) 776 | self.attn = WindowAttention( 777 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads, 778 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, 779 | token_projection=token_projection,se_layer=se_layer) 780 | 781 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 782 | self.norm2 = norm_layer(dim) 783 | mlp_hidden_dim = int(dim * mlp_ratio) 784 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop) if token_mlp=='ffn' else LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop) 785 | self.CAB = CAB(dim, kernel_size=3, reduction=4, bias=False, act=nn.PReLU()) 786 | 787 | 788 | def extra_repr(self) -> str: 789 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 790 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 791 | 792 | def forward(self, x, xm, mask=None, img_size = (128, 128)): 793 | B, L, C = x.shape 794 | H = img_size[0] 795 | W = img_size[1] 796 | assert L == W * H, \ 797 | f"Input image size ({H}*{W} doesn't match model ({L})." 
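        # (added, hedged annotation) The two branches below prepare additive attention
        # masks that WindowAttention adds to the attention logits before the softmax:
        #  - `attn_mask`, built from the optional `mask` argument, is -100.0 wherever
        #    the pairwise product of windowed mask values is non-zero and 0.0 elsewhere;
        #  - `shift_attn_mask`, the Swin-style mask used when shift_size > 0, is -100.0
        #    for token pairs that fall in different regions of the shifted partition.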
798 | 799 | ## input mask 800 | if mask != None: 801 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1) 802 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1 803 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 804 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size 805 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 806 | else: 807 | attn_mask = None 808 | 809 | ## shift mask 810 | if self.shift_size > 0: 811 | # calculate attention mask for SW-MSA 812 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x) 813 | h_slices = (slice(0, -self.win_size), 814 | slice(-self.win_size, -self.shift_size), 815 | slice(-self.shift_size, None)) 816 | w_slices = (slice(0, -self.win_size), 817 | slice(-self.win_size, -self.shift_size), 818 | slice(-self.shift_size, None)) 819 | cnt = 0 820 | for h in h_slices: 821 | for w in w_slices: 822 | shift_mask[:, h, w, :] = cnt 823 | cnt += 1 824 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1 825 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size 826 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size 827 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0)) 828 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask 829 | 830 | shortcut = x 831 | x = self.norm1(x) 832 | 833 | 834 | 835 | x = x.view(B, H, W, C) 836 | xm = xm.permute(0, 2, 3, 1) 837 | # cyclic shift 838 | if self.shift_size > 0: 839 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 840 | shifted_m = torch.roll(xm, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 841 | else: 842 | shifted_x = x 843 | shifted_m = xm 844 | 845 | # partition windows 846 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C 847 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C 848 | m_windows = window_partition(shifted_m, self.win_size) # nW*B, win_size, win_size, C N*C->C 849 | m_windows = m_windows.view(-1, self.win_size * self.win_size, 1) # nW*B, win_size*win_size, C 850 | 851 | # W-MSA/SW-MSA 852 | attn_windows = self.attn(x_windows, m_windows, mask=attn_mask) # nW*B, win_size*win_size, C 853 | 854 | # merge windows 855 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) 856 | 857 | 858 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C 859 | 860 | # reverse cyclic shift 861 | if self.shift_size > 0: 862 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 863 | else: 864 | x = shifted_x 865 | x = x.view(B, H * W, C) 866 | 867 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h=H, w=W) 868 | # bs,hidden_dim,32x32 869 | 870 | x = self.CAB(x) 871 | 872 | # flaten 873 | x = rearrange(x, ' b c h w -> b (h w) c', h=H, w=W) 874 | 875 | # FFN 876 | x = shortcut + self.drop_path(x) 877 | x = x + self.drop_path(self.mlp(self.norm2(x), img_size=img_size)) 878 | 879 | del attn_mask 880 | return x 881 | 882 | def flops(self): 883 | flops = 0 884 | H, W = self.input_resolution 885 | # norm1 886 | flops += self.dim * H * W 887 | # 
W-MSA/SW-MSA 888 | flops += self.attn.flops(H, W) 889 | # norm2 890 | flops += self.dim * H * W 891 | # mlp 892 | flops += self.mlp.flops(H,W) 893 | print("LeWin:{%.2f}"%(flops/1e9)) 894 | return flops 895 | 896 | 897 | ######################################### 898 | ########### Basic layer of ShadowFormer ################ 899 | class BasicShadowFormer(nn.Module): 900 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size, 901 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 902 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 903 | token_projection='linear',token_mlp='ffn',se_layer=False,cab=False): 904 | 905 | super().__init__() 906 | self.dim = dim 907 | self.input_resolution = input_resolution 908 | self.depth = depth 909 | self.use_checkpoint = use_checkpoint 910 | # build blocks 911 | if cab: 912 | self.blocks = nn.ModuleList([ 913 | CATransformerBlock(dim=dim, input_resolution=input_resolution, 914 | num_heads=num_heads, win_size=win_size, 915 | shift_size=0 if (i % 2 == 0) else win_size // 2, 916 | mlp_ratio=mlp_ratio, 917 | qkv_bias=qkv_bias, qk_scale=qk_scale, 918 | drop=drop, attn_drop=attn_drop, 919 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 920 | norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp, 921 | se_layer=se_layer) 922 | for i in range(depth)]) 923 | else: 924 | self.blocks = nn.ModuleList([ 925 | SIMTransformerBlock(dim=dim, input_resolution=input_resolution, 926 | num_heads=num_heads, win_size=win_size, 927 | shift_size=0 if (i % 2 == 0) else win_size // 2, 928 | mlp_ratio=mlp_ratio, 929 | qkv_bias=qkv_bias, qk_scale=qk_scale, 930 | drop=drop, attn_drop=attn_drop, 931 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 932 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 933 | for i in range(depth)]) 934 | 935 | def extra_repr(self) -> str: 936 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 937 | 938 | def forward(self, x, xm, mask=None, img_size=(128,128)): 939 | for blk in self.blocks: 940 | if self.use_checkpoint: 941 | x = checkpoint.checkpoint(blk, x) 942 | else: 943 | x = blk(x, xm, mask, img_size) 944 | return x 945 | 946 | def flops(self): 947 | flops = 0 948 | for blk in self.blocks: 949 | flops += blk.flops() 950 | return flops 951 | 952 | 953 | class ShadowFormer(nn.Module): 954 | def __init__(self, img_size=256, in_chans=3, 955 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], 956 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, 957 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 958 | norm_layer=nn.LayerNorm, patch_norm=True, 959 | use_checkpoint=False, token_projection='linear', token_mlp='leff', se_layer=True, 960 | dowsample=Downsample, upsample=Upsample, **kwargs): 961 | super().__init__() 962 | 963 | self.num_enc_layers = len(depths)//2 964 | self.num_dec_layers = len(depths)//2 965 | self.embed_dim = embed_dim 966 | self.patch_norm = patch_norm 967 | self.mlp_ratio = mlp_ratio 968 | self.token_projection = token_projection 969 | self.mlp = token_mlp 970 | self.win_size =win_size 971 | self.reso = img_size 972 | self.pos_drop = nn.Dropout(p=drop_rate) 973 | 974 | # stochastic depth 975 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))] 976 | conv_dpr = [drop_path_rate]*depths[4] 977 | dec_dpr = enc_dpr[::-1] 978 | 979 | 
# build layers 980 | 981 | # Input/Output 982 | self.input_proj = InputProj(in_channel=4, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU) 983 | self.output_proj = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1) 984 | # self.CAB = CAB(embed_dim, kernel_size=3, reduction=4, bias=False, act=nn.PReLU()) 985 | 986 | # Encoder 987 | self.encoderlayer_0 = BasicShadowFormer(dim=embed_dim, 988 | output_dim=embed_dim, 989 | input_resolution=(img_size, 990 | img_size), 991 | depth=depths[0], 992 | num_heads=num_heads[0], 993 | win_size=win_size, 994 | mlp_ratio=self.mlp_ratio, 995 | qkv_bias=qkv_bias, qk_scale=qk_scale, 996 | drop=drop_rate, attn_drop=attn_drop_rate, 997 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], 998 | norm_layer=norm_layer, 999 | use_checkpoint=use_checkpoint, 1000 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True) 1001 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2) 1002 | self.encoderlayer_1 = BasicShadowFormer(dim=embed_dim*2, 1003 | output_dim=embed_dim*2, 1004 | input_resolution=(img_size // 2, 1005 | img_size // 2), 1006 | depth=depths[1], 1007 | num_heads=num_heads[1], 1008 | win_size=win_size, 1009 | mlp_ratio=self.mlp_ratio, 1010 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1011 | drop=drop_rate, attn_drop=attn_drop_rate, 1012 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 1013 | norm_layer=norm_layer, 1014 | use_checkpoint=use_checkpoint, 1015 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True) 1016 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4) 1017 | self.encoderlayer_2 = BasicShadowFormer(dim=embed_dim*4, 1018 | output_dim=embed_dim*4, 1019 | input_resolution=(img_size // (2 ** 2), 1020 | img_size // (2 ** 2)), 1021 | depth=depths[2], 1022 | num_heads=num_heads[2], 1023 | win_size=win_size, 1024 | mlp_ratio=self.mlp_ratio, 1025 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1026 | drop=drop_rate, attn_drop=attn_drop_rate, 1027 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 1028 | norm_layer=norm_layer, 1029 | use_checkpoint=use_checkpoint, 1030 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 1031 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8) 1032 | 1033 | # Bottleneck 1034 | self.conv = BasicShadowFormer(dim=embed_dim*8, 1035 | output_dim=embed_dim*8, 1036 | input_resolution=(img_size // (2 ** 3), 1037 | img_size // (2 ** 3)), 1038 | depth=depths[4], 1039 | num_heads=num_heads[4], 1040 | win_size=win_size, 1041 | mlp_ratio=self.mlp_ratio, 1042 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1043 | drop=drop_rate, attn_drop=attn_drop_rate, 1044 | drop_path=conv_dpr, 1045 | norm_layer=norm_layer, 1046 | use_checkpoint=use_checkpoint, 1047 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 1048 | 1049 | # # Decoder 1050 | self.upsample_0 = upsample(embed_dim*8, embed_dim*4) 1051 | self.decoderlayer_0 = BasicShadowFormer(dim=embed_dim*8, 1052 | output_dim=embed_dim*8, 1053 | input_resolution=(img_size // (2 ** 2), 1054 | img_size // (2 ** 2)), 1055 | depth=depths[6], 1056 | num_heads=num_heads[6], 1057 | win_size=win_size, 1058 | mlp_ratio=self.mlp_ratio, 1059 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1060 | drop=drop_rate, attn_drop=attn_drop_rate, 1061 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 1062 | norm_layer=norm_layer, 1063 | use_checkpoint=use_checkpoint, 1064 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer) 1065 | 
self.upsample_1 = upsample(embed_dim*8, embed_dim*2) 1066 | self.decoderlayer_1 = BasicShadowFormer(dim=embed_dim*4, 1067 | output_dim=embed_dim*4, 1068 | input_resolution=(img_size // 2, 1069 | img_size // 2), 1070 | depth=depths[7], 1071 | num_heads=num_heads[7], 1072 | win_size=win_size, 1073 | mlp_ratio=self.mlp_ratio, 1074 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1075 | drop=drop_rate, attn_drop=attn_drop_rate, 1076 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 1077 | norm_layer=norm_layer, 1078 | use_checkpoint=use_checkpoint, 1079 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer, cab=True) 1080 | self.upsample_2 = upsample(embed_dim*4, embed_dim) 1081 | self.decoderlayer_2 = BasicShadowFormer(dim=embed_dim*2, 1082 | output_dim=embed_dim*2, 1083 | input_resolution=(img_size, 1084 | img_size), 1085 | depth=depths[8], 1086 | num_heads=num_heads[8], 1087 | win_size=win_size, 1088 | mlp_ratio=self.mlp_ratio, 1089 | qkv_bias=qkv_bias, qk_scale=qk_scale, 1090 | drop=drop_rate, attn_drop=attn_drop_rate, 1091 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 1092 | norm_layer=norm_layer, 1093 | use_checkpoint=use_checkpoint, 1094 | token_projection=token_projection,token_mlp=token_mlp,se_layer=se_layer,cab=True) 1095 | self.apply(self._init_weights) 1096 | 1097 | def _init_weights(self, m): 1098 | if isinstance(m, nn.Linear): 1099 | trunc_normal_(m.weight, std=.02) 1100 | if isinstance(m, nn.Linear) and m.bias is not None: 1101 | nn.init.constant_(m.bias, 0) 1102 | elif isinstance(m, nn.LayerNorm): 1103 | nn.init.constant_(m.bias, 0) 1104 | nn.init.constant_(m.weight, 1.0) 1105 | 1106 | @torch.jit.ignore 1107 | def no_weight_decay(self): 1108 | return {'absolute_pos_embed'} 1109 | 1110 | @torch.jit.ignore 1111 | def no_weight_decay_keywords(self): 1112 | return {'relative_position_bias_table'} 1113 | 1114 | def extra_repr(self) -> str: 1115 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}" 1116 | 1117 | def forward(self, x, xm, mask=None): 1118 | # Input Projection 1119 | xi = torch.cat((x, xm), dim=1) 1120 | self.img_size = (x.shape[2], x.shape[3]) 1121 | y = self.input_proj(xi) 1122 | y = self.pos_drop(y) 1123 | 1124 | #Encoder 1125 | conv0 = self.encoderlayer_0(y, xm, mask=mask, img_size = self.img_size) 1126 | pool0 = self.dowsample_0(conv0, img_size = self.img_size) 1127 | m = nn.MaxPool2d(2) 1128 | xm1 = m(xm) 1129 | self.img_size = (int(self.img_size[0]/2), int(self.img_size[1]/2)) 1130 | conv1 = self.encoderlayer_1(pool0, xm1, mask=mask, img_size = self.img_size) 1131 | pool1 = self.dowsample_1(conv1, img_size = self.img_size) 1132 | m = nn.MaxPool2d(2) 1133 | xm2 = m(xm1) 1134 | self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2)) 1135 | conv2 = self.encoderlayer_2(pool1, xm2, mask=mask, img_size = self.img_size) 1136 | pool2 = self.dowsample_2(conv2, img_size = self.img_size) 1137 | self.img_size = (int(self.img_size[0] / 2), int(self.img_size[1] / 2)) 1138 | m = nn.MaxPool2d(2) 1139 | xm3 = m(xm2) 1140 | 1141 | # Bottleneck 1142 | conv3 = self.conv(pool2, xm3, mask=mask, img_size = self.img_size) 1143 | 1144 | #Decoder 1145 | up0 = self.upsample_0(conv3, img_size = self.img_size) 1146 | self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2)) 1147 | deconv0 = torch.cat([up0,conv2],-1) 1148 | deconv0 = self.decoderlayer_0(deconv0, xm2, mask=mask, img_size = self.img_size) 1149 | 1150 | up1 = self.upsample_1(deconv0, img_size = 
self.img_size) 1151 | self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2)) 1152 | deconv1 = torch.cat([up1,conv1],-1) 1153 | deconv1 = self.decoderlayer_1(deconv1, xm1, mask=mask, img_size = self.img_size) 1154 | 1155 | up2 = self.upsample_2(deconv1, img_size = self.img_size) 1156 | self.img_size = (int(self.img_size[0] * 2), int(self.img_size[1] * 2)) 1157 | deconv2 = torch.cat([up2,conv0],-1) 1158 | deconv2 = self.decoderlayer_2(deconv2, xm, mask=mask, img_size = self.img_size) 1159 | 1160 | # Output Projection 1161 | y = self.output_proj(deconv2, img_size = self.img_size) + x 1162 | return y 1163 | --------------------------------------------------------------------------------
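Usage note (not part of the repository): a minimal sketch of how the ShadowFormer class defined in model.py above might be instantiated and run on a dummy shadow image and binary shadow mask. In the repository itself the model is built via utils.get_arch(opt) and fed (input_, mask) batches from the data loaders, as in train.py; the shapes and values below are illustrative assumptions only.

import torch
from model import ShadowFormer

# Default configuration: 256x256 inputs, embed_dim=32, window size 8.
model = ShadowFormer(img_size=256, in_chans=3, embed_dim=32, win_size=8)
model.eval()

x = torch.rand(1, 3, 256, 256)                    # shadow-affected RGB image in [0, 1]
xm = (torch.rand(1, 1, 256, 256) > 0.5).float()   # binary shadow mask

with torch.no_grad():
    y = model(x, xm)                              # shadow-removed estimate

print(y.shape)  # torch.Size([1, 3, 256, 256])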