├── LAM ├── SaliencyModel │ ├── __init__.py │ ├── attributes.py │ ├── BackProp.py │ └── utils.py ├── ModelZoo │ ├── NN │ │ ├── MPNCOV │ │ │ ├── __init__.py │ │ │ └── python │ │ │ │ ├── __init__.py │ │ │ │ └── MPNCOV.py │ │ ├── fsrcnn.py │ │ ├── edsr.py │ │ ├── rnan.py │ │ ├── rcan.py │ │ ├── rrdbnet.py │ │ ├── __init__.py │ │ ├── drln_ops.py │ │ ├── drln.py │ │ ├── common.py │ │ ├── man.py │ │ └── san.py │ ├── CARN │ │ ├── __init__.py │ │ ├── carn_m.py │ │ ├── carn.py │ │ └── ops.py │ ├── __init__.py │ └── utils.py ├── README.md ├── test_images │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ ├── 8.png │ ├── 9.png │ ├── a.png │ ├── b.png │ ├── c.png │ ├── d.png │ └── e.png ├── test_MAN_diff.py ├── test_MAN.py └── cal_metrix.py ├── images ├── Visual_Results │ ├── D0850 │ │ ├── temp │ │ ├── HR.png │ │ ├── MAN.png │ │ ├── EDSR.png │ │ ├── MAN-Tiny.png │ │ ├── EDSR-Base.png │ │ └── MAN-Light.png │ ├── U004 │ │ ├── HR.png │ │ ├── EDSR.png │ │ ├── MAN.png │ │ ├── EDSR-Base.png │ │ ├── MAN-Light.png │ │ └── MAN-Tiny.png │ ├── U012 │ │ ├── HR.png │ │ ├── EDSR.png │ │ ├── MAN.png │ │ ├── EDSR-Base.png │ │ ├── MAN-Light.png │ │ └── MAN-Tiny.png │ └── U044 │ │ ├── HR.png │ │ ├── EDSR.png │ │ ├── MAN.png │ │ ├── EDSR-Base.png │ │ ├── MAN-Light.png │ │ └── MAN-Tiny.png ├── MAN_arch.png ├── MAN_details.png └── man_ntire24.pdf ├── test.py ├── train.py ├── archs ├── __init__.py └── MAN_arch.py ├── options ├── test_MAN.yml └── trian_MAN.yml ├── README.md └── LICENSE /LAM/SaliencyModel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/MPNCOV/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/Visual_Results/D0850/temp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/MPNCOV/python/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /LAM/README.md: -------------------------------------------------------------------------------- 1 | # LAM 2 | Run `test_MAN.py` and `test_MAN_diff.py`. 
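Both scripts hard-code their inputs at the bottom of the file. A minimal sketch of what to edit before running them (assuming the pretrained weight files named in `ModelZoo/__init__.py` have been placed under `ModelZoo/models/`, and that the output folder `./lam_results/<image-id>/` already exists, since `plt.savefig` does not create directories):

```python
# Bottom of test_MAN.py -- edit these values, then run the script.
model_names = 'MAN@Base'            # any '<model>@<variant>' key from ModelZoo.MODEL_LIST,
                                    # e.g. 'MAN@Light', 'MAN@Tiny', 'EDSR@Large'
image_path  = 'test_images/3.png'   # test image; prepare_images produces the LR/HR pair from it
w, h = 90, 120                      # top-left corner of the 16x16 attribution window in the HR image
LAM(model_names, image_path, w, h)  # prints DI / PSNR / SSIM and saves maps to ./lam_results/<image-id>/
```

`test_MAN_diff.py` follows the same pattern but runs the attribution twice (`EDSR@Large` vs. `MAN@Light` by default) and saves only the KDE map of their difference.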
3 | -------------------------------------------------------------------------------- /images/MAN_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/MAN_arch.png -------------------------------------------------------------------------------- /LAM/test_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/1.png -------------------------------------------------------------------------------- /LAM/test_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/2.png -------------------------------------------------------------------------------- /LAM/test_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/3.png -------------------------------------------------------------------------------- /LAM/test_images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/4.png -------------------------------------------------------------------------------- /LAM/test_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/5.png -------------------------------------------------------------------------------- /LAM/test_images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/6.png -------------------------------------------------------------------------------- /LAM/test_images/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/7.png -------------------------------------------------------------------------------- /LAM/test_images/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/8.png -------------------------------------------------------------------------------- /LAM/test_images/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/9.png -------------------------------------------------------------------------------- /LAM/test_images/a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/a.png -------------------------------------------------------------------------------- /LAM/test_images/b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/b.png -------------------------------------------------------------------------------- /LAM/test_images/c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/c.png -------------------------------------------------------------------------------- /LAM/test_images/d.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/d.png -------------------------------------------------------------------------------- /LAM/test_images/e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/e.png -------------------------------------------------------------------------------- /images/MAN_details.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/MAN_details.png -------------------------------------------------------------------------------- /images/man_ntire24.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/man_ntire24.pdf -------------------------------------------------------------------------------- /images/Visual_Results/U004/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/EDSR.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN-Tiny.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN-Light.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os.path as osp 3 | 4 | import archs 5 | 6 | from basicsr.test import test_pipeline 7 | 8 | if __name__ == '__main__': 9 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 10 | test_pipeline(root_path) 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os.path as osp 3 | 4 | import archs 5 | from basicsr.train import train_pipeline 6 | 7 | if __name__ == '__main__': 8 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 9 | train_pipeline(root_path) 10 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules for registry 7 | # scan all the files that end with '_arch.py' under the archs folder 8 | arch_folder = osp.dirname(osp.abspath(__file__)) 9 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 10 | # import all the arch modules 11 | _arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] 12 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | from .carn import Net as CARN 6 | from .carn_m import Net as CARNM 7 | from ModelZoo import MODEL_DIR 8 | 9 | 10 | def load_carn(): 11 | carn_model = CARN(2) 12 | state_dict = torch.load(os.path.join(MODEL_DIR, 'carn.pth'), map_location=torch.device('cpu')) 13 | new_state_dict = OrderedDict() 14 | for k, v in state_dict.items(): 15 | name = k 16 | # name = k[7:] # remove "module." 17 | new_state_dict[name] = v 18 | 19 | carn_model.load_state_dict(new_state_dict) 20 | return carn_model 21 | 22 | def load_carnm(): 23 | carn_model = CARNM(2) 24 | state_dict = torch.load(os.path.join(MODEL_DIR, 'carn_m.pth'), map_location=torch.device('cpu')) 25 | new_state_dict = OrderedDict() 26 | for k, v in state_dict.items(): 27 | name = k 28 | # name = k[7:] # remove "module." 
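        # (The commented-out slice above strips the "module." prefix that
        #  torch.nn.DataParallel adds to checkpoint keys; enable it if
        #  load_state_dict reports unexpected "module.*" keys.)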
29 | new_state_dict[name] = v 30 | 31 | carn_model.load_state_dict(new_state_dict) 32 | return carn_model -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/fsrcnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch import nn 3 | 4 | 5 | class FSRCNN(nn.Module): 6 | def __init__(self, scale_factor, num_channels=3, d=56, s=12, m=4): 7 | super(FSRCNN, self).__init__() 8 | self.first_part = nn.Sequential( 9 | nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2), 10 | nn.PReLU(d) 11 | ) 12 | self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)] 13 | for _ in range(m): 14 | self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)]) 15 | self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)]) 16 | self.mid_part = nn.Sequential(*self.mid_part) 17 | self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2, 18 | output_padding=scale_factor-1) 19 | 20 | self._initialize_weights() 21 | 22 | def _initialize_weights(self): 23 | for m in self.first_part: 24 | if isinstance(m, nn.Conv2d): 25 | nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel()))) 26 | nn.init.zeros_(m.bias.data) 27 | for m in self.mid_part: 28 | if isinstance(m, nn.Conv2d): 29 | nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel()))) 30 | nn.init.zeros_(m.bias.data) 31 | nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001) 32 | nn.init.zeros_(self.last_part.bias.data) 33 | 34 | def forward(self, x): 35 | x = self.first_part(x) 36 | x = self.mid_part(x) 37 | x = self.last_part(x) 38 | return x -------------------------------------------------------------------------------- /options/test_MAN.yml: -------------------------------------------------------------------------------- 1 | name: MANx2 2 | model_type: SRModel 3 | scale: 2 4 | num_gpu: 1 # set num_gpu: 0 for cpu mode 5 | manual_seed: 0 6 | 7 | datasets: 8 | test_1: # the 1st test dataset 9 | name: Set5 10 | type: PairedImageDataset 11 | dataroot_gt: ./datasets/Set5/GTmod2 12 | dataroot_lq: ./datasets/Set5/LRbicx2 13 | io_backend: 14 | type: disk 15 | 16 | # test_2: # the 2nd test dataset 17 | # name: Set14 18 | # type: PairedImageDataset 19 | # dataroot_gt: ./datasets/Set14/GTmod2 20 | # dataroot_lq: ./datasets/Set14/LRbicx2 21 | # io_backend: 22 | # type: disk 23 | 24 | # test_3: 25 | # name: Urban100 26 | # type: PairedImageDataset 27 | # dataroot_gt: ./datasets/urban100/GTmod2 28 | # dataroot_lq: ./datasets/urban100/LRbicx2 29 | # io_backend: 30 | # type: disk 31 | 32 | # test_4: 33 | # name: BSDS100 34 | # type: PairedImageDataset 35 | # dataroot_gt: ./datasets/BSDS100/GTmod2 36 | # dataroot_lq: ./datasets/BSDS100/LRbicx2 37 | # io_backend: 38 | # type: disk 39 | 40 | # test_5: 41 | # name: Manga109 42 | # type: PairedImageDataset 43 | # dataroot_gt: ./datasets/manga109/GTmod2 44 | # dataroot_lq: ./datasets/manga109/LRbicx2 45 | # io_backend: 46 | # type: disk 47 | 48 | # network structures 49 | network_g: 50 | type: MAN 51 | scale: 2 #or 3/4 52 | n_resblocks: 36 # 5 for MAN-tiny; 24 for MAN-light; 36 for MAN 53 | n_resgroups: 1 54 | n_feats: 180 # 48 for MAN-tiny; 60 for MAN-light; 180 for MAN 55 | 56 | 57 | # path 58 | path: 59 | pretrain_network_g: ./experiments/pretrained_models/HAT_SRx2.pth 60 | strict_load_g: true 61 | param_key_g: 'params_ema' # only for MAN, for 
MAN-T and MAN-L, using ~ 62 | 63 | # validation settings 64 | val: 65 | save_img: true 66 | suffix: ~ # add suffix to saved images, if None, use exp name 67 | 68 | metrics: 69 | psnr: # metric name, can be arbitrary 70 | type: calculate_psnr 71 | crop_border: 2 72 | test_y_channel: true 73 | ssim: 74 | type: calculate_ssim 75 | crop_border: 2 76 | test_y_channel: true 77 | -------------------------------------------------------------------------------- /LAM/SaliencyModel/attributes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torch.nn.functional as F 4 | 5 | 6 | def _tensor_size(t): 7 | return t.size()[1] * t.size()[2] * t.size()[3] 8 | 9 | def reduce_func(method): 10 | """ 11 | 12 | :param method: ['mean', 'sum', 'max', 'min', 'count', 'std'] 13 | :return: 14 | """ 15 | if method == 'sum': 16 | return torch.sum 17 | elif method == 'mean': 18 | return torch.mean 19 | elif method == 'count': 20 | return lambda x: sum(x.size()) 21 | else: 22 | raise NotImplementedError() 23 | 24 | 25 | def attr_id(tensor, h, w, window=8, reduce='sum'): 26 | """ 27 | :param tensor: B, C, H, W tensor 28 | :param h: h position 29 | :param w: w position 30 | :param window: size of window 31 | :param reduce: reduce method, ['mean', 'sum', 'max', 'min'] 32 | :return: 33 | """ 34 | crop = tensor[:, :, h: h + window, w: w + window] 35 | return reduce_func(reduce)(crop) 36 | 37 | 38 | def attr_grad(tensor, h, w, window=8, reduce='sum'): 39 | """ 40 | :param tensor: B, C, H, W tensor 41 | :param h: h position 42 | :param w: w position 43 | :param window: size of window 44 | :param reduce: reduce method, ['mean', 'sum', 'max', 'min'] 45 | :return: 46 | """ 47 | h_x = tensor.size()[2] 48 | w_x = tensor.size()[3] 49 | h_grad = torch.pow(tensor[:, :, :h_x - 1, :] - tensor[:, :, 1:, :], 2) 50 | w_grad = torch.pow(tensor[:, :, :, :w_x - 1] - tensor[:, :, :, 1:], 2) 51 | grad = torch.pow(h_grad[:, :, :, :-1] + w_grad[:, :, :-1, :], 1 / 2) 52 | crop = grad[:, :, h: h + window, w: w + window] 53 | return reduce_func(reduce)(crop) 54 | 55 | 56 | # gabor_filter = cv2.getGaborKernel((21, 21), 10.0, -np.pi/4, 8.0, 1, 0, ktype=cv2.CV_32F) 57 | 58 | def attr_gabor_generator(gabor_filter): 59 | filter = torch.from_numpy(gabor_filter).view((1, 1,) + gabor_filter.shape).repeat(1,3,1,1) 60 | def attr_gabor(tensor, h, w, window=8, reduce='sum'): 61 | after_filter = F.conv2d(tensor, filter, bias=None) 62 | crop = after_filter[:, :, h: h + window, w: w + window] 63 | return reduce_func(reduce)(crop) 64 | return attr_gabor 65 | 66 | 67 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/carn_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..CARN import ops 4 | 5 | 6 | class Block(nn.Module): 7 | def __init__(self, 8 | in_channels, out_channels, 9 | group=1): 10 | super(Block, self).__init__() 11 | 12 | self.b1 = ops.EResidualBlock(64, 64, group=group) 13 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 14 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 15 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 16 | 17 | def forward(self, x): 18 | c0 = o0 = x 19 | 20 | b1 = self.b1(o0) 21 | c1 = torch.cat([c0, b1], dim=1) 22 | o1 = self.c1(c1) 23 | 24 | b2 = self.b1(o1) 25 | c2 = torch.cat([c1, b2], dim=1) 26 | o2 = self.c2(c2) 27 | 28 | b3 = self.b1(o2) 29 | c3 = torch.cat([c2, b3], dim=1) 30 | o3 = self.c3(c3) 31 | 32 | return 
o3 33 | 34 | 35 | class Net(nn.Module): 36 | def __init__(self, scale): 37 | super(Net, self).__init__() 38 | 39 | multi_scale = True 40 | group = 1 41 | 42 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 43 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 44 | 45 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 46 | 47 | self.b1 = Block(64, 64, group=group) 48 | self.b2 = Block(64, 64, group=group) 49 | self.b3 = Block(64, 64, group=group) 50 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 51 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 52 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 53 | 54 | self.upsample = ops.UpsampleBlock(64, scale=scale, 55 | multi_scale=multi_scale, 56 | group=group) 57 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 58 | 59 | def forward(self, x, scale): 60 | x = self.sub_mean(x) 61 | x = self.entry(x) 62 | c0 = o0 = x 63 | 64 | b1 = self.b1(o0) 65 | c1 = torch.cat([c0, b1], dim=1) 66 | o1 = self.c1(c1) 67 | 68 | b2 = self.b2(o1) 69 | c2 = torch.cat([c1, b2], dim=1) 70 | o2 = self.c2(c2) 71 | 72 | b3 = self.b3(o2) 73 | c3 = torch.cat([c2, b3], dim=1) 74 | o3 = self.c3(c3) 75 | 76 | out = self.upsample(o3, scale=scale) 77 | 78 | out = self.exit(out) 79 | out = self.add_mean(out) 80 | 81 | return out -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/edsr.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return EDSR(args) 8 | 9 | 10 | class EDSR(nn.Module): 11 | def __init__(self, num_channels=3, factor=4, width=64, depth=16, kernel_size=3,res_scale=1.0, conv=common.default_conv): 12 | super(EDSR, self).__init__() 13 | 14 | n_resblock = depth 15 | n_feats = width 16 | kernel_size = kernel_size 17 | scale = factor 18 | act = nn.ReLU(True) 19 | 20 | rgb_mean = (0.4488, 0.4371, 0.4040) 21 | rgb_std = (1.0, 1.0, 1.0) 22 | self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std) 23 | 24 | # define head module 25 | m_head = [conv(num_channels, n_feats, kernel_size)] 26 | 27 | # define body module 28 | m_body = [ 29 | common.ResBlock( 30 | conv, n_feats, kernel_size, act=act, res_scale=res_scale 31 | ) for _ in range(n_resblock) 32 | ] 33 | m_body.append(conv(n_feats, n_feats, kernel_size)) 34 | 35 | # define tail module 36 | m_tail = [ 37 | common.Upsampler(conv, scale, n_feats, act=False), 38 | conv(n_feats, num_channels, kernel_size) 39 | ] 40 | 41 | self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1) 42 | 43 | self.head = nn.Sequential(*m_head) 44 | self.body = nn.Sequential(*m_body) 45 | self.tail = nn.Sequential(*m_tail) 46 | 47 | def forward(self, x): 48 | x = x*255 49 | x = self.sub_mean(x) 50 | x = self.head(x) 51 | 52 | res = self.body(x) 53 | res += x 54 | 55 | x = self.tail(res) 56 | x = self.add_mean(x) 57 | x = x/255 58 | return x 59 | 60 | def load_state_dict(self, state_dict, strict=True): 61 | own_state = self.state_dict() 62 | for name, param in state_dict.items(): 63 | if name in own_state: 64 | if isinstance(param, nn.Parameter): 65 | param = param.data 66 | try: 67 | own_state[name].copy_(param) 68 | except Exception: 69 | if name.find('tail') == -1: 70 | raise RuntimeError('While copying the parameter named {}, ' 71 | 'whose dimensions in the model are {} and ' 72 | 'whose dimensions in the checkpoint are {}.' 
73 | .format(name, own_state[name].size(), param.size())) 74 | elif strict: 75 | if name.find('tail') == -1: 76 | raise KeyError('unexpected key "{}" in state_dict' 77 | .format(name)) -------------------------------------------------------------------------------- /LAM/test_MAN_diff.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch, cv2, os, sys, numpy as np, matplotlib.pyplot as plt 3 | from PIL import Image 4 | from ModelZoo.utils import load_as_tensor, Tensor2PIL, PIL2Tensor, _add_batch_one 5 | from ModelZoo import get_model, load_model, print_network 6 | from SaliencyModel.utils import vis_saliency, vis_saliency_kde, click_select_position, grad_abs_norm, grad_norm, prepare_images, make_pil_grid, blend_input,plot_diff_of_attrs_kde 7 | from SaliencyModel.utils import cv2_to_pil, pil_to_cv2, gini 8 | from SaliencyModel.attributes import attr_grad 9 | from SaliencyModel.BackProp import I_gradient, attribution_objective, Path_gradient 10 | from SaliencyModel.BackProp import saliency_map_PG as saliency_map 11 | from SaliencyModel.BackProp import GaussianBlurPath 12 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel 13 | from cal_metrix import calculate_psnr, calculate_ssim, bgr2ycbcr, tensor2img 14 | 15 | 16 | def LAM(image_path='test_images/3.png',w=90, h=120): 17 | window_size = 16 # Define windoes_size of D 18 | img_lr, img_hr = prepare_images(image_path) # Change this image name 19 | tensor_lr = PIL2Tensor(img_lr)[:3] ; tensor_hr = PIL2Tensor(img_hr)[:3] 20 | cv2_lr = np.moveaxis(tensor_lr.numpy(), 0, 2) ; cv2_hr = np.moveaxis(tensor_hr.numpy(), 0, 2) 21 | draw_img = pil_to_cv2(img_hr) 22 | cv2.rectangle(draw_img, (w, h), (w + window_size, h + window_size), (0, 0, 255), 2) 23 | position_pil = cv2_to_pil(draw_img) 24 | plt.imshow(position_pil) 25 | 26 | 27 | model = load_model('EDSR@Large') 28 | 29 | sigma = 1.2 ; fold = 50 ; l = 9 ; alpha = 0.5 30 | attr_objective = attribution_objective(attr_grad, h, w, window=window_size) 31 | gaus_blur_path_func = GaussianBlurPath(sigma, fold, l) 32 | interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True) 33 | grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy) 34 | abs_normed_grad_numpy = grad_abs_norm(grad_numpy) 35 | B = abs_normed_grad_numpy 36 | 37 | 38 | model = load_model('MAN@Light') 39 | 40 | attr_objective = attribution_objective(attr_grad, h, w, window=window_size) 41 | gaus_blur_path_func = GaussianBlurPath(sigma, fold, l) 42 | interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True) 43 | grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy) 44 | abs_normed_grad_numpy = grad_abs_norm(grad_numpy) 45 | A = abs_normed_grad_numpy 46 | 47 | 48 | res = plot_diff_of_attrs_kde(A,B) 49 | 50 | 51 | plt.axis('off') 52 | 53 | plt.imshow(res) 54 | plt.savefig('./lam_results/{}/diff_LE.png'.format(image_path[-5]), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 55 | 56 | 57 | 58 | 59 | 60 | image_path='test_images/e.png' 61 | w=120 62 | h=100 63 | LAM(image_path,w,h) -------------------------------------------------------------------------------- /options/trian_MAN.yml: -------------------------------------------------------------------------------- 1 | name: MAN_SR 2 | model_type: SRModel 3 | scale: 2 # 
2/3/4/8 4 | num_gpu: 8 # or 4 5 | manual_seed: 10 6 | 7 | datasets: 8 | train: 9 | name: DF2K 10 | type: PairedImageDataset 11 | dataroot_gt: datasets/DF2K/DF2K_HR_sub 12 | dataroot_lq: datasets/DF2K/DF2K_bicx2_sub 13 | meta_info_file: hat/data/meta_info/meta_info_DF2Ksub_GT.txt 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 128 #scale*48 or scale*64 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 6 24 | batch_size_per_gpu: 4 25 | dataset_enlarge_ratio: 1 26 | prefetch_mode: ~ 27 | 28 | val_1: 29 | name: Set5 30 | type: PairedImageDataset 31 | dataroot_gt: ./datasets/Set5/GTmod2 32 | dataroot_lq: ./datasets/Set5/LRbicx2 33 | io_backend: 34 | type: disk 35 | 36 | val_2: 37 | name: Set14 38 | type: PairedImageDataset 39 | dataroot_gt: ./datasets/Set14/GTmod2 40 | dataroot_lq: ./datasets/Set14/LRbicx2 41 | io_backend: 42 | type: disk 43 | 44 | # val_3: 45 | # name: Urban100 46 | # type: PairedImageDataset 47 | # dataroot_gt: ./datasets/urban100/GTmod2 48 | # dataroot_lq: ./datasets/urban100/LRbicx2 49 | # io_backend: 50 | # type: disk 51 | 52 | 53 | # network structures 54 | network_g: 55 | type: MAN 56 | scale: 2 #or 3/4 57 | n_resblocks: 36 # 5 for MAN-tiny; 24 for MAN-light; 36 for MAN 58 | n_resgroups: 1 59 | n_feats: 180 # 48 for MAN-tiny; 60 for MAN-light; 180 for MAN 60 | 61 | # path 62 | path: 63 | pretrain_network_g: ~ 64 | strict_load_g: true 65 | resume_state: ~ 66 | 67 | # training settings 68 | train: 69 | ema_decay: 0.999 70 | optim_g: 71 | type: Adam 72 | lr: !!float 5e-4 73 | weight_decay: 0 74 | betas: [0.9, 0.99] 75 | 76 | scheduler: 77 | type: MultiStepLR 78 | milestones: [800000, 1200000, 140000, 1500000] 79 | gamma: 0.5 80 | 81 | #type: CosineAnnealingRestartLR 82 | #periods: [1600000] 83 | #restart_weights: [1] 84 | #eta_min: !!float 1e-7 85 | 86 | total_iter: 1600000 87 | warmup_iter: -1 # no warm up 88 | 89 | # losses 90 | pixel_opt: 91 | type: L1Loss 92 | loss_weight: 1.0 93 | reduction: mean 94 | 95 | # validation settings 96 | val: 97 | val_freq: !!float 5e3 98 | save_img: false 99 | pbar: False 100 | 101 | metrics: 102 | psnr: 103 | type: calculate_psnr 104 | crop_border: 2 # 2/3/4 105 | test_y_channel: true 106 | better: higher # the higher, the better. Default: higher 107 | ssim: 108 | type: calculate_ssim 109 | crop_border: 2 # 2/3/4 110 | test_y_channel: true 111 | better: higher # the higher, the better. 
Default: higher 112 | 113 | # logging settings 114 | logger: 115 | print_freq: 100 116 | save_checkpoint_freq: !!float 5e3 117 | use_tb_logger: true 118 | wandb: 119 | project: ~ 120 | resume_id: ~ 121 | 122 | # dist training settings 123 | dist_params: 124 | backend: nccl 125 | port: 29500 126 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..CARN import ops 4 | 5 | 6 | class Block(nn.Module): 7 | def __init__(self, 8 | in_channels, out_channels, 9 | group=1): 10 | super(Block, self).__init__() 11 | 12 | self.b1 = ops.ResidualBlock(64, 64) 13 | self.b2 = ops.ResidualBlock(64, 64) 14 | self.b3 = ops.ResidualBlock(64, 64) 15 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 16 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 17 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 18 | 19 | def forward(self, x): 20 | c0 = o0 = x 21 | 22 | b1 = self.b1(o0) 23 | c1 = torch.cat([c0, b1], dim=1) 24 | o1 = self.c1(c1) 25 | 26 | b2 = self.b2(o1) 27 | c2 = torch.cat([c1, b2], dim=1) 28 | o2 = self.c2(c2) 29 | 30 | b3 = self.b3(o2) 31 | c3 = torch.cat([c2, b3], dim=1) 32 | o3 = self.c3(c3) 33 | 34 | return o3 35 | 36 | 37 | class Net(nn.Module): 38 | def __init__(self, scale): 39 | super(Net, self).__init__() 40 | 41 | multi_scale = True 42 | group = 1 43 | 44 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 45 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 46 | 47 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 48 | 49 | self.b1 = Block(64, 64) 50 | self.b2 = Block(64, 64) 51 | self.b3 = Block(64, 64) 52 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 53 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 54 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 55 | 56 | self.upsample = ops.UpsampleBlock(64, scale=scale, 57 | multi_scale=multi_scale, 58 | group=group) 59 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 60 | 61 | def forward(self, x, scale=4): 62 | x = self.sub_mean(x) 63 | x = self.entry(x) 64 | c0 = o0 = x 65 | 66 | b1 = self.b1(o0) 67 | c1 = torch.cat([c0, b1], dim=1) 68 | o1 = self.c1(c1) 69 | 70 | b2 = self.b2(o1) 71 | c2 = torch.cat([c1, b2], dim=1) 72 | o2 = self.c2(c2) 73 | 74 | b3 = self.b3(o2) 75 | c3 = torch.cat([c2, b3], dim=1) 76 | o3 = self.c3(c3) 77 | 78 | out = self.upsample(o3, scale=scale) 79 | 80 | out = self.exit(out) 81 | out = self.add_mean(out) 82 | 83 | return out 84 | 85 | 86 | class CARNet(nn.Module): 87 | def __init__(self, factor=4, num_channels=3): 88 | super(CARNet, self).__init__() 89 | 90 | multi_scale = False 91 | group = 1 92 | 93 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 94 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 95 | 96 | self.entry = nn.Conv2d(num_channels, 64, 3, 1, 1) 97 | 98 | self.b1 = Block(64, 64) 99 | self.b2 = Block(64, 64) 100 | self.b3 = Block(64, 64) 101 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 102 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 103 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 104 | 105 | self.upsample = ops.UpsampleBlock(64, scale=factor, 106 | multi_scale=multi_scale, 107 | group=group) 108 | self.exit = nn.Conv2d(64, num_channels, 3, 1, 1) 109 | 110 | def forward(self, x): 111 | x = self.sub_mean(x) 112 | x = self.entry(x) 113 | c0 = o0 = x 114 | 115 | b1 = self.b1(o0) 116 | c1 = torch.cat([c0, b1], dim=1) 117 | o1 = self.c1(c1) 118 | 119 | b2 = 
self.b2(o1) 120 | c2 = torch.cat([c1, b2], dim=1) 121 | o2 = self.c2(c2) 122 | 123 | b3 = self.b3(o2) 124 | c3 = torch.cat([c2, b3], dim=1) 125 | o3 = self.c3(c3) 126 | 127 | out = self.upsample(o3) 128 | 129 | out = self.exit(out) 130 | out = self.add_mean(out) 131 | 132 | return out -------------------------------------------------------------------------------- /LAM/ModelZoo/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | MODEL_DIR = 'ModelZoo/models' 5 | 6 | 7 | NN_LIST = [ 8 | 'RCAN', 9 | 'CARN', 10 | 'RRDBNet', 11 | 'RNAN', 12 | 'SAN', 13 | 'MAN', 14 | 'EDSR' 15 | ] 16 | 17 | 18 | MODEL_LIST = { 19 | 'RCAN': { 20 | 'Base': 'RCAN.pt', 21 | }, 22 | 'CARN': { 23 | 'Base': 'CARN_7400.pth', 24 | }, 25 | 'RRDBNet': { 26 | 'Base': 'RRDBNet_PSNR_SRx4_DF2K_official-150ff491.pth', 27 | }, 28 | 'SAN': { 29 | 'Base': 'SAN_BI4X.pt', 30 | }, 31 | 'RNAN': { 32 | 'Base': 'RNAN_SR_F64G10P48BIX4.pt', 33 | }, 34 | 'EDSR': { 35 | 'Base': 'edsr_baseline_x4-6b446fab.pt', 36 | 'Large': 'edsr_x4-4f62e9ef.pt', 37 | }, 38 | 'MAN': { 39 | 'Base': 'MANx4_DF2K.pth', 40 | 'Light': 'MAN-Light-x4.pth', 41 | 'Tiny': 'MAN-Tiny-x4.pth', 42 | }, 43 | } 44 | 45 | def print_network(model, model_name): 46 | num_params = 0 47 | for param in model.parameters(): 48 | num_params += param.numel() 49 | print('Network [%s] was created. Total number of parameters: %.1f kelo. ' 50 | 'To see the architecture, do print(network).' 51 | % (model_name, num_params / 1000)) 52 | 53 | 54 | def get_model(model_name, training_name='Base', factor=4, num_channels=3): 55 | """ 56 | All the models are defaulted to be X4 models, the Channels is defaulted to be RGB 3 channels. 57 | :param model_name: 58 | :param factor: 59 | :param num_channels: 60 | :return: 61 | """ 62 | print(f'Getting SR Network {model_name}') 63 | if model_name.split('-')[0] in NN_LIST: 64 | 65 | if model_name == 'RCAN': 66 | from .NN.rcan import RCAN 67 | net = RCAN(factor=factor, num_channels=num_channels) 68 | 69 | elif model_name == 'CARN': 70 | from .CARN.carn import CARNet 71 | net = CARNet(factor=factor, num_channels=num_channels) 72 | 73 | elif model_name == 'RRDBNet': 74 | from .NN.rrdbnet import RRDBNet 75 | net = RRDBNet(num_in_ch=num_channels, num_out_ch=num_channels) 76 | 77 | elif model_name == 'SAN': 78 | from .NN.san import SAN 79 | net = SAN(factor=factor, num_channels=num_channels) 80 | 81 | elif model_name == 'RNAN': 82 | from .NN.rnan import RNAN 83 | net = RNAN(factor=factor, num_channels=num_channels) 84 | 85 | elif model_name == 'EDSR': 86 | from .NN.edsr import EDSR 87 | if training_name == 'Base': 88 | net = EDSR(factor=factor, width=64, depth=16) 89 | else: 90 | net = EDSR(factor=factor, width=256, depth=32, res_scale=0.1) 91 | 92 | elif model_name == 'MAN': 93 | from .NN.man import MAN 94 | if training_name == 'Base': 95 | net = MAN(n_resblocks=36, n_feats=180, scale=factor) 96 | elif training_name == 'Tiny': 97 | net = MAN(n_resblocks=5, n_feats=48, scale=factor) 98 | else: 99 | net = MAN(n_resblocks=24, n_feats=60, scale=factor) 100 | else: 101 | raise NotImplementedError() 102 | 103 | print_network(net, model_name) 104 | return net 105 | else: 106 | raise NotImplementedError() 107 | 108 | 109 | def load_model(model_loading_name): 110 | """ 111 | :param model_loading_name: model_name-training_name 112 | :return: 113 | """ 114 | splitting = model_loading_name.split('@') 115 | if len(splitting) == 1: 116 | model_name = splitting[0] 117 | training_name = 'Base' 118 
| elif len(splitting) == 2: 119 | model_name = splitting[0] 120 | training_name = splitting[1] 121 | else: 122 | raise NotImplementedError() 123 | assert model_name in NN_LIST or model_name in MODEL_LIST.keys(), 'check your model name before @' 124 | net = get_model(model_name,training_name) 125 | state_dict_path = os.path.join(MODEL_DIR, MODEL_LIST[model_name][training_name]) 126 | print(f'Loading model {state_dict_path} for {model_name} network.') 127 | state_dict = torch.load(state_dict_path, map_location='cpu') 128 | if model_loading_name =='MAN@Base': 129 | state_dict = state_dict['params_ema'] 130 | net.load_state_dict(state_dict) 131 | return net 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(modules): 9 | pass 10 | 11 | 12 | class MeanShift(nn.Module): 13 | def __init__(self, mean_rgb, sub): 14 | super(MeanShift, self).__init__() 15 | 16 | sign = -1 if sub else 1 17 | r = mean_rgb[0] * sign 18 | g = mean_rgb[1] * sign 19 | b = mean_rgb[2] * sign 20 | 21 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 22 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 23 | self.shifter.bias.data = torch.Tensor([r, g, b]) 24 | 25 | # Freeze the mean shift layer 26 | for params in self.shifter.parameters(): 27 | params.requires_grad = False 28 | 29 | def forward(self, x): 30 | x = self.shifter(x) 31 | return x 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, 36 | in_channels, out_channels, 37 | ksize=3, stride=1, pad=1): 38 | super(BasicBlock, self).__init__() 39 | 40 | self.body = nn.Sequential( 41 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | init_weights(self.modules) 46 | 47 | def forward(self, x): 48 | out = self.body(x) 49 | return out 50 | 51 | 52 | class ResidualBlock(nn.Module): 53 | def __init__(self, 54 | in_channels, out_channels): 55 | super(ResidualBlock, self).__init__() 56 | 57 | self.body = nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 61 | ) 62 | 63 | init_weights(self.modules) 64 | 65 | def forward(self, x): 66 | out = self.body(x) 67 | out = F.relu(out + x) 68 | return out 69 | 70 | 71 | class EResidualBlock(nn.Module): 72 | def __init__(self, 73 | in_channels, out_channels, 74 | group=1): 75 | super(EResidualBlock, self).__init__() 76 | 77 | self.body = nn.Sequential( 78 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 83 | ) 84 | 85 | init_weights(self.modules) 86 | 87 | def forward(self, x): 88 | out = self.body(x) 89 | out = F.relu(out + x) 90 | return out 91 | 92 | 93 | class UpsampleBlock(nn.Module): 94 | def __init__(self, 95 | n_channels, scale, multi_scale, 96 | group=1): 97 | super(UpsampleBlock, self).__init__() 98 | 99 | if multi_scale: 100 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 101 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 102 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group) 103 | else: 104 | self.up = 
_UpsampleBlock(n_channels, scale=scale, group=group) 105 | 106 | self.multi_scale = multi_scale 107 | 108 | def forward(self, x, scale=None): 109 | if self.multi_scale: 110 | if scale == 2: 111 | return self.up2(x) 112 | elif scale == 3: 113 | return self.up3(x) 114 | elif scale == 4: 115 | return self.up4(x) 116 | else: 117 | return self.up(x) 118 | 119 | 120 | class _UpsampleBlock(nn.Module): 121 | def __init__(self, 122 | n_channels, scale, 123 | group=1): 124 | super(_UpsampleBlock, self).__init__() 125 | 126 | modules = [] 127 | if scale == 2 or scale == 4 or scale == 8: 128 | for _ in range(int(math.log(scale, 2))): 129 | modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 130 | modules += [nn.PixelShuffle(2)] 131 | elif scale == 3: 132 | modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 133 | modules += [nn.PixelShuffle(3)] 134 | 135 | self.body = nn.Sequential(*modules) 136 | init_weights(self.modules) 137 | 138 | def forward(self, x): 139 | out = self.body(x) 140 | return out -------------------------------------------------------------------------------- /LAM/test_MAN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch, cv2, os, sys, numpy as np, matplotlib.pyplot as plt 3 | from PIL import Image 4 | from ModelZoo.utils import load_as_tensor, Tensor2PIL, PIL2Tensor, _add_batch_one 5 | from ModelZoo import get_model, load_model, print_network 6 | from SaliencyModel.utils import vis_saliency, vis_saliency_kde, click_select_position, grad_abs_norm, grad_norm, prepare_images, make_pil_grid, blend_input 7 | from SaliencyModel.utils import cv2_to_pil, pil_to_cv2, gini 8 | from SaliencyModel.attributes import attr_grad 9 | from SaliencyModel.BackProp import I_gradient, attribution_objective, Path_gradient 10 | from SaliencyModel.BackProp import saliency_map_PG as saliency_map 11 | from SaliencyModel.BackProp import GaussianBlurPath 12 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel 13 | from cal_metrix import calculate_psnr, calculate_ssim, bgr2ycbcr, tensor2img 14 | 15 | #LAM: Interpreting Super-Resolution Networks with Local Attribution Maps 16 | 17 | def LAM(model_names='MAN@Base', image_path='test_images/3.png',w=90, h=120): 18 | model = load_model(model_names) 19 | window_size = 16 # Define windoes_size of D 20 | img_lr, img_hr = prepare_images(image_path) # Change this image name 21 | tensor_lr = PIL2Tensor(img_lr)[:3] ; tensor_hr = PIL2Tensor(img_hr)[:3] 22 | cv2_lr = np.moveaxis(tensor_lr.numpy(), 0, 2) ; cv2_hr = np.moveaxis(tensor_hr.numpy(), 0, 2) 23 | 24 | plt.imshow(cv2_hr) 25 | 26 | draw_img = pil_to_cv2(img_hr) 27 | cv2.rectangle(draw_img, (w, h), (w + window_size, h + window_size), (0, 0, 255), 2) 28 | position_pil = cv2_to_pil(draw_img) 29 | plt.imshow(position_pil) 30 | 31 | sigma = 1.2 ; fold = 50 ; l = 9 ; alpha = 0.5 32 | attr_objective = attribution_objective(attr_grad, h, w, window=window_size) 33 | gaus_blur_path_func = GaussianBlurPath(sigma, fold, l) 34 | interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True) 35 | grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy) 36 | abs_normed_grad_numpy = grad_abs_norm(grad_numpy) 37 | saliency_image_abs = vis_saliency(abs_normed_grad_numpy, zoomin=4) 38 | saliency_image_kde = 
vis_saliency_kde(abs_normed_grad_numpy) 39 | blend_abs_and_input = cv2_to_pil(pil_to_cv2(saliency_image_abs) * (1.0 - alpha) + pil_to_cv2(img_lr.resize(img_hr.size)) * alpha) 40 | blend_kde_and_input = cv2_to_pil(pil_to_cv2(saliency_image_kde) * (1.0 - alpha) + pil_to_cv2(img_lr.resize(img_hr.size)) * alpha) 41 | pil = make_pil_grid( 42 | [position_pil, 43 | saliency_image_abs, 44 | blend_abs_and_input, 45 | blend_kde_and_input, 46 | Tensor2PIL(torch.clamp(result, min=0., max=1.))] 47 | ) 48 | 49 | plt.axis('off') 50 | 51 | gini_index = gini(abs_normed_grad_numpy) 52 | diffusion_index = (1 - gini_index) * 100 53 | 54 | plt.imshow(img_lr.resize(img_hr.size)) 55 | plt.savefig('./lam_results/{}/0lr.png'.format(image_path[-5]), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 56 | 57 | plt.imshow(position_pil) 58 | plt.savefig('./lam_results/{}/1hr.png'.format(image_path[-5]), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 59 | plt.imshow(saliency_image_abs) 60 | plt.savefig('./lam_results/{}/2abs_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 61 | plt.imshow(blend_kde_and_input) 62 | plt.savefig('./lam_results/{}/3kde_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 63 | plt.imshow(Tensor2PIL(torch.clamp(result, min=0., max=1.))) 64 | plt.savefig('./lam_results/{}/4sr_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 65 | 66 | plt.imshow(pil) 67 | 68 | im_GT = tensor2img(tensor_hr) 69 | im_Gen = tensor2img(result) 70 | im_GT_in = bgr2ycbcr(im_GT) 71 | im_Gen_in = bgr2ycbcr(im_Gen) 72 | 73 | crop_border = 4 74 | 75 | if im_GT_in.ndim == 3: 76 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 77 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 78 | elif im_GT_in.ndim == 2: 79 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 80 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 81 | 82 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 83 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 84 | 85 | print('DI: {:.5f} PSNR: {:.3f}dB SSIM: {:.5f}'.format(diffusion_index,PSNR,SSIM)) 86 | 87 | model_names='EDSR@Large' 88 | image_path='test_images/e.png' 89 | w=120 90 | h=100 91 | LAM(model_names,image_path,w,h) -------------------------------------------------------------------------------- /LAM/ModelZoo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torchvision 4 | import torch 5 | 6 | 7 | IMG_EXTENSIONS = ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm'] 8 | 9 | 10 | def mkdir(path): 11 | if not os.path.exists(path): 12 | os.makedirs(path) 13 | 14 | 15 | def pil_loader(path, mode='RGB'): 16 | """ 17 | open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 18 | :param path: image path 19 | :return: PIL.Image 20 | """ 21 | assert _is_image_file(path), "%s is not an image" % path 22 | with open(path, 'rb') as f: 23 | with Image.open(f) as img: 24 | return img.convert(mode) 25 | 26 | 27 | def calculate_RF(model): 28 | layers = getLayers(model) 29 | r = 1 30 | for layer in layers[::-1]: 31 | if isinstance(layer, torch.nn.Conv2d): 32 | kernel = layer.kernel_size[0] 33 | padding = layer.padding[0] 34 | stride = layer.stride[0] 35 | r = stride * r + (kernel - stride) 36 | return r 37 | 38 | 39 | def getLayers(model): 40 | 
""" 41 | get each layer's name and its module 42 | :param model: 43 | :return: each layer's name and its module 44 | """ 45 | layers = [] 46 | 47 | def unfoldLayer(model): 48 | """ 49 | unfold each layer 50 | :param model: the given model or a single layer 51 | :param root: root name 52 | :return: 53 | """ 54 | 55 | # get all layers of the model 56 | layer_list = list(model.named_children()) 57 | for item in layer_list: 58 | module = item[1] 59 | sublayer = list(module.named_children()) 60 | sublayer_num = len(sublayer) 61 | 62 | # if current layer contains sublayers, add current layer name on its sublayers 63 | if sublayer_num == 0: 64 | layers.append(module) 65 | # if current layer contains sublayers, unfold them 66 | elif isinstance(module, torch.nn.Module): 67 | unfoldLayer(module) 68 | 69 | unfoldLayer(model) 70 | return layers 71 | 72 | 73 | def load_as_tensor(path, mode='RGB'): 74 | """ 75 | Load image to tensor 76 | :param path: image path 77 | :param mode: 'Y' returns 1 channel tensor, 'RGB' returns 3 channels, 'RGBA' returns 4 channels, 'YCbCr' returns 3 channels 78 | :return: 3D tensor 79 | """ 80 | if mode != 'Y': 81 | return PIL2Tensor(pil_loader(path, mode=mode)) 82 | else: 83 | return PIL2Tensor(pil_loader(path, mode='YCbCr'))[:1] 84 | 85 | 86 | def PIL2Tensor(pil_image): 87 | return torchvision.transforms.functional.to_tensor(pil_image) 88 | 89 | 90 | def Tensor2PIL(tensor_image, mode='RGB'): 91 | if len(tensor_image.size()) == 4 and tensor_image.size()[0] == 1: 92 | tensor_image = tensor_image.view(tensor_image.size()[1:]) 93 | return torchvision.transforms.functional.to_pil_image(tensor_image.detach(), mode=mode) 94 | 95 | 96 | def _is_image_file(filename): 97 | """ 98 | judge if the file is an image file 99 | :param filename: path 100 | :return: bool of judgement 101 | """ 102 | filename_lower = filename.lower() 103 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 104 | 105 | 106 | def image_files(path): 107 | """ 108 | return list of images in the path 109 | :param path: path to Data Folder, absolute path 110 | :return: 1D list of image files absolute path 111 | """ 112 | abs_path = os.path.abspath(path) 113 | image_files = os.listdir(abs_path) 114 | for i in range(len(image_files)): 115 | if (not os.path.isdir(image_files[i])) and (_is_image_file(image_files[i])): 116 | image_files[i] = os.path.join(abs_path, image_files[i]) 117 | return image_files 118 | 119 | 120 | def split_to_batches(l, n): 121 | for i in range(0, len(l), n): 122 | yield l[i:i + n] 123 | 124 | 125 | def _sigmoid_to_tanh(x): 126 | """ 127 | range [0, 1] to range [-1, 1] 128 | :param x: tensor type 129 | :return: tensor 130 | """ 131 | return (x - 0.5) * 2. 
132 | 133 | 134 | def _tanh_to_sigmoid(x): 135 | """ 136 | range [-1, 1] to range [0, 1] 137 | :param x: 138 | :return: 139 | """ 140 | return x * 0.5 + 0.5 141 | 142 | 143 | def _add_batch_one(tensor): 144 | """ 145 | Return a tensor with size (1, ) + tensor.size 146 | :param tensor: 2D or 3D tensor 147 | :return: 3D or 4D tensor 148 | """ 149 | return tensor.view((1, ) + tensor.size()) 150 | 151 | 152 | def _remove_batch(tensor): 153 | """ 154 | Return a tensor with size tensor.size()[1:] 155 | :param tensor: 3D or 4D tensor 156 | :return: 2D or 3D tensor 157 | """ 158 | return tensor.view(tensor.size()[1:]) 159 | 160 | def mod_crop(tensor, scale=4): 161 | B, C, H, W = tensor.shape 162 | return tensor[:, :, :H-H % scale, :W-W % scale] 163 | 164 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rnan.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return RNAN(args) 8 | 9 | 10 | ### RNAN 11 | class _ResGroup(nn.Module): 12 | def __init__(self, conv, n_feats, kernel_size, act, res_scale): 13 | super(_ResGroup, self).__init__() 14 | modules_body = [] 15 | modules_body.append( 16 | common.ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), 17 | res_scale=1)) 18 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 19 | self.body = nn.Sequential(*modules_body) 20 | 21 | def forward(self, x): 22 | res = self.body(x) 23 | return res 24 | 25 | 26 | class _NLResGroup(nn.Module): 27 | def __init__(self, conv, n_feats, kernel_size, act, res_scale): 28 | super(_NLResGroup, self).__init__() 29 | modules_body = [] 30 | modules_body.append( 31 | common.NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), 32 | res_scale=1)) 33 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 34 | self.body = nn.Sequential(*modules_body) 35 | 36 | def forward(self, x): 37 | res = self.body(x) 38 | return res 39 | 40 | 41 | class RNAN(nn.Module): 42 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv): 43 | super(RNAN, self).__init__() 44 | 45 | n_resgroup = 10 46 | n_resblock = 16 47 | n_feats = 64 48 | kernel_size = 3 49 | reduction = 16 50 | scale = factor 51 | act = nn.ReLU(True) 52 | 53 | # RGB mean for DIV2K 1-800 54 | rgb_mean = (0.4488, 0.4371, 0.4040) 55 | rgb_std = (1.0, 1.0, 1.0) 56 | self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std) 57 | 58 | # define head module 59 | modules_head = [conv(num_channels, n_feats, kernel_size)] 60 | 61 | # define body module 62 | modules_body_nl_low = [ 63 | _NLResGroup( 64 | conv, n_feats, kernel_size, act=act, res_scale=1.)] 65 | modules_body = [ 66 | _ResGroup( 67 | conv, n_feats, kernel_size, act=act, res_scale=1.) 
\ 68 | for _ in range(n_resgroup - 2)] 69 | modules_body_nl_high = [ 70 | _NLResGroup( 71 | conv, n_feats, kernel_size, act=act, res_scale=1.)] 72 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 73 | 74 | # define tail module 75 | modules_tail = [ 76 | common.Upsampler(conv, scale, n_feats, act=False), 77 | conv(n_feats, num_channels, kernel_size)] 78 | 79 | self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1) 80 | 81 | self.head = nn.Sequential(*modules_head) 82 | self.body_nl_low = nn.Sequential(*modules_body_nl_low) 83 | self.body = nn.Sequential(*modules_body) 84 | self.body_nl_high = nn.Sequential(*modules_body_nl_high) 85 | self.tail = nn.Sequential(*modules_tail) 86 | 87 | def forward(self, x): 88 | 89 | x = self.sub_mean(x * 255.) 90 | feats_shallow = self.head(x) 91 | 92 | res = self.body_nl_low(feats_shallow) 93 | res = self.body(res) 94 | res = self.body_nl_high(res) 95 | res += feats_shallow 96 | 97 | res_main = self.tail(res) 98 | 99 | res_main = self.add_mean(res_main) 100 | 101 | return res_main / 255. 102 | 103 | def load_state_dict(self, state_dict, strict=False): 104 | own_state = self.state_dict() 105 | for name, param in state_dict.items(): 106 | if name in own_state: 107 | if isinstance(param, nn.Parameter): 108 | param = param.data 109 | try: 110 | own_state[name].copy_(param) 111 | except Exception: 112 | if name.find('tail') >= 0: 113 | print('Replace pre-trained upsampler to new one...') 114 | else: 115 | raise RuntimeError('While copying the parameter named {}, ' 116 | 'whose dimensions in the model are {} and ' 117 | 'whose dimensions in the checkpoint are {}.' 118 | .format(name, own_state[name].size(), param.size())) 119 | elif strict: 120 | if name.find('tail') == -1: 121 | raise KeyError('unexpected key "{}" in state_dict' 122 | .format(name)) 123 | 124 | if strict: 125 | missing = set(own_state.keys()) - set(state_dict.keys()) 126 | if len(missing) > 0: 127 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rcan.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return RCAN(args) 8 | 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.conv_du = nn.Sequential( 16 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | y = self.avg_pool(x) 24 | y = self.conv_du(y) 25 | return x * y 26 | 27 | 28 | ## Residual Channel Attention Block (RCAB) 29 | class RCAB(nn.Module): 30 | def __init__( 31 | self, conv, n_feat, kernel_size, reduction, 32 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 33 | 34 | super(RCAB, self).__init__() 35 | modules_body = [] 36 | for i in range(2): 37 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 38 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 39 | if i == 0: modules_body.append(act) 40 | modules_body.append(CALayer(n_feat, reduction)) 41 | self.body = nn.Sequential(*modules_body) 42 | self.res_scale = res_scale 43 | 44 | def forward(self, x): 45 | res = self.body(x) 46 | # res = 
self.body(x).mul(self.res_scale) 47 | res += x 48 | return res 49 | 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | 69 | class RCAN(nn.Module): 70 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = 10 74 | n_resblocks = 20 75 | n_feats = 64 76 | kernel_size = 3 77 | reduction = 16 78 | scale = factor 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 1-800 82 | # rgb_mean = (0.4488, 0.4371, 0.4040) 83 | # RGB mean for DIVFlickr2K 1-3450 84 | rgb_mean = (0.4690, 0.4490, 0.4036) 85 | # rgb_mean = (0.4488, 0.4371, 0.4040) 86 | rgb_std = (1.0, 1.0, 1.0) 87 | self.sub_mean = common.MeanShift(255., rgb_mean, rgb_std) 88 | 89 | # define head module 90 | modules_head = [conv(num_channels, n_feats, kernel_size)] 91 | 92 | # define body module 93 | modules_body = [ 94 | ResidualGroup( 95 | conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) for _ in range(n_resgroups)] 96 | 97 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 98 | 99 | # define tail module 100 | modules_tail = [ 101 | common.Upsampler(conv, scale, n_feats, act=False), 102 | conv(n_feats, num_channels, kernel_size)] 103 | 104 | self.add_mean = common.MeanShift(255., rgb_mean, rgb_std, 1) 105 | 106 | self.head = nn.Sequential(*modules_head) 107 | self.body = nn.Sequential(*modules_body) 108 | self.tail = nn.Sequential(*modules_tail) 109 | 110 | def forward(self, x): 111 | x = self.sub_mean(x * 255) 112 | x = self.head(x) 113 | 114 | res = self.body(x) 115 | res += x 116 | 117 | x = self.tail(res) 118 | x = self.add_mean(x) 119 | 120 | return x / 255 121 | 122 | def load_state_dict(self, state_dict, strict=False): 123 | own_state = self.state_dict() 124 | for name, param in state_dict.items(): 125 | if name in own_state: 126 | if isinstance(param, nn.Parameter): 127 | param = param.data 128 | try: 129 | own_state[name].copy_(param) 130 | except Exception: 131 | if name.find('tail') >= 0: 132 | print('Replace pre-trained upsampler to new one...') 133 | else: 134 | raise RuntimeError('While copying the parameter named {}, ' 135 | 'whose dimensions in the model are {} and ' 136 | 'whose dimensions in the checkpoint are {}.' 137 | .format(name, own_state[name].size(), param.size())) 138 | elif strict: 139 | if name.find('tail') == -1: 140 | raise KeyError('unexpected key "{}" in state_dict' 141 | .format(name)) 142 | 143 | if strict: 144 | missing = set(own_state.keys()) - set(state_dict.keys()) 145 | if len(missing) > 0: 146 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/MPNCOV/python/MPNCOV.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @file: MPNCOV.py 3 | @author: Jiangtao Xie 4 | @author: Peihua Li 5 | 6 | Copyright (C) 2018 Peihua Li and Jiangtao Xie 7 | 8 | All rights reserved. 
9 | ''' 10 | import torch 11 | import numpy as np 12 | from torch.autograd import Function 13 | 14 | class Covpool(Function): 15 | @staticmethod 16 | def forward(ctx, input): 17 | x = input 18 | batchSize = x.data.shape[0] 19 | dim = x.data.shape[1] 20 | h = x.data.shape[2] 21 | w = x.data.shape[3] 22 | M = h*w 23 | x = x.reshape(batchSize,dim,M) 24 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 25 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 26 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 27 | ctx.save_for_backward(input,I_hat) 28 | return y 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | input,I_hat = ctx.saved_tensors 32 | x = input 33 | batchSize = x.data.shape[0] 34 | dim = x.data.shape[1] 35 | h = x.data.shape[2] 36 | w = x.data.shape[3] 37 | M = h*w 38 | x = x.reshape(batchSize,dim,M) 39 | grad_input = grad_output + grad_output.transpose(1,2) 40 | grad_input = grad_input.bmm(x).bmm(I_hat) 41 | grad_input = grad_input.reshape(batchSize,dim,h,w) 42 | return grad_input 43 | 44 | class Sqrtm(Function): 45 | @staticmethod 46 | def forward(ctx, input, iterN): 47 | x = input 48 | batchSize = x.data.shape[0] 49 | dim = x.data.shape[1] 50 | dtype = x.dtype 51 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 52 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 53 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 54 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device) 55 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1) 56 | if iterN < 2: 57 | ZY = 0.5*(I3 - A) 58 | Y[:,0,:,:] = A.bmm(ZY) 59 | else: 60 | ZY = 0.5*(I3 - A) 61 | Y[:,0,:,:] = A.bmm(ZY) 62 | Z[:,0,:,:] = ZY 63 | for i in range(1, iterN-1): 64 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 65 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 66 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 67 | ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])) 68 | y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 69 | ctx.save_for_backward(input, A, ZY, normA, Y, Z) 70 | ctx.iterN = iterN 71 | return y 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 75 | iterN = ctx.iterN 76 | x = input 77 | batchSize = x.data.shape[0] 78 | dim = x.data.shape[1] 79 | dtype = x.dtype 80 | der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 81 | der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA)) 82 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 83 | if iterN < 2: 84 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_sacleTrace)) 85 | else: 86 | dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) - 87 | Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom)) 88 | dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:]) 89 | for i in range(iterN-3, -1, -1): 90 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 91 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 92 | dldY_ = 0.5*(dldY.bmm(YZ) - 93 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 94 | ZY.bmm(dldY)) 95 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 96 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 97 | dldZ.bmm(ZY)) 98 | dldY = dldY_ 99 | dldZ = dldZ_ 100 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 101 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 102 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 103 
| for i in range(batchSize): 104 | grad_input[i,:,:] += (der_postComAux[i] \ 105 | - grad_aux[i] / (normA[i] * normA[i])) \ 106 | *torch.ones(dim,device = x.device).diag() 107 | return grad_input, None 108 | 109 | class Triuvec(Function): 110 | @staticmethod 111 | def forward(ctx, input): 112 | x = input 113 | batchSize = x.data.shape[0] 114 | dim = x.data.shape[1] 115 | dtype = x.dtype 116 | x = x.reshape(batchSize, dim*dim) 117 | I = torch.ones(dim,dim).triu().t().reshape(dim*dim) 118 | index = I.nonzero() 119 | y = torch.zeros(batchSize,dim*(dim+1)/2,device = x.device) 120 | for i in range(batchSize): 121 | y[i, :] = x[i, index].t() 122 | ctx.save_for_backward(input,index) 123 | return y 124 | @staticmethod 125 | def backward(ctx, grad_output): 126 | input,index = ctx.saved_tensors 127 | x = input 128 | batchSize = x.data.shape[0] 129 | dim = x.data.shape[1] 130 | dtype = x.dtype 131 | grad_input = torch.zeros(batchSize,dim,dim,device = x.device,requires_grad=False) 132 | grad_input = grad_input.reshape(batchSize,dim*dim) 133 | for i in range(batchSize): 134 | grad_input[i,index] = grad_output[i,:].reshape(index.size(),1) 135 | grad_input = grad_input.reshape(batchSize,dim,dim) 136 | return grad_input 137 | 138 | def CovpoolLayer(var): 139 | return Covpool.apply(var) 140 | 141 | def SqrtmLayer(var, iterN): 142 | return Sqrtm.apply(var, iterN) 143 | 144 | def TriuvecLayer(var): 145 | return Triuvec.apply(var) 146 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rrdbnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn import init as init 5 | 6 | def make_layer(basic_block, num_basic_block, **kwarg): 7 | """Make layers by stacking the same blocks. 8 | Args: 9 | basic_block (nn.module): nn.module class for basic block. 10 | num_basic_block (int): number of blocks. 11 | Returns: 12 | nn.Sequential: Stacked blocks in nn.Sequential. 13 | """ 14 | layers = [] 15 | for _ in range(num_basic_block): 16 | layers.append(basic_block(**kwarg)) 17 | return nn.Sequential(*layers) 18 | 19 | 20 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 21 | """Initialize network weights. 22 | Args: 23 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 24 | scale (float): Scale initialized weights, especially for residual 25 | blocks. Default: 1. 26 | bias_fill (float): The value to fill bias. Default: 0 27 | kwargs (dict): Other arguments for initialization function. 28 | """ 29 | if not isinstance(module_list, list): 30 | module_list = [module_list] 31 | for module in module_list: 32 | for m in module.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | init.kaiming_normal_(m.weight, **kwargs) 35 | m.weight.data *= scale 36 | if m.bias is not None: 37 | m.bias.data.fill_(bias_fill) 38 | elif isinstance(m, nn.Linear): 39 | init.kaiming_normal_(m.weight, **kwargs) 40 | m.weight.data *= scale 41 | if m.bias is not None: 42 | m.bias.data.fill_(bias_fill) 43 | elif isinstance(m, _BatchNorm): 44 | init.constant_(m.weight, 1) 45 | if m.bias is not None: 46 | m.bias.data.fill_(bias_fill) 47 | 48 | 49 | 50 | class ResidualDenseBlock(nn.Module): 51 | """Residual Dense Block. 52 | Used in RRDB block in ESRGAN. 53 | Args: 54 | num_feat (int): Channel number of intermediate features. 55 | num_grow_ch (int): Channels for each growth. 
56 | """ 57 | 58 | def __init__(self, num_feat=64, num_grow_ch=32): 59 | super(ResidualDenseBlock, self).__init__() 60 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 61 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 62 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 63 | 1) 64 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 65 | 1) 66 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 67 | 68 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 69 | 70 | # initialization 71 | default_init_weights( 72 | [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 73 | 74 | def forward(self, x): 75 | x1 = self.lrelu(self.conv1(x)) 76 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 77 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 78 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 79 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 80 | # Emperically, we use 0.2 to scale the residual for better performance 81 | return x5 * 0.2 + x 82 | 83 | 84 | class RRDB(nn.Module): 85 | """Residual in Residual Dense Block. 86 | Used in RRDB-Net in ESRGAN. 87 | Args: 88 | num_feat (int): Channel number of intermediate features. 89 | num_grow_ch (int): Channels for each growth. 90 | """ 91 | 92 | def __init__(self, num_feat, num_grow_ch=32): 93 | super(RRDB, self).__init__() 94 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 95 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 96 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 97 | 98 | def forward(self, x): 99 | out = self.rdb1(x) 100 | out = self.rdb2(out) 101 | out = self.rdb3(out) 102 | # Emperically, we use 0.2 to scale the residual for better performance 103 | return out * 0.2 + x 104 | 105 | 106 | class RRDBNet(nn.Module): 107 | """Networks consisting of Residual in Residual Dense Block, which is used 108 | in ESRGAN. 109 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 110 | Currently, it supports x4 upsampling scale factor. 111 | Args: 112 | num_in_ch (int): Channel number of inputs. 113 | num_out_ch (int): Channel number of outputs. 114 | num_feat (int): Channel number of intermediate features. 115 | Default: 64 116 | num_block (int): Block number in the trunk network. Defaults: 23 117 | num_grow_ch (int): Channels for each growth. Default: 32. 
118 | """ 119 | 120 | def __init__(self, 121 | num_in_ch=3, 122 | num_out_ch=3, 123 | num_feat=64, 124 | num_block=23, 125 | num_grow_ch=32): 126 | super(RRDBNet, self).__init__() 127 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 128 | self.body = make_layer( 129 | RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 130 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 131 | # upsample 132 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 133 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 134 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 135 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 136 | 137 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 138 | 139 | def forward(self, x): 140 | feat = self.conv_first(x) 141 | body_feat = self.conv_body(self.body(feat)) 142 | feat = feat + body_feat 143 | # upsample 144 | feat = self.lrelu( 145 | self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 146 | feat = self.lrelu( 147 | self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 148 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 149 | return out 150 | -------------------------------------------------------------------------------- /LAM/cal_metrix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM. 3 | same as MATLAB's results 4 | ''' 5 | import os 6 | import math 7 | import numpy as np 8 | import cv2 9 | import glob 10 | from torchvision.utils import make_grid 11 | 12 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 13 | ''' 14 | Converts a torch Tensor into an image Numpy array 15 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 16 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 17 | ''' 18 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 19 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 20 | n_dim = tensor.dim() 21 | if n_dim == 4: 22 | n_img = len(tensor) 23 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 24 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 25 | elif n_dim == 3: 26 | img_np = tensor.numpy() 27 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 28 | elif n_dim == 2: 29 | img_np = tensor.numpy() 30 | else: 31 | raise TypeError( 32 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 33 | if out_type == np.uint8: 34 | img_np = (img_np * 255.0).round() 35 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 36 | return img_np.astype(out_type) 37 | 38 | def main(): 39 | # Configurations 40 | 41 | # GT - Ground-truth; 42 | # Gen: Generated / Restored / Recovered images 43 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' 44 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' 45 | 46 | crop_border = 4 47 | suffix = '' # suffix for Gen images 48 | test_Y = False # True: test Y channel only; False: test RGB channels 49 | 50 | PSNR_all = [] 51 | SSIM_all = [] 52 | img_list = sorted(glob.glob(folder_GT + '/*')) 53 | 54 | if test_Y: 55 | print('Testing Y channel.') 56 | else: 57 | print('Testing RGB channels.') 58 | 59 | for i, img_path in enumerate(img_list): 60 | base_name = os.path.splitext(os.path.basename(img_path))[0] 61 | im_GT = cv2.imread(img_path) / 255. 
62 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. 63 | 64 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 65 | im_GT_in = bgr2ycbcr(im_GT) 66 | im_Gen_in = bgr2ycbcr(im_Gen) 67 | else: 68 | im_GT_in = im_GT 69 | im_Gen_in = im_Gen 70 | 71 | # crop borders 72 | if im_GT_in.ndim == 3: 73 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 74 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 75 | elif im_GT_in.ndim == 2: 76 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 77 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 78 | else: 79 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) 80 | 81 | # calculate PSNR and SSIM 82 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 83 | 84 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 85 | print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( 86 | i + 1, base_name, PSNR, SSIM)) 87 | PSNR_all.append(PSNR) 88 | SSIM_all.append(SSIM) 89 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( 90 | sum(PSNR_all) / len(PSNR_all), 91 | sum(SSIM_all) / len(SSIM_all))) 92 | 93 | 94 | def calculate_psnr(img1, img2): 95 | # img1 and img2 have range [0, 255] 96 | img1 = img1.astype(np.float64) 97 | img2 = img2.astype(np.float64) 98 | mse = np.mean((img1 - img2)**2) 99 | if mse == 0: 100 | return float('inf') 101 | return 20 * math.log10(255.0 / math.sqrt(mse)) 102 | 103 | 104 | def ssim(img1, img2): 105 | C1 = (0.01 * 255)**2 106 | C2 = (0.03 * 255)**2 107 | 108 | img1 = img1.astype(np.float64) 109 | img2 = img2.astype(np.float64) 110 | kernel = cv2.getGaussianKernel(11, 1.5) 111 | window = np.outer(kernel, kernel.transpose()) 112 | 113 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 114 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 115 | mu1_sq = mu1**2 116 | mu2_sq = mu2**2 117 | mu1_mu2 = mu1 * mu2 118 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 119 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 120 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 121 | 122 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 123 | (sigma1_sq + sigma2_sq + C2)) 124 | return ssim_map.mean() 125 | 126 | 127 | def calculate_ssim(img1, img2): 128 | '''calculate SSIM 129 | the same outputs as MATLAB's 130 | img1, img2: [0, 255] 131 | ''' 132 | if not img1.shape == img2.shape: 133 | raise ValueError('Input images must have the same dimensions.') 134 | if img1.ndim == 2: 135 | return ssim(img1, img2) 136 | elif img1.ndim == 3: 137 | if img1.shape[2] == 3: 138 | ssims = [] 139 | for i in range(3): 140 | ssims.append(ssim(img1, img2)) 141 | return np.array(ssims).mean() 142 | elif img1.shape[2] == 1: 143 | return ssim(np.squeeze(img1), np.squeeze(img2)) 144 | else: 145 | raise ValueError('Wrong input image dimensions.') 146 | 147 | 148 | def bgr2ycbcr(img, only_y=True): 149 | '''same as matlab rgb2ycbcr 150 | only_y: only return Y channel 151 | Input: 152 | uint8, [0, 255] 153 | float, [0, 1] 154 | ''' 155 | in_img_type = img.dtype 156 | img.astype(np.float32) 157 | if in_img_type != np.uint8: 158 | img *= 255. 
159 | # convert 160 | if only_y: 161 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 162 | else: 163 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 164 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 165 | if in_img_type == np.uint8: 166 | rlt = rlt.round() 167 | else: 168 | rlt /= 255. 169 | return rlt.astype(in_img_type) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /LAM/SaliencyModel/BackProp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | from ModelZoo.utils import _add_batch_one, _remove_batch 5 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel 6 | 7 | 8 | def attribution_objective(attr_func, h, w, window=16): 9 | def calculate_objective(image): 10 | return attr_func(image, h, w, window=window) 11 | return calculate_objective 12 | 13 | 14 | def saliency_map_gradient(numpy_image, model, attr_func): 15 | img_tensor = torch.from_numpy(numpy_image) 16 | img_tensor.requires_grad_(True) 17 | result = model(_add_batch_one(img_tensor)) 18 | target = attr_func(result) 19 | target.backward() 20 | return img_tensor.grad.numpy(), result 21 | 22 | 23 | def I_gradient(numpy_image, baseline_image, model, attr_objective, fold, interp='linear'): 24 | interpolated = interpolation(numpy_image, baseline_image, fold, mode=interp).astype(np.float32) 25 | grad_list = np.zeros_like(interpolated, dtype=np.float32) 26 | result_list = [] 27 | for i in range(fold): 28 | img_tensor = torch.from_numpy(interpolated[i]) 29 | img_tensor.requires_grad_(True) 30 | result = model(_add_batch_one(img_tensor)) 31 | target = attr_objective(result) 32 | target.backward() 33 | grad = img_tensor.grad.numpy() 34 | grad_list[i] = grad 35 | result_list.append(result) 36 | results_numpy = np.asarray(result_list) 37 | return grad_list, results_numpy, interpolated 38 | 39 | 40 | def GaussianBlurPath(sigma, fold, l=5): 41 | def path_interpolation_func(cv_numpy_image): 42 | h, w, c = cv_numpy_image.shape 43 | kernel_interpolation = np.zeros((fold + 1, l, l)) 44 | image_interpolation = np.zeros((fold, h, w, c)) 45 | lambda_derivative_interpolation = np.zeros((fold, h, w, c)) 46 | sigma_interpolation = np.linspace(sigma, 0, fold + 1) 47 | for i in range(fold + 1): 48 | kernel_interpolation[i] = isotropic_gaussian_kernel(l, sigma_interpolation[i]) 49 | for i in range(fold): 50 | image_interpolation[i] = cv2.filter2D(cv_numpy_image, -1, kernel_interpolation[i + 1]) 51 | lambda_derivative_interpolation[i] = cv2.filter2D(cv_numpy_image, -1, (kernel_interpolation[i + 1] - kernel_interpolation[i]) * fold) 52 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 53 | np.moveaxis(lambda_derivative_interpolation, 3, 1).astype(np.float32) 54 | return path_interpolation_func 55 | 56 | 57 | def GaussianLinearPath(sigma, fold, l=5): 58 | def path_interpolation_func(cv_numpy_image): 59 | kernel = isotropic_gaussian_kernel(l, sigma) 60 | baseline_image = cv2.filter2D(cv_numpy_image, -1, kernel) 61 | image_interpolation = interpolation(cv_numpy_image, baseline_image, fold, mode='linear').astype(np.float32) 62 | lambda_derivative_interpolation = np.repeat(np.expand_dims(cv_numpy_image - baseline_image, axis=0), fold, axis=0) 63 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 64 | np.moveaxis(lambda_derivative_interpolation, 
3, 1).astype(np.float32) 65 | return path_interpolation_func 66 | 67 | 68 | def LinearPath(fold): 69 | def path_interpolation_func(cv_numpy_image): 70 | baseline_image = np.zeros_like(cv_numpy_image) 71 | image_interpolation = interpolation(cv_numpy_image, baseline_image, fold, mode='linear').astype(np.float32) 72 | lambda_derivative_interpolation = np.repeat(np.expand_dims(cv_numpy_image - baseline_image, axis=0), fold, axis=0) 73 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 74 | np.moveaxis(lambda_derivative_interpolation, 3, 1).astype(np.float32) 75 | return path_interpolation_func 76 | 77 | 78 | def Path_gradient(numpy_image, model, attr_objective, path_interpolation_func, cuda=False): 79 | """ 80 | :param path_interpolation_func: 81 | return \lambda(\alpha) and d\lambda(\alpha)/d\alpha, for \alpha\in[0, 1] 82 | This function return pil_numpy_images 83 | :return: 84 | """ 85 | if cuda: 86 | model = model.cuda() 87 | cv_numpy_image = np.moveaxis(numpy_image, 0, 2) 88 | image_interpolation, lambda_derivative_interpolation = path_interpolation_func(cv_numpy_image) 89 | grad_accumulate_list = np.zeros_like(image_interpolation) 90 | result_list = [] 91 | for i in range(image_interpolation.shape[0]): 92 | img_tensor = torch.from_numpy(image_interpolation[i]) 93 | img_tensor.requires_grad_(True) 94 | if cuda: 95 | result = model(_add_batch_one(img_tensor).cuda()) 96 | target = attr_objective(result) 97 | target.backward() 98 | grad = img_tensor.grad.cpu().numpy() 99 | if np.any(np.isnan(grad)): 100 | grad[np.isnan(grad)] = 0.0 101 | result = result.cpu().detach() 102 | else: 103 | result = model(_add_batch_one(img_tensor)) 104 | target = attr_objective(result) 105 | target.backward() 106 | grad = img_tensor.grad.numpy() 107 | if np.any(np.isnan(grad)): 108 | grad[np.isnan(grad)] = 0.0 109 | 110 | grad_accumulate_list[i] = grad * lambda_derivative_interpolation[i] 111 | result_list.append(result) 112 | results_numpy = np.asarray(result_list) 113 | return grad_accumulate_list, results_numpy, image_interpolation 114 | 115 | 116 | def saliency_map_PG(grad_list, result_list): 117 | final_grad = grad_list.mean(axis=0) 118 | return final_grad, result_list[-1] 119 | 120 | 121 | def saliency_map_P_gradient( 122 | numpy_image, model, attr_objective, path_interpolation_func): 123 | grad_list, result_list, _ = Path_gradient(numpy_image, model, attr_objective, path_interpolation_func) 124 | final_grad = grad_list.mean(axis=0) 125 | return final_grad, result_list[-1] 126 | 127 | 128 | def saliency_map_I_gradient( 129 | numpy_image, model, attr_objective, baseline='gaus', fold=10, interp='linear'): 130 | """ 131 | :param numpy_image: RGB C x H x W 132 | :param model: 133 | :param attr_func: 134 | :param h: 135 | :param w: 136 | :param window: 137 | :param baseline: 138 | :return: 139 | """ 140 | numpy_baseline = np.moveaxis(IG_baseline(np.moveaxis(numpy_image, 0, 2) * 255., mode=baseline) / 255., 2, 0) 141 | grad_list, result_list, _ = I_gradient(numpy_image, numpy_baseline, model, attr_objective, fold, interp='linear') 142 | final_grad = grad_list.mean(axis=0) * (numpy_image - numpy_baseline) 143 | return final_grad, result_list[-1] 144 | 145 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | 9 
| class Model(nn.Module): 10 | def __init__(self, args, ckp, cpu): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = 4 15 | self.idx_scale = 0 16 | self.self_ensemble = True 17 | self.chop = True 18 | self.precision = 'single' 19 | self.cpu = cpu 20 | self.device = torch.device('cpu' if args.cpu else 'cuda') 21 | self.n_GPUs = args.n_GPUs 22 | self.save_models = args.save_models 23 | 24 | module = import_module('model.' + args.model.lower()) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | if args.print_model: print(self.model) 38 | 39 | def forward(self, x, idx_scale): 40 | self.idx_scale = idx_scale 41 | target = self.get_model() 42 | if hasattr(target, 'set_scale'): 43 | target.set_scale(idx_scale) 44 | 45 | if self.self_ensemble and not self.training: 46 | if self.chop: 47 | forward_function = self.forward_chop 48 | else: 49 | forward_function = self.model.forward 50 | 51 | return self.forward_x8(x, forward_function) 52 | elif self.chop and not self.training: 53 | return self.forward_chop(x) 54 | else: 55 | return self.model(x) 56 | 57 | def get_model(self): 58 | if self.n_GPUs == 1: 59 | return self.model 60 | else: 61 | return self.model.module 62 | 63 | def state_dict(self, **kwargs): 64 | target = self.get_model() 65 | return target.state_dict(**kwargs) 66 | 67 | def save(self, apath, epoch, is_best=False): 68 | target = self.get_model() 69 | torch.save( 70 | target.state_dict(), 71 | os.path.join(apath, 'model', 'model_latest.pt') 72 | ) 73 | if is_best: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_best.pt') 77 | ) 78 | 79 | if self.save_models: 80 | torch.save( 81 | target.state_dict(), 82 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 83 | ) 84 | 85 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 86 | if cpu: 87 | kwargs = {'map_location': lambda storage, loc: storage} 88 | else: 89 | kwargs = {} 90 | 91 | if resume == -1: 92 | self.get_model().load_state_dict( 93 | torch.load( 94 | os.path.join(apath, 'model', 'model_latest.pt'), 95 | **kwargs 96 | ), 97 | strict=False 98 | ) 99 | elif resume == 0: 100 | if pre_train != '.': 101 | print('Loading model from {}'.format(pre_train)) 102 | self.get_model().load_state_dict( 103 | torch.load(pre_train, **kwargs), 104 | strict=False 105 | ) 106 | else: 107 | self.get_model().load_state_dict( 108 | torch.load( 109 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 110 | **kwargs 111 | ), 112 | strict=False 113 | ) 114 | 115 | # shave = 10, min_size=160000 116 | def forward_chop(self, x, shave=10, min_size=160000): 117 | scale = self.scale[self.idx_scale] 118 | n_GPUs = min(self.n_GPUs, 4) 119 | b, c, h, w = x.size() 120 | h_half, w_half = h // 2, w // 2 121 | h_size, w_size = h_half + shave, w_half + shave 122 | lr_list = [ 123 | x[:, :, 0:h_size, 0:w_size], 124 | x[:, :, 0:h_size, (w - w_size):w], 125 | x[:, :, (h - h_size):h, 0:w_size], 126 | x[:, :, (h - h_size):h, (w - w_size):w]] 127 | 128 | if w_size * h_size < min_size: 129 | sr_list = [] 130 | for i in range(0, 4, n_GPUs): 131 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 132 | sr_batch = self.model(lr_batch) 133 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 134 | 
else: 135 | sr_list = [ 136 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 137 | for patch in lr_list 138 | ] 139 | 140 | h, w = scale * h, scale * w 141 | h_half, w_half = scale * h_half, scale * w_half 142 | h_size, w_size = scale * h_size, scale * w_size 143 | shave *= scale 144 | 145 | output = x.new(b, c, h, w) 146 | output[:, :, 0:h_half, 0:w_half] \ 147 | = sr_list[0][:, :, 0:h_half, 0:w_half] 148 | output[:, :, 0:h_half, w_half:w] \ 149 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 150 | output[:, :, h_half:h, 0:w_half] \ 151 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 152 | output[:, :, h_half:h, w_half:w] \ 153 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 154 | 155 | return output 156 | 157 | def forward_x8(self, x, forward_function): 158 | def _transform(v, op): 159 | if self.precision != 'single': v = v.float() 160 | 161 | v2np = v.data.cpu().numpy() 162 | if op == 'v': 163 | tfnp = v2np[:, :, :, ::-1].copy() 164 | elif op == 'h': 165 | tfnp = v2np[:, :, ::-1, :].copy() 166 | elif op == 't': 167 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 168 | 169 | ret = torch.Tensor(tfnp).to(self.device) 170 | if self.precision == 'half': ret = ret.half() 171 | 172 | return ret 173 | 174 | lr_list = [x] 175 | for tf in 'v', 'h', 't': 176 | lr_list.extend([_transform(t, tf) for t in lr_list]) 177 | 178 | sr_list = [forward_function(aug) for aug in lr_list] 179 | for i in range(len(sr_list)): 180 | if i > 3: 181 | sr_list[i] = _transform(sr_list[i], 't') 182 | if i % 4 > 1: 183 | sr_list[i] = _transform(sr_list[i], 'h') 184 | if (i % 4) % 2 == 1: 185 | sr_list[i] = _transform(sr_list[i], 'v') 186 | 187 | output_cat = torch.cat(sr_list, dim=0) 188 | output = output_cat.mean(dim=0, keepdim=True) 189 | 190 | return output -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/drln_ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(modules): 9 | pass 10 | 11 | 12 | class MeanShift(nn.Module): 13 | def __init__(self, mean_rgb, sub): 14 | super(MeanShift, self).__init__() 15 | 16 | sign = -1 if sub else 1 17 | r = mean_rgb[0] * sign 18 | g = mean_rgb[1] * sign 19 | b = mean_rgb[2] * sign 20 | 21 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 22 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 23 | self.shifter.bias.data = torch.Tensor([r, g, b]) 24 | 25 | # Freeze the mean shift layer 26 | for params in self.shifter.parameters(): 27 | params.requires_grad = False 28 | 29 | def forward(self, x): 30 | x = self.shifter(x) 31 | return x 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, 36 | in_channels, out_channels, 37 | ksize=3, stride=1, pad=1, dilation=1): 38 | super(BasicBlock, self).__init__() 39 | 40 | self.body = nn.Sequential( 41 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, dilation), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | init_weights(self.modules) 46 | 47 | def forward(self, x): 48 | out = self.body(x) 49 | return out 50 | 51 | 52 | class GBasicBlock(nn.Module): 53 | def __init__(self, 54 | in_channels, out_channels, 55 | ksize=3, stride=1, pad=1, dilation=1): 56 | super(GBasicBlock, self).__init__() 57 | 58 | self.body = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, dilation, groups=4), 60 | 
nn.ReLU(inplace=True) 61 | ) 62 | 63 | init_weights(self.modules) 64 | 65 | def forward(self, x): 66 | out = self.body(x) 67 | return out 68 | 69 | 70 | class BasicBlockSig(nn.Module): 71 | def __init__(self, 72 | in_channels, out_channels, 73 | ksize=3, stride=1, pad=1): 74 | super(BasicBlockSig, self).__init__() 75 | 76 | self.body = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 78 | nn.Sigmoid() 79 | ) 80 | 81 | init_weights(self.modules) 82 | 83 | def forward(self, x): 84 | out = self.body(x) 85 | return out 86 | 87 | 88 | class GBasicBlockSig(nn.Module): 89 | def __init__(self, 90 | in_channels, out_channels, 91 | ksize=3, stride=1, pad=1): 92 | super(GBasicBlockSig, self).__init__() 93 | 94 | self.body = nn.Sequential( 95 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, groups=4), 96 | nn.Sigmoid() 97 | ) 98 | 99 | init_weights(self.modules) 100 | 101 | def forward(self, x): 102 | out = self.body(x) 103 | return out 104 | 105 | 106 | class ResidualBlock(nn.Module): 107 | def __init__(self, 108 | in_channels, out_channels): 109 | super(ResidualBlock, self).__init__() 110 | 111 | self.body = nn.Sequential( 112 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 115 | ) 116 | 117 | init_weights(self.modules) 118 | 119 | def forward(self, x): 120 | out = self.body(x) 121 | out = F.relu(out + x) 122 | return out 123 | 124 | 125 | class GResidualBlock(nn.Module): 126 | def __init__(self, 127 | in_channels, out_channels): 128 | super(GResidualBlock, self).__init__() 129 | 130 | self.body = nn.Sequential( 131 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=4), 132 | nn.ReLU(inplace=True), 133 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 134 | ) 135 | 136 | init_weights(self.modules) 137 | 138 | def forward(self, x): 139 | out = self.body(x) 140 | out = F.relu(out + x) 141 | return out 142 | 143 | 144 | class EResidualBlock(nn.Module): 145 | def __init__(self, 146 | in_channels, out_channels, 147 | group=1): 148 | super(EResidualBlock, self).__init__() 149 | 150 | self.body = nn.Sequential( 151 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 152 | nn.ReLU(inplace=True), 153 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 156 | ) 157 | 158 | init_weights(self.modules) 159 | 160 | def forward(self, x): 161 | out = self.body(x) 162 | out = F.relu(out + x) 163 | return out 164 | 165 | 166 | class ConvertBlock(nn.Module): 167 | def __init__(self, 168 | in_channels, out_channels, 169 | blocks): 170 | super(ConvertBlock, self).__init__() 171 | 172 | self.body = nn.Sequential( 173 | nn.Conv2d(in_channels * blocks, out_channels * blocks // 2, 3, 1, 1), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(out_channels * blocks // 2, out_channels * blocks // 4, 3, 1, 1), 176 | nn.ReLU(inplace=True), 177 | nn.Conv2d(out_channels * blocks // 4, out_channels, 3, 1, 1), 178 | ) 179 | 180 | init_weights(self.modules) 181 | 182 | def forward(self, x): 183 | out = self.body(x) 184 | # out = F.relu(out + x) 185 | return out 186 | 187 | 188 | class UpsampleBlock(nn.Module): 189 | def __init__(self, 190 | n_channels, scale, multi_scale, 191 | group=1): 192 | super(UpsampleBlock, self).__init__() 193 | 194 | if multi_scale: 195 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 196 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 197 | self.up4 = 
_UpsampleBlock(n_channels, scale=4, group=group) 198 | else: 199 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group) 200 | 201 | self.multi_scale = multi_scale 202 | 203 | def forward(self, x, scale): 204 | if self.multi_scale: 205 | if scale == 2: 206 | return self.up2(x) 207 | elif scale == 3: 208 | return self.up3(x) 209 | elif scale == 4: 210 | return self.up4(x) 211 | else: 212 | return self.up(x) 213 | 214 | 215 | class _UpsampleBlock(nn.Module): 216 | def __init__(self, 217 | n_channels, scale, 218 | group=1): 219 | super(_UpsampleBlock, self).__init__() 220 | 221 | modules = [] 222 | if scale == 2 or scale == 4 or scale == 8: 223 | for _ in range(int(math.log(scale, 2))): 224 | modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 225 | modules += [nn.PixelShuffle(2)] 226 | elif scale == 3: 227 | modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 228 | modules += [nn.PixelShuffle(3)] 229 | 230 | self.body = nn.Sequential(*modules) 231 | init_weights(self.modules) 232 | 233 | def forward(self, x): 234 | out = self.body(x) 235 | return out -------------------------------------------------------------------------------- /LAM/SaliencyModel/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy import stats 4 | import matplotlib.pyplot as plt 5 | import matplotlib as mpl 6 | from PIL import Image 7 | import cv2 8 | 9 | 10 | def cv2_to_pil(img): 11 | image = Image.fromarray(cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)) 12 | return image 13 | 14 | 15 | def pil_to_cv2(img): 16 | image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 17 | return image 18 | 19 | 20 | def make_pil_grid(pil_image_list): 21 | sizex, sizey = pil_image_list[0].size 22 | for img in pil_image_list: 23 | assert sizex == img.size[0] and sizey == img.size[1], 'check image size' 24 | 25 | target = Image.new('RGB', (sizex * len(pil_image_list), sizey)) 26 | left = 0 27 | right = sizex 28 | for i in range(len(pil_image_list)): 29 | target.paste(pil_image_list[i], (left, 0, right, sizey)) 30 | left += sizex 31 | right += sizex 32 | return target 33 | 34 | 35 | def blend_input(map, input): 36 | return Image.blend(map, input, 0.4) 37 | 38 | 39 | def count_saliency_pixels(map, threshold=0.95): 40 | sum_threshold = map.reshape(-1).sum() * threshold 41 | cum_sum = -np.cumsum(np.sort(-map.reshape(-1))) 42 | return len(cum_sum[cum_sum < sum_threshold]) 43 | 44 | 45 | def plot_diff_of_attrs_kde(A, B, zoomin=4, blend=0.5): 46 | grad_flat = A.reshape((-1)) 47 | datapoint_y, datapoint_x = np.mgrid[0:A.shape[0]:1, 0:A.shape[1]:1] 48 | Y, X = np.mgrid[0:A.shape[0]:1, 0:A.shape[1]:1] 49 | positions = np.vstack([X.ravel(), Y.ravel()]) 50 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 51 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 52 | Za = np.reshape(kernel(positions).T, A.shape) 53 | Za = Za / Za.max() 54 | 55 | grad_flat = B.reshape((-1)) 56 | datapoint_y, datapoint_x = np.mgrid[0:B.shape[0]:1, 0:B.shape[1]:1] 57 | Y, X = np.mgrid[0:B.shape[0]:1, 0:B.shape[1]:1] 58 | positions = np.vstack([X.ravel(), Y.ravel()]) 59 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 60 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 61 | Zb = np.reshape(kernel(positions).T, B.shape) 62 | Zb = Zb / Zb.max() 63 | 64 | diff = Za - Zb 65 | diff_norm = diff / diff.max() 66 | vis = Zb - blend*diff_norm 67 | 
68 | cmap = plt.get_cmap('seismic') 69 | # cmap = plt.get_cmap('Purples') 70 | map_color = (255 * cmap(vis * 0.5 + 0.5)).astype(np.uint8) 71 | # map_color = (255 * cmap(Z)).astype(np.uint8) 72 | Img = Image.fromarray(map_color) 73 | s1, s2 = Img.size 74 | return Img.resize((s1 * zoomin, s2 * zoomin), Image.BICUBIC) 75 | 76 | 77 | def vis_saliency_kde(map, zoomin=4): 78 | grad_flat = map.reshape((-1)) 79 | datapoint_y, datapoint_x = np.mgrid[0:map.shape[0]:1, 0:map.shape[1]:1] 80 | Y, X = np.mgrid[0:map.shape[0]:1, 0:map.shape[1]:1] 81 | positions = np.vstack([X.ravel(), Y.ravel()]) 82 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 83 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 84 | Z = np.reshape(kernel(positions).T, map.shape) 85 | Z = Z / Z.max() 86 | cmap = plt.get_cmap('seismic') 87 | # cmap = plt.get_cmap('Purples') 88 | map_color = (255 * cmap(Z * 0.5 + 0.5)).astype(np.uint8) 89 | # map_color = (255 * cmap(Z)).astype(np.uint8) 90 | Img = Image.fromarray(map_color) 91 | s1, s2 = Img.size 92 | return Img.resize((s1 * zoomin, s2 * zoomin), Image.BICUBIC) 93 | 94 | 95 | def vis_saliency(map, zoomin=4): 96 | """ 97 | :param map: the saliency map, 2D, norm to [0, 1] 98 | :param zoomin: the resize factor, nn upsample 99 | :return: 100 | """ 101 | cmap = plt.get_cmap('seismic') 102 | # cmap = plt.get_cmap('Purples') 103 | map_color = (255 * cmap(map * 0.5 + 0.5)).astype(np.uint8) 104 | # map_color = (255 * cmap(map)).astype(np.uint8) 105 | Img = Image.fromarray(map_color) 106 | s1, s2 = Img.size 107 | Img = Img.resize((s1 * zoomin, s2 * zoomin), Image.NEAREST) 108 | return Img.convert('RGB') 109 | 110 | 111 | def click_select_position(pil_img, window_size=16): 112 | """ 113 | 114 | :param pil_img: 115 | :param window_size: 116 | :return: w, h 117 | """ 118 | cv2_img = pil_to_cv2(pil_img) 119 | position = [-1, -1] 120 | def mouse(event, x, y, flags, param): 121 | """""" 122 | if event == cv2.EVENT_LBUTTONDOWN: 123 | xy = "%d, %d" % (x, y) 124 | position[0] = x 125 | position[1] = y 126 | draw_img = cv2_img.copy() 127 | cv2.rectangle(draw_img, (x, y), (x + window_size, y + window_size), (0,0,255), 2) 128 | cv2.putText(draw_img, xy, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), thickness = 1) 129 | cv2.imshow("image", draw_img) 130 | 131 | cv2.namedWindow("image") 132 | cv2.imshow("image", cv2_img) 133 | cv2.resizeWindow("image", 800, 600) 134 | cv2.setMouseCallback("image", mouse) 135 | 136 | cv2.waitKey(0) 137 | cv2.destroyAllWindows() 138 | return_img = cv2_img.copy() 139 | cv2.rectangle(return_img, (position[0], position[1]), (position[0] + window_size, position[1] + window_size), (0, 0, 255), 2) 140 | return position[0], position[1], cv2_to_pil(return_img) 141 | 142 | 143 | def prepare_images(hr_path, scale=4): 144 | hr_pil = Image.open(hr_path) 145 | sizex, sizey = hr_pil.size 146 | hr_pil = hr_pil.crop((0, 0, sizex - sizex % scale, sizey - sizey % scale)) 147 | sizex, sizey = hr_pil.size 148 | lr_pil = hr_pil.resize((sizex // scale, sizey // scale), Image.BICUBIC) 149 | return lr_pil, hr_pil 150 | 151 | 152 | def grad_abs_norm(grad): 153 | """ 154 | 155 | :param grad: numpy array 156 | :return: 157 | """ 158 | grad_2d = np.abs(grad.sum(axis=0)) 159 | grad_max = grad_2d.max() 160 | grad_norm = grad_2d / grad_max 161 | return grad_norm 162 | 163 | 164 | def grad_norm(grad): 165 | """ 166 | 167 | :param grad: numpy array 168 | :return: 169 | """ 170 | grad_2d = grad.sum(axis=0) 171 | grad_max = max(grad_2d.max(), abs(grad_2d.min())) 172 | grad_norm = 
grad_2d / grad_max 173 | return grad_norm 174 | 175 | 176 | def grad_abs_norm_singlechannel(grad): 177 | """ 178 | 179 | :param grad: numpy array 180 | :return: 181 | """ 182 | grad_2d = np.abs(grad) 183 | grad_max = grad_2d.max() 184 | grad_norm = grad_2d / grad_max 185 | return grad_norm 186 | 187 | 188 | def IG_baseline(numpy_image, mode='gaus'): 189 | """ 190 | :param numpy_image: cv2 image 191 | :param mode: 192 | :return: 193 | """ 194 | if mode == 'l1': 195 | raise NotImplementedError() 196 | elif mode == 'gaus': 197 | ablated = cv2.GaussianBlur(numpy_image, (7, 7), 0) 198 | elif mode == 'bif': 199 | ablated = cv2.bilateralFilter(numpy_image, 15, 90, 90) 200 | elif mode == 'mean': 201 | ablated = cv2.medianBlur(numpy_image, 5) 202 | else: 203 | ablated = cv2.GaussianBlur(numpy_image, (7, 7), 0) 204 | return ablated 205 | 206 | 207 | def interpolation(x, x_prime, fold, mode='linear'): 208 | diff = x - x_prime 209 | l = np.linspace(0, 1, fold).reshape((fold, 1, 1, 1)) 210 | interp_list = l * diff + x_prime 211 | return interp_list 212 | 213 | 214 | def isotropic_gaussian_kernel(l, sigma, epsilon=1e-5): 215 | ax = np.arange(-l // 2 + 1., l // 2 + 1.) 216 | xx, yy = np.meshgrid(ax, ax) 217 | kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * (sigma + epsilon) ** 2)) 218 | return kernel / np.sum(kernel) 219 | 220 | 221 | def gini(array): 222 | """Calculate the Gini coefficient of a numpy array.""" 223 | # based on bottom eq: 224 | # http://www.statsdirect.com/help/generatedimages/equations/equation154.svg 225 | # from: 226 | # http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm 227 | # All values are treated equally, arrays must be 1d: 228 | array = array.flatten() 229 | if np.amin(array) < 0: 230 | # Values cannot be negative: 231 | array -= np.amin(array) 232 | # Values cannot be 0: 233 | array += 0.0000001 234 | # Values must be sorted: 235 | array = np.sort(array) 236 | # Index per array element: 237 | index = np.arange(1,array.shape[0]+1) 238 | # Number of array elements: 239 | n = array.shape[0] 240 | # Gini coefficient: 241 | return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array))) 242 | 243 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ##
Multi-scale Attention Network for Single Image Super-Resolution
2 | 3 |
4 | 5 | [Yan Wang](https://scholar.google.com/citations?user=SXIehvoAAAAJ&hl=en), [Yusen Li](https://scholar.google.com/citations?user=4EJ9aekAAAAJ&hl=en&oi=ao), Gang Wang, Xiaoguang Liu 6 |
7 | 8 | Nankai University
9 | 10 |
11 | 12 | 13 | 14 | 15 | 16 | 17 |

18 | 19 | **Overview:** To unleash the potential of ConvNets in super-resolution, we propose a multi-scale attention network (MAN) that couples a classical multi-scale mechanism with the emerging large kernel attention. In particular, we propose multi-scale large kernel attention (MLKA) and a gated spatial attention unit (GSAU). Experimental results show that MAN performs on par with SwinIR while offering varied trade-offs between state-of-the-art performance and computation. 20 | 21 | 22 | 23 | This repository contains the [PyTorch](https://pytorch.org/) implementation of ***MAN*** (CVPRW 2024). 24 |
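For readers unfamiliar with large kernel attention, the sketch below shows the decomposed large-kernel attention popularized by [VAN](https://github.com/Visual-Attention-Network/VAN-Classification), which MAN builds on. The module name `LKASketch` and the exact kernel/dilation choices are illustrative assumptions and not the code in `archs/MAN_arch.py`.

```
import torch
import torch.nn as nn


class LKASketch(nn.Module):
    """Decomposed large kernel attention (VAN-style), for illustration only.

    A large receptive field is approximated by a 5x5 depth-wise conv, a 7x7
    depth-wise conv with dilation 3, and a 1x1 point-wise conv; the result
    modulates the input element-wise. MAN extends this idea to several kernel
    scales (MLKA), which is not reproduced here.
    """

    def __init__(self, dim=60):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)          # local depth-wise conv
        self.conv_spatial = nn.Conv2d(dim, dim, 7, padding=9,
                                      groups=dim, dilation=3)               # dilated depth-wise conv
        self.conv1 = nn.Conv2d(dim, dim, 1)                                 # point-wise mixing

    def forward(self, x):
        attn = self.conv1(self.conv_spatial(self.conv0(x)))
        return x * attn                                                      # attention map gates the input


if __name__ == "__main__":
    out = LKASketch(60)(torch.randn(1, 60, 48, 48))
    print(out.shape)  # torch.Size([1, 60, 48, 48])
```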
Table of contents 25 |

26 | 27 | 1. [Requirements](#%EF%B8%8F-requirements) 28 | 2. [Datasets](#-datasets) 29 | 3. [Implementary Details](#-implementary-details) 30 | 4. [Train and Test](#%EF%B8%8F-train-and-test) 31 | 5. [Results and Models](#-results-and-models) 32 | 6. [Acknowledgments](#-acknowledgments) 33 | 7. [Citation](#-citation) 34 | 35 |

36 | 37 |
38 | 39 | 40 | --- 41 | 42 | ⚙️ Requirements 43 | --- 44 | - [PyTorch >= 1.8](https://pytorch.org/) 45 | - [BasicSR >= 1.3.5](https://github.com/xinntao/BasicSR-examples/blob/master/README.md) 46 | 47 | 48 | 🎈 Datasets 49 | --- 50 | 51 | *Training*: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [DF2K](https://openmmlab.medium.com/awesome-datasets-for-super-resolution-introduction-and-pre-processing-55f8501f8b18). 52 | 53 | *Testing*: Set5, Set14, BSD100, Urban100, Manga109 ([Google Drive](https://drive.google.com/file/d/1SbdbpUZwWYDIEhvxQQaRsokySkcYJ8dq/view?usp=sharing)/[Baidu Netdisk](https://pan.baidu.com/s/1zfmkFK3liwNpW4NtPnWbrw?pwd=nbjl)). 54 | 55 | *Preparing*: Please refer to the [Dataset Preparation](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md) of BasicSR. 56 | 57 | 🔎 Implementary Details 58 | --- 59 | 60 | *Network architecture*: Group number (n_resgroups): *1 for simplicity*, MAB number (n_resblocks): *5/24/36*, channel width (n_feats): *48/60/180* for *tiny/light/base MAN*. 61 |
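For quick reference, the three presets above can be written down as a small table. This is only a convenience mapping of the numbers given in the text; the actual argument names accepted by `archs/MAN_arch.py` and the YAML options are defined in those files.

```
# Reference presets for the three MAN variants described above (sketch only;
# how these values map onto archs/MAN_arch.py is not shown here).
MAN_VARIANTS = {
    "tiny":  {"n_resgroups": 1, "n_resblocks": 5,  "n_feats": 48},
    "light": {"n_resgroups": 1, "n_resblocks": 24, "n_feats": 60},
    "base":  {"n_resgroups": 1, "n_resblocks": 36, "n_feats": 180},
}
```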

62 | ![MAN architecture](images/MAN_arch.png)

63 | Overview of the proposed MAN, which consists of three components: the shallow feature extraction module (SF), the deep feature extraction module (DF) based on 64 | multiple multi-scale attention blocks (MAB), and the high-quality image reconstruction module. 65 | 66 | &nbsp; 67 | 68 | *Component details:* Three multi-scale decomposition modes are utilized in MLKA, and a 7×7 depth-wise convolution is used in the GSAU. 69 |
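To make the gating idea concrete, below is a minimal, self-contained GSAU-style block. It is a sketch assuming a split-and-gate layout around the 7×7 depth-wise convolution mentioned above, and it may differ from the actual block in `archs/MAN_arch.py` (normalization, expansion ratio, and residual scaling are omitted).

```
import torch
import torch.nn as nn


class GSAUSketch(nn.Module):
    """Simplified gated spatial attention unit (GSAU), for illustration only.

    One branch is passed through a 7x7 depth-wise convolution and acts as a
    spatial gate that modulates the other branch element-wise.
    """

    def __init__(self, n_feats=60):
        super().__init__()
        self.proj_in = nn.Conv2d(n_feats, n_feats * 2, 1)    # expand, then split into value/gate
        self.dwconv = nn.Conv2d(n_feats, n_feats, 7,
                                padding=3, groups=n_feats)    # 7x7 depth-wise spatial gate
        self.proj_out = nn.Conv2d(n_feats, n_feats, 1)

    def forward(self, x):
        v, g = self.proj_in(x).chunk(2, dim=1)                # value branch and gate branch
        out = self.proj_out(v * self.dwconv(g))               # gated spatial attention
        return out + x                                        # residual connection


if __name__ == "__main__":
    y = GSAUSketch(60)(torch.randn(1, 60, 32, 32))
    print(y.shape)  # torch.Size([1, 60, 32, 32])
```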

70 | ![MAN component details](images/MAN_details.png)

71 | Details of Multi-scale Large Kernel Attention (MLKA), Gated Spatial Attention Unit (GSAU), and Large Kernel Attention Tail (LKAT). 72 |   73 | 74 | ▶️ Train and Test 75 | --- 76 | 77 | The [BasicSR](https://github.com/XPixelGroup/BasicSR) framework is utilized to train our MAN, also testing. 78 | 79 | #### Training with the example option 80 | 81 | ``` 82 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 83 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/trian_MAN.yml --launcher pytorch 84 | ``` 85 | #### Testing with the example option 86 | 87 | ``` 88 | python test.py -opt options/test_MAN.yml 89 | ``` 90 | 91 | The training/testing results will be saved in the `./experiments` and `./results` folders, respectively. 92 | 93 | 📊 Results and Models 94 | --- 95 | 96 | Pretrained models available at [Google Drive](https://drive.google.com/drive/folders/1sARYFkVeTIFVCa2EnZg9TjZvirDvUNOL?usp=sharing) and [Baidu Netdisk](https://pan.baidu.com/s/15CTY-mgdTuOc1I8mzIA4Ug?pwd=mans) (pwd: **mans** for all links). 97 | 98 | |HR (x4)|MAN-tiny|[EDSR-base+](https://github.com/sanghyun-son/EDSR-PyTorch)|MAN-light|[EDSR+](https://github.com/sanghyun-son/EDSR-PyTorch)|MAN| 99 | | :----- | :-----: | :-----: | :-----: | :-----: | :-----: | 100 | | | ||||| 101 | | | ||||| 102 | | | ||||| 103 | | | ||||| 104 | |**Params/FLOPs**| 150K/8G|1518K/114G|840K/47G|43090K/2895G|8712K/495G| 105 | 106 | Results of our MAN-tiny/light/base models. Set5 validation set is used below to show the general performance. The visual results of five testsets are provided in the last column. 107 | 108 | | Methods | Params | FLOPs |PSNR/SSIM (x2)|PSNR/SSIM (x3)|PSNR/SSIM (x4)|Results| 109 | |:---------|:---------:|:--------:|:------:|:------:|:------:|:--------:| 110 | | MAN-tiny | 150K | 8.4G | 37.91/0.9603 | 34.23/0.9258 | 32.07/0.8930 | [x2](https://pan.baidu.com/s/1mYkGvAlz0bSZuCVubkpsmg?pwd=mans)/[x3](https://pan.baidu.com/s/1RP5gGu-QPXTkH1NPH7axag?pwd=mans)/[x4](https://pan.baidu.com/s/1u22su2bT4Pq_idVxAnqWdw?pwd=mans) | 111 | | MAN-light| 840K | 47.1G | 38.18/0.9612 | 34.65/0.9292 | 32.50/0.8988 | [x2](https://pan.baidu.com/s/1AVuPa7bsbb3qMQqMSM-IJQ?pwd=mans)/[x3](https://pan.baidu.com/s/1TRL7-Y23JddVOpEhH0ObEQ?pwd=mans)/[x4](https://pan.baidu.com/s/1T2bPZcjFRxAgMxGWtPv-Lw?pwd=mans) | 112 | | MAN+ | 8712K | 495G | 38.44/0.9623 | 34.97/0.9315 | 32.87/0.9030 | [x2](https://pan.baidu.com/s/1pTb3Fob_7MOxMKIdopI0hQ?pwd=mans)/[x3](https://pan.baidu.com/s/1L3HEtcraU8Y9VY-HpCZdfg?pwd=mans)/[x4](https://pan.baidu.com/s/1FCNqht9zi9HecG3ExRdeWQ?pwd=mans) | 113 | 114 | 💖 Acknowledgments 115 | --- 116 | 117 | We would thank [VAN](https://github.com/Visual-Attention-Network/VAN-Classification) and [BasicSR](https://github.com/XPixelGroup/BasicSR) for their enlightening work! 
118 | 119 | 🎓 Citation 120 | --- 121 | ``` 122 | @inproceedings{wang2024multi, 123 | title={Multi-scale Attention Network for Single Image Super-Resolution}, 124 | author={Wang, Yan and Li, Yusen and Wang, Gang and Liu, Xiaoguang}, 125 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 126 | year={2024} 127 | } 128 | ``` 129 | 130 | or 131 | 132 | ``` 133 | @article{wang2022multi, 134 | title={Multi-scale Attention Network for Single Image Super-Resolution}, 135 | author={Wang, Yan and Li, Yusen and Wang, Gang and Liu, Xiaoguang}, 136 | journal={arXiv preprint arXiv:2209.14145}, 137 | year={2022} 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/drln.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..NN import drln_ops as ops 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_model(args, parent=False): 8 | return DRLN(args) 9 | 10 | 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | 17 | self.c1 = ops.BasicBlock(channel, channel // reduction, 3, 1, 3, 3) 18 | self.c2 = ops.BasicBlock(channel, channel // reduction, 3, 1, 5, 5) 19 | self.c3 = ops.BasicBlock(channel, channel // reduction, 3, 1, 7, 7) 20 | self.c4 = ops.BasicBlockSig((channel // reduction) * 3, channel, 3, 1, 1) 21 | 22 | def forward(self, x): 23 | y = self.avg_pool(x) 24 | c1 = self.c1(y) 25 | c2 = self.c2(y) 26 | c3 = self.c3(y) 27 | c_out = torch.cat([c1, c2, c3], dim=1) 28 | y = self.c4(c_out) 29 | return x * y 30 | 31 | 32 | class Block(nn.Module): 33 | def __init__(self, in_channels, out_channels, group=1): 34 | super(Block, self).__init__() 35 | 36 | self.r1 = ops.ResidualBlock(in_channels, out_channels) 37 | self.r2 = ops.ResidualBlock(in_channels * 2, out_channels * 2) 38 | self.r3 = ops.ResidualBlock(in_channels * 4, out_channels * 4) 39 | self.g = ops.BasicBlock(in_channels * 8, out_channels, 1, 1, 0) 40 | self.ca = CALayer(in_channels) 41 | 42 | def forward(self, x): 43 | c0 = x 44 | 45 | r1 = self.r1(c0) 46 | c1 = torch.cat([c0, r1], dim=1) 47 | 48 | r2 = self.r2(c1) 49 | c2 = torch.cat([c1, r2], dim=1) 50 | 51 | r3 = self.r3(c2) 52 | c3 = torch.cat([c2, r3], dim=1) 53 | 54 | g = self.g(c3) 55 | out = self.ca(g) 56 | return out 57 | 58 | 59 | class DRLN(nn.Module): 60 | def __init__(self): 61 | super(DRLN, self).__init__() 62 | 63 | # n_resgroups = args.n_resgroups 64 | # n_resblocks = args.n_resblocks 65 | # n_feats = args.n_feats 66 | # kernel_size = 3 67 | # reduction = args.reduction 68 | # scale = args.scale[0] 69 | # act = nn.ReLU(True) 70 | 71 | self.scale = 4 72 | chs = 64 73 | 74 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 75 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 76 | 77 | self.head = nn.Conv2d(3, chs, 3, 1, 1) 78 | 79 | self.b1 = Block(chs, chs) 80 | self.b2 = Block(chs, chs) 81 | self.b3 = Block(chs, chs) 82 | self.b4 = Block(chs, chs) 83 | self.b5 = Block(chs, chs) 84 | self.b6 = Block(chs, chs) 85 | self.b7 = Block(chs, chs) 86 | self.b8 = Block(chs, chs) 87 | self.b9 = Block(chs, chs) 88 | self.b10 = Block(chs, chs) 89 | self.b11 = Block(chs, chs) 90 | self.b12 = Block(chs, chs) 91 | self.b13 = Block(chs, chs) 92 | self.b14 = Block(chs, chs) 93 | self.b15 = Block(chs, chs) 94 | self.b16 = Block(chs, chs) 
95 | self.b17 = Block(chs, chs) 96 | self.b18 = Block(chs, chs) 97 | self.b19 = Block(chs, chs) 98 | self.b20 = Block(chs, chs) 99 | 100 | self.c1 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 101 | self.c2 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 102 | self.c3 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 103 | self.c4 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 104 | self.c5 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 105 | self.c6 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 106 | self.c7 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 107 | self.c8 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 108 | self.c9 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 109 | self.c10 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 110 | self.c11 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 111 | self.c12 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 112 | self.c13 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 113 | self.c14 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 114 | self.c15 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 115 | self.c16 = ops.BasicBlock(chs * 5, chs, 3, 1, 1) 116 | self.c17 = ops.BasicBlock(chs * 2, chs, 3, 1, 1) 117 | self.c18 = ops.BasicBlock(chs * 3, chs, 3, 1, 1) 118 | self.c19 = ops.BasicBlock(chs * 4, chs, 3, 1, 1) 119 | self.c20 = ops.BasicBlock(chs * 5, chs, 3, 1, 1) 120 | 121 | self.upsample = ops.UpsampleBlock(chs, self.scale, multi_scale=False) 122 | # self.convert = ops.ConvertBlock(chs, chs, 20) 123 | self.tail = nn.Conv2d(chs, 3, 3, 1, 1) 124 | 125 | def forward(self, x): 126 | x = self.sub_mean(x * 255.) 127 | x = self.head(x) 128 | c0 = o0 = x 129 | 130 | b1 = self.b1(o0) 131 | c1 = torch.cat([c0, b1], dim=1) 132 | o1 = self.c1(c1) 133 | 134 | b2 = self.b2(o1) 135 | c2 = torch.cat([c1, b2], dim=1) 136 | o2 = self.c2(c2) 137 | 138 | b3 = self.b3(o2) 139 | c3 = torch.cat([c2, b3], dim=1) 140 | o3 = self.c3(c3) 141 | a1 = o3 + c0 142 | 143 | b4 = self.b4(a1) 144 | c4 = torch.cat([o3, b4], dim=1) 145 | o4 = self.c4(c4) 146 | 147 | b5 = self.b5(a1) 148 | c5 = torch.cat([c4, b5], dim=1) 149 | o5 = self.c5(c5) 150 | 151 | b6 = self.b6(o5) 152 | c6 = torch.cat([c5, b6], dim=1) 153 | o6 = self.c6(c6) 154 | a2 = o6 + a1 155 | 156 | b7 = self.b7(a2) 157 | c7 = torch.cat([o6, b7], dim=1) 158 | o7 = self.c7(c7) 159 | 160 | b8 = self.b8(o7) 161 | c8 = torch.cat([c7, b8], dim=1) 162 | o8 = self.c8(c8) 163 | 164 | b9 = self.b9(o8) 165 | c9 = torch.cat([c8, b9], dim=1) 166 | o9 = self.c9(c9) 167 | a3 = o9 + a2 168 | 169 | b10 = self.b10(a3) 170 | c10 = torch.cat([o9, b10], dim=1) 171 | o10 = self.c10(c10) 172 | 173 | b11 = self.b11(o10) 174 | c11 = torch.cat([c10, b11], dim=1) 175 | o11 = self.c11(c11) 176 | 177 | b12 = self.b12(o11) 178 | c12 = torch.cat([c11, b12], dim=1) 179 | o12 = self.c12(c12) 180 | a4 = o12 + a3 181 | 182 | b13 = self.b13(a4) 183 | c13 = torch.cat([o12, b13], dim=1) 184 | o13 = self.c13(c13) 185 | 186 | b14 = self.b14(o13) 187 | c14 = torch.cat([c13, b14], dim=1) 188 | o14 = self.c14(c14) 189 | 190 | b15 = self.b15(o14) 191 | c15 = torch.cat([c14, b15], dim=1) 192 | o15 = self.c15(c15) 193 | 194 | b16 = self.b16(o15) 195 | c16 = torch.cat([c15, b16], dim=1) 196 | o16 = self.c16(c16) 197 | a5 = o16 + a4 198 | 199 | b17 = self.b17(a5) 200 | c17 = torch.cat([o16, b17], dim=1) 201 | o17 = self.c17(c17) 202 | 203 | b18 = self.b18(o17) 204 | c18 = torch.cat([c17, b18], dim=1) 205 | o18 = self.c18(c18) 206 | 207 | b19 = self.b19(o18) 208 | c19 = torch.cat([c18, b19], dim=1) 209 | o19 = self.c19(c19) 210 | 211 | b20 = self.b20(o19) 212 | c20 = torch.cat([c19, b20], dim=1) 213 | o20 = self.c20(c20) 214 | a6 = o20 + a5 215 | 216 | # c_out = 
torch.cat([b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20], dim=1) 217 | 218 | # b = self.convert(c_out) 219 | b_out = a6 + x 220 | out = self.upsample(b_out, scale=self.scale) 221 | 222 | out = self.tail(out) 223 | f_out = self.add_mean(out) 224 | 225 | return f_out / 255. 226 | 227 | def load_state_dict(self, state_dict, strict=False): 228 | own_state = self.state_dict() 229 | for name, param in state_dict.items(): 230 | if name in own_state: 231 | if isinstance(param, nn.Parameter): 232 | param = param.data 233 | try: 234 | own_state[name].copy_(param) 235 | except Exception: 236 | if name.find('tail') >= 0 or name.find('upsample') >= 0: 237 | print('Replace pre-trained upsampler to new one...') 238 | else: 239 | raise RuntimeError('While copying the parameter named {}, ' 240 | 'whose dimensions in the model are {} and ' 241 | 'whose dimensions in the checkpoint are {}.' 242 | .format(name, own_state[name].size(), param.size())) 243 | elif strict: 244 | if name.find('tail') == -1: 245 | raise KeyError('unexpected key "{}" in state_dict' 246 | .format(name)) 247 | 248 | if strict: 249 | missing = set(own_state.keys()) - set(state_dict.keys()) 250 | if len(missing) > 0: 251 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 252 | 253 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright MAN Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 12 | return nn.Conv2d( 13 | in_channels, out_channels, kernel_size, 14 | padding=(kernel_size//2), bias=bias) 15 | 16 | 17 | class MeanShift(nn.Conv2d): 18 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 19 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 20 | std = torch.Tensor(rgb_std) 21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 22 | self.weight.data.div_(std.view(3, 1, 1, 1)) 23 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 24 | self.bias.data.div_(std) 25 | self.requires_grad = False 26 | 27 | 28 | class DownBlock(nn.Module): 29 | def __init__(self, opt, scale, nFeat=None, in_channels=None, out_channels=None): 30 | super(DownBlock, self).__init__() 31 | negval = opt.negval 32 | 33 | if nFeat is None: 34 | nFeat = opt.n_feats 35 | 36 | if in_channels is None: 37 | in_channels = opt.n_colors 38 | 39 | if out_channels is None: 40 | out_channels = opt.n_colors 41 | 42 | dual_block = [ 43 | nn.Sequential( 44 | nn.Conv2d(in_channels, nFeat, kernel_size=3, stride=2, padding=1, bias=False), 45 | nn.LeakyReLU(negative_slope=negval, inplace=True) 46 | ) 47 | ] 48 | 49 | for _ in range(1, int(np.log2(scale))): 50 | dual_block.append( 51 | nn.Sequential( 52 | nn.Conv2d(nFeat, nFeat, kernel_size=3, stride=2, padding=1, bias=False), 53 | nn.LeakyReLU(negative_slope=negval, inplace=True) 54 | ) 55 | ) 56 | 57 | dual_block.append(nn.Conv2d(nFeat, out_channels, kernel_size=3, stride=1, padding=1, bias=False)) 58 | 59 | self.dual_module = nn.Sequential(*dual_block) 60 | 61 | def forward(self, x): 62 | x = self.dual_module(x) 63 | return x 64 | 65 | 66 | class BasicBlock(nn.Sequential): 67 | def __init__( 68 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 69 | bn=True, act=nn.ReLU(True)): 70 | 71 | m = [nn.Conv2d( 72 | in_channels, out_channels, kernel_size, 73 | padding=(kernel_size//2), stride=stride, bias=bias) 74 | ] 75 | if bn: m.append(nn.BatchNorm2d(out_channels)) 76 | if act is not None: m.append(act) 77 | super(BasicBlock, self).__init__(*m) 78 | 79 | 80 | class ResBlock(nn.Module): 81 | def __init__( 82 | self, conv, n_feat, kernel_size, 83 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 84 | 85 | super(ResBlock, self).__init__() 86 | m = [] 87 | for i in range(2): 88 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 89 | if bn: m.append(nn.BatchNorm2d(n_feat)) 90 | if i == 0: m.append(act) 91 | 92 | self.body = nn.Sequential(*m) 93 | self.res_scale = res_scale 94 | 95 | def forward(self, x): 96 | res = self.body(x).mul(self.res_scale) 97 | res += x 98 | 99 | return res 100 | 101 | 102 | class Upsampler(nn.Sequential): 103 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 104 | 105 | m = [] 106 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
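            # scale is a power of two: stack log2(scale) stages of conv(n_feat -> 4*n_feat) followed by PixelShuffle(2);
            # each stage turns the 4x channel expansion into a 2x spatial upscale, so a x4 model uses two such stages.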
107 | for _ in range(int(math.log(scale, 2))): 108 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 109 | m.append(nn.PixelShuffle(2)) 110 | if bn: m.append(nn.BatchNorm2d(n_feat)) 111 | if act: m.append(act()) 112 | elif scale == 3: 113 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 114 | m.append(nn.PixelShuffle(3)) 115 | if bn: m.append(nn.BatchNorm2d(n_feat)) 116 | if act: m.append(act()) 117 | else: 118 | raise NotImplementedError 119 | 120 | super(Upsampler, self).__init__(*m) 121 | 122 | 123 | ## add SELayer 124 | class SELayer(nn.Module): 125 | def __init__(self, channel, reduction=16): 126 | super(SELayer, self).__init__() 127 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 128 | self.conv_du = nn.Sequential( 129 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 130 | nn.ReLU(inplace=True), 131 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 132 | nn.Sigmoid() 133 | ) 134 | 135 | def forward(self, x): 136 | y = self.avg_pool(x) 137 | y = self.conv_du(y) 138 | return x * y 139 | 140 | 141 | ## add SEResBlock 142 | class SEResBlock(nn.Module): 143 | def __init__( 144 | self, conv, n_feat, kernel_size, reduction, 145 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 146 | 147 | super(SEResBlock, self).__init__() 148 | modules_body = [] 149 | for i in range(2): 150 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 151 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 152 | if i == 0: modules_body.append(act) 153 | modules_body.append(SELayer(n_feat, reduction)) 154 | self.body = nn.Sequential(*modules_body) 155 | self.res_scale = res_scale 156 | 157 | def forward(self, x): 158 | res = self.body(x) 159 | #res = self.body(x).mul(self.res_scale) 160 | res += x 161 | 162 | return res 163 | 164 | 165 | ## Channel Attention (CA) Layer 166 | class CALayer(nn.Module): 167 | def __init__(self, channel, reduction=16): 168 | super(CALayer, self).__init__() 169 | # global average pooling: feature --> point 170 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 171 | # feature channel downscale and upscale --> channel weight 172 | self.conv_du = nn.Sequential( 173 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 176 | nn.Sigmoid() 177 | ) 178 | 179 | def forward(self, x): 180 | y = self.avg_pool(x) 181 | y = self.conv_du(y) 182 | return x * y 183 | 184 | 185 | ## Residual Channel Attention Block (RCAB) 186 | class RCAB(nn.Module): 187 | def __init__(self, conv, n_feat, kernel_size, reduction=16, bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 188 | super(RCAB, self).__init__() 189 | modules_body = [] 190 | for i in range(2): 191 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 192 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 193 | if i == 0: modules_body.append(act) 194 | modules_body.append(CALayer(n_feat, reduction)) 195 | self.body = nn.Sequential(*modules_body) 196 | self.res_scale = res_scale 197 | 198 | def forward(self, x): 199 | res = self.body(x) 200 | res += x 201 | return res 202 | 203 | 204 | # add NonLocalBlock2D 205 | # reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py 206 | class NonLocalBlock2D(nn.Module): 207 | def __init__(self, in_channels, inter_channels): 208 | super(NonLocalBlock2D, self).__init__() 209 | 210 | self.in_channels = in_channels 211 | self.inter_channels = inter_channels 212 | 213 | self.g = 
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 214 | padding=0) 215 | 216 | self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, 217 | padding=0) 218 | # for pytorch 0.3.1 219 | # nn.init.constant(self.W.weight, 0) 220 | # nn.init.constant(self.W.bias, 0) 221 | # for pytorch 0.4.0 222 | nn.init.constant_(self.W.weight, 0) 223 | nn.init.constant_(self.W.bias, 0) 224 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 225 | padding=0) 226 | 227 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 228 | padding=0) 229 | 230 | def forward(self, x): 231 | batch_size = x.size(0) 232 | 233 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 234 | 235 | g_x = g_x.permute(0, 2, 1) 236 | 237 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 238 | 239 | theta_x = theta_x.permute(0, 2, 1) 240 | 241 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 242 | 243 | f = torch.matmul(theta_x, phi_x) 244 | 245 | f_div_C = F.softmax(f, dim=1) 246 | 247 | y = torch.matmul(f_div_C, g_x) 248 | 249 | y = y.permute(0, 2, 1).contiguous() 250 | 251 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 252 | W_y = self.W(y) 253 | z = W_y + x 254 | 255 | return z 256 | 257 | 258 | ## define trunk branch 259 | class TrunkBranch(nn.Module): 260 | def __init__( 261 | self, conv, n_feat, kernel_size, 262 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 263 | super(TrunkBranch, self).__init__() 264 | modules_body = [] 265 | for i in range(2): 266 | modules_body.append( 267 | ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 268 | self.body = nn.Sequential(*modules_body) 269 | 270 | def forward(self, x): 271 | tx = self.body(x) 272 | 273 | return tx 274 | 275 | 276 | ## define mask branch 277 | class MaskBranchDownUp(nn.Module): 278 | def __init__( 279 | self, conv, n_feat, kernel_size, 280 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 281 | super(MaskBranchDownUp, self).__init__() 282 | 283 | MB_RB1 = [] 284 | MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 285 | 286 | MB_Down = [] 287 | MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1)) 288 | 289 | MB_RB2 = [] 290 | for i in range(2): 291 | MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 292 | 293 | MB_Up = [] 294 | MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2)) 295 | 296 | MB_RB3 = [] 297 | MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 298 | 299 | MB_1x1conv = [] 300 | MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True)) 301 | 302 | MB_sigmoid = [] 303 | MB_sigmoid.append(nn.Sigmoid()) 304 | 305 | self.MB_RB1 = nn.Sequential(*MB_RB1) 306 | self.MB_Down = nn.Sequential(*MB_Down) 307 | self.MB_RB2 = nn.Sequential(*MB_RB2) 308 | self.MB_Up = nn.Sequential(*MB_Up) 309 | self.MB_RB3 = nn.Sequential(*MB_RB3) 310 | self.MB_1x1conv = nn.Sequential(*MB_1x1conv) 311 | self.MB_sigmoid = nn.Sequential(*MB_sigmoid) 312 | 313 | def forward(self, x): 314 | x_RB1 = self.MB_RB1(x) 315 | x_Down = self.MB_Down(x_RB1) 316 | x_RB2 = self.MB_RB2(x_Down) 317 | x_Up = self.MB_Up(x_RB2) 318 | x_preRB3 = x_RB1 + x_Up 319 | x_RB3 = 
self.MB_RB3(x_preRB3) 320 | x_1x1 = self.MB_1x1conv(x_RB3) 321 | mx = self.MB_sigmoid(x_1x1) 322 | 323 | return mx 324 | 325 | 326 | ## define nonlocal mask branch 327 | class NLMaskBranchDownUp(nn.Module): 328 | def __init__( 329 | self, conv, n_feat, kernel_size, 330 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 331 | super(NLMaskBranchDownUp, self).__init__() 332 | 333 | MB_RB1 = [] 334 | MB_RB1.append(NonLocalBlock2D(n_feat, n_feat // 2)) 335 | MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 336 | 337 | MB_Down = [] 338 | MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1)) 339 | 340 | MB_RB2 = [] 341 | for i in range(2): 342 | MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 343 | 344 | MB_Up = [] 345 | MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2)) 346 | 347 | MB_RB3 = [] 348 | MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 349 | 350 | MB_1x1conv = [] 351 | MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True)) 352 | 353 | MB_sigmoid = [] 354 | MB_sigmoid.append(nn.Sigmoid()) 355 | 356 | self.MB_RB1 = nn.Sequential(*MB_RB1) 357 | self.MB_Down = nn.Sequential(*MB_Down) 358 | self.MB_RB2 = nn.Sequential(*MB_RB2) 359 | self.MB_Up = nn.Sequential(*MB_Up) 360 | self.MB_RB3 = nn.Sequential(*MB_RB3) 361 | self.MB_1x1conv = nn.Sequential(*MB_1x1conv) 362 | self.MB_sigmoid = nn.Sequential(*MB_sigmoid) 363 | 364 | def forward(self, x): 365 | x_RB1 = self.MB_RB1(x) 366 | x_Down = self.MB_Down(x_RB1) 367 | x_RB2 = self.MB_RB2(x_Down) 368 | x_Up = self.MB_Up(x_RB2) 369 | x_preRB3 = x_RB1 + x_Up 370 | x_RB3 = self.MB_RB3(x_preRB3) 371 | x_1x1 = self.MB_1x1conv(x_RB3) 372 | mx = self.MB_sigmoid(x_1x1) 373 | 374 | return mx 375 | 376 | 377 | ## define residual attention module 378 | class ResAttModuleDownUpPlus(nn.Module): 379 | def __init__( 380 | self, conv, n_feat, kernel_size, 381 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 382 | super(ResAttModuleDownUpPlus, self).__init__() 383 | RA_RB1 = [] 384 | RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 385 | RA_TB = [] 386 | RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 387 | RA_MB = [] 388 | RA_MB.append(MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 389 | RA_tail = [] 390 | for i in range(2): 391 | RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 392 | 393 | self.RA_RB1 = nn.Sequential(*RA_RB1) 394 | self.RA_TB = nn.Sequential(*RA_TB) 395 | self.RA_MB = nn.Sequential(*RA_MB) 396 | self.RA_tail = nn.Sequential(*RA_tail) 397 | 398 | def forward(self, input): 399 | RA_RB1_x = self.RA_RB1(input) 400 | tx = self.RA_TB(RA_RB1_x) 401 | mx = self.RA_MB(RA_RB1_x) 402 | txmx = tx * mx 403 | hx = txmx + RA_RB1_x 404 | hx = self.RA_tail(hx) 405 | 406 | return hx 407 | 408 | 409 | ## define nonlocal residual attention module 410 | class NLResAttModuleDownUpPlus(nn.Module): 411 | def __init__( 412 | self, conv, n_feat, kernel_size, 413 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 414 | super(NLResAttModuleDownUpPlus, self).__init__() 415 | RA_RB1 = [] 416 | RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 417 | RA_TB = [] 418 | 
RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 419 | RA_MB = [] 420 | RA_MB.append(NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 421 | RA_tail = [] 422 | for i in range(2): 423 | RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1)) 424 | 425 | self.RA_RB1 = nn.Sequential(*RA_RB1) 426 | self.RA_TB = nn.Sequential(*RA_TB) 427 | self.RA_MB = nn.Sequential(*RA_MB) 428 | self.RA_tail = nn.Sequential(*RA_tail) 429 | 430 | def forward(self, input): 431 | RA_RB1_x = self.RA_RB1(input) 432 | tx = self.RA_TB(RA_RB1_x) 433 | mx = self.RA_MB(RA_RB1_x) 434 | txmx = tx * mx 435 | hx = txmx + RA_RB1_x 436 | hx = self.RA_tail(hx) 437 | 438 | return hx -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/man.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from basicsr.utils.registry import ARCH_REGISTRY 9 | 10 | #LKA from VAN (https://github.com/Visual-Attention-Network) 11 | class LKA(nn.Module): 12 | def __init__(self, dim): 13 | super().__init__() 14 | self.conv0 = nn.Conv2d(dim, dim, 7, padding=7//2, groups=dim) 15 | self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9//2)*4), groups=dim, dilation=4) 16 | self.conv1 = nn.Conv2d(dim, dim, 1) 17 | 18 | def forward(self, x): 19 | u = x.clone() 20 | attn = self.conv0(x) 21 | attn = self.conv_spatial(attn) 22 | attn = self.conv1(attn) 23 | 24 | return u * attn 25 | 26 | class Attention(nn.Module): 27 | def __init__(self, n_feats): 28 | super().__init__() 29 | 30 | self.norm = LayerNorm(n_feats, data_format='channels_first') 31 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 32 | 33 | self.proj_1 = nn.Conv2d(n_feats, n_feats, 1) 34 | self.spatial_gating_unit = LKA(n_feats) 35 | self.proj_2 = nn.Conv2d(n_feats, n_feats, 1) 36 | 37 | def forward(self, x): 38 | shorcut = x.clone() 39 | x = self.proj_1(self.norm(x)) 40 | x = self.spatial_gating_unit(x) 41 | x = self.proj_2(x) 42 | x = x*self.scale + shorcut 43 | return x 44 | #---------------------------------------------------------------------------------------------------------------- 45 | 46 | class MLP(nn.Module): 47 | def __init__(self, n_feats): 48 | super().__init__() 49 | 50 | self.norm = LayerNorm(n_feats, data_format='channels_first') 51 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 52 | 53 | i_feats = 2*n_feats 54 | 55 | self.fc1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 56 | self.act = nn.GELU() 57 | self.fc2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0) 58 | 59 | def forward(self, x): 60 | shortcut = x.clone() 61 | x = self.norm(x) 62 | x = self.fc1(x) 63 | x = self.act(x) 64 | x = self.fc2(x) 65 | 66 | return x*self.scale + shortcut 67 | 68 | 69 | class CFF(nn.Module): 70 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'): 71 | super().__init__() 72 | i_feats =n_feats*2 73 | 74 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 75 | self.DWConv1 = nn.Sequential( 76 | nn.Conv2d(i_feats, i_feats, 7, 1, 7//2, groups= n_feats), 77 | nn.GELU()) 78 | self.Conv2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0) 79 | 80 | self.norm = LayerNorm(n_feats, data_format='channels_first') 81 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), 
requires_grad=True) 82 | 83 | def forward(self, x): 84 | shortcut = x.clone() 85 | 86 | #Ghost Expand 87 | x = self.Conv1(self.norm(x)) 88 | x = self.DWConv1(x) 89 | x = self.Conv2(x) 90 | 91 | return x*self.scale + shortcut 92 | 93 | 94 | class SimpleGate(nn.Module): 95 | def __init__(self, n_feats): 96 | super().__init__() 97 | i_feats =n_feats*2 98 | 99 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 100 | #self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats) 101 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 102 | 103 | self.norm = LayerNorm(n_feats, data_format='channels_first') 104 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 105 | 106 | def forward(self, x): 107 | shortcut = x.clone() 108 | 109 | #Ghost Expand 110 | x = self.Conv1(self.norm(x)) 111 | a, x = torch.chunk(x, 2, dim=1) 112 | x = x*a #self.DWConv1(a) 113 | x = self.Conv2(x) 114 | 115 | return x*self.scale + shortcut 116 | #----------------------------------------------------------------------------------------------------------------- 117 | #RCAN-style 118 | class RCBv6(nn.Module): 119 | def __init__( 120 | self, n_feats, k, lk=7, res_scale=1.0, style ='X', act = nn.SiLU(), deploy=False): 121 | super().__init__() 122 | self.LKA = nn.Sequential( 123 | nn.Conv2d(n_feats, n_feats, 5, 1, lk//2, groups= n_feats), 124 | nn.Conv2d(n_feats, n_feats, 7, stride=1, padding=9, groups=n_feats, dilation=3), 125 | nn.Conv2d(n_feats, n_feats, 1, 1, 0), 126 | nn.Sigmoid()) 127 | 128 | #self.LFE2 = LFEv3(n_feats, attn ='CA') 129 | 130 | self.LFE = nn.Sequential( 131 | nn.Conv2d(n_feats,n_feats,3,1,1), 132 | nn.GELU(), 133 | nn.Conv2d(n_feats,n_feats,3,1,1)) 134 | 135 | 136 | def forward(self, x, pre_attn=None, RAA=None): 137 | shortcut = x.clone() 138 | x = self.LFE(x) 139 | 140 | x = self.LKA(x)*x 141 | 142 | return x + shortcut 143 | 144 | #----------------------------------------------------------------------------------------------------------------- 145 | 146 | 147 | class MLKA_Ablation(nn.Module): 148 | def __init__(self, n_feats, k=2, squeeze_factor=15): 149 | super().__init__() 150 | i_feats = 2*n_feats 151 | 152 | self.n_feats= n_feats 153 | self.i_feats = i_feats 154 | 155 | self.norm = LayerNorm(n_feats, data_format='channels_first') 156 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 157 | 158 | k = 2 159 | 160 | #Multiscale Large Kernel Attention 161 | self.LKA7 = nn.Sequential( 162 | nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k), 163 | nn.Conv2d(n_feats//k, n_feats//k, 9, stride=1, padding=(9//2)*4, groups=n_feats//k, dilation=4), 164 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0)) 165 | self.LKA5 = nn.Sequential( 166 | nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k), 167 | nn.Conv2d(n_feats//k, n_feats//k, 7, stride=1, padding=(7//2)*3, groups=n_feats//k, dilation=3), 168 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0)) 169 | '''self.LKA3 = nn.Sequential( 170 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k), 171 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2), 172 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))''' 173 | 174 | #self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k) 175 | self.X5 = nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k) 176 | self.X7 = nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k) 177 | 178 | self.proj_first = nn.Sequential( 179 | nn.Conv2d(n_feats, 
i_feats, 1, 1, 0)) 180 | 181 | self.proj_last = nn.Sequential( 182 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 183 | 184 | 185 | def forward(self, x, pre_attn=None, RAA=None): 186 | shortcut = x.clone() 187 | 188 | x = self.norm(x) 189 | 190 | x = self.proj_first(x) 191 | 192 | a, x = torch.chunk(x, 2, dim=1) 193 | 194 | #u_1, u_2, u_3= torch.chunk(u, 3, dim=1) 195 | a_1, a_2 = torch.chunk(a, 2, dim=1) 196 | 197 | a = torch.cat([self.LKA7(a_1)*self.X7(a_1), self.LKA5(a_2)*self.X5(a_2)], dim=1) 198 | 199 | x = self.proj_last(x*a)*self.scale + shortcut 200 | 201 | return x 202 | #----------------------------------------------------------------------------------------------------------------- 203 | 204 | 205 | class LayerNorm(nn.Module): 206 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 207 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 208 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 209 | with shape (batch_size, channels, height, width). 210 | """ 211 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 212 | super().__init__() 213 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 214 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 215 | self.eps = eps 216 | self.data_format = data_format 217 | if self.data_format not in ["channels_last", "channels_first"]: 218 | raise NotImplementedError 219 | self.normalized_shape = (normalized_shape, ) 220 | 221 | def forward(self, x): 222 | if self.data_format == "channels_last": 223 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 224 | elif self.data_format == "channels_first": 225 | u = x.mean(1, keepdim=True) 226 | s = (x - u).pow(2).mean(1, keepdim=True) 227 | x = (x - u) / torch.sqrt(s + self.eps) 228 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 229 | return x 230 | 231 | 232 | 233 | class SGAB(nn.Module): 234 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'): 235 | super().__init__() 236 | i_feats =n_feats*2 237 | 238 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 239 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats) 240 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 241 | 242 | self.norm = LayerNorm(n_feats, data_format='channels_first') 243 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 244 | 245 | def forward(self, x): 246 | shortcut = x.clone() 247 | 248 | #Ghost Expand 249 | x = self.Conv1(self.norm(x)) 250 | a, x = torch.chunk(x, 2, dim=1) 251 | x = x*self.DWConv1(a) 252 | x = self.Conv2(x) 253 | 254 | return x*self.scale + shortcut 255 | 256 | class GroupGLKA(nn.Module): 257 | def __init__(self, n_feats, k=2, squeeze_factor=15): 258 | super().__init__() 259 | i_feats = 2*n_feats 260 | 261 | self.n_feats= n_feats 262 | self.i_feats = i_feats 263 | 264 | self.norm = LayerNorm(n_feats, data_format='channels_first') 265 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 266 | 267 | #Multiscale Large Kernel Attention 268 | self.LKA7 = nn.Sequential( 269 | nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3), 270 | nn.Conv2d(n_feats//3, n_feats//3, 9, stride=1, padding=(9//2)*4, groups=n_feats//3, dilation=4), 271 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 272 | self.LKA5 = nn.Sequential( 273 | nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3), 274 | 
nn.Conv2d(n_feats//3, n_feats//3, 7, stride=1, padding=(7//2)*3, groups=n_feats//3, dilation=3), 275 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 276 | self.LKA3 = nn.Sequential( 277 | nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3), 278 | nn.Conv2d(n_feats//3, n_feats//3, 5, stride=1, padding=(5//2)*2, groups=n_feats//3, dilation=2), 279 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 280 | 281 | self.X3 = nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3) 282 | self.X5 = nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3) 283 | self.X7 = nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3) 284 | 285 | self.proj_first = nn.Sequential( 286 | nn.Conv2d(n_feats, i_feats, 1, 1, 0)) 287 | 288 | self.proj_last = nn.Sequential( 289 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 290 | 291 | 292 | def forward(self, x, pre_attn=None, RAA=None): 293 | shortcut = x.clone() 294 | 295 | x = self.norm(x) 296 | 297 | x = self.proj_first(x) 298 | 299 | a, x = torch.chunk(x, 2, dim=1) 300 | 301 | a_1, a_2, a_3= torch.chunk(a, 3, dim=1) 302 | 303 | a = torch.cat([self.LKA3(a_1)*self.X3(a_1), self.LKA5(a_2)*self.X5(a_2), self.LKA7(a_3)*self.X7(a_3)], dim=1) 304 | 305 | x = self.proj_last(x*a)*self.scale + shortcut 306 | 307 | return x 308 | 309 | 310 | # MAB 311 | class MAB(nn.Module): 312 | def __init__( 313 | self, n_feats): 314 | super().__init__() 315 | 316 | self.LKA = GroupGLKA(n_feats) 317 | 318 | self.LFE = SGAB(n_feats) 319 | 320 | def forward(self, x, pre_attn=None, RAA=None): 321 | #large kernel attention 322 | x = self.LKA(x) 323 | 324 | #local feature extraction 325 | x = self.LFE(x) 326 | 327 | return x 328 | 329 | class LKAT(nn.Module): 330 | def __init__(self, n_feats): 331 | super().__init__() 332 | 333 | #self.norm = LayerNorm(n_feats, data_format='channels_first') 334 | #self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 335 | 336 | self.conv0 = nn.Sequential( 337 | nn.Conv2d(n_feats, n_feats, 1, 1, 0), 338 | nn.GELU()) 339 | 340 | self.att = nn.Sequential( 341 | nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats), 342 | nn.Conv2d(n_feats, n_feats, 9, stride=1, padding=(9//2)*3, groups=n_feats, dilation=3), 343 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 344 | 345 | self.conv1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 346 | 347 | def forward(self, x): 348 | x = self.conv0(x) 349 | x = x*self.att(x) 350 | x = self.conv1(x) 351 | return x 352 | 353 | class ResGroup(nn.Module): 354 | def __init__(self, n_resblocks, n_feats, res_scale=1.0): 355 | super(ResGroup, self).__init__() 356 | self.body = nn.ModuleList([ 357 | MAB(n_feats) \ 358 | for _ in range(n_resblocks)]) 359 | 360 | self.body_t = LKAT(n_feats) 361 | 362 | def forward(self, x): 363 | res = x.clone() 364 | 365 | for i, block in enumerate(self.body): 366 | res = block(res) 367 | 368 | x = self.body_t(res) + x 369 | 370 | return x 371 | 372 | class MeanShift(nn.Conv2d): 373 | def __init__( 374 | self, rgb_range, 375 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 376 | 377 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 378 | std = torch.Tensor(rgb_std) 379 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 380 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 381 | for p in self.parameters(): 382 | p.requires_grad = False 383 | 384 | 385 | class MAN(nn.Module): 386 | def __init__(self, n_resblocks=36, n_resgroups=1, n_colors=3, n_feats=180, scale=2, res_scale = 1.0): 387 | super(MAN, 
self).__init__() 388 | 389 | #res_scale = res_scale 390 | self.n_resgroups = n_resgroups 391 | 392 | self.sub_mean = MeanShift(1.0) 393 | self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 394 | 395 | # define body module 396 | self.body = nn.ModuleList([ 397 | ResGroup( 398 | n_resblocks, n_feats, res_scale=res_scale) 399 | for i in range(n_resgroups)]) 400 | 401 | if self.n_resgroups > 1: 402 | self.body_t = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 403 | 404 | # define tail module 405 | self.tail = nn.Sequential( 406 | nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1), 407 | nn.PixelShuffle(scale) 408 | ) 409 | self.add_mean = MeanShift(1.0, sign=1) 410 | 411 | def forward(self, x): 412 | x = self.sub_mean(x) 413 | x = self.head(x) 414 | res = x 415 | for i in self.body: 416 | res = i(res) 417 | if self.n_resgroups>1: 418 | res = self.body_t(res) + x 419 | x = self.tail(res) 420 | x = self.add_mean(x) 421 | return x 422 | 423 | def visual_feature(self, x): 424 | fea = [] 425 | x = self.head(x) 426 | res = x 427 | 428 | for i in self.body: 429 | temp = res 430 | res = i(res) 431 | fea.append(res) 432 | 433 | res = self.body_t(res) + x 434 | 435 | x = self.tail(res) 436 | return x, fea 437 | 438 | def load_state_dict(self, state_dict, strict=False): 439 | own_state = self.state_dict() 440 | for name, param in state_dict.items(): 441 | if name in own_state: 442 | if isinstance(param, nn.Parameter): 443 | param = param.data 444 | try: 445 | own_state[name].copy_(param) 446 | except Exception: 447 | if name.find('tail') >= 0: 448 | print('Replace pre-trained upsampler to new one...') 449 | else: 450 | raise RuntimeError('While copying the parameter named {}, ' 451 | 'whose dimensions in the model are {} and ' 452 | 'whose dimensions in the checkpoint are {}.' 
453 | .format(name, own_state[name].size(), param.size())) 454 | elif strict: 455 | if name.find('tail') == -1: 456 | raise KeyError('unexpected key "{}" in state_dict' 457 | .format(name)) 458 | 459 | if strict: 460 | missing = set(own_state.keys()) - set(state_dict.keys()) 461 | if len(missing) > 0: 462 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /archs/MAN_arch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from basicsr.utils.registry import ARCH_REGISTRY 9 | 10 | #LKA from VAN (https://github.com/Visual-Attention-Network) 11 | class LKA(nn.Module): 12 | def __init__(self, dim): 13 | super().__init__() 14 | self.conv0 = nn.Conv2d(dim, dim, 7, padding=7//2, groups=dim) 15 | self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9//2)*4), groups=dim, dilation=4) 16 | self.conv1 = nn.Conv2d(dim, dim, 1) 17 | 18 | def forward(self, x): 19 | u = x.clone() 20 | attn = self.conv0(x) 21 | attn = self.conv_spatial(attn) 22 | attn = self.conv1(attn) 23 | 24 | return u * attn 25 | 26 | class Attention(nn.Module): 27 | def __init__(self, n_feats): 28 | super().__init__() 29 | 30 | self.norm = LayerNorm(n_feats, data_format='channels_first') 31 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 32 | 33 | self.proj_1 = nn.Conv2d(n_feats, n_feats, 1) 34 | self.spatial_gating_unit = LKA(n_feats) 35 | self.proj_2 = nn.Conv2d(n_feats, n_feats, 1) 36 | 37 | def forward(self, x): 38 | shorcut = x.clone() 39 | x = self.proj_1(self.norm(x)) 40 | x = self.spatial_gating_unit(x) 41 | x = self.proj_2(x) 42 | x = x*self.scale + shorcut 43 | return x 44 | #---------------------------------------------------------------------------------------------------------------- 45 | 46 | class MLP(nn.Module): 47 | def __init__(self, n_feats): 48 | super().__init__() 49 | 50 | self.norm = LayerNorm(n_feats, data_format='channels_first') 51 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 52 | 53 | i_feats = 2*n_feats 54 | 55 | self.fc1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 56 | self.act = nn.GELU() 57 | self.fc2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0) 58 | 59 | def forward(self, x): 60 | shortcut = x.clone() 61 | x = self.norm(x) 62 | x = self.fc1(x) 63 | x = self.act(x) 64 | x = self.fc2(x) 65 | 66 | return x*self.scale + shortcut 67 | 68 | 69 | class CFF(nn.Module): 70 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'): 71 | super().__init__() 72 | i_feats =n_feats*2 73 | 74 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 75 | self.DWConv1 = nn.Sequential( 76 | nn.Conv2d(i_feats, i_feats, 7, 1, 7//2, groups= n_feats), 77 | nn.GELU()) 78 | self.Conv2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0) 79 | 80 | self.norm = LayerNorm(n_feats, data_format='channels_first') 81 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 82 | 83 | def forward(self, x): 84 | shortcut = x.clone() 85 | 86 | #Ghost Expand 87 | x = self.Conv1(self.norm(x)) 88 | x = self.DWConv1(x) 89 | x = self.Conv2(x) 90 | 91 | return x*self.scale + shortcut 92 | 93 | 94 | class SimpleGate(nn.Module): 95 | def __init__(self, n_feats): 96 | super().__init__() 97 | i_feats =n_feats*2 98 | 99 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 100 | #self.DWConv1 = 
nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats) 101 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 102 | 103 | self.norm = LayerNorm(n_feats, data_format='channels_first') 104 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 105 | 106 | def forward(self, x): 107 | shortcut = x.clone() 108 | 109 | #Ghost Expand 110 | x = self.Conv1(self.norm(x)) 111 | a, x = torch.chunk(x, 2, dim=1) 112 | x = x*a #self.DWConv1(a) 113 | x = self.Conv2(x) 114 | 115 | return x*self.scale + shortcut 116 | #----------------------------------------------------------------------------------------------------------------- 117 | #RCAN-style 118 | class RCBv6(nn.Module): 119 | def __init__( 120 | self, n_feats, k, lk=7, res_scale=1.0, style ='X', act = nn.SiLU(), deploy=False): 121 | super().__init__() 122 | self.LKA = nn.Sequential( 123 | nn.Conv2d(n_feats, n_feats, 5, 1, lk//2, groups= n_feats), 124 | nn.Conv2d(n_feats, n_feats, 7, stride=1, padding=9, groups=n_feats, dilation=3), 125 | nn.Conv2d(n_feats, n_feats, 1, 1, 0), 126 | nn.Sigmoid()) 127 | 128 | #self.LFE2 = LFEv3(n_feats, attn ='CA') 129 | 130 | self.LFE = nn.Sequential( 131 | nn.Conv2d(n_feats,n_feats,3,1,1), 132 | nn.GELU(), 133 | nn.Conv2d(n_feats,n_feats,3,1,1)) 134 | 135 | 136 | def forward(self, x, pre_attn=None, RAA=None): 137 | shortcut = x.clone() 138 | x = self.LFE(x) 139 | 140 | x = self.LKA(x)*x 141 | 142 | return x + shortcut 143 | 144 | #----------------------------------------------------------------------------------------------------------------- 145 | 146 | 147 | class MLKA_Ablation(nn.Module): 148 | def __init__(self, n_feats, k=2, squeeze_factor=15): 149 | super().__init__() 150 | i_feats = 2*n_feats 151 | 152 | self.n_feats= n_feats 153 | self.i_feats = i_feats 154 | 155 | self.norm = LayerNorm(n_feats, data_format='channels_first') 156 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 157 | 158 | k = 2 159 | 160 | #Multiscale Large Kernel Attention 161 | self.LKA7 = nn.Sequential( 162 | nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k), 163 | nn.Conv2d(n_feats//k, n_feats//k, 9, stride=1, padding=(9//2)*4, groups=n_feats//k, dilation=4), 164 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0)) 165 | self.LKA5 = nn.Sequential( 166 | nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k), 167 | nn.Conv2d(n_feats//k, n_feats//k, 7, stride=1, padding=(7//2)*3, groups=n_feats//k, dilation=3), 168 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0)) 169 | '''self.LKA3 = nn.Sequential( 170 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k), 171 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2), 172 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))''' 173 | 174 | #self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k) 175 | self.X5 = nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k) 176 | self.X7 = nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k) 177 | 178 | self.proj_first = nn.Sequential( 179 | nn.Conv2d(n_feats, i_feats, 1, 1, 0)) 180 | 181 | self.proj_last = nn.Sequential( 182 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 183 | 184 | 185 | def forward(self, x, pre_attn=None, RAA=None): 186 | shortcut = x.clone() 187 | 188 | x = self.norm(x) 189 | 190 | x = self.proj_first(x) 191 | 192 | a, x = torch.chunk(x, 2, dim=1) 193 | 194 | #u_1, u_2, u_3= torch.chunk(u, 3, dim=1) 195 | a_1, a_2 = torch.chunk(a, 2, dim=1) 196 | 197 | a = 
torch.cat([self.LKA7(a_1)*self.X7(a_1), self.LKA5(a_2)*self.X5(a_2)], dim=1) 198 | 199 | x = self.proj_last(x*a)*self.scale + shortcut 200 | 201 | return x 202 | #----------------------------------------------------------------------------------------------------------------- 203 | 204 | 205 | class LayerNorm(nn.Module): 206 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 207 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 208 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 209 | with shape (batch_size, channels, height, width). 210 | """ 211 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 212 | super().__init__() 213 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 214 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 215 | self.eps = eps 216 | self.data_format = data_format 217 | if self.data_format not in ["channels_last", "channels_first"]: 218 | raise NotImplementedError 219 | self.normalized_shape = (normalized_shape, ) 220 | 221 | def forward(self, x): 222 | if self.data_format == "channels_last": 223 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 224 | elif self.data_format == "channels_first": 225 | u = x.mean(1, keepdim=True) 226 | s = (x - u).pow(2).mean(1, keepdim=True) 227 | x = (x - u) / torch.sqrt(s + self.eps) 228 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 229 | return x 230 | 231 | 232 | 233 | class SGAB(nn.Module): 234 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'): 235 | super().__init__() 236 | i_feats =n_feats*2 237 | 238 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0) 239 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats) 240 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 241 | 242 | self.norm = LayerNorm(n_feats, data_format='channels_first') 243 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 244 | 245 | def forward(self, x): 246 | shortcut = x.clone() 247 | 248 | #Ghost Expand 249 | x = self.Conv1(self.norm(x)) 250 | a, x = torch.chunk(x, 2, dim=1) 251 | x = x*self.DWConv1(a) 252 | x = self.Conv2(x) 253 | 254 | return x*self.scale + shortcut 255 | 256 | class GroupGLKA(nn.Module): 257 | def __init__(self, n_feats, k=2, squeeze_factor=15): 258 | super().__init__() 259 | i_feats = 2*n_feats 260 | 261 | self.n_feats= n_feats 262 | self.i_feats = i_feats 263 | 264 | self.norm = LayerNorm(n_feats, data_format='channels_first') 265 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 266 | 267 | #Multiscale Large Kernel Attention 268 | self.LKA7 = nn.Sequential( 269 | nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3), 270 | nn.Conv2d(n_feats//3, n_feats//3, 9, stride=1, padding=(9//2)*4, groups=n_feats//3, dilation=4), 271 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 272 | self.LKA5 = nn.Sequential( 273 | nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3), 274 | nn.Conv2d(n_feats//3, n_feats//3, 7, stride=1, padding=(7//2)*3, groups=n_feats//3, dilation=3), 275 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 276 | self.LKA3 = nn.Sequential( 277 | nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3), 278 | nn.Conv2d(n_feats//3, n_feats//3, 5, stride=1, padding=(5//2)*2, groups=n_feats//3, dilation=2), 279 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0)) 280 | 281 | self.X3 = 
nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3) 282 | self.X5 = nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3) 283 | self.X7 = nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3) 284 | 285 | self.proj_first = nn.Sequential( 286 | nn.Conv2d(n_feats, i_feats, 1, 1, 0)) 287 | 288 | self.proj_last = nn.Sequential( 289 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 290 | 291 | 292 | def forward(self, x, pre_attn=None, RAA=None): 293 | shortcut = x.clone() 294 | 295 | x = self.norm(x) 296 | 297 | x = self.proj_first(x) 298 | 299 | a, x = torch.chunk(x, 2, dim=1) 300 | 301 | a_1, a_2, a_3= torch.chunk(a, 3, dim=1) 302 | 303 | a = torch.cat([self.LKA3(a_1)*self.X3(a_1), self.LKA5(a_2)*self.X5(a_2), self.LKA7(a_3)*self.X7(a_3)], dim=1) 304 | 305 | x = self.proj_last(x*a)*self.scale + shortcut 306 | 307 | return x 308 | 309 | 310 | # MAB 311 | class MAB(nn.Module): 312 | def __init__( 313 | self, n_feats): 314 | super().__init__() 315 | 316 | self.LKA = GroupGLKA(n_feats) 317 | 318 | self.LFE = SGAB(n_feats) 319 | 320 | def forward(self, x, pre_attn=None, RAA=None): 321 | #large kernel attention 322 | x = self.LKA(x) 323 | 324 | #local feature extraction 325 | x = self.LFE(x) 326 | 327 | return x 328 | 329 | class LKAT(nn.Module): 330 | def __init__(self, n_feats): 331 | super().__init__() 332 | 333 | #self.norm = LayerNorm(n_feats, data_format='channels_first') 334 | #self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 335 | 336 | self.conv0 = nn.Sequential( 337 | nn.Conv2d(n_feats, n_feats, 1, 1, 0), 338 | nn.GELU()) 339 | 340 | self.att = nn.Sequential( 341 | nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats), 342 | nn.Conv2d(n_feats, n_feats, 9, stride=1, padding=(9//2)*3, groups=n_feats, dilation=3), 343 | nn.Conv2d(n_feats, n_feats, 1, 1, 0)) 344 | 345 | self.conv1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 346 | 347 | def forward(self, x): 348 | x = self.conv0(x) 349 | x = x*self.att(x) 350 | x = self.conv1(x) 351 | return x 352 | 353 | class ResGroup(nn.Module): 354 | def __init__(self, n_resblocks, n_feats, res_scale=1.0): 355 | super(ResGroup, self).__init__() 356 | self.body = nn.ModuleList([ 357 | MAB(n_feats) \ 358 | for _ in range(n_resblocks)]) 359 | 360 | self.body_t = LKAT(n_feats) 361 | 362 | def forward(self, x): 363 | res = x.clone() 364 | 365 | for i, block in enumerate(self.body): 366 | res = block(res) 367 | 368 | x = self.body_t(res) + x 369 | 370 | return x 371 | 372 | class MeanShift(nn.Conv2d): 373 | def __init__( 374 | self, rgb_range, 375 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 376 | 377 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 378 | std = torch.Tensor(rgb_std) 379 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 380 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 381 | for p in self.parameters(): 382 | p.requires_grad = False 383 | 384 | 385 | @ARCH_REGISTRY.register() 386 | class MAN(nn.Module): 387 | def __init__(self, n_resblocks=36, n_resgroups=1, n_colors=3, n_feats=180, scale=2, res_scale = 1.0): 388 | super(MAN, self).__init__() 389 | 390 | #res_scale = res_scale 391 | self.n_resgroups = n_resgroups 392 | 393 | self.sub_mean = MeanShift(1.0) 394 | self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1) 395 | 396 | # define body module 397 | self.body = nn.ModuleList([ 398 | ResGroup( 399 | n_resblocks, n_feats, res_scale=res_scale) 400 | for i in range(n_resgroups)]) 401 | 402 | if self.n_resgroups > 1: 403 | 
self.body_t = nn.Conv2d(n_feats, n_feats, 3, 1, 1) 404 | 405 | # define tail module 406 | self.tail = nn.Sequential( 407 | nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1), 408 | nn.PixelShuffle(scale) 409 | ) 410 | self.add_mean = MeanShift(1.0, sign=1) 411 | 412 | def forward(self, x): 413 | x = self.sub_mean(x) 414 | x = self.head(x) 415 | res = x 416 | for i in self.body: 417 | res = i(res) 418 | if self.n_resgroups>1: 419 | res = self.body_t(res) + x 420 | x = self.tail(res) 421 | x = self.add_mean(x) 422 | return x 423 | 424 | def visual_feature(self, x): 425 | fea = [] 426 | x = self.head(x) 427 | res = x 428 | 429 | for i in self.body: 430 | temp = res 431 | res = i(res) 432 | fea.append(res) 433 | 434 | res = self.body_t(res) + x 435 | 436 | x = self.tail(res) 437 | return x, fea 438 | 439 | def load_state_dict(self, state_dict, strict=False): 440 | own_state = self.state_dict() 441 | for name, param in state_dict.items(): 442 | if name in own_state: 443 | if isinstance(param, nn.Parameter): 444 | param = param.data 445 | try: 446 | own_state[name].copy_(param) 447 | except Exception: 448 | if name.find('tail') >= 0: 449 | print('Replace pre-trained upsampler to new one...') 450 | else: 451 | raise RuntimeError('While copying the parameter named {}, ' 452 | 'whose dimensions in the model are {} and ' 453 | 'whose dimensions in the checkpoint are {}.' 454 | .format(name, own_state[name].size(), param.size())) 455 | elif strict: 456 | if name.find('tail') == -1: 457 | raise KeyError('unexpected key "{}" in state_dict' 458 | .format(name)) 459 | 460 | if strict: 461 | missing = set(own_state.keys()) - set(state_dict.keys()) 462 | if len(missing) > 0: 463 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 464 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/san.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from ..NN.MPNCOV.python import MPNCOV 6 | 7 | 8 | # 9 | def make_model(args, parent=False): 10 | return SAN(args) 11 | 12 | 13 | ## non_local module 14 | class _NonLocalBlockND(nn.Module): 15 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', 16 | sub_sample=True, bn_layer=True): 17 | super(_NonLocalBlockND, self).__init__() 18 | assert dimension in [1, 2, 3] 19 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation'] 20 | 21 | # print('Dimension: %d, mode: %s' % (dimension, mode)) 22 | 23 | self.mode = mode 24 | self.dimension = dimension 25 | self.sub_sample = sub_sample 26 | 27 | self.in_channels = in_channels 28 | self.inter_channels = inter_channels 29 | 30 | if self.inter_channels is None: 31 | self.inter_channels = in_channels // 2 32 | if self.inter_channels == 0: 33 | self.inter_channels = 1 34 | 35 | if dimension == 3: 36 | conv_nd = nn.Conv3d 37 | max_pool = nn.MaxPool3d 38 | bn = nn.BatchNorm3d 39 | elif dimension == 2: 40 | conv_nd = nn.Conv2d 41 | max_pool = nn.MaxPool2d 42 | sub_sample = nn.Upsample 43 | bn = nn.BatchNorm2d 44 | else: 45 | conv_nd = nn.Conv1d 46 | max_pool = nn.MaxPool1d 47 | bn = nn.BatchNorm1d 48 | 49 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 50 | kernel_size=1, stride=1, padding=0) 51 | 52 | if bn_layer: 53 | self.W = nn.Sequential( 54 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 55 | 
kernel_size=1, stride=1, padding=0), 56 | bn(self.in_channels) 57 | ) 58 | nn.init.constant_(self.W[1].weight, 0) 59 | nn.init.constant_(self.W[1].bias, 0) 60 | else: 61 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | nn.init.constant_(self.W.weight, 0) 64 | nn.init.constant_(self.W.bias, 0) 65 | 66 | self.theta = None 67 | self.phi = None 68 | self.concat_project = None 69 | # self.fc = nn.Linear(64,2304,bias=True) 70 | # self.sub_bilinear = nn.Upsample(size=(48,48),mode='bilinear') 71 | # self.sub_maxpool = nn.AdaptiveMaxPool2d(output_size=(48,48)) 72 | if mode in ['embedded_gaussian', 'dot_product', 'concatenation']: 73 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 74 | kernel_size=1, stride=1, padding=0) 75 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 76 | kernel_size=1, stride=1, padding=0) 77 | 78 | if mode == 'embedded_gaussian': 79 | self.operation_function = self._embedded_gaussian 80 | elif mode == 'dot_product': 81 | self.operation_function = self._dot_product 82 | elif mode == 'concatenation': 83 | self.operation_function = self._concatenation 84 | self.concat_project = nn.Sequential( 85 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False), 86 | nn.ReLU() 87 | ) 88 | elif mode == 'gaussian': 89 | self.operation_function = self._gaussian 90 | 91 | if sub_sample: 92 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 93 | if self.phi is None: 94 | self.phi = max_pool(kernel_size=2) 95 | else: 96 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 97 | 98 | def forward(self, x): 99 | ''' 100 | :param x: (b, c, t, h, w) 101 | :return: 102 | ''' 103 | 104 | output = self.operation_function(x) 105 | return output 106 | 107 | def _embedded_gaussian(self, x): 108 | batch_size, C, H, W = x.shape 109 | 110 | # x_sub = self.sub_bilinear(x) # bilinear downsample 111 | # x_sub = self.sub_maxpool(x) # maxpool downsample 112 | 113 | ## 114 | # g_x = x.view(batch_size, self.inter_channels, -1) 115 | # g_x = g_x.permute(0, 2, 1) 116 | # 117 | # # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) 118 | # # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) 119 | # # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) 120 | # theta_x = x.view(batch_size, self.inter_channels, -1) 121 | # theta_x = theta_x.permute(0, 2, 1) 122 | # fc = self.fc(theta_x) 123 | # # phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 124 | # # f = torch.matmul(theta_x, phi_x) 125 | # # return f 126 | # # f_div_C = F.softmax(fc, dim=-1) 127 | # return fc 128 | 129 | ## 130 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) 131 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 132 | g_x = g_x.permute(0, 2, 1) 133 | 134 | # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) 135 | # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) 136 | # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) 137 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 138 | theta_x = theta_x.permute(0, 2, 1) 139 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 140 | f = torch.matmul(theta_x, phi_x) 141 | # return f 142 | f_div_C = F.softmax(f, dim=-1) 143 | # return f_div_C 144 | # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) 145 | y = torch.matmul(f_div_C, g_x) 146 | y = y.permute(0, 2, 1).contiguous() 147 | y = 
y.view(batch_size, self.inter_channels, *x.size()[2:]) 148 | W_y = self.W(y) 149 | z = W_y + x 150 | 151 | return z 152 | 153 | def _gaussian(self, x): 154 | batch_size = x.size(0) 155 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 156 | g_x = g_x.permute(0, 2, 1) 157 | 158 | theta_x = x.view(batch_size, self.in_channels, -1) 159 | theta_x = theta_x.permute(0, 2, 1) 160 | 161 | if self.sub_sample: 162 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1) 163 | else: 164 | phi_x = x.view(batch_size, self.in_channels, -1) 165 | 166 | f = torch.matmul(theta_x, phi_x) 167 | f_div_C = F.softmax(f, dim=-1) 168 | 169 | y = torch.matmul(f_div_C, g_x) 170 | y = y.permute(0, 2, 1).contiguous() 171 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 172 | W_y = self.W(y) 173 | z = W_y + x 174 | 175 | return z 176 | 177 | def _dot_product(self, x): 178 | batch_size = x.size(0) 179 | 180 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 181 | g_x = g_x.permute(0, 2, 1) 182 | 183 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 184 | theta_x = theta_x.permute(0, 2, 1) 185 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 186 | f = torch.matmul(theta_x, phi_x) 187 | N = f.size(-1) 188 | f_div_C = f / N 189 | 190 | y = torch.matmul(f_div_C, g_x) 191 | y = y.permute(0, 2, 1).contiguous() 192 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 193 | W_y = self.W(y) 194 | z = W_y + x 195 | 196 | return z 197 | 198 | def _concatenation(self, x): 199 | batch_size = x.size(0) 200 | 201 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 202 | g_x = g_x.permute(0, 2, 1) 203 | 204 | # (b, c, N, 1) 205 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1) 206 | # (b, c, 1, N) 207 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1) 208 | 209 | h = theta_x.size(2) 210 | w = phi_x.size(3) 211 | theta_x = theta_x.repeat(1, 1, 1, w) 212 | phi_x = phi_x.repeat(1, 1, h, 1) 213 | 214 | concat_feature = torch.cat([theta_x, phi_x], dim=1) 215 | f = self.concat_project(concat_feature) 216 | b, _, h, w = f.size() 217 | f = f.view(b, h, w) 218 | 219 | N = f.size(-1) 220 | f_div_C = f / N 221 | 222 | y = torch.matmul(f_div_C, g_x) 223 | y = y.permute(0, 2, 1).contiguous() 224 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 225 | W_y = self.W(y) 226 | z = W_y + x 227 | 228 | return z 229 | 230 | 231 | class NONLocalBlock1D(_NonLocalBlockND): 232 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 233 | super(NONLocalBlock1D, self).__init__(in_channels, 234 | inter_channels=inter_channels, 235 | dimension=1, mode=mode, 236 | sub_sample=sub_sample, 237 | bn_layer=bn_layer) 238 | 239 | 240 | class NONLocalBlock2D(_NonLocalBlockND): 241 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True): 242 | super(NONLocalBlock2D, self).__init__(in_channels, 243 | inter_channels=inter_channels, 244 | dimension=2, mode=mode, 245 | sub_sample=sub_sample, 246 | bn_layer=bn_layer) 247 | 248 | 249 | ## Channel Attention (CA) Layer 250 | class CALayer(nn.Module): 251 | def __init__(self, channel, reduction=8): 252 | super(CALayer, self).__init__() 253 | # global average pooling: feature --> point 254 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 255 | self.max_pool = nn.AdaptiveMaxPool2d(1) 256 | # feature channel downscale and upscale --> channel weight 257 | self.conv_du = nn.Sequential( 258 | 
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 259 | nn.ReLU(inplace=True), 260 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 261 | # nn.Sigmoid() 262 | # nn.BatchNorm2d(channel) 263 | ) 264 | 265 | def forward(self, x): 266 | _, _, h, w = x.shape 267 | y_ave = self.avg_pool(x) 268 | # y_max = self.max_pool(x) 269 | y_ave = self.conv_du(y_ave) 270 | # y_max = self.conv_du(y_max) 271 | # y = y_ave + y_max 272 | # expand y to C*H*W 273 | # expand_y = y.expand(-1,-1,h,w) 274 | return y_ave 275 | 276 | 277 | ## second-order Channel attention (SOCA) 278 | class SOCA(nn.Module): 279 | def __init__(self, channel, reduction=8): 280 | super(SOCA, self).__init__() 281 | # global average pooling: feature --> point 282 | # self.avg_pool = nn.AdaptiveAvgPool2d(1) 283 | # self.max_pool = nn.AdaptiveMaxPool2d(1) 284 | self.max_pool = nn.MaxPool2d(kernel_size=2) 285 | 286 | # feature channel downscale and upscale --> channel weight 287 | self.conv_du = nn.Sequential( 288 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 289 | nn.ReLU(inplace=True), 290 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 291 | nn.Sigmoid() 292 | # nn.BatchNorm2d(channel) 293 | ) 294 | 295 | def forward(self, x): 296 | batch_size, C, h, w = x.shape # x: NxCxHxW 297 | N = int(h * w) 298 | min_h = min(h, w) 299 | h1 = 1000 300 | w1 = 1000 301 | if h < h1 and w < w1: 302 | x_sub = x 303 | elif h < h1 and w > w1: 304 | # H = (h - h1) // 2 305 | W = (w - w1) // 2 306 | x_sub = x[:, :, :, W:(W + w1)] 307 | elif w < w1 and h > h1: 308 | H = (h - h1) // 2 309 | # W = (w - w1) // 2 310 | x_sub = x[:, :, H:H + h1, :] 311 | else: 312 | H = (h - h1) // 2 313 | W = (w - w1) // 2 314 | x_sub = x[:, :, H:(H + h1), W:(W + w1)] 315 | # subsample 316 | # subsample_scale = 2 317 | # subsample = nn.Upsample(size=(h // subsample_scale, w // subsample_scale), mode='nearest') 318 | # x_sub = subsample(x) 319 | # max_pool = nn.MaxPool2d(kernel_size=2) 320 | # max_pool = nn.AvgPool2d(kernel_size=2) 321 | # x_sub = self.max_pool(x) 322 | ## 323 | ## MPN-COV 324 | cov_mat = MPNCOV.CovpoolLayer(x_sub) # Global Covariance pooling layer 325 | cov_mat_sqrt = MPNCOV.SqrtmLayer(cov_mat, 326 | 5) # Matrix square root layer( including pre-norm,Newton-Schulz iter. and post-com. 
with 5 iteration) 327 | ## 328 | cov_mat_sum = torch.mean(cov_mat_sqrt, 1) 329 | cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1) 330 | # y_ave = self.avg_pool(x) 331 | # y_max = self.max_pool(x) 332 | y_cov = self.conv_du(cov_mat_sum) 333 | # y_max = self.conv_du(y_max) 334 | # y = y_ave + y_max 335 | # expand y to C*H*W 336 | # expand_y = y.expand(-1,-1,h,w) 337 | return y_cov * x 338 | 339 | 340 | ## self-attention+ channel attention module 341 | class Nonlocal_CA(nn.Module): 342 | def __init__(self, in_feat=64, inter_feat=32, reduction=8, sub_sample=False, bn_layer=True): 343 | super(Nonlocal_CA, self).__init__() 344 | # second-order channel attention 345 | self.soca = SOCA(in_feat, reduction=reduction) 346 | # nonlocal module 347 | self.non_local = _NonLocalBlockND(in_channels=in_feat, inter_channels=inter_feat, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer) 348 | 349 | self.sigmoid = nn.Sigmoid() 350 | 351 | def forward(self, x): 352 | ## divide feature map into 4 part 353 | batch_size, C, H, W = x.shape 354 | H1 = int(H / 2) 355 | W1 = int(W / 2) 356 | nonlocal_feat = torch.zeros_like(x) 357 | 358 | feat_sub_lu = x[:, :, :H1, :W1] 359 | feat_sub_ld = x[:, :, H1:, :W1] 360 | feat_sub_ru = x[:, :, :H1, W1:] 361 | feat_sub_rd = x[:, :, H1:, W1:] 362 | 363 | nonlocal_lu = self.non_local(feat_sub_lu) 364 | nonlocal_ld = self.non_local(feat_sub_ld) 365 | nonlocal_ru = self.non_local(feat_sub_ru) 366 | nonlocal_rd = self.non_local(feat_sub_rd) 367 | nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu 368 | nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld 369 | nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru 370 | nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd 371 | 372 | return nonlocal_feat 373 | 374 | 375 | ## Residual Block (RB) 376 | class RB(nn.Module): 377 | def __init__(self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(inplace=True), 378 | res_scale=1, dilation=2): 379 | super(RB, self).__init__() 380 | modules_body = [] 381 | 382 | # self.gamma1 = nn.Parameter(torch.ones(1)) 383 | self.gamma1 = 1.0 384 | # self.salayer = SALayer(n_feat, reduction=reduction, dilation=dilation) 385 | # self.salayer = SALayer2(n_feat, reduction=reduction, dilation=dilation) 386 | 387 | self.conv_first = nn.Sequential(conv(n_feat, n_feat, kernel_size, bias=bias), 388 | act, 389 | conv(n_feat, n_feat, kernel_size, bias=bias) 390 | ) 391 | 392 | self.res_scale = res_scale 393 | 394 | def forward(self, x): 395 | y = self.conv_first(x) 396 | y = y + x 397 | 398 | return y 399 | 400 | 401 | ## Local-source Residual Attention Group (LSRARG) 402 | class LSRAG(nn.Module): 403 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 404 | super(LSRAG, self).__init__() 405 | ## 406 | self.rcab = nn.ModuleList([RB(conv, n_feat, kernel_size, reduction, \ 407 | bias=True, bn=False, act=nn.ReLU(inplace=True), res_scale=1) for _ in 408 | range(n_resblocks)]) 409 | self.soca = (SOCA(n_feat, reduction=reduction)) 410 | self.conv_last = (conv(n_feat, n_feat, kernel_size)) 411 | self.n_resblocks = n_resblocks 412 | ## 413 | # modules_body = [] 414 | self.gamma = nn.Parameter(torch.zeros(1)) 415 | # self.gamma = 0.2 416 | # for i in range(n_resblocks): 417 | # modules_body.append(RCAB(conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(inplace=True), res_scale=1)) 418 | # modules_body.append(SOCA(n_feat,reduction=reduction)) 419 | # # modules_body.append(Nonlocal_CA(in_feat=n_feat, inter_feat=n_feat//8, reduction =reduction, sub_sample=False, 
bn_layer=False)) 420 | # modules_body.append(conv(n_feat, n_feat, kernel_size)) 421 | # self.body = nn.Sequential(*modules_body) 422 | ## 423 | 424 | def make_layer(self, block, num_of_layer): 425 | layers = [] 426 | for _ in range(num_of_layer): 427 | layers.append(block) 428 | return nn.ModuleList(layers) 429 | # return nn.Sequential(*layers) 430 | 431 | def forward(self, x): 432 | residual = x 433 | # batch_size,C,H,W = x.shape 434 | # y_pre = self.body(x) 435 | # y_pre = y_pre + x 436 | # return y_pre 437 | 438 | ## share-source skip connection 439 | 440 | for i, l in enumerate(self.rcab): 441 | # x = l(x) + self.gamma*residual 442 | x = l(x) 443 | x = self.soca(x) 444 | x = self.conv_last(x) 445 | 446 | x = x + residual 447 | 448 | return x 449 | ## 450 | 451 | 452 | ## Second-order Channel Attention Network (SAN) 453 | class SAN(nn.Module): 454 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv): 455 | super(SAN, self).__init__() 456 | n_resgroups = 20 457 | n_resblocks = 10 458 | n_feats = 64 459 | kernel_size = 3 460 | reduction = 16 461 | scale = factor 462 | act = nn.ReLU(inplace=True) 463 | 464 | # RGB mean for DIV2K 465 | rgb_mean = (0.4488, 0.4371, 0.4040) 466 | rgb_std = (1.0, 1.0, 1.0) 467 | self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std) 468 | # self.soca= SOCA(n_feats, reduction=reduction) 469 | 470 | # define head module 471 | modules_head = [conv(num_channels, n_feats, kernel_size)] 472 | 473 | # define body module 474 | ## share-source skip connection 475 | 476 | ## 477 | self.gamma = nn.Parameter(torch.zeros(1)) 478 | # self.gamma = 0.2 479 | self.n_resgroups = n_resgroups 480 | self.RG = nn.ModuleList([LSRAG(conv, n_feats, kernel_size, reduction, \ 481 | act=act, res_scale=1.0, n_resblocks=n_resblocks) for _ in 482 | range(n_resgroups)]) 483 | self.conv_last = conv(n_feats, n_feats, kernel_size) 484 | 485 | # modules_body = [ 486 | # ResidualGroup( 487 | # conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 488 | # for _ in range(n_resgroups)] 489 | # modules_body.append(conv(n_feats, n_feats, kernel_size)) 490 | 491 | # define tail module 492 | modules_tail = [ 493 | common.Upsampler(conv, scale, n_feats, act=False), 494 | conv(n_feats, num_channels, kernel_size)] 495 | 496 | self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1) 497 | self.non_local = Nonlocal_CA(in_feat=n_feats, inter_feat=n_feats // 8, reduction=8, sub_sample=False, 498 | bn_layer=False) 499 | 500 | self.head = nn.Sequential(*modules_head) 501 | # self.body = nn.Sequential(*modules_body) 502 | self.tail = nn.Sequential(*modules_tail) 503 | 504 | def make_layer(self, block, num_of_layer): 505 | layers = [] 506 | for _ in range(num_of_layer): 507 | layers.append(block) 508 | 509 | return nn.ModuleList(layers) 510 | # return nn.Sequential(*layers) 511 | 512 | def forward(self, x): 513 | x = self.sub_mean(x * 255.) 
514 | x = self.head(x) 515 | 516 | ## add nonlocal 517 | xx = self.non_local(x) 518 | 519 | # share-source skip connection 520 | residual = xx 521 | 522 | # res = self.RG(xx) 523 | # res = res + xx 524 | ## share-source residual group 525 | for i, l in enumerate(self.RG): 526 | xx = l(xx) + self.gamma * residual 527 | # xx = self.gamma*xx + residual 528 | # body part 529 | # res = self.body(xx) 530 | ## 531 | ## add nonlocal 532 | res = self.non_local(xx) 533 | ## 534 | # res = self.soca(res) 535 | # res += x 536 | res = res + x 537 | 538 | x = self.tail(res) 539 | x = self.add_mean(x) 540 | 541 | return x / 255. 542 | 543 | def load_state_dict(self, state_dict, strict=False): 544 | own_state = self.state_dict() 545 | for name, param in state_dict.items(): 546 | if name in own_state: 547 | if isinstance(param, nn.Parameter): 548 | param = param.data 549 | try: 550 | own_state[name].copy_(param) 551 | except Exception: 552 | if name.find('tail') >= 0: 553 | print('Replacing pre-trained upsampler with a new one...') 554 | else: 555 | raise RuntimeError('While copying the parameter named {}, ' 556 | 'whose dimensions in the model are {} and ' 557 | 'whose dimensions in the checkpoint are {}.' 558 | .format(name, own_state[name].size(), param.size())) 559 | elif strict: 560 | if name.find('tail') == -1: 561 | raise KeyError('unexpected key "{}" in state_dict' 562 | .format(name)) 563 | 564 | if strict: 565 | missing = set(own_state.keys()) - set(state_dict.keys()) 566 | if len(missing) > 0: 567 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) --------------------------------------------------------------------------------
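As a quick sanity check on the MAN architecture listed above, here is a minimal smoke-test sketch. It assumes the repository root is on PYTHONPATH, that BasicSR is installed (the class is registered via ARCH_REGISTRY), and that MAN is imported from archs/MAN_arch.py; the small configuration values are illustrative only, not the released MAN/MAN-Light/MAN-Tiny settings.

import torch
from archs.MAN_arch import MAN  # adjust the import to wherever the MAN class above lives

# n_feats must be divisible by 3: GroupGLKA chunks the gating branch into three
# equal channel groups for the 3x3 / 5x5 / 7x7 large-kernel-attention paths.
model = MAN(n_resblocks=5, n_resgroups=1, n_feats=60, scale=4).eval()

lr = torch.rand(1, 3, 32, 32)   # dummy LR input in [0, 1] (MeanShift is built with rgb_range=1.0)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)                 # expected: torch.Size([1, 3, 128, 128]), i.e. 4x upscaling via PixelShuffle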
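The non-local block in LAM/ModelZoo/NN/san.py can likewise be exercised in isolation. A minimal sketch, assuming the repository root is on PYTHONPATH so the package-relative imports in san.py resolve:

import torch
from LAM.ModelZoo.NN.san import NONLocalBlock2D

# Embedded-Gaussian non-local attention; inter_channels defaults to in_channels // 2,
# and sub_sample max-pools the phi/g branches to shrink the attention matrix.
block = NONLocalBlock2D(in_channels=64, sub_sample=True, bn_layer=True).eval()

feat = torch.rand(2, 64, 24, 24)
with torch.no_grad():
    out = block(feat)
print(out.shape)  # torch.Size([2, 64, 24, 24]); the affine parameters of the BN in W are
                  # zero-initialised, so the untrained block starts as an identity (out == feat)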
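Finally, a hedged end-to-end sketch for SAN itself. The forward pass multiplies the input by 255 before the mean shift and divides by 255 on the way out, so inputs are expected in [0, 1]; with the hard-coded 20 residual groups of 10 blocks each this is slow on CPU, but it should run:

import torch
from LAM.ModelZoo.NN.san import SAN

model = SAN(factor=4, num_channels=3).eval()  # x4 scale factor, RGB input

lr = torch.rand(1, 3, 48, 48)   # values in [0, 1]; forward() handles the *255 / /255 rescaling internally
with torch.no_grad():
    sr = model(lr)
print(sr.shape)                 # expected: torch.Size([1, 3, 192, 192])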