├── image
│   ├── 1.png
│   ├── 2.png
│   └── table.png
├── loss
│   ├── CL1.py
│   └── perceptual.py
├── metrics.py
├── README.md
├── test.py
├── condconv.py
├── dataloader.py
├── base_net_snow.py
└── SnowFormer.py
/image/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/1.png
--------------------------------------------------------------------------------
/image/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/2.png
--------------------------------------------------------------------------------
/image/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ephemeral182/SnowFormer/HEAD/image/table.png
--------------------------------------------------------------------------------
/loss/CL1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | class L1_Charbonnier_loss(torch.nn.Module):
4 | """L1 Charbonnierloss."""
5 | def __init__(self):
6 | super(L1_Charbonnier_loss, self).__init__()
7 | self.eps = 1e-6
8 |
9 | def forward(self, X, Y):
10 | diff = torch.add(X, -Y)
11 | error = torch.sqrt(diff * diff + self.eps)
12 | loss = torch.mean(error)
13 | return loss
14 |
15 | class PSNRLoss(torch.nn.Module):
16 |
17 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
18 | super(PSNRLoss, self).__init__()
19 | assert reduction == 'mean'
20 | self.loss_weight = loss_weight
21 | self.scale = 10 / np.log(10)  # 10/ln(10): turns ln(MSE) into 10*log10(MSE)
22 | self.toY = toY
23 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
24 | self.first = True
25 |
26 | def forward(self, pred, target):
27 | assert len(pred.size()) == 4
28 | if self.toY:
29 | if self.first:
30 | self.coef = self.coef.to(pred.device)
31 | self.first = False
32 |
33 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
34 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
35 |
36 | pred, target = pred / 255., target / 255.
37 | pass
38 | assert len(pred.size()) == 4
39 |
40 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
--------------------------------------------------------------------------------
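A minimal usage sketch for the two criteria above (illustrative only, not part of the repository; the `loss_weight=0.1` is an arbitrary choice):

```python
import torch
from loss.CL1 import L1_Charbonnier_loss, PSNRLoss

charbonnier = L1_Charbonnier_loss()
psnr_term = PSNRLoss(loss_weight=0.1)   # scaled 10*log10(MSE), so minimising it raises PSNR

pred = torch.rand(2, 3, 64, 64, requires_grad=True)   # stand-in for a network output in [0, 1]
target = torch.rand(2, 3, 64, 64)

loss = charbonnier(pred, target) + psnr_term(pred, target)
loss.backward()
```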
/loss/perceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | from torchvision.models.vgg import vgg19,vgg16
7 | import torch.nn as nn
8 |
9 |
10 |
11 |
12 | class PerceptualLoss(nn.Module):
13 | def __init__(self):
14 | super(PerceptualLoss,self).__init__()
15 | self.L1 = nn.L1Loss()
16 | self.mse = nn.MSELoss()
17 | vgg = vgg19(pretrained=True).eval()
18 | self.loss_net1 = nn.Sequential(*list(vgg.features)[:1]).eval()
19 | self.loss_net3 = nn.Sequential(*list(vgg.features)[:3]).eval()
20 | self.loss_net5 = nn.Sequential(*list(vgg.features)[:5]).eval()
21 | self.loss_net9 = nn.Sequential(*list(vgg.features)[:9]).eval()
22 | self.loss_net13 = nn.Sequential(*list(vgg.features)[:13]).eval()
23 | def forward(self,x,y):
24 | loss1 = self.L1(self.loss_net1(x),self.loss_net1(y))
25 | loss3 = self.L1(self.loss_net3(x),self.loss_net3(y))
26 | loss5 = self.L1(self.loss_net5(x),self.loss_net5(y))
27 | loss9 = self.L1(self.loss_net9(x),self.loss_net9(y))
28 | loss13 = self.L1(self.loss_net13(x),self.loss_net13(y))
29 | #print(self.loss_net13(x).shape)
30 | loss = 0.2*loss1 + 0.2*loss3 + 0.2*loss5 + 0.2*loss9 + 0.2*loss13
31 | return loss
32 |
33 |
34 |
35 | class PerceptualLoss2(nn.Module):
36 | def __init__(self):
37 | super(PerceptualLoss2,self).__init__()
38 | self.L1 = nn.L1Loss()
39 | self.mse = nn.MSELoss()
40 | vgg = vgg19(pretrained=True).eval()
41 | self.loss_net1 = nn.Sequential(*list(vgg.features)[:1]).eval()
42 | self.loss_net3 = nn.Sequential(*list(vgg.features)[:3]).eval()
43 | def forward(self,x,y):
44 | loss1 = self.L1(self.loss_net1(x),self.loss_net1(y))
45 | loss3 = self.L1(self.loss_net3(x),self.loss_net3(y))
46 | loss = 0.5*loss1+0.5*loss3
47 | return loss
--------------------------------------------------------------------------------
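A minimal sketch of calling the VGG-19 perceptual loss above (assumed usage, not from the repository; the first call downloads the torchvision VGG-19 weights):

```python
import torch
from loss.perceptual import PerceptualLoss

criterion = PerceptualLoss()            # builds five shallow VGG-19 feature extractors

pred = torch.rand(1, 3, 128, 128, requires_grad=True)
target = torch.rand(1, 3, 128, 128)

loss = criterion(pred, target)          # average of L1 distances at the five VGG stages
loss.backward()
```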
/metrics.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import math
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
16 | from torchvision.transforms import ToPILImage
17 |
18 | def gaussian(window_size, sigma):
19 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
20 | return gauss / gauss.sum()
21 |
22 | def create_window(window_size, channel):
23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
26 | return window
27 |
28 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
29 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
30 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
31 | mu1_sq = mu1.pow(2)
32 | mu2_sq = mu2.pow(2)
33 | mu1_mu2 = mu1 * mu2
34 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
35 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
36 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
37 | C1 = 0.01 ** 2
38 | C2 = 0.03 ** 2
39 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
40 |
41 | if size_average:
42 | return ssim_map.mean()
43 | else:
44 | return ssim_map.mean(1).mean(1).mean(1)
45 |
46 |
47 | def SSIM(img1, img2, window_size=11, size_average=True):
48 | img1=torch.clamp(img1,min=0,max=1)
49 | img2=torch.clamp(img2,min=0,max=1)
50 | (_, channel, _, _) = img1.size()
51 | window = create_window(window_size, channel)
52 | if img1.is_cuda:
53 | window = window.cuda(img1.get_device())
54 | window = window.type_as(img1)
55 | return _ssim(img1, img2, window, window_size, channel, size_average)
56 | def PSNR(pred, gt):
57 | pred=pred.clamp(0,1).cpu().numpy()
58 | gt=gt.clamp(0,1).cpu().numpy()
59 | imdff = pred - gt
60 | rmse = math.sqrt(np.mean(imdff ** 2))
61 | if rmse == 0:
62 | return 100
63 | return 20 * math.log10( 1.0 / rmse)
64 |
65 | if __name__ == "__main__":
66 | pass
67 |
--------------------------------------------------------------------------------
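Both metrics expect 4-D tensors in [0, 1]; a quick sanity check on dummy data:

```python
import torch
from metrics import SSIM, PSNR

restored = torch.rand(1, 3, 256, 256)
clean = torch.rand(1, 3, 256, 256)

print('SSIM:', SSIM(restored, clean).item())   # windowed SSIM averaged over the image
print('PSNR:', PSNR(restored, clean))          # PSNR in dB for a [0, 1] dynamic range
```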
/README.md:
--------------------------------------------------------------------------------
1 | # SnowFormer: Context Interaction Transformer with Scale-awareness for Single Image Desnowing
2 | **Authors:** [Sixiang Chen](https://scholar.google.com/citations?hl=zh-CN&user=EtljKSgAAAAJ), [Tian Ye](https://scholar.google.com/citations?user=1sGXZ-wAAAAJ&hl=zh-CN), [Yun Liu](http://ai.swu.edu.cn/info/1071/1804.htm), [Erkang Chen](https://scholar.google.com/citations?user=hWo1RTsAAAAJ&hl=zh-CN)
3 |
4 | [SnowFormer: Context Interaction Transformer with Scale-awareness for Single Image Desnowing](https://arxiv.org/abs/2208.09703)
5 |
6 | > **Abstract:** *Due to various and complicated snow degradations, single image desnowing is a challenging image restoration task. As prior arts can not handle it ideally, we propose a novel transformer, SnowFormer, which explores efficient cross-attentions to build local-global context interaction across patches and surpasses existing works that employ local operators or vanilla transformers. Compared to prior desnowing methods and universal image restoration methods, SnowFormer has several benefits. Firstly, unlike the multi-head self-attention in recent image restoration Vision Transformers, SnowFormer incorporates the multi-head cross-attention mechanism to perform local-global context interaction between scale-aware snow queries and local-patch embeddings. Second, the snow queries in SnowFormer are generated by the query generator from aggregated scale-aware features, which are rich in potential clean cues, leading to superior restoration results. Third, SnowFormer outshines advanced state-of-the-art desnowing networks and the prevalent universal image restoration transformers on six synthetic and real-world datasets.*
7 |
8 | #### News
9 | - **November 22, 2022:** Checkpoint of CSD is updated. :fire:
10 | ## Network Architecture
11 |
14 | ![](image/1.png)
21 | ![](image/2.png)
28 | ![](image/table.png)
31 |
32 | ## Installation
33 | Our SnowFormer is built in PyTorch 1.12.0; we train and test it on Ubuntu 20.04 (Python 3.8, CUDA 11.6).
34 |
35 | For installation, please follow these instructions.
36 | ```
37 | conda create -n py38 python=3.8
38 | conda activate py38
39 | conda install pytorch=1.12
40 | pip install opencv-python tqdm tensorboardX ....
41 | ```
42 | ## Dataset
43 | We sample 2000 images from each desnowing dataset for the test stage, including CSD, SRRS, Snow100K, SnowCityScapes and SnowKITTI. We provide download links for the datasets; the password is **ephe**.
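The loaders in `dataloader.py` read snowy/clean pairs from fixed sub-folders; based on the paths used there, a CSD test split is expected to look like:

```
CSD/Test/
├── Snow/        # snowy inputs
└── Gt/          # clean ground truth
```

SRRS uses `Syn/` and `gt/`, and Snow100K uses `synthetic/` and `gt/` in the same way.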
44 |
45 |
63 |
64 | ## Testing Stage
65 | You can download the pre-trained CSD model from [Pre-trained model](https://pan.baidu.com/s/1UvokNiZ9ZNhl8tPXg2yysQ) **Password: ephe** and save it in `model_path`.
66 |
67 | Run the following command:
68 | ```python
69 | python3 test.py --dataset_type CSD --dataset_CSD dataset_CSD --model_path model_path
70 | ```
71 | The results are saved in `./out/<dataset_type>/`.
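For programmatic use, a minimal inference sketch (the checkpoint path and input are placeholders):

```python
import torch
from SnowFormer import Transformer

model = Transformer().cuda().eval()
model.load_state_dict(torch.load('model_path/SnowFormer_CSD.pth'))

with torch.no_grad():
    snowy = torch.rand(1, 3, 256, 256).cuda()   # replace with a real image tensor in [0, 1]
    clean = model(snowy)
```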
72 |
73 | ## TODO List
74 | - [ ] Checkpoints of SRRS, Snow100K, SnowCityScapes and SnowKITTI
75 | - [ ] Train.py
76 |
77 | ## Citation
78 | ```Bibtex
79 | @article{chen2022snowformer,
80 | title={SnowFormer: Scale-aware Transformer via Context Interaction for Single Image Desnowing},
81 | author={Chen, Sixiang and Ye, Tian and Liu, Yun and Chen, Erkang and Shi, Jun and Zhou, Jingchun},
82 | journal={arXiv preprint arXiv:2208.09703},
83 | year={2022}
84 | }
85 | ```
86 | ## Contact
87 | If you have any questions, please contact us at 282542428@qq.com, ephemeral182@gmail.com or sixiangchen@hkust-gz.edu.cn.
88 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os.path import exists, join as join_paths
4 | import torch
5 | import numpy as np
6 | from torchvision.transforms import functional as FF
7 | from metrics import *
8 | import warnings
9 | from torchvision.utils import save_image,make_grid
10 | from tqdm import tqdm
11 | from torch.utils.data import DataLoader
12 | from dataloader import *
13 |
14 | from SnowFormer import *
15 | from PIL import Image
16 | warnings.filterwarnings("ignore")
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--cuda', action='store_true', help='use GPU computation')
21 | parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
22 | parser.add_argument('--tile', type=int, default=256, help='Tile size, None for no tile during testing (testing as a whole)')
23 | parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
24 | parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8')
25 | parser.add_argument('--dataset_type', type=str, default='CSD', help='CSD/SRRS/Snow100K')
26 | parser.add_argument('--dataset_CSD', type=str, default='/home/PublicDataset/CSD/Test', help='path of CSD dataset')
27 | parser.add_argument('--dataset_SRRS', type=str, default='/home/PublicDataset/SRRS/SRRS-2021/', help='path of SRRS dataset')
28 | parser.add_argument('--dataset_Snow100K', type=str, default='/home/PublicDataset/Snow100K/media/jdway/GameSSD/overlapping/test/test/', help='path of Snow100k dataset')
29 | parser.add_argument('--savepath', type=str, default='./out/', help='path of output image')
30 | parser.add_argument('--model_path', type=str, default='/mnt/csx/SnowFormer/SnowFormer/SnowFormer_CSD.pth', help='path of SnowFormer checkpoint')
31 |
32 | opt = parser.parse_args()
33 | if opt.dataset_type == 'CSD':
34 | snow_test = DataLoader(dataset=CSD_Dataset(opt.dataset_CSD,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
35 | if opt.dataset_type == 'SRRS':
36 | snow_test = DataLoader(dataset=SRRS_Dataset(opt.dataset_SRRS,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
37 | if opt.dataset_type == 'Snow100K':
38 | snow_test = DataLoader(dataset=Snow100K_Dataset(opt.dataset_Snow100K,train=False,size=256,rand_inpaint=False,rand_augment=None),batch_size=1,shuffle=False,num_workers=4)
39 |
40 |
41 | netG_1 = Transformer().cuda()
42 |
43 | if __name__ == '__main__':
44 |
45 | ssims = []
46 | psnrs = []
47 | rmses = []
48 |
49 | g1ckpt1 = opt.model_path
50 | ckpt = torch.load(g1ckpt1)
51 | netG_1.load_state_dict(ckpt)
52 |
53 | savepath_dataset = os.path.join(opt.savepath,opt.dataset_type)
54 | if not os.path.exists(savepath_dataset):
55 | os.makedirs(savepath_dataset)
56 | loop = tqdm(enumerate(snow_test),total=len(snow_test))
57 |
58 | for idx,(haze,clean,name) in loop:
59 |
60 | with torch.no_grad():
61 |
62 | haze = haze.cuda();clean = clean.cuda()
63 |
64 | b, c, h, w = haze.size()
65 |
66 | tile = min(opt.tile, h, w)
67 | tile_overlap = opt.tile_overlap
68 | sf = opt.scale
69 |
70 | stride = tile - tile_overlap
71 | h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
72 | w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
73 | E1 = torch.zeros(b, c, h*sf, w*sf).type_as(haze)
74 | W1 = torch.zeros_like(E1)
75 |
76 | for h_idx in h_idx_list:
77 | for w_idx in w_idx_list:
78 | in_patch = haze[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
79 | out_patch1 = netG_1(in_patch)
80 | out_patch_mask1 = torch.ones_like(out_patch1)
81 | E1[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch1)
82 | W1[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask1)
83 | dehaze = E1.div_(W1)
84 |
85 | save_image(dehaze,os.path.join(savepath_dataset,'%s.png'%(name[0])),normalize=False)  # name is a length-1 batch from the DataLoader
86 |
87 |
88 | ssim1=SSIM(dehaze,clean).item()
89 | psnr1=PSNR(dehaze,clean)
90 |
91 | ssims.append(ssim1)
92 | psnrs.append(psnr1)
93 |
94 | print('Generated images %04d of %04d' % (idx+1, len(snow_test)))
95 | print('ssim:',(ssim1))
96 | print('psnr:',(psnr1))
97 |
98 | ssim = np.mean(ssims)
99 | psnr = np.mean(psnrs)
100 | print('ssim_avg:',ssim)
101 | print('psnr_avg:',psnr)
102 |
--------------------------------------------------------------------------------
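`test.py` restores full-resolution images tile-by-tile and averages overlapping predictions by dividing the accumulated output by a hit-count map. The scheme reduces to the following self-contained sketch (identity stand-in for the network, names are mine):

```python
import torch

def tiled_apply(model, img, tile=256, overlap=32):
    """Run `model` on overlapping tiles and average the overlaps, as in test.py."""
    b, c, h, w = img.shape
    tile = min(tile, h, w)
    stride = tile - overlap
    h_idx = list(range(0, h - tile, stride)) + [h - tile]
    w_idx = list(range(0, w - tile, stride)) + [w - tile]
    out = torch.zeros_like(img)     # accumulated predictions
    hits = torch.zeros_like(img)    # how many tiles covered each pixel
    for i in h_idx:
        for j in w_idx:
            out[..., i:i + tile, j:j + tile] += model(img[..., i:i + tile, j:j + tile])
            hits[..., i:i + tile, j:j + tile] += 1
    return out / hits

print(tiled_apply(lambda t: t, torch.rand(1, 3, 300, 400)).shape)   # torch.Size([1, 3, 300, 400])
```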
/condconv.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.nn.modules.conv import _ConvNd
7 | from torch.nn.modules.utils import _pair
8 | from torch.nn.parameter import Parameter
9 |
10 |
11 | class _routing(nn.Module):
12 |
13 | def __init__(self, in_channels, num_experts, dropout_rate):
14 | super(_routing, self).__init__()
15 |
16 | self.dropout = nn.Dropout(dropout_rate)
17 | self.fc = nn.Sequential(
18 | nn.Linear(in_channels, in_channels),
19 | nn.LeakyReLU(0.1, True),
20 | nn.Linear(in_channels, num_experts)
21 | )
22 |
23 | def forward(self, x):
24 | x = torch.flatten(x)
25 | x = self.dropout(x)
26 | x = self.fc(x)
27 | return torch.sigmoid(x)  # F.sigmoid is deprecated
28 |
29 |
30 | class DynamicCondConv2D(_ConvNd):
31 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
32 | padding=0, dilation=1, groups=1,
33 | bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2,rooting_channels =512):
34 | kernel_size = _pair(kernel_size)
35 | stride = _pair(stride)
36 | padding = _pair(padding)
37 | dilation = _pair(dilation)
38 | super(DynamicCondConv2D, self).__init__(
39 | in_channels, out_channels, kernel_size, stride, padding, dilation,
40 | False, _pair(0), groups, bias, padding_mode)
41 |
42 | #self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
43 | self._routing_fn = _routing(rooting_channels, num_experts, dropout_rate)
44 |
45 | self.weight = Parameter(torch.Tensor(
46 | num_experts, out_channels, in_channels // groups, *kernel_size))
47 |
48 | self.reset_parameters()
49 |
50 | def _conv_forward(self, input, weight):
51 | if self.padding_mode != 'zeros':
52 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
53 | weight, self.bias, self.stride,
54 | _pair(0), self.dilation, self.groups)
55 | return F.conv2d(input, weight, self.bias, self.stride,
56 | self.padding, self.dilation, self.groups)
57 |
58 | def forward(self, inputs_q):
59 | inputs = inputs_q[0]
60 | kernel_conditions = inputs_q[1]
61 | b, _, _, _ = inputs.size()
62 | res = []
63 | for i,input in enumerate(inputs):
64 | input = input.unsqueeze(0)
65 | routing_weights = self._routing_fn(kernel_conditions[i])
66 | kernels = torch.sum(routing_weights[: ,None, None, None, None] * self.weight, 0)
67 | out = self._conv_forward(input, kernels)
68 | res.append(out)
69 | return torch.cat(res, dim=0)
70 |
71 |
72 |
73 | class CondConv2D(_ConvNd):
74 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
75 | padding=0, dilation=1, groups=1,
76 | bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
77 | kernel_size = _pair(kernel_size)
78 | stride = _pair(stride)
79 | padding = _pair(padding)
80 | dilation = _pair(dilation)
81 | super(CondConv2D, self).__init__(
82 | in_channels, out_channels, kernel_size, stride, padding, dilation,
83 | False, _pair(0), groups, bias, padding_mode)
84 |
85 | self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
86 | self._routing_fn = _routing(in_channels, num_experts, dropout_rate)
87 |
88 | self.weight = Parameter(torch.Tensor(
89 | num_experts, out_channels, in_channels // groups, *kernel_size))
90 |
91 | self.reset_parameters()
92 |
93 | def _conv_forward(self, input, weight):
94 | if self.padding_mode != 'zeros':
95 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
96 | weight, self.bias, self.stride,
97 | _pair(0), self.dilation, self.groups)
98 | return F.conv2d(input, weight, self.bias, self.stride,
99 | self.padding, self.dilation, self.groups)
100 |
101 | def forward(self, inputs):
102 | b, _, _, _ = inputs.size()
103 | res = []
104 | for input in inputs:
105 | input = input.unsqueeze(0)
106 | pooled_inputs = self._avg_pooling(input)
107 | routing_weights = self._routing_fn(pooled_inputs)
108 | kernels = torch.sum(routing_weights[: ,None, None, None, None] * self.weight, 0)
109 | out = self._conv_forward(input, kernels)
110 | res.append(out)
111 | return torch.cat(res, dim=0)
112 |
113 |
--------------------------------------------------------------------------------
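A small sketch of dropping `CondConv2D` in for a regular convolution: each sample in the batch gets its own kernel, mixed from `num_experts` expert weights by the routing MLP (dummy data, illustrative only):

```python
import torch
from condconv import CondConv2D

conv = CondConv2D(in_channels=16, out_channels=32, kernel_size=3, padding=1, num_experts=3)

x = torch.rand(4, 16, 64, 64)
y = conv(x)            # routing weights are computed per sample from a pooled descriptor
print(y.shape)         # torch.Size([4, 32, 64, 64])
```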
/dataloader.py:
--------------------------------------------------------------------------------
2 | import os
3 | import torch.utils.data as data
4 | import numpy as np
5 | from PIL import Image
7 | import torchvision.transforms as tfs
8 | from torchvision.transforms import functional as FF
9 | import os,sys
10 | import random
12 | from torchvision.utils import make_grid
13 | #from RandomMask1 import *
14 | random.seed(2)
15 | np.random.seed(2)
16 |
17 | p = 1
18 | AugDict = {
19 | 1:tfs.ColorJitter(brightness=p), #Brightness
20 | 2:tfs.ColorJitter(contrast=p), #Contrast
21 | 3:tfs.ColorJitter(saturation=p), #Saturation
22 | 4:tfs.GaussianBlur(kernel_size=5), #Gaussian Blur
23 | #5:GaussianNoise(std=1), #Gaussian Noise
24 | #5:RandomMaskwithRatio(64,patch_size=4,ratio=0.7), #Random Mask
25 | }
26 |
27 | class CSD_Dataset(data.Dataset):
28 | def __init__(self,path,train=False,size=256,format='.tif',rand_inpaint=False,rand_augment=None):
29 | super(CSD_Dataset,self).__init__()
30 | self.size=size
31 | self.rand_augment=rand_augment
32 | self.rand_inpaint=rand_inpaint
33 | self.InpaintSize = 64
34 | print('crop size',size)
35 | self.train=train
36 | self.format=format
37 | self.haze_imgs_dir=os.listdir(os.path.join(path,'Snow'))
38 | print('======>total number for training:',len(self.haze_imgs_dir))
39 | self.haze_imgs=[os.path.join(path,'Snow',img) for img in self.haze_imgs_dir]
40 | self.clear_dir=os.path.join(path,'Gt')
41 | def __getitem__(self, index):
42 | haze=Image.open(self.haze_imgs[index])
43 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1]
44 | while haze.size[0]
92 | print('======>total number for training:',len(self.haze_imgs_dir))
93 | self.haze_imgs=[os.path.join(path,'Syn',img) for img in self.haze_imgs_dir]
94 | self.clear_dir=os.path.join(path,'gt')
95 | def __getitem__(self, index):
96 | haze=Image.open(self.haze_imgs[index])
97 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1]
98 | while haze.size[0]
146 | print('======>total number for training:',len(self.haze_imgs_dir))
147 | self.haze_imgs=[os.path.join(path,'synthetic',img) for img in self.haze_imgs_dir]
148 | self.clear_dir=os.path.join(path,'gt')
149 | def __getitem__(self, index):
150 | haze=Image.open(self.haze_imgs[index])
151 | self.format = self.haze_imgs[index].split('/')[-1].split(".")[-1]
152 | while haze.size[0]
--------------------------------------------------------------------------------
/base_net_snow.py:
--------------------------------------------------------------------------------
14 | return rearrange(x, 'b c h w -> b (h w) c')
15 |
16 | def to_4d(x,h,w):
17 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
18 |
19 | class BiasFree_LayerNorm(nn.Module):
20 | def __init__(self, normalized_shape):
21 | super(BiasFree_LayerNorm, self).__init__()
22 | if isinstance(normalized_shape, numbers.Integral):
23 | normalized_shape = (normalized_shape,)
24 | normalized_shape = torch.Size(normalized_shape)
25 |
26 | assert len(normalized_shape) == 1
27 |
28 | self.weight = nn.Parameter(torch.ones(normalized_shape))
29 | self.normalized_shape = normalized_shape
30 |
31 | def forward(self, x):
32 | sigma = x.var(-1, keepdim=True, unbiased=False)
33 | return x / torch.sqrt(sigma+1e-5) * self.weight
34 |
35 | class WithBias_LayerNorm(nn.Module):
36 | def __init__(self, normalized_shape):
37 | super(WithBias_LayerNorm, self).__init__()
38 | if isinstance(normalized_shape, numbers.Integral):
39 | normalized_shape = (normalized_shape,)
40 | normalized_shape = torch.Size(normalized_shape)
41 |
42 | assert len(normalized_shape) == 1
43 |
44 | self.weight = nn.Parameter(torch.ones(normalized_shape))
45 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
46 | self.normalized_shape = normalized_shape
47 |
48 | def forward(self, x):
49 | mu = x.mean(-1, keepdim=True)
50 | sigma = x.var(-1, keepdim=True, unbiased=False)
51 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
52 |
53 |
54 | class LayerNorm(nn.Module):
55 | def __init__(self, dim, LayerNorm_type):
56 | super(LayerNorm, self).__init__()
57 | if LayerNorm_type =='BiasFree':
58 | self.body = BiasFree_LayerNorm(dim)
59 | else:
60 | self.body = WithBias_LayerNorm(dim)
61 |
62 | def forward(self, x):
63 | h, w = x.shape[-2:]
64 | return to_4d(self.body(to_3d(x)), h, w)
65 | class Down(nn.Module):
66 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding):
67 | super(Down,self).__init__()
68 | self.down = nn.Sequential(
69 | nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding,bias=False),
70 | )
71 | def forward(self,x):
72 | x = self.down(x)
73 | return x
74 |
75 | class Up(nn.Module):
76 | def __init__(self,in_channels,out_channels,kernel_size,stride,padding):
77 | super(Up,self).__init__()
78 | self.up = nn.Sequential(
79 | nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size=kernel_size,stride=stride,padding=padding))
80 | )
81 | def forward(self,x):
82 | x = self.up(x)
83 | return x
84 | class UpSample(nn.Module):
85 | def __init__(self, in_channels,out_channels,s_factor):
86 | super(UpSample, self).__init__()
87 | self.up = nn.Sequential(nn.Upsample(scale_factor=s_factor, mode='bilinear', align_corners=False),
88 | nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False))
89 |
90 | def forward(self, x):
91 | x = self.up(x)
92 | return x
93 |
94 |
95 | class DWconv(nn.Module):
96 | def __init__(self,in_channels,out_channels):
97 | super(DWconv,self).__init__()
98 | self.dwconv = nn.Conv2d(in_channels,out_channels,3,1,1,bias=False,groups=in_channels)
99 | def forward(self,x):
100 | x = self.dwconv(x)
101 | return x
102 |
103 |
104 | class ChannelAttention(nn.Module):
105 | def __init__(self,chns,factor,dynamic=False):
106 | super().__init__()
107 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
108 | if dynamic == False:
109 | self.channel_map = nn.Sequential(
110 | nn.Conv2d(chns,chns//factor,1,1,0),
111 | nn.LeakyReLU(),
112 | nn.Conv2d(chns//factor,chns,1,1,0),
113 | nn.Sigmoid()
114 | )
115 | else:
116 | self.channel_map = nn.Sequential(
117 | CondConv2D(chns,chns//factor,1,1,0),
118 | nn.LeakyReLU(),
119 | CondConv2D(chns//factor,chns,1,1,0),
120 | nn.Sigmoid()
121 | )
122 | def forward(self,x):
123 | avg_pool = self.avg_pool(x)
124 | map = self.channel_map(avg_pool)
125 | return x*map
126 |
127 |
128 | class LKA(nn.Module):
129 | def __init__(self, dim):
130 | super().__init__()
131 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
132 | self.act1 = nn.GELU()
133 | self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
134 | self.act2 = nn.GELU()
135 | self.conv1 = nn.Conv2d(dim, dim, 1)
136 |
137 |
138 | def forward(self, x):
139 | u = x.clone()
140 | attn = self.conv0(x)
141 | attn = self.act1(attn)
142 | attn = self.conv_spatial(attn)
143 | attn = self.act2(attn)
144 | attn = self.conv1(attn)
145 |
146 | return u * attn
147 |
148 | class LKA_dynamic(nn.Module):
149 | def __init__(self, dim):
150 | super().__init__()
151 | self.conv0 = CondConv2D(dim,dim,5,1,2,1,dim)
152 | self.act1 = nn.GELU()
153 | self.conv_spatial = CondConv2D(dim,dim,7,1,9,3,dim)
154 | self.act2 = nn.GELU()
155 | self.conv1 = nn.Conv2d(dim, dim, 1)
156 |
157 |
158 | def forward(self, x):
159 | u = x.clone()
160 | attn = self.conv0(x)
161 | attn = self.act1(attn)
162 | attn = self.conv_spatial(attn)
163 | attn = self.act2(attn)
164 | attn = self.conv1(attn)
165 |
166 | return u * attn
167 |
168 |
169 |
170 |
171 | class Attention(nn.Module):
172 | def __init__(self, d_model,dynamic=True):
173 | super().__init__()
174 |
175 | self.conv11 = nn.Conv2d(d_model,d_model,kernel_size=3,stride=1,padding=1,groups=d_model)
176 | #self.activation = nn.GELU()
177 | if dynamic == True:
178 | self.spatial_gating_unit = LKA_dynamic(d_model)
179 | else:
180 | self.spatial_gating_unit = LKA(d_model)
181 | self.conv22 = nn.Conv2d(d_model,d_model,kernel_size=3,stride=1,padding=1,groups=d_model)
182 |
183 | def forward(self, x):
184 | x = self.conv11(x)
185 | x = self.spatial_gating_unit(x)
186 | x = self.conv22(x)
187 | return x
188 |
189 | class ConvBlock(nn.Module):
190 | def __init__(self, inp, oup, stride, expand_ratio,VAN=False,dynamic=False):
191 | super(ConvBlock, self).__init__()
192 | self.VAN = VAN
193 | hidden_dim = round(inp * expand_ratio)
194 | self.identity = stride == 1 and inp == oup
195 | self.apply(self._init_weight)
196 |
197 | if self.VAN == True:
198 | if expand_ratio == 1:
199 | self.conv = nn.Sequential(
200 |
201 | LayerNorm(hidden_dim, 'BiasFree'),
202 | Attention(hidden_dim,dynamic=dynamic),
203 | )
204 | else:
205 | self.conv = nn.Sequential(
206 |
207 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ),
208 | LayerNorm(hidden_dim, 'BiasFree'),
209 | Attention(hidden_dim,dynamic=dynamic),
210 | nn.Conv2d(hidden_dim, oup, 1, 1, 0),
211 |
212 | )
213 | else:
214 | if dynamic == False:
215 | if expand_ratio == 1:
216 | self.conv = nn.Sequential(
217 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
218 | LayerNorm(hidden_dim, 'BiasFree'),
219 | nn.GELU(),
220 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
221 | ChannelAttention(hidden_dim,4,dynamic=dynamic),
222 |
223 | nn.Conv2d(hidden_dim, oup, 1, 1, 0),
224 |
225 | )
226 | else:
227 | self.conv = nn.Sequential(
228 | # pw
229 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ),
230 |
231 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
232 | LayerNorm(hidden_dim, 'BiasFree'),
233 | nn.GELU(),
234 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
235 |
236 | ChannelAttention(hidden_dim,4,dynamic=dynamic),
237 |
238 | nn.Conv2d(hidden_dim, oup, 1, 1, 0),
239 | )
240 | else:
241 | if expand_ratio == 1:
242 | self.conv = nn.Sequential(
243 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
244 | LayerNorm(hidden_dim, 'BiasFree'),
245 | nn.GELU(),
246 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
247 | ChannelAttention(hidden_dim,4,dynamic=dynamic),
248 | # pw-linear
249 | nn.Conv2d(hidden_dim, oup, 1, 1, 0),
250 |
251 | )
252 | else:
253 | self.conv = nn.Sequential(
254 |
255 | nn.Conv2d(inp, hidden_dim, 1, 1, 0 ),
256 |
257 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
258 | LayerNorm(hidden_dim, 'BiasFree'),
259 | nn.GELU(),
260 | CondConv2D(hidden_dim, hidden_dim, 3, stride, 1,dilation=1, groups=hidden_dim),
261 |
262 | ChannelAttention(hidden_dim,4,dynamic=dynamic),
263 |
264 | nn.Conv2d(hidden_dim, oup, 1, 1, 0),
265 | )
266 | def _init_weight(self,m):
267 | if isinstance(m,nn.Linear):
268 | trunc_normal_(m.weight,std=0.02)
269 | if isinstance(m,nn.Linear) and m.bias is not None:
270 | nn.init.constant_(m.bias,0)
271 | elif isinstance(m,nn.LayerNorm):
272 | nn.init.constant_(m.bias,0)
273 | nn.init.constant_(m.weight,1.0)
274 | elif isinstance(m,nn.Conv2d):
275 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
276 | fan_out //= m.groups
277 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out))
278 | if m.bias is not None:
279 | m.bias.data.zero_()
280 |
281 | def forward(self, x):
282 | if self.identity:
283 | return x + self.conv(x)
284 | else:
285 | return self.conv(x)
286 |
287 |
288 | class Conv_block(nn.Module):
289 | def __init__(self,n,in_channel,out_channele,expand_ratio,VAN=False,dynamic=False):
290 | super(Conv_block,self).__init__()
291 |
292 | layers=[]
293 | for i in range(n):
294 | layers.append(ConvBlock(in_channel,out_channele,1 if i==0 else 1,expand_ratio,VAN=VAN,dynamic=dynamic))
295 | in_channel = out_channele
296 | self.model = nn.Sequential(*layers)
297 | def forward(self,x):
298 | x = self.model(x)
299 | return x
300 |
301 |
302 |
303 |
304 |
305 |
--------------------------------------------------------------------------------
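A quick shape check for the static building blocks above (the dynamic variants route through `CondConv2D` instead). This assumes the module's own imports, truncated at the top of the file, are intact:

```python
import torch
from base_net_snow import LKA, ChannelAttention, Conv_block

x = torch.rand(1, 32, 64, 64)
print(LKA(32)(x).shape)                          # large-kernel attention, shape preserved
print(ChannelAttention(32, factor=4)(x).shape)   # SE-style channel gating
print(Conv_block(n=2, in_channel=32, out_channele=32, expand_ratio=2)(x).shape)
```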
/SnowFormer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from timm.models.layers import DropPath,to_2tuple,trunc_normal_
5 | import math
6 | import time
7 | from base_net_snow import *
8 | import torch.nn.functional as F
9 |
10 | def _to_channel_last(x):
11 | """
12 | Args:
13 | x: (B, C, H, W)
14 | Returns:
15 | x: (B, H, W, C)
16 | """
17 | return x.permute(0, 2, 3, 1)
18 |
19 |
20 | def _to_channel_first(x):
21 | """
22 | Args:
23 | x: (B, H, W, C)
24 | Returns:
25 | x: (B, C, H, W)
26 | """
27 | return x.permute(0, 3, 1, 2)
28 |
29 |
30 | def window_partition(x, window_size):
31 | """
32 | Args:
33 | x: (B, H, W, C)
34 | window_size: window size
35 | Returns:
36 | local window features (num_windows*B, window_size, window_size, C)
37 | """
38 | B, H, W, C = x.shape
39 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
40 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
41 | return windows
42 |
43 |
44 | def window_reverse(windows, window_size, H, W):
45 | """
46 | Args:
47 | windows: local window features (num_windows*B, window_size, window_size, C)
48 | window_size: Window size
49 | H: Height of image
50 | W: Width of image
51 | Returns:
52 | x: (B, H, W, C)
53 | """
54 | B = int(windows.shape[0] / (H * W / window_size / window_size))
55 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
57 | return x
58 |
59 |
60 | class Mlp(nn.Module):
61 | def __init__(self,
62 | in_features,
63 | hidden_features=None,
64 | out_features=None,
65 | act_layer=nn.GELU,
66 | drop=0.):
67 |
68 | super().__init__()
69 | out_features = out_features or in_features
70 | hidden_features = hidden_features or in_features
71 | self.fc1 = nn.Linear(in_features, hidden_features)
72 | self.act = act_layer()
73 | self.fc2 = nn.Linear(hidden_features, out_features)
74 | self.drop = nn.Dropout(drop)
75 |
76 | def forward(self, x):
77 | x = self.fc1(x)
78 | x = self.act(x)
79 | x = self.drop(x)
80 | x = self.fc2(x)
81 | x = self.drop(x)
82 | return x
83 |
84 | class channel_shuffle(nn.Module):
85 | def __init__(self,groups=3):
86 | super(channel_shuffle,self).__init__()
87 | self.groups = groups
88 |
89 | def forward(self,x):
90 | B,C,H,W = x.shape
91 | assert C % self.groups == 0
92 | C_per_group = C // self.groups
93 | x = x.view(B,self.groups,C_per_group,H,W)
94 | x = x.transpose(1,2).contiguous()
95 |
96 | x = x.view(B,C,H,W)
97 | return x
98 |
99 | class overlapPatchEmbed(nn.Module):
100 | def __init__(self,img_size=224,patch_size=7,stride=4,in_channels=3,dim=768):
101 | super(overlapPatchEmbed,self).__init__()
102 |
103 | patch_size=to_2tuple(patch_size)
104 |
105 | self.patch_size = patch_size
106 | self.proj = nn.Conv2d(in_channels,dim,kernel_size=patch_size,stride=stride,padding=(patch_size[0]//2,patch_size[1]//2))
107 | self.norm = nn.LayerNorm(dim)
108 |
109 | self.apply(self._init_weight)
110 |
111 | def _init_weight(self,m):
112 | if isinstance(m,nn.Linear):
113 | trunc_normal_(m.weight,std=0.02)
114 | if isinstance(m,nn.Linear) and m.bias is not None:
115 | nn.init.constant_(m.bias,0)
116 | elif isinstance(m,nn.LayerNorm):
117 | nn.init.constant_(m.bias,0)
118 | nn.init.constant_(m.weight,1.0)
119 | elif isinstance(m,nn.Conv2d):
120 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
121 | fan_out //= m.groups
122 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out))
123 | if m.bias is not None:
124 | m.bias.data.zero_()
125 |
126 | def forward(self,x):
127 | x = self.proj(x)
128 | return x
129 |
130 |
131 | class Attention(nn.Module):
132 | def __init__(self, dim, num_head=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
133 | super().__init__()
134 | assert dim % num_head == 0, f"dim {dim} should be divided by num_heads {num_head}."
135 |
136 | self.dim = dim
137 | self.num_heads = num_head
138 | head_dim = dim // num_head
139 | self.scale = qk_scale or head_dim ** -0.5
140 |
141 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
142 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
143 |
144 | self.attn_drop = nn.Dropout(attn_drop)
145 | self.proj = nn.Conv2d(dim, dim,1,1)
146 | self.proj_drop = nn.Dropout(proj_drop)
147 |
148 | self.sr_ratio = sr_ratio
149 | if sr_ratio > 1:
150 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
151 | self.norm = nn.LayerNorm(dim)
152 | self.conv = nn.Conv2d(dim,dim,3,1,1,groups=dim)
153 | self.apply(self._init_weights)
154 |
155 | def _init_weights(self, m):
156 | if isinstance(m, nn.Linear):
157 | trunc_normal_(m.weight, std=.02)
158 | if isinstance(m, nn.Linear) and m.bias is not None:
159 | nn.init.constant_(m.bias, 0)
160 | elif isinstance(m, nn.LayerNorm):
161 | nn.init.constant_(m.bias, 0)
162 | nn.init.constant_(m.weight, 1.0)
163 | elif isinstance(m, nn.Conv2d):
164 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
165 | fan_out //= m.groups
166 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
167 | if m.bias is not None:
168 | m.bias.data.zero_()
169 |
170 | def forward(self, x, H, W):
171 |
172 | B, N, C = x.shape
173 | x_conv = self.conv(x.reshape(B,H,W,C).permute(0,3,1,2))
174 |
175 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
176 | if self.sr_ratio > 1:
177 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
178 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
179 | x_ = self.norm(x_)
180 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
181 | else:
182 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
183 | k, v = kv[0], kv[1]
184 |
185 | attn = (q @ k.transpose(-2, -1)) * self.scale
186 | attn = attn.softmax(dim=-1)
187 | attn = self.attn_drop(attn)
188 |
189 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
190 | x = self.proj(x.transpose(1,2).reshape(B,C,H,W))
191 | x = self.proj_drop(x)
192 | x = x+x_conv
193 | return x
194 |
195 |
196 | class SimpleGate(nn.Module):
197 | def forward(self,x):
198 | x1, x2 = x.chunk(2, dim=1)
199 | return x1 * x2
200 |
201 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
202 | return nn.Conv2d(
203 | in_channels, out_channels, kernel_size,
204 | padding=(kernel_size//2), bias=bias, stride = stride)
205 |
206 |
207 | class MFFN(nn.Module):
208 | def __init__(self, dim, FFN_expand=2,norm_layer='WithBias'):
209 | super(MFFN, self).__init__()
210 |
211 | self.conv1 = nn.Conv2d(dim,dim*FFN_expand,1)
212 | self.conv33 = nn.Conv2d(dim*FFN_expand,dim*FFN_expand,3,1,1,groups=dim*FFN_expand)
213 | self.conv55 = nn.Conv2d(dim*FFN_expand,dim*FFN_expand,5,1,2,groups=dim*FFN_expand)
214 | self.sg = SimpleGate()
215 | self.conv4 = nn.Conv2d(dim,dim,1)
216 |
217 | self.apply(self._init_weights)
218 | def _init_weights(self,m):
219 | if isinstance(m,nn.Linear):
220 | trunc_normal_(m.weight,std=0.02)
221 | if isinstance(m,nn.Linear) and m.bias is not None:
222 | nn.init.constant_(m.bias,0)
223 | elif isinstance(m,nn.LayerNorm):
224 | nn.init.constant_(m.bias,0)
225 | nn.init.constant_(m.weight,1.0)
226 | elif isinstance(m,nn.Conv2d):
227 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
228 | fan_out //= m.groups
229 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out))
230 | if m.bias is not None:
231 | m.bias.data.zero_()
232 | def forward(self, x):
233 | x1 = self.conv1(x)
234 | x33 = self.conv33(x1)
235 | x55 = self.conv55(x1)
236 | x = x1+x33+x55
237 | x = self.sg(x)
238 | x = self.conv4(x)
239 | return x
240 |
241 | class Scale_aware_Query(nn.Module):
242 | def __init__(self,
243 | dim,
244 | out_channel,
245 | window_size,
246 | num_heads):
247 | super().__init__()
248 | self.dim = dim
249 | self.out_channel = out_channel
250 | self.window_size = window_size
251 | self.conv = nn.Conv2d(dim,out_channel,1,1,0)
252 |
253 | layers=[]
254 | for i in range(3):
255 | layers.append(CALayer(out_channel,4))
256 | layers.append(SALayer(out_channel,4))
257 | self.globalgen = nn.Sequential(*layers)
258 |
259 | self.num_heads = num_heads
260 | self.N = window_size * window_size
261 | self.dim_head = out_channel // self.num_heads
262 |
263 | def forward(self, x):
264 | x = self.conv(x)
265 | x = F.interpolate(x,(self.window_size,self.window_size),mode="bicubic")  # F.upsample is deprecated
266 | x = self.globalgen(x)
267 | B = x.shape[0]
268 | x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(0, 1, 3, 2, 4)
269 | return x
270 |
271 | class LocalContext_Interaction(nn.Module):
272 |
273 | def __init__(self,
274 | dim,
275 | num_heads,
276 | window_size,
277 | qkv_bias=True,
278 | qk_scale=None,
279 | attn_drop=0.,
280 | proj_drop=0.,
281 | ):
282 | """
283 | Args:
284 | dim: feature size dimension.
285 | num_heads: number of attention head.
286 | window_size: window size.
287 | qkv_bias: bool argument for query, key, value learnable bias.
288 | qk_scale: bool argument to scaling query, key.
289 | attn_drop: attention dropout rate.
290 | proj_drop: output dropout rate.
291 | """
292 |
293 | super().__init__()
294 | self.dim = dim
295 | window_size = (window_size,window_size)
296 | self.window_size = window_size
297 | self.num_heads = num_heads
298 | head_dim = dim // num_heads
299 | self.scale = qk_scale or head_dim ** -0.5
300 | self.relative_position_bias_table = nn.Parameter(
301 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
302 | coords_h = torch.arange(self.window_size[0])
303 | coords_w = torch.arange(self.window_size[1])
304 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
305 | coords_flatten = torch.flatten(coords, 1)
306 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
307 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
308 | relative_coords[:, :, 0] += self.window_size[0] - 1
309 | relative_coords[:, :, 1] += self.window_size[1] - 1
310 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
311 | relative_position_index = relative_coords.sum(-1)
312 | self.register_buffer("relative_position_index", relative_position_index)
313 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
314 | self.attn_drop = nn.Dropout(attn_drop)
315 | self.proj = nn.Linear(dim, dim)
316 | self.proj_drop = nn.Dropout(proj_drop)
317 |
318 | trunc_normal_(self.relative_position_bias_table, std=.02)
319 | self.softmax = nn.Softmax(dim=-1)
320 |
321 | def forward(self, x,q_global):
322 | B_, N, C = x.shape
323 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
324 | q, k, v = qkv[0], qkv[1], qkv[2]
325 | q = q * self.scale
326 | attn = (q @ k.transpose(-2, -1))
327 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
328 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
329 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
330 | attn = attn + relative_position_bias.unsqueeze(0)
331 | attn = self.softmax(attn)
332 | attn = self.attn_drop(attn)
333 |
334 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
335 | x = self.proj(x)
336 | x = self.proj_drop(x)
337 | return x
338 |
339 | class GlobalContext_Interaction(nn.Module):
340 |
341 | def __init__(self,
342 | dim,
343 | num_heads,
344 | window_size,
345 | qkv_bias=True,
346 | qk_scale=None,
347 | attn_drop=0.,
348 | proj_drop=0.,
349 | ):
350 | """
351 | Args:
352 | dim: feature size dimension.
353 | num_heads: number of attention head.
354 | window_size: window size.
355 | qkv_bias: bool argument for query, key, value learnable bias.
356 | qk_scale: bool argument to scaling query, key.
357 | attn_drop: attention dropout rate.
358 | proj_drop: output dropout rate.
359 | """
360 |
361 | super().__init__()
362 | window_size = (window_size, window_size)
363 | self.window_size = window_size
364 | self.num_heads = num_heads
365 | head_dim = dim // num_heads
366 | self.scale = qk_scale or head_dim ** -0.5
367 | self.relative_position_bias_table = nn.Parameter(
368 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
369 | coords_h = torch.arange(self.window_size[0])
370 | coords_w = torch.arange(self.window_size[1])
371 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
372 | coords_flatten = torch.flatten(coords, 1)
373 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
374 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
375 | relative_coords[:, :, 0] += self.window_size[0] - 1
376 | relative_coords[:, :, 1] += self.window_size[1] - 1
377 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
378 | relative_position_index = relative_coords.sum(-1)
379 | self.register_buffer("relative_position_index", relative_position_index)
380 | self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
381 | self.attn_drop = nn.Dropout(attn_drop)
382 | self.proj = nn.Linear(dim, dim)
383 | self.proj_drop = nn.Dropout(proj_drop)
384 | trunc_normal_(self.relative_position_bias_table, std=.02)
385 | self.softmax = nn.Softmax(dim=-1)
386 |
387 | def forward(self, x, q_global):
388 | B_, N, C = x.shape
389 | B = q_global.shape[0]
390 | kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
391 | k, v = kv[0], kv[1]
392 | q_global = q_global.repeat(1, B_ // B, 1, 1, 1)
393 | q = q_global.reshape(B_, self.num_heads, N, C // self.num_heads)
394 | q = q * self.scale
395 | attn = (q @ k.transpose(-2, -1))
396 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
397 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
398 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
399 | attn = attn + relative_position_bias.unsqueeze(0)
400 | attn = self.softmax(attn)
401 | attn = self.attn_drop(attn)
402 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
403 | x = self.proj(x)
404 | x = self.proj_drop(x)
405 | return x
406 |
407 | class Context_Interaction_Block(nn.Module):
408 |
409 | def __init__(self,
410 | latent_dim,
411 | dim,
412 | num_heads,
413 | window_size=8,
414 | mlp_ratio=4.,
415 | qkv_bias=True,
416 | qk_scale=None,
417 | drop=0.,
418 | attn_drop=0.,
419 | drop_path=0.,
420 | act_layer=nn.GELU,
421 | attention=LocalContext_Interaction,
422 | norm_layer=nn.LayerNorm,
423 | ):
424 |
425 |
426 | super().__init__()
427 | self.window_size = window_size
428 | self.norm1 = norm_layer(dim)
429 |
430 | self.attn = attention(
431 | dim,
432 | num_heads=num_heads,
433 | window_size=window_size,
434 | qkv_bias=qkv_bias,
435 | qk_scale=qk_scale,
436 | attn_drop=attn_drop,
437 | proj_drop=drop,
438 | )
439 |
440 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
441 | self.norm2 = norm_layer(dim)
442 | self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
443 | self.layer_scale = False
444 |
445 | self.gamma1 = 1.0
446 | self.gamma2 = 1.0
447 |
448 | def forward(self, x,q_global):
449 | B,H, W,C = x.shape
450 | shortcut = x
451 | x = self.norm1(x)
452 | x_windows = window_partition(x, self.window_size)
453 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
454 | attn_windows = self.attn(x_windows,q_global)
455 | x = window_reverse(attn_windows, self.window_size, H, W)
456 | x = shortcut + self.drop_path(self.gamma1 * x)
457 | x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
458 | return x
459 |
460 | class Context_Interaction_layer(nn.Module):
461 | def __init__(self,n,latent_dim,in_channel,head,window_size,globalatten=False):
462 | super(Context_Interaction_layer,self).__init__()
463 |
464 | #layers=[]
465 | self.globalatten = globalatten
466 | self.model = nn.ModuleList([
467 | Context_Interaction_Block(
468 | latent_dim,
469 | in_channel,
470 | num_heads=head,
471 | window_size=window_size,
472 | attention=GlobalContext_Interaction if i%2 == 1 and self.globalatten == True else LocalContext_Interaction,
473 | )
474 | for i in range(n)])
475 |
476 | if self.globalatten == True:
477 | self.gen = Scale_aware_Query(latent_dim,in_channel,window_size=8,num_heads=head)
478 | def forward(self,x,latent):
479 | if self.globalatten == True:
480 | q_global = self.gen(latent)
481 | x = _to_channel_last(x)
482 | for model in self.model:
483 | x = model(x, q_global)
484 | else:
485 | x = _to_channel_last(x)
486 | for model in self.model:
487 | x = model(x,1)
488 | return x
489 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
490 | return nn.Conv2d(
491 | in_channels, out_channels, kernel_size,
492 | padding=(kernel_size//2), bias=bias, stride = stride)
493 |
494 |
495 | ##########################################################################
496 | ## Channel Attention Layer
497 | class CALayer(nn.Module):
498 | def __init__(self, channel, reduction=4, bias=False):
499 | super(CALayer, self).__init__()
500 | # global average pooling: feature --> point
501 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
502 | self.channel = channel
503 | self.reduction = reduction
504 | # feature channel downscale and upscale --> channel weight
505 | self.conv_du = nn.Sequential(
506 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
507 | nn.ReLU(inplace=True),
508 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
509 | nn.Sigmoid()
510 | )
511 |
512 | def forward(self, x):
513 | y = self.avg_pool(x)
514 | y = self.conv_du(y)
515 | return x * y
516 |
517 | class SALayer(nn.Module):
518 | def __init__(self, channel,reduction=16):
519 | super(SALayer, self).__init__()
520 | self.pa = nn.Sequential(
521 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
522 | nn.ReLU(inplace=True),
523 | nn.Conv2d(channel // reduction, 1, 1, padding=0, bias=True),
524 | nn.Sigmoid()
525 | )
526 | def forward(self, x):
527 | y = self.pa(x)
528 | return x * y
529 |
530 | class Refine_Block(nn.Module):
531 | def __init__(self, n_feat, kernel_size, reduction, bias, act):
532 | super(Refine_Block, self).__init__()
533 | modules_body = []
534 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
535 | modules_body.append(act)
536 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
537 |
538 | self.CA = CALayer(n_feat, reduction, bias=bias)
539 | self.body = nn.Sequential(*modules_body)
540 |
541 | def forward(self, x):
542 | res = self.body(x)
543 | res = self.CA(res)
544 | res += x
545 | return res
546 |
547 | class Refine(nn.Module):
548 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab):
549 | super(Refine, self).__init__()
550 | modules_body = []
551 | modules_body = [Refine_Block(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)]
552 | modules_body.append(conv(n_feat, n_feat, kernel_size))
553 | self.body = nn.Sequential(*modules_body)
554 |
555 | def forward(self, x):
556 | res = self.body(x)
557 | res += x
558 | return res
559 |
560 | class HFPH(nn.Module):
561 | def __init__(self, n_feat,fusion_dim, kernel_size, reduction, act, bias, num_cab):
562 | super(HFPH, self).__init__()
563 | self.refine0 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab)
564 | self.refine1 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab)
565 | self.refine2 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab)
566 | self.refine3 = Refine(fusion_dim, kernel_size, reduction, act, bias, num_cab)
567 |
568 | self.up_enc1 = UpSample(n_feat[1],fusion_dim,s_factor=2)
569 | self.up_dec1 = UpSample(n_feat[1],fusion_dim,s_factor=2)
570 |
571 | self.up_enc2 = UpSample(n_feat[2],fusion_dim,s_factor=4)
572 | self.up_dec2 = UpSample(n_feat[2],fusion_dim,s_factor=4)
573 |
574 | self.up_enc3 = UpSample(n_feat[3],fusion_dim,s_factor=8)
575 | self.up_dec3 = UpSample(n_feat[3],fusion_dim,s_factor=8)
576 |
577 | layer0=[]
578 | for i in range(2):
579 | layer0.append(CALayer(fusion_dim,16))
580 | layer0.append(SALayer(fusion_dim,16))
581 | self.conv_enc0 = nn.Sequential(*layer0)
582 |
583 | layer1=[]
584 | for i in range(2):
585 | layer1.append(CALayer(fusion_dim,16))
586 | layer1.append(SALayer(fusion_dim,16))
587 | self.conv_enc1 = nn.Sequential(*layer1)
588 |
589 | layer2=[]
590 | for i in range(2):
591 | layer2.append(CALayer(fusion_dim,16))
592 | layer2.append(SALayer(fusion_dim,16))
593 | self.conv_enc2 = nn.Sequential(*layer2)
594 |
595 | layer3=[]
596 | for i in range(2):
597 | layer3.append(CALayer(fusion_dim,16))
598 | layer3.append(SALayer(fusion_dim,16))
599 | self.conv_enc3 = nn.Sequential(*layer3)
600 |
601 | def forward(self, x, encoder_outs, decoder_outs):
602 | x = x + self.conv_enc0(encoder_outs[0] + decoder_outs[3])
603 | x = self.refine0(x)
604 |
605 | x = x + self.conv_enc1(self.up_enc1(encoder_outs[1]) + self.up_dec1(decoder_outs[2]))
606 | x = self.refine1(x)
607 |
608 | x = x + self.conv_enc2(self.up_enc2(encoder_outs[2]) + self.up_dec2(decoder_outs[1]))
609 | x = self.refine2(x)
610 |
611 | x = x + self.conv_enc3(self.up_enc3(encoder_outs[3]) + self.up_dec3(decoder_outs[0]))
612 | x = self.refine3(x)
613 |
614 | return x
615 |
616 | class Transformer_block(nn.Module):
617 | def __init__(self,dim,num_head=8,groups=2,qkv_bias=False,qk_scale=None,attn_drop=0.,proj_drop=0.,FFN_expand=2,norm_layer='WithBias',act_layer=nn.GELU,l_drop=0.,mlp_ratio=2,drop_path=0.,sr=1):
618 | super(Transformer_block,self).__init__()
619 | self.dim=dim*2
620 | self.num_head = num_head
621 |
622 | self.norm1 = LayerNorm(self.dim//2, norm_layer)
623 | self.norm3 = LayerNorm(self.dim//2, norm_layer)
624 |
625 | self.attn_nn = Attention(dim=self.dim//2,num_head=num_head,qkv_bias=qkv_bias,qk_scale=qk_scale,attn_drop=attn_drop,proj_drop=proj_drop,sr_ratio=sr)
626 |
627 | self.ffn = MFFN(self.dim//2, FFN_expand=2,norm_layer=nn.LayerNorm)
628 |
629 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
630 | self.apply(self._init_weights)
631 |
632 | def _init_weights(self,m):
633 | if isinstance(m,nn.Linear):
634 | trunc_normal_(m.weight,std=0.02)
635 | if isinstance(m,nn.Linear) and m.bias is not None:
636 | nn.init.constant_(m.bias,0)
637 | elif isinstance(m,nn.LayerNorm):
638 | nn.init.constant_(m.bias,0)
639 | nn.init.constant_(m.weight,1.0)
640 | elif isinstance(m,nn.Conv2d):
641 | fan_out = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
642 | fan_out //= m.groups
643 | m.weight.data.normal_(0,math.sqrt(2.0/fan_out))
644 | if m.bias is not None:
645 | m.bias.data.zero_()
646 |
647 | def forward(self,x):
648 | ind = x
649 | b,c,h,w = x.shape
651 | x = self.attn_nn(self.norm1(x).reshape(b,c,h*w).transpose(1,2),h,w)
652 | b,c,h,w = x.shape
653 | x = self.drop_path(x)
654 | x = x+self.drop_path(self.ffn(self.norm3(x)))
655 | return x
656 |
657 | class Transformer(nn.Module):
658 | def __init__(self,
659 | in_channels=3,
660 | out_channels=3,
661 | transformer_blocks = 8,
662 | dim=[16,32,64,128,256],
663 | window_size = [8,8,8,8],
664 | patch_size = 64,
665 | swin_num = [4,6,7,8],
666 | head = [1,2,4,8,16],
667 | FFN_expand_=2,
668 | qkv_bias_=False,
669 | qk_sacle_=None,
670 | attn_drop_=0.,
671 | proj_drop_=0.,
672 | norm_layer_= 'WithBias',
673 | act_layer_=nn.GELU,
674 | l_drop_=0.,
675 | drop_path_=0.,
676 | sr=1,
677 | conv_num = [4,6,7,8],
678 | expand_ratio = [1,2,2,2],
679 | VAN = False,
680 | dynamic_e = False,
681 | dynamic_d = False,
682 | global_atten = True,
683 | ):
684 | super(Transformer,self).__init__()
685 | self.patch_size = patch_size
686 |
687 | self.embed = Down(in_channels,dim[0],3,1,1)
688 | self.conv0 = nn.Conv2d(dim[0],dim[4],1)
689 | self.encoder_level0 = nn.Sequential(Conv_block(conv_num[0],dim[0],dim[0],expand_ratio=expand_ratio[0],VAN=VAN,dynamic=dynamic_e))
690 |
691 | self.down0 = Down(dim[0],dim[1],3,2,1) ## H//2,W//2
692 | self.conv1 = nn.Conv2d(dim[1],dim[4],1)
693 | self.encoder_level1 = nn.Sequential(Conv_block(conv_num[1],dim[1],dim[1],expand_ratio=expand_ratio[1],VAN=VAN,dynamic=dynamic_e))
694 |
695 | self.down1 = Down(dim[1],dim[2],3,2,1) ## H//4,W//4
696 | self.conv2 = nn.Conv2d(dim[2],dim[4],1)
697 | self.encoder_level2 = nn.Sequential(Conv_block(conv_num[2],dim[2],dim[2],expand_ratio=expand_ratio[2],VAN=VAN,dynamic=dynamic_e))
698 |
699 | self.down2 = Down(dim[2],dim[3],3,2,1) ## H//8,W//8
700 | self.conv3 = nn.Conv2d(dim[3],dim[4],1)
701 | self.encoder_level3 = nn.Sequential(Conv_block(conv_num[3],dim[3],dim[3],expand_ratio=expand_ratio[3],VAN=VAN,dynamic=dynamic_e))
702 |
703 | self.down3 = Down(dim[3],dim[4],3,2,1) ## H//16,W//16
704 |
705 | self.latent = nn.Sequential(*[Transformer_block(dim=(dim[4]),num_head=head[4],qkv_bias=qkv_bias_,qk_scale=qk_sacle_,attn_drop=attn_drop_,proj_drop=proj_drop_,FFN_expand=FFN_expand_,norm_layer=norm_layer_,act_layer=act_layer_,l_drop=l_drop_,drop_path=drop_path_,sr=sr) for i in range(transformer_blocks)])
706 |
707 | self.up3 = Up((dim[4]),dim[3],4,2,1)
708 | self.ca3 = CALayer(dim[3]*2,reduction=4)
709 | self.reduce_chan_level3 = nn.Conv2d(dim[3]*2, dim[3], kernel_size=1, bias=False)
710 | self.decoder_level3 = Context_Interaction_layer(n=swin_num[3],latent_dim=dim[4],in_channel=dim[3],head=head[3],window_size=window_size[3],globalatten=global_atten)
711 | self.up2 = Up(dim[3],dim[2],4,2,1)
712 | self.ca2 = CALayer(dim[2]*2,reduction=4)
713 | self.reduce_chan_level2 = nn.Conv2d(dim[2]*2, dim[2], kernel_size=1, bias=False)
714 | self.decoder_level2 = Context_Interaction_layer(n=swin_num[2],latent_dim=dim[4],in_channel=dim[2],head=head[2],window_size=window_size[2],globalatten=global_atten)
715 |
716 | self.up1 = Up(dim[2],dim[1],4,2,1)
717 | self.ca1 = CALayer(dim[1]*2,reduction=4)
718 | self.reduce_chan_level1 = nn.Conv2d(dim[1]*2, dim[1], kernel_size=1, bias=False)
719 | self.decoder_level1 = Context_Interaction_layer(n=swin_num[1],latent_dim=dim[4],in_channel=dim[1],head=head[1],window_size=window_size[1],globalatten=global_atten)
720 |
721 | self.up0 = Up(dim[1],dim[0],4,2,1)
722 | self.ca0 = CALayer(dim[0]*2,reduction=4)
723 | self.reduce_chan_level0 = nn.Conv2d(dim[0]*2, dim[0], kernel_size=1, bias=False)
724 | self.decoder_level0 = Context_Interaction_layer(n=swin_num[0],latent_dim=dim[4],in_channel=dim[0],head=head[0],window_size=window_size[0],globalatten=global_atten)
725 |
726 | self.refinement = HFPH(n_feat=dim,fusion_dim=dim[0],kernel_size=3,reduction=4,act=nn.GELU(),bias=True,num_cab=6)
727 |
728 | self.out2 = nn.Conv2d(dim[0],out_channels,kernel_size=3,stride=1,padding=1)
729 |
730 | def check_image_size(self, x):
731 | _, _, h, w = x.size()
732 | mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
733 | mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
734 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
735 | return x
736 | def forward(self,x):
737 | x = self.check_image_size(x)
738 | encoder_item = []
739 | decoder_item = []
740 | inp_enc_level0 = self.embed(x)
741 | inp_enc_level0 = self.encoder_level0(inp_enc_level0)
742 | encoder_item.append(inp_enc_level0)
743 |
744 | inp_enc_level1 = self.down0(inp_enc_level0)
745 | inp_enc_level1 = self.encoder_level1(inp_enc_level1)
746 | encoder_item.append(inp_enc_level1)
747 |
748 | inp_enc_level2 = self.down1(inp_enc_level1)
749 | inp_enc_level2 = self.encoder_level2(inp_enc_level2)
750 | encoder_item.append(inp_enc_level2)
751 |
752 | inp_enc_level3 = self.down2(inp_enc_level2)
753 | inp_enc_level3 = self.encoder_level3(inp_enc_level3)
754 | encoder_item.append(inp_enc_level3)
755 |
756 | out_enc_level4 = self.down3(inp_enc_level3)
757 | top_0 = F.adaptive_max_pool2d(inp_enc_level0,(out_enc_level4.shape[2],out_enc_level4.shape[3]))
758 | top_1 = F.adaptive_max_pool2d(inp_enc_level1,(out_enc_level4.shape[2],out_enc_level4.shape[3]))
759 | top_2 = F.adaptive_max_pool2d(inp_enc_level2,(out_enc_level4.shape[2],out_enc_level4.shape[3]))
760 | top_3 = F.adaptive_max_pool2d(inp_enc_level3,(out_enc_level4.shape[2],out_enc_level4.shape[3]))
761 |
762 |
763 | latent = out_enc_level4+self.conv0(top_0)+self.conv1(top_1)+self.conv2(top_2)+self.conv3(top_3)
764 |
765 | latent = self.latent(latent)
766 |
767 | inp_dec_level3 = self.up3(latent)
768 | inp_dec_level3 = F.interpolate(inp_dec_level3,(inp_enc_level3.shape[2],inp_enc_level3.shape[3]),mode="bicubic")
769 | inp_dec_level3 = torch.cat([inp_dec_level3, inp_enc_level3], 1)
770 | inp_dec_level3 = self.ca3(inp_dec_level3)
771 | inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
772 |
773 | out_dec_level3 = self.decoder_level3(inp_dec_level3,latent)
774 | out_dec_level3 = _to_channel_first(out_dec_level3)
775 | decoder_item.append(out_dec_level3)
776 |
777 | inp_dec_level2 = self.up2(out_dec_level3)
778 | inp_dec_level2 = F.interpolate(inp_dec_level2,(inp_enc_level2.shape[2],inp_enc_level2.shape[3]),mode="bicubic")
779 | inp_dec_level2 = torch.cat([inp_dec_level2, inp_enc_level2], 1)
780 | inp_dec_level2 = self.ca2(inp_dec_level2)
781 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
782 | out_dec_level2 = self.decoder_level2(inp_dec_level2,latent)
783 | out_dec_level2 = _to_channel_first(out_dec_level2)
784 | decoder_item.append(out_dec_level2)
785 |
786 | inp_dec_level1 = self.up1(out_dec_level2)
787 | inp_dec_level1 = F.interpolate(inp_dec_level1,(inp_enc_level1.shape[2],inp_enc_level1.shape[3]),mode="bicubic")
788 | inp_dec_level1 = torch.cat([inp_dec_level1, inp_enc_level1], 1)
789 | inp_dec_level1 = self.ca1(inp_dec_level1)
790 | inp_dec_level1 = self.reduce_chan_level1(inp_dec_level1)
791 | out_dec_level1 = self.decoder_level1(inp_dec_level1,latent)
792 | out_dec_level1 = _to_channel_first(out_dec_level1)
793 | decoder_item.append(out_dec_level1)
794 |
795 | inp_dec_level0 = self.up0(out_dec_level1)
796 | inp_dec_level0 = F.interpolate(inp_dec_level0,(inp_enc_level0.shape[2],inp_enc_level0.shape[3]),mode="bicubic")
797 | inp_dec_level0 = torch.cat([inp_dec_level0, inp_enc_level0], 1)
798 | inp_dec_level0 = self.ca0(inp_dec_level0)
799 | inp_dec_level0 = self.reduce_chan_level0(inp_dec_level0)
800 | out_dec_level0 = self.decoder_level0(inp_dec_level0,latent)
801 | out_dec_level0 = _to_channel_first(out_dec_level0)
802 | decoder_item.append(out_dec_level0)
803 |
804 | out_dec_level0_refine = self.refinement(out_dec_level0,encoder_item,decoder_item)
805 | out_dec_level1 = self.out2(out_dec_level0_refine) + x
806 |
807 | return out_dec_level1
808 | # from ptflops import get_model_complexity_info
809 |
810 | # model = Transformer().cuda()
811 | # H,W=256,256
812 | # flops_t, params_t = get_model_complexity_info(model, (3, H,W), as_strings=True, print_per_layer_stat=True)
813 |
814 | # print(f"net flops:{flops_t} parameters:{params_t}")
815 | # model = nn.DataParallel(model)
816 | # x = torch.ones([1,3,H,W]).cuda()
817 | # b = model(x)
818 | # steps=25
819 | # # print(b)
820 | # time_avgs=[]
821 | # with torch.no_grad():
822 | # for step in range(steps):
823 |
824 | # torch.cuda.synchronize()
825 | # start = time.time()
826 | # result = model(x)
827 | # torch.cuda.synchronize()
828 | # time_interval = time.time() - start
829 | # if step>5:
830 | # time_avgs.append(time_interval)
831 | # #print('run time:',time_interval)
832 | # print('avg time:',np.mean(time_avgs),'fps:',(1/np.mean(time_avgs)),' size:',H,W)
833 |
--------------------------------------------------------------------------------