├── image
│   ├── 1.png
│   ├── 2.png
│   └── table.png
├── loss
│   ├── CL1.py
│   └── perceptual.py
├── metrics.py
├── README.md
├── test.py
├── condconv.py
├── dataloader.py
├── base_net_snow.py
└── SnowFormer.py
/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/1.png
--------------------------------------------------------------------------------
/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/2.png
--------------------------------------------------------------------------------
/image/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/table.png
--------------------------------------------------------------------------------
/loss/CL1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | class L1_Charbonnier_loss(torch.nn.Module):
4 |     """L1 Charbonnier loss (a smooth, differentiable approximation of L1)."""
5 |     def __init__(self):
6 |         super(L1_Charbonnier_loss, self).__init__()
7 |         self.eps = 1e-6
8 | 
9 |     def forward(self, X, Y):
10 |         diff = torch.add(X, -Y)
11 |         error = torch.sqrt(diff * diff + self.eps)
12 |         loss = torch.mean(error)
13 |         return loss
14 | 
15 | class PSNRLoss(torch.nn.Module):
16 | 
17 |     def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
18 |         super(PSNRLoss, self).__init__()
19 |         assert reduction == 'mean'
20 |         self.loss_weight = loss_weight
21 |         self.scale = 10 / np.log(10)
22 |         self.toY = toY
23 |         self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
24 |         self.first = True
25 | 
26 |     def forward(self, pred, target):
27 |         assert len(pred.size()) == 4
28 |         if self.toY:
29 |             if self.first:
30 |                 self.coef = self.coef.to(pred.device)
31 |                 self.first = False
32 | 
33 |             pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
34 |             target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
35 | 
36 |             pred, target = pred / 255., target / 255.
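        # self.scale * log(mse) = 10 * log10(MSE), i.e. minus the PSNR of images
        # in [0, 1]; minimizing this loss therefore maximizes PSNR.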
39 | 
40 |         return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
--------------------------------------------------------------------------------
/loss/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
6 | from torchvision.models.vgg import vgg19
7 | import torch.nn as nn
8 | 
9 | 
10 | 
11 | 
12 | class PerceptualLoss(nn.Module):
13 |     def __init__(self):
14 |         super(PerceptualLoss, self).__init__()
15 |         self.L1 = nn.L1Loss()
17 |         vgg = vgg19(pretrained=True).eval()
        # freeze VGG so the loss network is never updated by the optimizer
        for p in vgg.parameters():
            p.requires_grad = False
        # feature taps at increasing depths of VGG-19
18 |         self.loss_net1 = nn.Sequential(*list(vgg.features)[:1]).eval()
19 |         self.loss_net3 = nn.Sequential(*list(vgg.features)[:3]).eval()
20 |         self.loss_net5 = nn.Sequential(*list(vgg.features)[:5]).eval()
21 |         self.loss_net9 = nn.Sequential(*list(vgg.features)[:9]).eval()
22 |         self.loss_net13 = nn.Sequential(*list(vgg.features)[:13]).eval()
23 |     def forward(self, x, y):
24 |         loss1 = self.L1(self.loss_net1(x), self.loss_net1(y))
25 |         loss3 = self.L1(self.loss_net3(x), self.loss_net3(y))
26 |         loss5 = self.L1(self.loss_net5(x), self.loss_net5(y))
27 |         loss9 = self.L1(self.loss_net9(x), self.loss_net9(y))
28 |         loss13 = self.L1(self.loss_net13(x), self.loss_net13(y))
30 |         loss = 0.2*loss1 + 0.2*loss3 + 0.2*loss5 + 0.2*loss9 + 0.2*loss13
31 |         return loss
32 | 
33 | 
34 | 
35 | class PerceptualLoss2(nn.Module):
36 |     def __init__(self):
37 |         super(PerceptualLoss2, self).__init__()
38 |         self.L1 = nn.L1Loss()
40 |         vgg = vgg19(pretrained=True).eval()
        for p in vgg.parameters():
            p.requires_grad = False
41 |         self.loss_net1 = nn.Sequential(*list(vgg.features)[:1]).eval()
42 |         self.loss_net3 = nn.Sequential(*list(vgg.features)[:3]).eval()
43 |     def forward(self, x, y):
44 |         loss1 = self.L1(self.loss_net1(x), self.loss_net1(y))
45 |         loss3 = self.L1(self.loss_net3(x), self.loss_net3(y))
46 |         loss = 0.5*loss1 + 0.5*loss3
47 |         return loss
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import math
3 | import numpy as np
4 | 
5 | import torch
6 | import torch.nn.functional as F
17 | 
18 | def gaussian(window_size, sigma):
19 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
20 |     return gauss / gauss.sum()
21 | 
22 | def create_window(window_size, channel):
23 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
26 |     return window
27 | 
28 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
29 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
30 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
31 |     mu1_sq = mu1.pow(2)
32 |     mu2_sq = mu2.pow(2)
33 |     mu1_mu2 = mu1 * mu2
34 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, 
groups=channel) - mu1_sq
35 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
36 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
37 |     C1 = 0.01 ** 2
38 |     C2 = 0.03 ** 2
39 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
40 | 
41 |     if size_average:
42 |         return ssim_map.mean()
43 |     else:
44 |         return ssim_map.mean(1).mean(1).mean(1)
45 | 
46 | 
47 | def SSIM(img1, img2, window_size=11, size_average=True):
48 |     img1 = torch.clamp(img1, min=0, max=1)
49 |     img2 = torch.clamp(img2, min=0, max=1)
50 |     (_, channel, _, _) = img1.size()
51 |     window = create_window(window_size, channel)
52 |     if img1.is_cuda:
53 |         window = window.cuda(img1.get_device())
54 |     window = window.type_as(img1)
55 |     return _ssim(img1, img2, window, window_size, channel, size_average)
56 | def PSNR(pred, gt):
57 |     pred = pred.clamp(0, 1).cpu().numpy()
58 |     gt = gt.clamp(0, 1).cpu().numpy()
59 |     imdff = pred - gt
60 |     rmse = math.sqrt(np.mean(imdff ** 2))
61 |     if rmse == 0:
62 |         return 100
63 |     return 20 * math.log10(1.0 / rmse)
64 | 
65 | if __name__ == "__main__":
66 |     pass
67 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SnowFormer: Context Interaction Transformer with Scale-awareness for Single Image Desnowing
**Authors:** [Sixiang Chen](https://scholar.google.com/citations?hl=zh-CN&user=EtljKSgAAAAJ), [Tian Ye](https://scholar.google.com/citations?user=1sGXZ-wAAAAJ&hl=zh-CN), [Yun Liu](http://ai.swu.edu.cn/info/1071/1804.htm), [Erkang Chen](https://scholar.google.com/citations?user=hWo1RTsAAAAJ&hl=zh-CN)

[SnowFormer: Context Interaction Transformer with Scale-awareness for Single Image Desnowing](https://arxiv.org/abs/2208.09703)

> **Abstract:** *Due to various and complicated snow degradations, single image desnowing is a challenging image restoration task. As prior arts cannot handle it ideally, we propose a novel transformer, SnowFormer, which explores efficient cross-attentions to build local-global context interaction across patches and surpasses existing works that employ local operators or vanilla transformers. Compared to prior desnowing methods and universal image restoration methods, SnowFormer has several benefits. First, unlike the multi-head self-attention in recent image restoration Vision Transformers, SnowFormer incorporates the multi-head cross-attention mechanism to perform local-global context interaction between scale-aware snow queries and local-patch embeddings. Second, the snow queries in SnowFormer are generated by the query generator from aggregated scale-aware features, which are rich in potential clean cues, leading to superior restoration results. Third, SnowFormer outshines advanced state-of-the-art desnowing networks and the prevalent universal image restoration transformers on six synthetic and real-world datasets.*

#### News
- **November 22, 2022:** Checkpoint of CSD is updated. :fire:
## Network Architecture
<img src="image/1.png">

<img src="image/2.png">

<img src="image/table.png">
## Installation
Our SnowFormer is built with PyTorch 1.12.0; we train and test it in an Ubuntu 20.04 environment (Python 3.8, CUDA 11.6).

To install, please follow these instructions.
```
conda create -n py38 python=3.8
conda activate py38
conda install pytorch=1.12
pip install opencv-python tqdm tensorboardX ....
```
## Dataset
We sample 2000 images from every desnowing dataset for the test stage, including CSD, SRRS, Snow100K, SnowCityScapes and SnowKITTI. We provide download links for the datasets; the password is **ephe**.

| Dataset | CSD | SRRS | Snow100K | SnowCityScapes | SnowKITTI |
| :-----: | :-: | :--: | :------: | :------------: | :-------: |
| Link | Download | Download | Download | Download | Download |
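For reference, the test-time dataloaders in `dataloader.py` read degraded/clean pairs from fixed sub-directories of each dataset root (`Snow`/`Gt` for CSD, `Syn`/`gt` for SRRS, `synthetic`/`gt` for Snow100K). A rough sketch of the expected layout, assuming the default paths in `test.py`:

```
<dataset_CSD>/            # e.g. CSD/Test
├── Snow/                 # snowy inputs
└── Gt/                   # ground truth
<dataset_SRRS>/
├── Syn/
└── gt/
<dataset_Snow100K>/
├── synthetic/
└── gt/
```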
## Testing Stage
You can download the pre-trained CSD model from [Pre-trained model](https://pan.baidu.com/s/1UvokNiZ9ZNhl8tPXg2yysQ) **Password: ephe** and save it in model_path.

Run the following commands:
```python
python3 test.py --dataset_type CSD --dataset_CSD dataset_CSD --model_path model_path
```
The results are saved in `./out/dataset_type/`

## TODO List
- [ ] Checkpoints of SRRS, Snow100K, SnowCityScapes and SnowKITTI
- [ ] Train.py

## Citation
```Bibtex
@article{chen2022snowformer,
  title={SnowFormer: Scale-aware Transformer via Context Interaction for Single Image Desnowing},
  author={Chen, Sixiang and Ye, Tian and Liu, Yun and Chen, Erkang and Shi, Jun and Zhou, Jingchun},
  journal={arXiv preprint arXiv:2208.09703},
  year={2022}
}
```
## Contact
If you have any questions, please contact us at 282542428@qq.com, ephemeral182@gmail.com or sixiangchen@hkust-gz.edu.cn.
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
4 | import torch
5 | import numpy as np
7 | from metrics import *
8 | import warnings
9 | from torchvision.utils import save_image
10 | from tqdm import tqdm
11 | from torch.utils.data import DataLoader
12 | from dataloader import *
13 | 
14 | from SnowFormer import *
15 | from PIL import Image
16 | warnings.filterwarnings("ignore")
18 | 
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--cuda', action='store_true', help='use GPU computation')
21 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
22 | parser.add_argument('--tile', type=int, default=256, help='Tile size, None for no tile during testing (testing as a whole)')
23 | parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
24 | parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8')
25 | parser.add_argument('--dataset_type', type=str, default='CSD', help='CSD/SRRS/Snow100K')
26 | parser.add_argument('--dataset_CSD', type=str, default='/home/PublicDataset/CSD/Test', help='path of CSD dataset')
27 | parser.add_argument('--dataset_SRRS', type=str, default='/home/PublicDataset/SRRS/SRRS-2021/', help='path of SRRS dataset')
28 | parser.add_argument('--dataset_Snow100K', type=str, default='/home/PublicDataset/Snow100K/media/jdway/GameSSD/overlapping/test/test/', help='path of Snow100k dataset')
29 | parser.add_argument('--savepath', type=str, default='./out/', help='path of output image')
30 | parser.add_argument('--model_path', type=str, default='/mnt/csx/SnowFormer/SnowFormer/SnowFormer_CSD.pth', help='path of SnowFormer checkpoint')
31 | 
32 | opt = parser.parse_args()
33 | if opt.dataset_type == 'CSD':
34 |     snow_test = DataLoader(dataset=CSD_Dataset(opt.dataset_CSD,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
35 | if opt.dataset_type == 'SRRS':
36 |     snow_test = DataLoader(dataset=SRRS_Dataset(opt.dataset_SRRS,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
37 | if opt.dataset_type == 'Snow100K':
38 |     snow_test = 
DataLoader(dataset=Snow100K_Dataset(opt.dataset_Snow100K,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
39 | 
40 | 
41 | netG_1 = Transformer().cuda()
42 | 
43 | if __name__ == '__main__':
44 | 
45 |     ssims = []
46 |     psnrs = []
48 | 
49 |     g1ckpt1 = opt.model_path
50 |     ckpt = torch.load(g1ckpt1)
51 |     netG_1.load_state_dict(ckpt)
52 | 
53 |     savepath_dataset = os.path.join(opt.savepath, opt.dataset_type)
54 |     if not os.path.exists(savepath_dataset):
55 |         os.makedirs(savepath_dataset)
56 |     loop = tqdm(enumerate(snow_test), total=len(snow_test))
57 | 
58 |     for idx, (haze, clean, name) in loop:
59 | 
60 |         with torch.no_grad():
61 | 
62 |             haze = haze.cuda(); clean = clean.cuda()
63 | 
64 |             b, c, h, w = haze.size()
65 | 
66 |             tile = min(opt.tile, h, w)
67 |             tile_overlap = opt.tile_overlap
68 |             sf = opt.scale
69 | 
            # overlapping-tile inference: every tile's output is accumulated in E1
            # and the per-pixel hit count in W1, so dividing E1 by W1 averages the
            # predictions wherever tiles overlap
70 |             stride = tile - tile_overlap
71 |             h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
72 |             w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
73 |             E1 = torch.zeros(b, c, h*sf, w*sf).type_as(haze)
74 |             W1 = torch.zeros_like(E1)
75 | 
76 |             for h_idx in h_idx_list:
77 |                 for w_idx in w_idx_list:
78 |                     in_patch = haze[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
79 |                     out_patch1 = netG_1(in_patch)
80 |                     out_patch_mask1 = torch.ones_like(out_patch1)
81 |                     E1[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch1)
82 |                     W1[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask1)
83 |             dehaze = E1.div_(W1)
84 | 
            # name is wrapped in a batch of size 1 by the DataLoader
85 |             save_image(dehaze, os.path.join(savepath_dataset, '%s.png' % (name[0])), normalize=False)
86 | 
87 | 
88 |             ssim1 = SSIM(dehaze, clean).item()
89 |             psnr1 = PSNR(dehaze, clean)
90 | 
91 |             ssims.append(ssim1)
92 |             psnrs.append(psnr1)
93 | 
94 |             print('Generated images %04d of %04d' % (idx+1, len(snow_test)))
95 |             print('ssim:', (ssim1))
96 |             print('psnr:', (psnr1))
97 | 
98 |     ssim = np.mean(ssims)
99 |     psnr = np.mean(psnrs)
100 |     print('ssim_avg:', ssim)
101 |     print('psnr_avg:', psnr)
--------------------------------------------------------------------------------
/condconv.py:
--------------------------------------------------------------------------------
1 | import functools
2 | 
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn.modules.conv import _ConvNd
7 | from torch.nn.modules.utils import _pair
8 | from torch.nn.parameter import Parameter
9 | 
10 | 
11 | class _routing(nn.Module):
    """Routing network: maps a (flattened) condition vector to one sigmoid weight per expert."""
12 | 
13 |     def __init__(self, in_channels, num_experts, dropout_rate):
14 |         super(_routing, self).__init__()
15 | 
16 |         self.dropout = nn.Dropout(dropout_rate)
17 |         self.fc = nn.Sequential(
18 |             nn.Linear(in_channels, in_channels),
19 |             nn.LeakyReLU(0.1, True),
20 |             nn.Linear(in_channels, num_experts)
21 |         )
22 | 
23 |     def forward(self, x):
24 |         x = torch.flatten(x)
25 |         x = self.dropout(x)
26 |         x = self.fc(x)
27 |         return torch.sigmoid(x)
28 | 
29 | 
30 | class DynamicCondConv2D(_ConvNd):
31 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
32 |                  padding=0, dilation=1, groups=1,
33 |                  bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2, routing_channels=512):
34 |         kernel_size = _pair(kernel_size)
35 |         stride = _pair(stride)
36 |         padding = _pair(padding)
37 |         dilation = _pair(dilation)
38 |         super(DynamicCondConv2D, self).__init__(
39 |             in_channels, out_channels, kernel_size, stride, padding, dilation,
40 |             False, _pair(0), groups, bias, padding_mode)
41 | 
42 |         #self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
43 |         self._routing_fn = 
_routing(routing_channels, num_experts, dropout_rate)
44 | 
45 |         self.weight = Parameter(torch.Tensor(
46 |             num_experts, out_channels, in_channels // groups, *kernel_size))
47 | 
48 |         self.reset_parameters()
49 | 
50 |     def _conv_forward(self, input, weight):
51 |         if self.padding_mode != 'zeros':
52 |             return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
53 |                             weight, self.bias, self.stride,
54 |                             _pair(0), self.dilation, self.groups)
55 |         return F.conv2d(input, weight, self.bias, self.stride,
56 |                         self.padding, self.dilation, self.groups)
57 | 
58 |     def forward(self, inputs_q):
        # inputs_q packs (feature maps, per-sample routing conditions)
59 |         inputs = inputs_q[0]
60 |         kernel_conditions = inputs_q[1]
61 |         b, _, _, _ = inputs.size()
62 |         res = []
        # per-sample routing: each input gets its own kernel, mixed from the experts
63 |         for i, input in enumerate(inputs):
64 |             input = input.unsqueeze(0)
65 |             routing_weights = self._routing_fn(kernel_conditions[i])
66 |             kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)
67 |             out = self._conv_forward(input, kernels)
68 |             res.append(out)
69 |         return torch.cat(res, dim=0)
70 | 
71 | 
72 | 
73 | class CondConv2D(_ConvNd):
74 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
75 |                  padding=0, dilation=1, groups=1,
76 |                  bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
77 |         kernel_size = _pair(kernel_size)
78 |         stride = _pair(stride)
79 |         padding = _pair(padding)
80 |         dilation = _pair(dilation)
81 |         super(CondConv2D, self).__init__(
82 |             in_channels, out_channels, kernel_size, stride, padding, dilation,
83 |             False, _pair(0), groups, bias, padding_mode)
84 | 
85 |         self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
86 |         self._routing_fn = _routing(in_channels, num_experts, dropout_rate)
87 | 
88 |         self.weight = Parameter(torch.Tensor(
89 |             num_experts, out_channels, in_channels // groups, *kernel_size))
90 | 
91 |         self.reset_parameters()
92 | 
93 |     def _conv_forward(self, input, weight):
94 |         if self.padding_mode != 'zeros':
95 |             return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
96 |                             weight, self.bias, self.stride,
97 |                             _pair(0), self.dilation, self.groups)
98 |         return F.conv2d(input, weight, self.bias, self.stride,
99 |                         self.padding, self.dilation, self.groups)
100 | 
101 |     def forward(self, inputs):
102 |         b, _, _, _ = inputs.size()
103 |         res = []
        # routing conditions come from global average pooling of each input
104 |         for input in inputs:
105 |             input = input.unsqueeze(0)
106 |             pooled_inputs = self._avg_pooling(input)
107 |             routing_weights = self._routing_fn(pooled_inputs)
108 |             kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)
109 |             out = self._conv_forward(input, kernels)
110 |             res.append(out)
111 |         return torch.cat(res, dim=0)
112 | 
113 | 
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
2 | import os
3 | import torch.utils.data as data
4 | import numpy as np
5 | from PIL import Image
7 | import torchvision.transforms as tfs
8 | from torchvision.transforms import functional as FF
9 | import sys
10 | import random
12 | from torchvision.utils import make_grid
13 | #from RandomMask1 import *
14 | random.seed(2)
15 | np.random.seed(2)
16 | 
17 | p = 1
18 | AugDict = {
19 |     1:tfs.ColorJitter(brightness=p), #Brightness
20 |     2:tfs.ColorJitter(contrast=p), #Contrast
21 |     3:tfs.ColorJitter(saturation=p), #Saturation
22 |     4:tfs.GaussianBlur(kernel_size=5), 
#Gaussian Blur 23 | #5:GaussianNoise(std=1), #Gaussian Noise 24 | #5:RandomMaskwithRatio(64,patch_size=4,ratio=0.7), #Random Mask 25 | } 26 | 27 | class CSD_Dataset(data.Dataset): 28 | def __init__(self,path,train=False,size=256,format='.tif',rand_inpaint=False,rand_augment=None): 29 | super(CSD_Dataset,self).__init__() 30 | self.size=size 31 | self.rand_augment=rand_augment 32 | self.rand_inpaint=rand_inpaint 33 | self.InpaintSize = 64 34 | print('crop size',size) 35 | self.train=train 36 | self.format=format 37 | self.haze_imgs_dir=os.listdir(os.path.join(path,'Snow')) 38 | print('======>total number for training:',len(self.haze_imgs_dir)) 39 | self.haze_imgs=[os.path.join(path,'Snow',img) for img in self.haze_imgs_dir] 40 | self.clear_dir=os.path.join(path,'Gt') 41 | def __getitem__(self, index): 42 | haze=Image.open(self.haze_imgs[index]) 43 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1] 44 | while haze.size[0]total number for training:',len(self.haze_imgs_dir)) 93 | self.haze_imgs=[os.path.join(path,'Syn',img) for img in self.haze_imgs_dir] 94 | self.clear_dir=os.path.join(path,'gt') 95 | def __getitem__(self, index): 96 | haze=Image.open(self.haze_imgs[index]) 97 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1] 98 | while haze.size[0]total number for training:',len(self.haze_imgs_dir)) 147 | self.haze_imgs=[os.path.join(path,'synthetic',img) for img in self.haze_imgs_dir] 148 | self.clear_dir=os.path.join(path,'gt') 149 | def __getitem__(self, index): 150 | haze=Image.open(self.haze_imgs[index]) 151 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1] 152 | while haze.size[0] b (h w) c') 15 | 16 | def to_4d(x,h,w): 17 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 18 | 19 | class BiasFree_LayerNorm(nn.Module): 20 | def __init__(self, normalized_shape): 21 | super(BiasFree_LayerNorm, self).__init__() 22 | if isinstance(normalized_shape, numbers.Integral): 23 | normalized_shape = (normalized_shape,) 24 | normalized_shape = torch.Size(normalized_shape) 25 | 26 | assert len(normalized_shape) == 1 27 | 28 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 29 | self.normalized_shape = normalized_shape 30 | 31 | def forward(self, x): 32 | sigma = x.var(-1, keepdim=True, unbiased=False) 33 | return x / torch.sqrt(sigma+1e-5) * self.weight 34 | 35 | class WithBias_LayerNorm(nn.Module): 36 | def __init__(self, normalized_shape): 37 | super(WithBias_LayerNorm, self).__init__() 38 | if isinstance(normalized_shape, numbers.Integral): 39 | normalized_shape = (normalized_shape,) 40 | normalized_shape = torch.Size(normalized_shape) 41 | 42 | assert len(normalized_shape) == 1 43 | 44 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 45 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 46 | self.normalized_shape = normalized_shape 47 | 48 | def forward(self, x): 49 | mu = x.mean(-1, keepdim=True) 50 | sigma = x.var(-1, keepdim=True, unbiased=False) 51 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 52 | 53 | 54 | class LayerNorm(nn.Module): 55 | def __init__(self, dim, LayerNorm_type): 56 | super(LayerNorm, self).__init__() 57 | if LayerNorm_type =='BiasFree': 58 | self.body = BiasFree_LayerNorm(dim) 59 | else: 60 | self.body = WithBias_LayerNorm(dim) 61 | 62 | def forward(self, x): 63 | h, w = x.shape[-2:] 64 | return to_4d(self.body(to_3d(x)), h, w) 65 | class Down(nn.Module): 66 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding): 67 | super(Down,self).__init__() 68 | 
self.down = nn.Sequential( 69 | nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,bias=False), 70 | ) 71 | def forward(self,x): 72 | x = self.down(x) 73 | return x 74 | 75 | class Up(nn.Module): 76 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding): 77 | super(Up,self).__init__() 78 | self.up = nn.Sequential( 79 | nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding)) 80 | ) 81 | def forward(self,x): 82 | x = self.up(x) 83 | return x 84 | class UpSample(nn.Module): 85 | def __init__(self, in_channels,out_channels,s_factor): 86 | super(UpSample, self).__init__() 87 | self.up = nn.Sequential(nn.Upsample(scale_factor=s_factor, mode='bilinear', align_corners=False), 88 | nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)) 89 | 90 | def forward(self, x): 91 | x = self.up(x) 92 | return x 93 | 94 | 95 | class DWconv(nn.Module): 96 | def __init__(self,in_channels,out_channels): 97 | super(DWconv,self).__init__() 98 | self.dwconv = nn.Conv2d(in_channels,out_channels,3,1,1,bias=False,groups=in_channels) 99 | def forward(self,x): 100 | x = self.dwconv(x) 101 | return x 102 | 103 | 104 | class ChannelAttention(nn.Module): 105 | def __init__(self,chns,factor,dynamic=False): 106 | super().__init__() 107 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 108 | if dynamic == False: 109 | self.channel_map = nn.Sequential( 110 | nn.Conv2d(chns,chns//factor,1,1,0), 111 | nn.LeakyReLU(), 112 | nn.Conv2d(chns//factor,chns,1,1,0), 113 | nn.Sigmoid() 114 | ) 115 | else: 116 | self.channel_map = nn.Sequential( 117 | CondConv2D(chns,chns//factor,1,1,0), 118 | nn.LeakyReLU(), 119 | CondConv2D(chns//factor,chns,1,1,0), 120 | nn.Sigmoid() 121 | ) 122 | def forward(self,x): 123 | avg_pool = self.avg_pool(x) 124 | map = self.channel_map(avg_pool) 125 | return x*map 126 | 127 | 128 | class LKA(nn.Module): 129 | def __init__(self, dim): 130 | super().__init__() 131 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 132 | self.act1 = nn.GELU() 133 | self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 134 | self.act2 = nn.GELU() 135 | self.conv1 = nn.Conv2d(dim, dim, 1) 136 | 137 | 138 | def forward(self, x): 139 | u = x.clone() 140 | attn = self.conv0(x) 141 | attn = self.act1(attn) 142 | attn = self.conv_spatial(attn) 143 | attn = self.act2(attn) 144 | attn = self.conv1(attn) 145 | 146 | return u * attn 147 | 148 | class LKA_dynamic(nn.Module): 149 | def __init__(self, dim): 150 | super().__init__() 151 | self.conv0 = CondConv2D(dim,dim,5,1,2,1,dim) 152 | self.act1 = nn.GELU() 153 | self.conv_spatial = CondConv2D(dim,dim,7,1,9,3,dim) 154 | self.act2 = nn.GELU() 155 | self.conv1 = nn.Conv2d(dim, dim, 1) 156 | 157 | 158 | def forward(self, x): 159 | u = x.clone() 160 | attn = self.conv0(x) 161 | attn = self.act1(attn) 162 | attn = self.conv_spatial(attn) 163 | attn = self.act2(attn) 164 | attn = self.conv1(attn) 165 | 166 | return u * attn 167 | 168 | 169 | 170 | 171 | class Attention(nn.Module): 172 | def __init__(self, d_model,dynamic=True): 173 | super().__init__() 174 | 175 | self.conv11 = nn.Conv2d(d_model,d_model,kernel_size=3,stride=1,padding=1,groups=d_model) 176 | #self.activation = nn.GELU() 177 | if dynamic == True: 178 | self.spatial_gating_unit = LKA_dynamic(d_model) 179 | else: 180 | self.spatial_gating_unit = LKA(d_model) 181 | self.conv22 = nn.Conv2d(d_model,d_model,kernel_size=3,stride=1,padding=1,groups=d_model) 182 | 183 | def 
forward(self, x): 184 | x = self.conv11(x) 185 | x = self.spatial_gating_unit(x) 186 | x = self.conv22(x) 187 | return x 188 | 189 | class ConvBlock(nn.Module): 190 | def __init__(self, inp, oup, stride, expand_ratio,VAN=False,dynamic=False): 191 | super(ConvBlock, self).__init__() 192 | self.VAN = VAN 193 | hidden_dim = round(inp * expand_ratio) 194 | self.identity = stride == 1 and inp == oup 195 | self.apply(self._init_weight) 196 | 197 | if self.VAN == True: 198 | if expand_ratio == 1: 199 | self.conv = nn.Sequential( 200 | 201 | LayerNorm(hidden_dim, 'BiasFree'), 202 | Attention(hidden_dim,dynamic=dynamic), 203 | ) 204 | else: 205 | self.conv = nn.Sequential( 206 | 207 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ), 208 | LayerNorm(hidden_dim, 'BiasFree'), 209 | Attention(hidden_dim,dynamic=dynamic), 210 | nn.Conv2d(hidden_dim, oup, 1, 1, 0), 211 | 212 | ) 213 | else: 214 | if dynamic == False: 215 | if expand_ratio == 1: 216 | self.conv = nn.Sequential( 217 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 218 | LayerNorm(hidden_dim, 'BiasFree'), 219 | nn.GELU(), 220 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 221 | ChannelAttention(hidden_dim,4,dynamic=dynamic), 222 | 223 | nn.Conv2d(hidden_dim, oup, 1, 1, 0), 224 | 225 | ) 226 | else: 227 | self.conv = nn.Sequential( 228 | # pw 229 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ), 230 | 231 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 232 | LayerNorm(hidden_dim, 'BiasFree'), 233 | nn.GELU(), 234 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 235 | 236 | ChannelAttention(hidden_dim,4,dynamic=dynamic), 237 | 238 | nn.Conv2d(hidden_dim, oup, 1, 1, 0), 239 | ) 240 | else: 241 | if expand_ratio == 1: 242 | self.conv = nn.Sequential( 243 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 244 | LayerNorm(hidden_dim, 'BiasFree'), 245 | nn.GELU(), 246 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 247 | ChannelAttention(hidden_dim,4,dynamic=dynamic), 248 | # pw-linear 249 | nn.Conv2d(hidden_dim, oup, 1, 1, 0), 250 | 251 | ) 252 | else: 253 | self.conv = nn.Sequential( 254 | 255 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ), 256 | 257 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 258 | LayerNorm(hidden_dim, 'BiasFree'), 259 | nn.GELU(), 260 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim), 261 | 262 | ChannelAttention(hidden_dim,4,dynamic=dynamic), 263 | 264 | nn.Conv2d(hidden_dim, oup, 1, 1, 0), 265 | ) 266 | def _init_weight(self,m): 267 | if isinstance(m,nn.Linear): 268 | trunc_normal_(m.weight,std=0.02) 269 | if isinstance(m,nn.Linear) and m.bias is not None: 270 | nn.init.constant_(m.bias,0) 271 | elif isinstance(m,nn.LayerNorm): 272 | nn.init.constant_(m.bias,0) 273 | nn.init.constant_(m.weight,1.0) 274 | elif isinstance(m,nn.Conv2d): 275 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels 276 | fan_out //= m.groups 277 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out)) 278 | if m.bias is not None: 279 | m.bias.data.zero_() 280 | 281 | def forward(self, x): 282 | if self.identity: 283 | return x + self.conv(x) 284 | else: 285 | return self.conv(x) 286 | 287 | 288 | class Conv_block(nn.Module): 289 | def __init__(self,n,in_channel,out_channele,expand_ratio,VAN=False,dynamic=False): 290 | super(Conv_block,self).__init__() 291 | 292 | layers=[] 293 | for i in range(n): 294 | 
layers.append(ConvBlock(in_channel,out_channele,1 if i==0 else 1,expand_ratio,VAN=VAN,dynamic=dynamic)) 295 | in_channel = out_channele 296 | self.model = nn.Sequential(*layers) 297 | def forward(self,x): 298 | x = self.model(x) 299 | return x 300 | 301 | 302 | 303 | 304 | 305 | -------------------------------------------------------------------------------- /SnowFormer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from timm.models.layers import DropPath,to_2tuple,trunc_normal_ 5 | import math 6 | import time 7 | from base_net_snow import * 8 | import torch.nn.functional as F 9 | 10 | def _to_channel_last(x): 11 | """ 12 | Args: 13 | x: (B, C, H, W) 14 | Returns: 15 | x: (B, H, W, C) 16 | """ 17 | return x.permute(0, 2, 3, 1) 18 | 19 | 20 | def _to_channel_first(x): 21 | """ 22 | Args: 23 | x: (B, H, W, C) 24 | Returns: 25 | x: (B, C, H, W) 26 | """ 27 | return x.permute(0, 3, 1, 2) 28 | 29 | 30 | def window_partition(x, window_size): 31 | """ 32 | Args: 33 | x: (B, H, W, C) 34 | window_size: window size 35 | Returns: 36 | local window features (num_windows*B, window_size, window_size, C) 37 | """ 38 | B, H, W, C = x.shape 39 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 40 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 41 | return windows 42 | 43 | 44 | def window_reverse(windows, window_size, H, W): 45 | """ 46 | Args: 47 | windows: local window features (num_windows*B, window_size, window_size, C) 48 | window_size: Window size 49 | H: Height of image 50 | W: Width of image 51 | Returns: 52 | x: (B, H, W, C) 53 | """ 54 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 55 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 57 | return x 58 | 59 | 60 | class Mlp(nn.Module): 61 | def __init__(self, 62 | in_features, 63 | hidden_features=None, 64 | out_features=None, 65 | act_layer=nn.GELU, 66 | drop=0.): 67 | 68 | super().__init__() 69 | out_features = out_features or in_features 70 | hidden_features = hidden_features or in_features 71 | self.fc1 = nn.Linear(in_features, hidden_features) 72 | self.act = act_layer() 73 | self.fc2 = nn.Linear(hidden_features, out_features) 74 | self.drop = nn.Dropout(drop) 75 | 76 | def forward(self, x): 77 | x = self.fc1(x) 78 | x = self.act(x) 79 | x = self.drop(x) 80 | x = self.fc2(x) 81 | x = self.drop(x) 82 | return x 83 | 84 | class channel_shuffle(nn.Module): 85 | def __init__(self,groups=3): 86 | super(channel_shuffle,self).__init__() 87 | self.groups = groups 88 | 89 | def forward(self,x): 90 | B,C,H,W = x.shape 91 | assert C % self.groups == 0 92 | C_per_group = C // self.groups 93 | x = x.view(B,self.groups,C_per_group,H,W) 94 | x = x.transpose(1,2).contiguous() 95 | 96 | x = x.view(B,C,H,W) 97 | return x 98 | 99 | class overlapPatchEmbed(nn.Module): 100 | def __init__(self,img_size=224,patch_size=7,stride=4,in_channels=3,dim=768): 101 | super(overlapPatchEmbed,self).__init__() 102 | 103 | patch_size=to_2tuple(patch_size) 104 | 105 | self.patch_size = patch_size 106 | self.proj = nn.Conv2d(in_channels,dim,kernel_size=patch_size,stride=stride,padding=(patch_size[0]//2,patch_size[1]//2)) 107 | self.norm = nn.LayerNorm(dim) 108 | 109 | self.apply(self._init_weight) 110 | 111 | def _init_weight(self,m): 112 | if isinstance(m,nn.Linear): 113 | 
trunc_normal_(m.weight,std=0.02) 114 | if isinstance(m,nn.Linear) and m.bias is not None: 115 | nn.init.constant_(m.bias,0) 116 | elif isinstance(m,nn.LayerNorm): 117 | nn.init.constant_(m.bias,0) 118 | nn.init.constant_(m.weight,1.0) 119 | elif isinstance(m,nn.Conv2d): 120 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels 121 | fan_out //= m.groups 122 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out)) 123 | if m.bias is not None: 124 | m.bias.data.zero_() 125 | 126 | def forward(self,x): 127 | x = self.proj(x) 128 | return x 129 | 130 | 131 | class Attention(nn.Module): 132 | def __init__(self, dim, num_head=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 133 | super().__init__() 134 | assert dim % num_head == 0, f"dim {dim} should be divided by num_heads {num_head}." 135 | 136 | self.dim = dim 137 | self.num_heads = num_head 138 | head_dim = dim // num_head 139 | self.scale = qk_scale or head_dim ** -0.5 140 | 141 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 142 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 143 | 144 | self.attn_drop = nn.Dropout(attn_drop) 145 | self.proj = nn.Conv2d(dim, dim,1,1) 146 | self.proj_drop = nn.Dropout(proj_drop) 147 | 148 | self.sr_ratio = sr_ratio 149 | if sr_ratio > 1: 150 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 151 | self.norm = nn.LayerNorm(dim) 152 | self.conv = nn.Conv2d(dim,dim,3,1,1,groups=dim) 153 | self.apply(self._init_weights) 154 | 155 | def _init_weights(self, m): 156 | if isinstance(m, nn.Linear): 157 | trunc_normal_(m.weight, std=.02) 158 | if isinstance(m, nn.Linear) and m.bias is not None: 159 | nn.init.constant_(m.bias, 0) 160 | elif isinstance(m, nn.LayerNorm): 161 | nn.init.constant_(m.bias, 0) 162 | nn.init.constant_(m.weight, 1.0) 163 | elif isinstance(m, nn.Conv2d): 164 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 165 | fan_out //= m.groups 166 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 167 | if m.bias is not None: 168 | m.bias.data.zero_() 169 | 170 | def forward(self, x, H, W): 171 | 172 | B, N, C = x.shape 173 | x_conv = self.conv(x.reshape(B,H,W,C).permute(0,3,1,2)) 174 | 175 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 176 | if self.sr_ratio > 1: 177 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 178 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 179 | x_ = self.norm(x_) 180 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 181 | else: 182 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 183 | k, v = kv[0], kv[1] 184 | 185 | attn = (q @ k.transpose(-2, -1)) * self.scale 186 | attn = attn.softmax(dim=-1) 187 | attn = self.attn_drop(attn) 188 | 189 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 190 | x = self.proj(x.transpose(1,2).reshape(B,C,H,W)) 191 | x = self.proj_drop(x) 192 | x = x+x_conv 193 | return x 194 | 195 | 196 | class SimpleGate(nn.Module): 197 | def forward(self,x): 198 | x1, x2 = x.chunk(2, dim=1) 199 | return x1 * x2 200 | 201 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 202 | return nn.Conv2d( 203 | in_channels, out_channels, kernel_size, 204 | padding=(kernel_size//2), bias=bias, stride = stride) 205 | 206 | 207 | class MFFN(nn.Module): 208 | def __init__(self, dim, FFN_expand=2,norm_layer='WithBias'): 209 | super(MFFN, self).__init__() 210 | 211 | self.conv1 = nn.Conv2d(dim,dim*FFN_expand,1) 212 | self.conv33 = 
nn.Conv2d(dim*FFN_expand,dim*FFN_expand,3,1,1,groups=dim*FFN_expand) 213 | self.conv55 = nn.Conv2d(dim*FFN_expand,dim*FFN_expand,5,1,2,groups=dim*FFN_expand) 214 | self.sg = SimpleGate() 215 | self.conv4 = nn.Conv2d(dim,dim,1) 216 | 217 | self.apply(self._init_weights) 218 | def _init_weights(self,m): 219 | if isinstance(m,nn.Linear): 220 | trunc_normal_(m.weight,std=0.02) 221 | if isinstance(m,nn.Linear) and m.bias is not None: 222 | nn.init.constant_(m.bias,0) 223 | elif isinstance(m,nn.LayerNorm): 224 | nn.init.constant_(m.bias,0) 225 | nn.init.constant_(m.weight,1.0) 226 | elif isinstance(m,nn.Conv2d): 227 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels 228 | fan_out //= m.groups 229 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out)) 230 | if m.bias is not None: 231 | m.bias.data.zero_() 232 | def forward(self, x): 233 | x1 = self.conv1(x) 234 | x33 = self.conv33(x1) 235 | x55 = self.conv55(x1) 236 | x = x1+x33+x55 237 | x = self.sg(x) 238 | x = self.conv4(x) 239 | return x 240 | 241 | class Scale_aware_Query(nn.Module): 242 | def __init__(self, 243 | dim, 244 | out_channel, 245 | window_size, 246 | num_heads): 247 | super().__init__() 248 | self.dim = dim 249 | self.out_channel = out_channel 250 | self.window_size = window_size 251 | self.conv = nn.Conv2d(dim,out_channel,1,1,0) 252 | 253 | layers=[] 254 | for i in range(3): 255 | layers.append(CALayer(out_channel,4)) 256 | layers.append(SALayer(out_channel,4)) 257 | self.globalgen = nn.Sequential(*layers) 258 | 259 | self.num_heads = num_heads 260 | self.N = window_size * window_size 261 | self.dim_head = out_channel // self.num_heads 262 | 263 | def forward(self, x): 264 | x = self.conv(x) 265 | x = F.upsample(x,(self.window_size,self.window_size),mode="bicubic") 266 | x = self.globalgen(x) 267 | B = x.shape[0] 268 | x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(0, 1, 3, 2, 4) 269 | return x 270 | 271 | class LocalContext_Interaction(nn.Module): 272 | 273 | def __init__(self, 274 | dim, 275 | num_heads, 276 | window_size, 277 | qkv_bias=True, 278 | qk_scale=None, 279 | attn_drop=0., 280 | proj_drop=0., 281 | ): 282 | """ 283 | Args: 284 | dim: feature size dimension. 285 | num_heads: number of attention head. 286 | window_size: window size. 287 | qkv_bias: bool argument for query, key, value learnable bias. 288 | qk_scale: bool argument to scaling query, key. 289 | attn_drop: attention dropout rate. 290 | proj_drop: output dropout rate. 
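            Note: window-based multi-head self-attention with a learned relative
            position bias; the q_global argument is accepted only for interface
            compatibility with GlobalContext_Interaction and is unused here.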
291 | """ 292 | 293 | super().__init__() 294 | self.dim = dim 295 | window_size = (window_size,window_size) 296 | self.window_size = window_size 297 | self.num_heads = num_heads 298 | head_dim = dim // num_heads 299 | self.scale = qk_scale or head_dim ** -0.5 300 | self.relative_position_bias_table = nn.Parameter( 301 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 302 | coords_h = torch.arange(self.window_size[0]) 303 | coords_w = torch.arange(self.window_size[1]) 304 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 305 | coords_flatten = torch.flatten(coords, 1) 306 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 307 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 308 | relative_coords[:, :, 0] += self.window_size[0] - 1 309 | relative_coords[:, :, 1] += self.window_size[1] - 1 310 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 311 | relative_position_index = relative_coords.sum(-1) 312 | self.register_buffer("relative_position_index", relative_position_index) 313 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 314 | self.attn_drop = nn.Dropout(attn_drop) 315 | self.proj = nn.Linear(dim, dim) 316 | self.proj_drop = nn.Dropout(proj_drop) 317 | 318 | trunc_normal_(self.relative_position_bias_table, std=.02) 319 | self.softmax = nn.Softmax(dim=-1) 320 | 321 | def forward(self, x,q_global): 322 | B_, N, C = x.shape 323 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 324 | q, k, v = qkv[0], qkv[1], qkv[2] 325 | q = q * self.scale 326 | attn = (q @ k.transpose(-2, -1)) 327 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 328 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 329 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 330 | attn = attn + relative_position_bias.unsqueeze(0) 331 | attn = self.softmax(attn) 332 | attn = self.attn_drop(attn) 333 | 334 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 335 | x = self.proj(x) 336 | x = self.proj_drop(x) 337 | return x 338 | 339 | class GlobalContext_Interaction(nn.Module): 340 | 341 | def __init__(self, 342 | dim, 343 | num_heads, 344 | window_size, 345 | qkv_bias=True, 346 | qk_scale=None, 347 | attn_drop=0., 348 | proj_drop=0., 349 | ): 350 | """ 351 | Args: 352 | dim: feature size dimension. 353 | num_heads: number of attention head. 354 | window_size: window size. 355 | qkv_bias: bool argument for query, key, value learnable bias. 356 | qk_scale: bool argument to scaling query, key. 357 | attn_drop: attention dropout rate. 358 | proj_drop: output dropout rate. 
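            Note: the cross-attention variant of the block above. Keys and values
            come from the local window embeddings, while the queries are the
            scale-aware snow queries produced by Scale_aware_Query, which realizes
            the local-global context interaction described in the paper.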
359 | """ 360 | 361 | super().__init__() 362 | window_size = (window_size, window_size) 363 | self.window_size = window_size 364 | self.num_heads = num_heads 365 | head_dim = dim // num_heads 366 | self.scale = qk_scale or head_dim ** -0.5 367 | self.relative_position_bias_table = nn.Parameter( 368 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) 369 | coords_h = torch.arange(self.window_size[0]) 370 | coords_w = torch.arange(self.window_size[1]) 371 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 372 | coords_flatten = torch.flatten(coords, 1) 373 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 374 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 375 | relative_coords[:, :, 0] += self.window_size[0] - 1 376 | relative_coords[:, :, 1] += self.window_size[1] - 1 377 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 378 | relative_position_index = relative_coords.sum(-1) 379 | self.register_buffer("relative_position_index", relative_position_index) 380 | self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias) 381 | self.attn_drop = nn.Dropout(attn_drop) 382 | self.proj = nn.Linear(dim, dim) 383 | self.proj_drop = nn.Dropout(proj_drop) 384 | trunc_normal_(self.relative_position_bias_table, std=.02) 385 | self.softmax = nn.Softmax(dim=-1) 386 | 387 | def forward(self, x, q_global): 388 | B_, N, C = x.shape 389 | B = q_global.shape[0] 390 | kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 391 | k, v = kv[0], kv[1] 392 | q_global = q_global.repeat(1, B_ // B, 1, 1, 1) 393 | q = q_global.reshape(B_, self.num_heads, N, C // self.num_heads) 394 | q = q * self.scale 395 | attn = (q @ k.transpose(-2, -1)) 396 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 397 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) 398 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 399 | attn = attn + relative_position_bias.unsqueeze(0) 400 | attn = self.softmax(attn) 401 | attn = self.attn_drop(attn) 402 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 403 | x = self.proj(x) 404 | x = self.proj_drop(x) 405 | return x 406 | 407 | class Context_Interaction_Block(nn.Module): 408 | 409 | def __init__(self, 410 | latent_dim, 411 | dim, 412 | num_heads, 413 | window_size=8, 414 | mlp_ratio=4., 415 | qkv_bias=True, 416 | qk_scale=None, 417 | drop=0., 418 | attn_drop=0., 419 | drop_path=0., 420 | act_layer=nn.GELU, 421 | attention=LocalContext_Interaction, 422 | norm_layer=nn.LayerNorm, 423 | ): 424 | 425 | 426 | super().__init__() 427 | self.window_size = window_size 428 | self.norm1 = norm_layer(dim) 429 | 430 | self.attn = attention( 431 | dim, 432 | num_heads=num_heads, 433 | window_size=window_size, 434 | qkv_bias=qkv_bias, 435 | qk_scale=qk_scale, 436 | attn_drop=attn_drop, 437 | proj_drop=drop, 438 | ) 439 | 440 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 441 | self.norm2 = norm_layer(dim) 442 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 443 | self.layer_scale = False 444 | 445 | self.gamma1 = 1.0 446 | self.gamma2 = 1.0 447 | 448 | def forward(self, x,q_global): 449 | B,H, W,C = x.shape 450 | shortcut = x 451 | x = self.norm1(x) 452 | x_windows = window_partition(x, self.window_size) 453 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 454 | attn_windows = self.attn(x_windows,q_global) 455 | x = window_reverse(attn_windows, self.window_size, H, W) 456 | x = shortcut + self.drop_path(self.gamma1 * x) 457 | x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) 458 | return x 459 | 460 | class Context_Interaction_layer(nn.Module): 461 | def __init__(self,n,latent_dim,in_channel,head,window_size,globalatten=False): 462 | super(Context_Interaction_layer,self).__init__() 463 | 464 | #layers=[] 465 | self.globalatten = globalatten 466 | self.model = nn.ModuleList([ 467 | Context_Interaction_Block( 468 | latent_dim, 469 | in_channel, 470 | num_heads=head, 471 | window_size=window_size, 472 | attention=GlobalContext_Interaction if i%2 == 1 and self.globalatten == True else LocalContext_Interaction, 473 | ) 474 | for i in range(n)]) 475 | 476 | if self.globalatten == True: 477 | self.gen = Scale_aware_Query(latent_dim,in_channel,window_size=8,num_heads=head) 478 | def forward(self,x,latent): 479 | if self.globalatten == True: 480 | q_global = self.gen(latent) 481 | x = _to_channel_last(x) 482 | for model in self.model: 483 | x = model(x, q_global) 484 | else: 485 | x = _to_channel_last(x) 486 | for model in self.model: 487 | x = model(x,1) 488 | return x 489 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 490 | return nn.Conv2d( 491 | in_channels, out_channels, kernel_size, 492 | padding=(kernel_size//2), bias=bias, stride = stride) 493 | 494 | 495 | ########################################################################## 496 | ## Channel Attention Layer 497 | class CALayer(nn.Module): 498 | def __init__(self, channel, reduction=4, bias=False): 499 | super(CALayer, self).__init__() 500 | # global average pooling: feature --> point 501 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 502 | self.channel = channel 503 | self.reduction = reduction 504 | # feature channel downscale and upscale --> channel weight 505 | self.conv_du = nn.Sequential( 506 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 507 | nn.ReLU(inplace=True), 508 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 509 | nn.Sigmoid() 510 | ) 511 | 512 | def forward(self, x): 513 | y = self.avg_pool(x) 514 | y = self.conv_du(y) 515 | return x * y 516 | 517 | class SALayer(nn.Module): 518 | def __init__(self, channel,reduction=16): 519 | super(SALayer, self).__init__() 520 | self.pa = nn.Sequential( 521 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 522 | nn.ReLU(inplace=True), 523 | nn.Conv2d(channel // reduction, 1, 1, padding=0, bias=True), 524 | nn.Sigmoid() 525 | ) 526 | def forward(self, x): 527 | y = self.pa(x) 528 | return x * y 529 | 530 | class Refine_Block(nn.Module): 531 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 532 | super(Refine_Block, self).__init__() 533 | modules_body = [] 534 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 535 | modules_body.append(act) 536 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 537 | 538 | 
self.CA = CALayer(n_feat, reduction, bias=bias) 539 | self.body = nn.Sequential(*modules_body) 540 | 541 | def forward(self, x): 542 | res = self.body(x) 543 | res = self.CA(res) 544 | res += x 545 | return res 546 | 547 | class Refine(nn.Module): 548 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 549 | super(Refine, self).__init__() 550 | modules_body = [] 551 | modules_body = [Refine_Block(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] 552 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 553 | self.body = nn.Sequential(*modules_body) 554 | 555 | def forward(self, x): 556 | res = self.body(x) 557 | res += x 558 | return res 559 | 560 | class HFPH(nn.Module): 561 | def __init__(self, n_feat,fusion_dim, kernel_size, reduction, act, bias, num_cab): 562 | super(HFPH, self).__init__() 563 | self.refine0 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab) 564 | self.refine1 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab) 565 | self.refine2 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab) 566 | self.refine3 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab) 567 | 568 | self.up_enc1 = UpSample(n_feat[1],fusion_dim,s_factor=2) 569 | self.up_dec1 = UpSample(n_feat[1],fusion_dim,s_factor=2) 570 | 571 | self.up_enc2 = UpSample(n_feat[2],fusion_dim,s_factor=4) 572 | self.up_dec2 = UpSample(n_feat[2],fusion_dim,s_factor=4) 573 | 574 | self.up_enc3 = UpSample(n_feat[3],fusion_dim,s_factor=8) 575 | self.up_dec3 = UpSample(n_feat[3],fusion_dim,s_factor=8) 576 | 577 | layer0=[] 578 | for i in range(2): 579 | layer0.append(CALayer(fusion_dim,16)) 580 | layer0.append(SALayer(fusion_dim,16)) 581 | self.conv_enc0 = nn.Sequential(*layer0) 582 | 583 | layer1=[] 584 | for i in range(2): 585 | layer1.append(CALayer(fusion_dim,16)) 586 | layer1.append(SALayer(fusion_dim,16)) 587 | self.conv_enc1 = nn.Sequential(*layer1) 588 | 589 | layer2=[] 590 | for i in range(2): 591 | layer2.append(CALayer(fusion_dim,16)) 592 | layer2.append(SALayer(fusion_dim,16)) 593 | self.conv_enc2 = nn.Sequential(*layer2) 594 | 595 | layer3=[] 596 | for i in range(2): 597 | layer3.append(CALayer(fusion_dim,16)) 598 | layer3.append(SALayer(fusion_dim,16)) 599 | self.conv_enc3 = nn.Sequential(*layer3) 600 | 601 | def forward(self, x, encoder_outs, decoder_outs): 602 | x = x + self.conv_enc0(encoder_outs[0] + decoder_outs[3]) 603 | x = self.refine0(x) 604 | 605 | x = x + self.conv_enc1(self.up_enc1(encoder_outs[1]) + self.up_dec1(decoder_outs[2])) 606 | x = self.refine1(x) 607 | 608 | x = x + self.conv_enc2(self.up_enc2(encoder_outs[2]) + self.up_dec2(decoder_outs[1])) 609 | x = self.refine2(x) 610 | 611 | x = x + self.conv_enc3(self.up_enc3(encoder_outs[3]) + self.up_dec3(decoder_outs[0])) 612 | x = self.refine3(x) 613 | 614 | return x 615 | 616 | class Transformer_block(nn.Module): 617 | def __init__(self,dim,num_head=8,groups=2,qkv_bias=False,qk_scale=None,attn_drop=0.,proj_drop=0.,FFN_expand=2,norm_layer='WithBias',act_layer=nn.GELU,l_drop=0.,mlp_ratio=2,drop_path=0.,sr=1): 618 | super(Transformer_block,self).__init__() 619 | self.dim=dim*2 620 | self.num_head = num_head 621 | 622 | self.norm1 = LayerNorm(self.dim//2, norm_layer) 623 | self.norm3 = LayerNorm(self.dim//2, norm_layer) 624 | 625 | self.attn_nn = Attention(dim=self.dim//2,num_head=num_head,qkv_bias=qkv_bias,qk_scale=qk_scale,attn_drop=attn_drop,proj_drop=proj_drop,sr_ratio=sr) 626 | 627 | self.ffn = MFFN(self.dim//2, 
FFN_expand=2,norm_layer=nn.LayerNorm) 628 | 629 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 630 | self.apply(self._init_weights) 631 | 632 | def _init_weights(self,m): 633 | if isinstance(m,nn.Linear): 634 | trunc_normal_(m.weight,std=0.02) 635 | if isinstance(m,nn.Linear) and m.bias is not None: 636 | nn.init.constant_(m.bias,0) 637 | elif isinstance(m,nn.LayerNorm): 638 | nn.init.constant_(m.bias,0) 639 | nn.init.constant_(m.weight,1.0) 640 | elif isinstance(m,nn.Conv2d): 641 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels 642 | fan_out //= m.groups 643 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out)) 644 | if m.bias is not None: 645 | m.bias.data.zero_() 646 | 647 | def forward(self,x): 648 | ind = x 649 | b,c,h,w = x.shape 650 | b,c,h,w = x.shape 651 | x = self.attn_nn(self.norm1(x).reshape(b,c,h*w).transpose(1,2),h,w) 652 | b,c,h,w = x.shape 653 | x = self.drop_path(x) 654 | x = x+self.drop_path(self.ffn(self.norm3(x))) 655 | return x 656 | 657 | class Transformer(nn.Module): 658 | def __init__(self, 659 | in_channels=3, 660 | out_cahnnels=3, 661 | transformer_blocks = 8, 662 | dim=[16,32,64,128,256], 663 | window_size = [8,8,8,8], 664 | patch_size = 64, 665 | swin_num = [4,6,7,8], 666 | head = [1,2,4,8,16], 667 | FFN_expand_=2, 668 | qkv_bias_=False, 669 | qk_sacle_=None, 670 | attn_drop_=0., 671 | proj_drop_=0., 672 | norm_layer_= 'WithBias', 673 | act_layer_=nn.GELU, 674 | l_drop_=0., 675 | drop_path_=0., 676 | sr=1, 677 | conv_num = [4,6,7,8], 678 | expand_ratio = [1,2,2,2], 679 | VAN = False, 680 | dynamic_e = False, 681 | dynamic_d = False, 682 | global_atten = True, 683 | ): 684 | super(Transformer,self).__init__() 685 | self.patch_size = patch_size 686 | 687 | self.embed = Down(in_channels,dim[0],3,1,1) 688 | self.conv0 = nn.Conv2d(dim[0],dim[4],1) 689 | self.encoder_level0 = nn.Sequential(Conv_block(conv_num[0],dim[0],dim[0],expand_ratio=expand_ratio[0],VAN=VAN,dynamic=dynamic_e)) 690 | 691 | self.down0 = Down(dim[0],dim[1],3,2,1) ## H//2,W//2 692 | self.conv1 = nn.Conv2d(dim[1],dim[4],1) 693 | self.encoder_level1 = nn.Sequential(Conv_block(conv_num[1],dim[1],dim[1],expand_ratio=expand_ratio[1],VAN=VAN,dynamic=dynamic_e)) 694 | 695 | self.down1 = Down(dim[1],dim[2],3,2,1) ## H//4,W//4 696 | self.conv2 = nn.Conv2d(dim[2],dim[4],1) 697 | self.encoder_level2 = nn.Sequential(Conv_block(conv_num[2],dim[2],dim[2],expand_ratio=expand_ratio[2],VAN=VAN,dynamic=dynamic_e)) 698 | 699 | self.down2 = Down(dim[2],dim[3],3,2,1) ## H//8,W//8 700 | self.conv3 = nn.Conv2d(dim[3],dim[4],1) 701 | self.encoder_level3 = nn.Sequential(Conv_block(conv_num[3],dim[3],dim[3],expand_ratio=expand_ratio[3],VAN=VAN,dynamic=dynamic_e)) 702 | 703 | self.down3 = Down(dim[3],dim[4],3,2,1) ## H//16,W//16 704 | 705 | self.latent = nn.Sequential(*[Transformer_block(dim=(dim[4]),num_head=head[4],qkv_bias=qkv_bias_,qk_scale=qk_sacle_,attn_drop=attn_drop_,proj_drop=proj_drop_,FFN_expand=FFN_expand_,norm_layer=norm_layer_,act_layer=act_layer_,l_drop=l_drop_,drop_path=drop_path_,sr=sr) for i in range(transformer_blocks)]) 706 | 707 | self.up3 = Up((dim[4]),dim[3],4,2,1) 708 | self.ca3 = CALayer(dim[3]*2,reduction=4) 709 | self.reduce_chan_level3 = nn.Conv2d(dim[3]*2, dim[3], kernel_size=1, bias=False) 710 | self.decoder_level3 = Context_Interaction_layer(n=swin_num[3],latent_dim=dim[4],in_channel=dim[3],head=head[3],window_size=window_size[3],globalatten=global_atten) 711 | self.up2 = Up(dim[3],dim[2],4,2,1) 712 | self.ca2 = CALayer(dim[2]*2,reduction=4) 713 | 
self.reduce_chan_level2 = nn.Conv2d(dim[2]*2, dim[2], kernel_size=1, bias=False) 714 | self.decoder_level2 = Context_Interaction_layer(n=swin_num[2],latent_dim=dim[4],in_channel=dim[2],head=head[2],window_size=window_size[2],globalatten=global_atten) 715 | 716 | self.up1 = Up(dim[2],dim[1],4,2,1) 717 | self.ca1 = CALayer(dim[1]*2,reduction=4) 718 | self.reduce_chan_level1 = nn.Conv2d(dim[1]*2, dim[1], kernel_size=1, bias=False) 719 | self.decoder_level1 = Context_Interaction_layer(n=swin_num[1],latent_dim=dim[4],in_channel=dim[1],head=head[1],window_size=window_size[1],globalatten=global_atten) 720 | 721 | self.up0 = Up(dim[1],dim[0],4,2,1) 722 | self.ca0 = CALayer(dim[0]*2,reduction=4) 723 | self.reduce_chan_level0 = nn.Conv2d(dim[0]*2, dim[0], kernel_size=1, bias=False) 724 | self.decoder_level0 = Context_Interaction_layer(n=swin_num[0],latent_dim=dim[4],in_channel=dim[0],head=head[0],window_size=window_size[0],globalatten=global_atten) 725 | 726 | self.refinement = HFPH(n_feat=dim,fusion_dim=dim[0],kernel_size=3,reduction=4,act=nn.GELU(),bias=True,num_cab=6) 727 | 728 | self.out2 = nn.Conv2d(dim[0],out_cahnnels,kernel_size=3,stride=1,padding=1) 729 | 730 | def check_image_size(self, x): 731 | _, _, h, w = x.size() 732 | mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size 733 | mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size 734 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 735 | return x 736 | def forward(self,x): 737 | x = self.check_image_size(x) 738 | encoder_item = [] 739 | decoder_item = [] 740 | inp_enc_level0 = self.embed(x) 741 | inp_enc_level0 = self.encoder_level0(inp_enc_level0) 742 | encoder_item.append(inp_enc_level0) 743 | 744 | inp_enc_level1 = self.down0(inp_enc_level0) 745 | inp_enc_level1 = self.encoder_level1(inp_enc_level1) 746 | encoder_item.append(inp_enc_level1) 747 | 748 | inp_enc_level2 = self.down1(inp_enc_level1) 749 | inp_enc_level2 = self.encoder_level2(inp_enc_level2) 750 | encoder_item.append(inp_enc_level2) 751 | 752 | inp_enc_level3 = self.down2(inp_enc_level2) 753 | inp_enc_level3 = self.encoder_level3(inp_enc_level3) 754 | encoder_item.append(inp_enc_level3) 755 | 756 | out_enc_level4 = self.down3(inp_enc_level3) 757 | top_0 = F.adaptive_max_pool2d(inp_enc_level0,(out_enc_level4.shape[2],out_enc_level4.shape[3])) 758 | top_1 = F.adaptive_max_pool2d(inp_enc_level1,(out_enc_level4.shape[2],out_enc_level4.shape[3])) 759 | top_2 = F.adaptive_max_pool2d(inp_enc_level2,(out_enc_level4.shape[2],out_enc_level4.shape[3])) 760 | top_3 = F.adaptive_max_pool2d(inp_enc_level3,(out_enc_level4.shape[2],out_enc_level4.shape[3])) 761 | 762 | 763 | latent = out_enc_level4+self.conv0(top_0)+self.conv1(top_1)+self.conv2(top_2)+self.conv3(top_3) 764 | 765 | latent = self.latent(latent) 766 | 767 | inp_dec_level3 = self.up3(latent) 768 | inp_dec_level3 = F.upsample(inp_dec_level3,(inp_enc_level3.shape[2],inp_enc_level3.shape[3]),mode="bicubic") 769 | inp_dec_level3 = torch.cat([inp_dec_level3, inp_enc_level3], 1) 770 | inp_dec_level3 = self.ca3(inp_dec_level3) 771 | inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3) 772 | 773 | out_dec_level3 = self.decoder_level3(inp_dec_level3,latent) 774 | out_dec_level3 = _to_channel_first(out_dec_level3) 775 | decoder_item.append(out_dec_level3) 776 | 777 | inp_dec_level2 = self.up2(out_dec_level3) 778 | inp_dec_level2 = F.upsample(inp_dec_level2,(inp_enc_level2.shape[2],inp_enc_level2.shape[3]),mode="bicubic") 779 | inp_dec_level2 = torch.cat([inp_dec_level2, inp_enc_level2], 1) 
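        # fuse the encoder skip: channel attention over the concatenated features,
        # then a 1x1 conv back to dim[2] before the context-interaction decoder stage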
780 | inp_dec_level2 = self.ca2(inp_dec_level2) 781 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) 782 | out_dec_level2 = self.decoder_level2(inp_dec_level2,latent) 783 | out_dec_level2 = _to_channel_first(out_dec_level2) 784 | decoder_item.append(out_dec_level2) 785 | 786 | inp_dec_level1 = self.up1(out_dec_level2) 787 | inp_dec_level1 = F.upsample(inp_dec_level1,(inp_enc_level1.shape[2],inp_enc_level1.shape[3]),mode="bicubic") 788 | inp_dec_level1 = torch.cat([inp_dec_level1, inp_enc_level1], 1) 789 | inp_dec_level1 = self.ca1(inp_dec_level1) 790 | inp_dec_level1 = self.reduce_chan_level1(inp_dec_level1) 791 | out_dec_level1 = self.decoder_level1(inp_dec_level1,latent) 792 | out_dec_level1 = _to_channel_first(out_dec_level1) 793 | decoder_item.append(out_dec_level1) 794 | 795 | inp_dec_level0 = self.up0(out_dec_level1) 796 | inp_dec_level0 = F.upsample(inp_dec_level0,(inp_enc_level0.shape[2],inp_enc_level0.shape[3]),mode="bicubic") 797 | inp_dec_level0 = torch.cat([inp_dec_level0, inp_enc_level0], 1) 798 | inp_dec_level0 = self.ca0(inp_dec_level0) 799 | inp_dec_level0 = self.reduce_chan_level0(inp_dec_level0) 800 | out_dec_level0 = self.decoder_level0(inp_dec_level0,latent) 801 | out_dec_level0 = _to_channel_first(out_dec_level0) 802 | decoder_item.append(out_dec_level0) 803 | 804 | out_dec_level0_refine = self.refinement(out_dec_level0,encoder_item,decoder_item) 805 | out_dec_level1 = self.out2(out_dec_level0_refine) + x 806 | 807 | return out_dec_level1 808 | # from ptflops import get_model_complexity_info 809 | 810 | # model = Transformer().cuda() 811 | # H,W=256,256 812 | # flops_t, params_t = get_model_complexity_info(model, (3, H,W), as_strings=True, print_per_layer_stat=True) 813 | 814 | # print(f"net flops:{flops_t} parameters:{params_t}") 815 | # model = nn.DataParallel(model) 816 | # x = torch.ones([1,3,H,W]).cuda() 817 | # b = model(x) 818 | # steps=25 819 | # # print(b) 820 | # time_avgs=[] 821 | # with torch.no_grad(): 822 | # for step in range(steps): 823 | 824 | # torch.cuda.synchronize() 825 | # start = time.time() 826 | # result = model(x) 827 | # torch.cuda.synchronize() 828 | # time_interval = time.time() - start 829 | # if step>5: 830 | # time_avgs.append(time_interval) 831 | # #print('run time:',time_interval) 832 | # print('avg time:',np.mean(time_avgs),'fps:',(1/np.mean(time_avgs)),' size:',H,W) 833 | --------------------------------------------------------------------------------
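For quick orientation, here is a minimal single-image inference sketch assembled from `test.py` and `SnowFormer.py`. It is only a sketch: the checkpoint and image file names are placeholders, a CUDA device is assumed, and whole-image inference is used instead of the tiled scheme in `test.py`.

```python
# Hypothetical single-image desnowing sketch (file names are placeholders).
import torch
from PIL import Image
from torchvision.transforms import functional as TF

from SnowFormer import Transformer

model = Transformer().cuda().eval()
# test.py loads a plain state_dict, so load_state_dict works directly
model.load_state_dict(torch.load('SnowFormer_CSD.pth'))

img = TF.to_tensor(Image.open('snowy.png').convert('RGB')).unsqueeze(0).cuda()
with torch.no_grad():
    out = model(img).clamp(0, 1)

# forward() reflect-pads H and W to a multiple of patch_size (64) and does not
# crop back, so trim the output to the input resolution before saving
out = out[..., :img.shape[-2], :img.shape[-1]]
TF.to_pil_image(out.squeeze(0).cpu()).save('desnowed.png')
```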