├── LAM
│   ├── SaliencyModel
│   │   ├── __init__.py
│   │   ├── attributes.py
│   │   ├── BackProp.py
│   │   └── utils.py
│   ├── ModelZoo
│   │   ├── NN
│   │   │   ├── MPNCOV
│   │   │   │   ├── __init__.py
│   │   │   │   └── python
│   │   │   │       ├── __init__.py
│   │   │   │       └── MPNCOV.py
│   │   │   ├── fsrcnn.py
│   │   │   ├── edsr.py
│   │   │   ├── rnan.py
│   │   │   ├── rcan.py
│   │   │   ├── rrdbnet.py
│   │   │   ├── __init__.py
│   │   │   ├── drln_ops.py
│   │   │   ├── drln.py
│   │   │   ├── common.py
│   │   │   ├── man.py
│   │   │   └── san.py
│   │   ├── CARN
│   │   │   ├── __init__.py
│   │   │   ├── carn_m.py
│   │   │   ├── carn.py
│   │   │   └── ops.py
│   │   ├── __init__.py
│   │   └── utils.py
│   ├── README.md
│   ├── test_images
│   │   ├── 1.png
│   │   ├── 2.png
│   │   ├── 3.png
│   │   ├── 4.png
│   │   ├── 5.png
│   │   ├── 6.png
│   │   ├── 7.png
│   │   ├── 8.png
│   │   ├── 9.png
│   │   ├── a.png
│   │   ├── b.png
│   │   ├── c.png
│   │   ├── d.png
│   │   └── e.png
│   ├── test_MAN_diff.py
│   ├── test_MAN.py
│   └── cal_metrix.py
├── images
│   ├── Visual_Results
│   │   ├── D0850
│   │   │   ├── temp
│   │   │   ├── HR.png
│   │   │   ├── MAN.png
│   │   │   ├── EDSR.png
│   │   │   ├── MAN-Tiny.png
│   │   │   ├── EDSR-Base.png
│   │   │   └── MAN-Light.png
│   │   ├── U004
│   │   │   ├── HR.png
│   │   │   ├── EDSR.png
│   │   │   ├── MAN.png
│   │   │   ├── EDSR-Base.png
│   │   │   ├── MAN-Light.png
│   │   │   └── MAN-Tiny.png
│   │   ├── U012
│   │   │   ├── HR.png
│   │   │   ├── EDSR.png
│   │   │   ├── MAN.png
│   │   │   ├── EDSR-Base.png
│   │   │   ├── MAN-Light.png
│   │   │   └── MAN-Tiny.png
│   │   └── U044
│   │       ├── HR.png
│   │       ├── EDSR.png
│   │       ├── MAN.png
│   │       ├── EDSR-Base.png
│   │       ├── MAN-Light.png
│   │       └── MAN-Tiny.png
│   ├── MAN_arch.png
│   ├── MAN_details.png
│   └── man_ntire24.pdf
├── test.py
├── train.py
├── archs
│   ├── __init__.py
│   └── MAN_arch.py
├── options
│   ├── test_MAN.yml
│   └── trian_MAN.yml
├── README.md
└── LICENSE

/LAM/SaliencyModel/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/MPNCOV/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/Visual_Results/D0850/temp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/MPNCOV/python/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/LAM/README.md:
--------------------------------------------------------------------------------
1 | # LAM
2 | Run `test_MAN.py` and `test_MAN_diff.py`.
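
Both scripts compute Local Attribution Maps (LAM; Gu and Dong, "Interpreting Super-Resolution Networks with Local Attribution Maps", CVPR 2021) for the SR networks in `ModelZoo`. Pretrained weights are expected under `ModelZoo/models/` (the expected file names are listed in `MODEL_LIST` in `ModelZoo/__init__.py`), and figures are written to `./lam_results/`. A run is configured by editing the variables at the bottom of each script; for example, in `test_MAN.py` (a sketch using values already present in the script):

```python
model_names = 'MAN@Base'          # '<model>@<training>' key from MODEL_LIST
image_path = 'test_images/3.png'  # HR test image; the LR input is prepared by prepare_images
w, h = 90, 120                    # top-left corner of the 16x16 window D being attributed
LAM(model_names, image_path, w, h)
```

Note that the per-image output folder (e.g. `lam_results/3/`) must exist before running, since the scripts save their figures into `./lam_results/<image id>/`.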
3 | -------------------------------------------------------------------------------- /images/MAN_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/MAN_arch.png -------------------------------------------------------------------------------- /LAM/test_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/1.png -------------------------------------------------------------------------------- /LAM/test_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/2.png -------------------------------------------------------------------------------- /LAM/test_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/3.png -------------------------------------------------------------------------------- /LAM/test_images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/4.png -------------------------------------------------------------------------------- /LAM/test_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/5.png -------------------------------------------------------------------------------- /LAM/test_images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/6.png -------------------------------------------------------------------------------- /LAM/test_images/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/7.png -------------------------------------------------------------------------------- /LAM/test_images/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/8.png -------------------------------------------------------------------------------- /LAM/test_images/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/9.png -------------------------------------------------------------------------------- /LAM/test_images/a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/a.png -------------------------------------------------------------------------------- /LAM/test_images/b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/b.png -------------------------------------------------------------------------------- /LAM/test_images/c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/c.png -------------------------------------------------------------------------------- /LAM/test_images/d.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/d.png -------------------------------------------------------------------------------- /LAM/test_images/e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/LAM/test_images/e.png -------------------------------------------------------------------------------- /images/MAN_details.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/MAN_details.png -------------------------------------------------------------------------------- /images/man_ntire24.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/man_ntire24.pdf -------------------------------------------------------------------------------- /images/Visual_Results/U004/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/HR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/HR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/EDSR.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/EDSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/EDSR.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U004/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U004/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U012/MAN-Tiny.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U012/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN-Light.png -------------------------------------------------------------------------------- /images/Visual_Results/U044/MAN-Tiny.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/U044/MAN-Tiny.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/EDSR-Base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/EDSR-Base.png -------------------------------------------------------------------------------- /images/Visual_Results/D0850/MAN-Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/icandle/MAN/HEAD/images/Visual_Results/D0850/MAN-Light.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os.path as osp 3 | 4 | import archs 5 | 6 | from basicsr.test import test_pipeline 7 | 8 | if __name__ == '__main__': 9 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 10 | test_pipeline(root_path) 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os.path as osp 3 | 4 | import archs 5 | from basicsr.train import train_pipeline 6 | 7 | if __name__ == '__main__': 8 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 9 | train_pipeline(root_path) 10 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules for registry 7 | # scan all the files that end with '_arch.py' under the archs folder 8 | arch_folder = osp.dirname(osp.abspath(__file__)) 9 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 10 | # import all the arch modules 11 | _arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] 12 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | from .carn import Net as CARN 6 | from .carn_m import Net as CARNM 7 | from ModelZoo import MODEL_DIR 8 | 9 | 10 | def load_carn(): 11 | carn_model = CARN(2) 12 | state_dict = torch.load(os.path.join(MODEL_DIR, 'carn.pth'), map_location=torch.device('cpu')) 13 | new_state_dict = OrderedDict() 14 | for k, v in state_dict.items(): 15 | name = k 16 | # name = k[7:] # remove "module." 17 | new_state_dict[name] = v 18 | 19 | carn_model.load_state_dict(new_state_dict) 20 | return carn_model 21 | 22 | def load_carnm(): 23 | carn_model = CARNM(2) 24 | state_dict = torch.load(os.path.join(MODEL_DIR, 'carn_m.pth'), map_location=torch.device('cpu')) 25 | new_state_dict = OrderedDict() 26 | for k, v in state_dict.items(): 27 | name = k 28 | # name = k[7:] # remove "module." 
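# (the commented line above strips the 'module.' prefix that nn.DataParallel
# prepends to parameter names; enable it when loading a checkpoint that was
# saved from a DataParallel-wrapped model)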
29 |         new_state_dict[name] = v
30 |
31 |     carn_model.load_state_dict(new_state_dict)
32 |     return carn_model
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/fsrcnn.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import nn
3 |
4 |
5 | class FSRCNN(nn.Module):
6 |     def __init__(self, scale_factor, num_channels=3, d=56, s=12, m=4):
7 |         super(FSRCNN, self).__init__()
8 |         self.first_part = nn.Sequential(
9 |             nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),
10 |             nn.PReLU(d)
11 |         )
12 |         self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]
13 |         for _ in range(m):
14 |             self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])
15 |         self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])
16 |         self.mid_part = nn.Sequential(*self.mid_part)
17 |         self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,
18 |                                             output_padding=scale_factor-1)
19 |
20 |         self._initialize_weights()
21 |
22 |     def _initialize_weights(self):
23 |         for m in self.first_part:
24 |             if isinstance(m, nn.Conv2d):
25 |                 nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
26 |                 nn.init.zeros_(m.bias.data)
27 |         for m in self.mid_part:
28 |             if isinstance(m, nn.Conv2d):
29 |                 nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
30 |                 nn.init.zeros_(m.bias.data)
31 |         nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
32 |         nn.init.zeros_(self.last_part.bias.data)
33 |
34 |     def forward(self, x):
35 |         x = self.first_part(x)
36 |         x = self.mid_part(x)
37 |         x = self.last_part(x)
38 |         return x
--------------------------------------------------------------------------------
/options/test_MAN.yml:
--------------------------------------------------------------------------------
1 | name: MANx2
2 | model_type: SRModel
3 | scale: 2
4 | num_gpu: 1  # set num_gpu: 0 for cpu mode
5 | manual_seed: 0
6 |
7 | datasets:
8 |   test_1:  # the 1st test dataset
9 |     name: Set5
10 |     type: PairedImageDataset
11 |     dataroot_gt: ./datasets/Set5/GTmod2
12 |     dataroot_lq: ./datasets/Set5/LRbicx2
13 |     io_backend:
14 |       type: disk
15 |
16 |   # test_2:  # the 2nd test dataset
17 |   #   name: Set14
18 |   #   type: PairedImageDataset
19 |   #   dataroot_gt: ./datasets/Set14/GTmod2
20 |   #   dataroot_lq: ./datasets/Set14/LRbicx2
21 |   #   io_backend:
22 |   #     type: disk
23 |
24 |   # test_3:
25 |   #   name: Urban100
26 |   #   type: PairedImageDataset
27 |   #   dataroot_gt: ./datasets/urban100/GTmod2
28 |   #   dataroot_lq: ./datasets/urban100/LRbicx2
29 |   #   io_backend:
30 |   #     type: disk
31 |
32 |   # test_4:
33 |   #   name: BSDS100
34 |   #   type: PairedImageDataset
35 |   #   dataroot_gt: ./datasets/BSDS100/GTmod2
36 |   #   dataroot_lq: ./datasets/BSDS100/LRbicx2
37 |   #   io_backend:
38 |   #     type: disk
39 |
40 |   # test_5:
41 |   #   name: Manga109
42 |   #   type: PairedImageDataset
43 |   #   dataroot_gt: ./datasets/manga109/GTmod2
44 |   #   dataroot_lq: ./datasets/manga109/LRbicx2
45 |   #   io_backend:
46 |   #     type: disk
47 |
48 | # network structures
49 | network_g:
50 |   type: MAN
51 |   scale: 2  # or 3/4
52 |   n_resblocks: 36  # 5 for MAN-tiny; 24 for MAN-light; 36 for MAN
53 |   n_resgroups: 1
54 |   n_feats: 180  # 48 for MAN-tiny; 60 for MAN-light; 180 for MAN
55 |
56 |
57 | # path
58 | path:
59 |   pretrain_network_g: ./experiments/pretrained_models/MANx2.pth  # path to your pretrained MAN x2 checkpoint
60 |   strict_load_g: true
61 |   param_key_g: 'params_ema'  # only for MAN; for
MAN-T and MAN-L, using ~ 62 | 63 | # validation settings 64 | val: 65 | save_img: true 66 | suffix: ~ # add suffix to saved images, if None, use exp name 67 | 68 | metrics: 69 | psnr: # metric name, can be arbitrary 70 | type: calculate_psnr 71 | crop_border: 2 72 | test_y_channel: true 73 | ssim: 74 | type: calculate_ssim 75 | crop_border: 2 76 | test_y_channel: true 77 | -------------------------------------------------------------------------------- /LAM/SaliencyModel/attributes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torch.nn.functional as F 4 | 5 | 6 | def _tensor_size(t): 7 | return t.size()[1] * t.size()[2] * t.size()[3] 8 | 9 | def reduce_func(method): 10 | """ 11 | 12 | :param method: ['mean', 'sum', 'max', 'min', 'count', 'std'] 13 | :return: 14 | """ 15 | if method == 'sum': 16 | return torch.sum 17 | elif method == 'mean': 18 | return torch.mean 19 | elif method == 'count': 20 | return lambda x: sum(x.size()) 21 | else: 22 | raise NotImplementedError() 23 | 24 | 25 | def attr_id(tensor, h, w, window=8, reduce='sum'): 26 | """ 27 | :param tensor: B, C, H, W tensor 28 | :param h: h position 29 | :param w: w position 30 | :param window: size of window 31 | :param reduce: reduce method, ['mean', 'sum', 'max', 'min'] 32 | :return: 33 | """ 34 | crop = tensor[:, :, h: h + window, w: w + window] 35 | return reduce_func(reduce)(crop) 36 | 37 | 38 | def attr_grad(tensor, h, w, window=8, reduce='sum'): 39 | """ 40 | :param tensor: B, C, H, W tensor 41 | :param h: h position 42 | :param w: w position 43 | :param window: size of window 44 | :param reduce: reduce method, ['mean', 'sum', 'max', 'min'] 45 | :return: 46 | """ 47 | h_x = tensor.size()[2] 48 | w_x = tensor.size()[3] 49 | h_grad = torch.pow(tensor[:, :, :h_x - 1, :] - tensor[:, :, 1:, :], 2) 50 | w_grad = torch.pow(tensor[:, :, :, :w_x - 1] - tensor[:, :, :, 1:], 2) 51 | grad = torch.pow(h_grad[:, :, :, :-1] + w_grad[:, :, :-1, :], 1 / 2) 52 | crop = grad[:, :, h: h + window, w: w + window] 53 | return reduce_func(reduce)(crop) 54 | 55 | 56 | # gabor_filter = cv2.getGaborKernel((21, 21), 10.0, -np.pi/4, 8.0, 1, 0, ktype=cv2.CV_32F) 57 | 58 | def attr_gabor_generator(gabor_filter): 59 | filter = torch.from_numpy(gabor_filter).view((1, 1,) + gabor_filter.shape).repeat(1,3,1,1) 60 | def attr_gabor(tensor, h, w, window=8, reduce='sum'): 61 | after_filter = F.conv2d(tensor, filter, bias=None) 62 | crop = after_filter[:, :, h: h + window, w: w + window] 63 | return reduce_func(reduce)(crop) 64 | return attr_gabor 65 | 66 | 67 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/carn_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..CARN import ops 4 | 5 | 6 | class Block(nn.Module): 7 | def __init__(self, 8 | in_channels, out_channels, 9 | group=1): 10 | super(Block, self).__init__() 11 | 12 | self.b1 = ops.EResidualBlock(64, 64, group=group) 13 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 14 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 15 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 16 | 17 | def forward(self, x): 18 | c0 = o0 = x 19 | 20 | b1 = self.b1(o0) 21 | c1 = torch.cat([c0, b1], dim=1) 22 | o1 = self.c1(c1) 23 | 24 | b2 = self.b1(o1) 25 | c2 = torch.cat([c1, b2], dim=1) 26 | o2 = self.c2(c2) 27 | 28 | b3 = self.b1(o2) 29 | c3 = torch.cat([c2, b3], dim=1) 30 | o3 = self.c3(c3) 31 | 32 | return 
o3 33 | 34 | 35 | class Net(nn.Module): 36 | def __init__(self, scale): 37 | super(Net, self).__init__() 38 | 39 | multi_scale = True 40 | group = 1 41 | 42 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 43 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 44 | 45 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 46 | 47 | self.b1 = Block(64, 64, group=group) 48 | self.b2 = Block(64, 64, group=group) 49 | self.b3 = Block(64, 64, group=group) 50 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 51 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 52 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 53 | 54 | self.upsample = ops.UpsampleBlock(64, scale=scale, 55 | multi_scale=multi_scale, 56 | group=group) 57 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 58 | 59 | def forward(self, x, scale): 60 | x = self.sub_mean(x) 61 | x = self.entry(x) 62 | c0 = o0 = x 63 | 64 | b1 = self.b1(o0) 65 | c1 = torch.cat([c0, b1], dim=1) 66 | o1 = self.c1(c1) 67 | 68 | b2 = self.b2(o1) 69 | c2 = torch.cat([c1, b2], dim=1) 70 | o2 = self.c2(c2) 71 | 72 | b3 = self.b3(o2) 73 | c3 = torch.cat([c2, b3], dim=1) 74 | o3 = self.c3(c3) 75 | 76 | out = self.upsample(o3, scale=scale) 77 | 78 | out = self.exit(out) 79 | out = self.add_mean(out) 80 | 81 | return out -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/edsr.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return EDSR(args) 8 | 9 | 10 | class EDSR(nn.Module): 11 | def __init__(self, num_channels=3, factor=4, width=64, depth=16, kernel_size=3,res_scale=1.0, conv=common.default_conv): 12 | super(EDSR, self).__init__() 13 | 14 | n_resblock = depth 15 | n_feats = width 16 | kernel_size = kernel_size 17 | scale = factor 18 | act = nn.ReLU(True) 19 | 20 | rgb_mean = (0.4488, 0.4371, 0.4040) 21 | rgb_std = (1.0, 1.0, 1.0) 22 | self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std) 23 | 24 | # define head module 25 | m_head = [conv(num_channels, n_feats, kernel_size)] 26 | 27 | # define body module 28 | m_body = [ 29 | common.ResBlock( 30 | conv, n_feats, kernel_size, act=act, res_scale=res_scale 31 | ) for _ in range(n_resblock) 32 | ] 33 | m_body.append(conv(n_feats, n_feats, kernel_size)) 34 | 35 | # define tail module 36 | m_tail = [ 37 | common.Upsampler(conv, scale, n_feats, act=False), 38 | conv(n_feats, num_channels, kernel_size) 39 | ] 40 | 41 | self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1) 42 | 43 | self.head = nn.Sequential(*m_head) 44 | self.body = nn.Sequential(*m_body) 45 | self.tail = nn.Sequential(*m_tail) 46 | 47 | def forward(self, x): 48 | x = x*255 49 | x = self.sub_mean(x) 50 | x = self.head(x) 51 | 52 | res = self.body(x) 53 | res += x 54 | 55 | x = self.tail(res) 56 | x = self.add_mean(x) 57 | x = x/255 58 | return x 59 | 60 | def load_state_dict(self, state_dict, strict=True): 61 | own_state = self.state_dict() 62 | for name, param in state_dict.items(): 63 | if name in own_state: 64 | if isinstance(param, nn.Parameter): 65 | param = param.data 66 | try: 67 | own_state[name].copy_(param) 68 | except Exception: 69 | if name.find('tail') == -1: 70 | raise RuntimeError('While copying the parameter named {}, ' 71 | 'whose dimensions in the model are {} and ' 72 | 'whose dimensions in the checkpoint are {}.' 
73 |                                        .format(name, own_state[name].size(), param.size()))
74 |             elif strict:
75 |                 if name.find('tail') == -1:
76 |                     raise KeyError('unexpected key "{}" in state_dict'
77 |                                    .format(name))
--------------------------------------------------------------------------------
/LAM/test_MAN_diff.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch, cv2, os, sys, numpy as np, matplotlib.pyplot as plt
3 | from PIL import Image
4 | from ModelZoo.utils import load_as_tensor, Tensor2PIL, PIL2Tensor, _add_batch_one
5 | from ModelZoo import get_model, load_model, print_network
6 | from SaliencyModel.utils import vis_saliency, vis_saliency_kde, click_select_position, grad_abs_norm, grad_norm, prepare_images, make_pil_grid, blend_input, plot_diff_of_attrs_kde
7 | from SaliencyModel.utils import cv2_to_pil, pil_to_cv2, gini
8 | from SaliencyModel.attributes import attr_grad
9 | from SaliencyModel.BackProp import I_gradient, attribution_objective, Path_gradient
10 | from SaliencyModel.BackProp import saliency_map_PG as saliency_map
11 | from SaliencyModel.BackProp import GaussianBlurPath
12 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel
13 | from cal_metrix import calculate_psnr, calculate_ssim, bgr2ycbcr, tensor2img
14 |
15 |
16 | def LAM(image_path='test_images/3.png', w=90, h=120):
17 |     window_size = 16  # window size of the attributed patch D
18 |     img_lr, img_hr = prepare_images(image_path)  # change this image name
19 |     tensor_lr = PIL2Tensor(img_lr)[:3] ; tensor_hr = PIL2Tensor(img_hr)[:3]
20 |     cv2_lr = np.moveaxis(tensor_lr.numpy(), 0, 2) ; cv2_hr = np.moveaxis(tensor_hr.numpy(), 0, 2)
21 |     draw_img = pil_to_cv2(img_hr)
22 |     cv2.rectangle(draw_img, (w, h), (w + window_size, h + window_size), (0, 0, 255), 2)
23 |     position_pil = cv2_to_pil(draw_img)
24 |     plt.imshow(position_pil)
25 |
26 |
27 |     model = load_model('EDSR@Large')  # reference model (attribution map B)
28 |
29 |     sigma = 1.2 ; fold = 50 ; l = 9 ; alpha = 0.5
30 |     attr_objective = attribution_objective(attr_grad, h, w, window=window_size)
31 |     gaus_blur_path_func = GaussianBlurPath(sigma, fold, l)
32 |     interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True)
33 |     grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy)
34 |     abs_normed_grad_numpy = grad_abs_norm(grad_numpy)
35 |     B = abs_normed_grad_numpy
36 |
37 |
38 |     model = load_model('MAN@Light')  # compared model (attribution map A)
39 |
40 |     attr_objective = attribution_objective(attr_grad, h, w, window=window_size)
41 |     gaus_blur_path_func = GaussianBlurPath(sigma, fold, l)
42 |     interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True)
43 |     grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy)
44 |     abs_normed_grad_numpy = grad_abs_norm(grad_numpy)
45 |     A = abs_normed_grad_numpy
46 |
47 |
48 |     res = plot_diff_of_attrs_kde(A, B)  # KDE visualization of the difference between the two attribution maps
49 |
50 |
51 |     plt.axis('off')
52 |
53 |     plt.imshow(res)
54 |     plt.savefig('./lam_results/{}/diff_LE.png'.format(image_path[-5]), dpi=300, bbox_inches='tight', pad_inches=0.0)
55 |
56 |
57 |
58 |
59 |
60 | image_path = 'test_images/e.png'
61 | w = 120
62 | h = 100
63 | LAM(image_path, w, h)
--------------------------------------------------------------------------------
/options/trian_MAN.yml:
--------------------------------------------------------------------------------
1 | name: MAN_SR
2 | model_type: SRModel
3 | scale: 2  #
2/3/4/8
4 | num_gpu: 8  # or 4
5 | manual_seed: 10
6 |
7 | datasets:
8 |   train:
9 |     name: DF2K
10 |     type: PairedImageDataset
11 |     dataroot_gt: datasets/DF2K/DF2K_HR_sub
12 |     dataroot_lq: datasets/DF2K/DF2K_bicx2_sub
13 |     meta_info_file: hat/data/meta_info/meta_info_DF2Ksub_GT.txt
14 |     io_backend:
15 |       type: disk
16 |
17 |     gt_size: 128  # scale*48 or scale*64 (2*64 here for x2)
18 |     use_hflip: true
19 |     use_rot: true
20 |
21 |     # data loader
22 |     use_shuffle: true
23 |     num_worker_per_gpu: 6
24 |     batch_size_per_gpu: 4
25 |     dataset_enlarge_ratio: 1
26 |     prefetch_mode: ~
27 |
28 |   val_1:
29 |     name: Set5
30 |     type: PairedImageDataset
31 |     dataroot_gt: ./datasets/Set5/GTmod2
32 |     dataroot_lq: ./datasets/Set5/LRbicx2
33 |     io_backend:
34 |       type: disk
35 |
36 |   val_2:
37 |     name: Set14
38 |     type: PairedImageDataset
39 |     dataroot_gt: ./datasets/Set14/GTmod2
40 |     dataroot_lq: ./datasets/Set14/LRbicx2
41 |     io_backend:
42 |       type: disk
43 |
44 |   # val_3:
45 |   #   name: Urban100
46 |   #   type: PairedImageDataset
47 |   #   dataroot_gt: ./datasets/urban100/GTmod2
48 |   #   dataroot_lq: ./datasets/urban100/LRbicx2
49 |   #   io_backend:
50 |   #     type: disk
51 |
52 |
53 | # network structures
54 | network_g:
55 |   type: MAN
56 |   scale: 2  # or 3/4
57 |   n_resblocks: 36  # 5 for MAN-tiny; 24 for MAN-light; 36 for MAN
58 |   n_resgroups: 1
59 |   n_feats: 180  # 48 for MAN-tiny; 60 for MAN-light; 180 for MAN
60 |
61 | # path
62 | path:
63 |   pretrain_network_g: ~
64 |   strict_load_g: true
65 |   resume_state: ~
66 |
67 | # training settings
68 | train:
69 |   ema_decay: 0.999
70 |   optim_g:
71 |     type: Adam
72 |     lr: !!float 5e-4
73 |     weight_decay: 0
74 |     betas: [0.9, 0.99]
75 |
76 |   scheduler:
77 |     type: MultiStepLR
78 |     milestones: [800000, 1200000, 1400000, 1500000]
79 |     gamma: 0.5
80 |
81 |   #type: CosineAnnealingRestartLR
82 |   #periods: [1600000]
83 |   #restart_weights: [1]
84 |   #eta_min: !!float 1e-7
85 |
86 |   total_iter: 1600000
87 |   warmup_iter: -1  # no warm up
88 |
89 |   # losses
90 |   pixel_opt:
91 |     type: L1Loss
92 |     loss_weight: 1.0
93 |     reduction: mean
94 |
95 | # validation settings
96 | val:
97 |   val_freq: !!float 5e3
98 |   save_img: false
99 |   pbar: false
100 |
101 |   metrics:
102 |     psnr:
103 |       type: calculate_psnr
104 |       crop_border: 2  # 2/3/4
105 |       test_y_channel: true
106 |       better: higher  # the higher, the better. Default: higher
107 |     ssim:
108 |       type: calculate_ssim
109 |       crop_border: 2  # 2/3/4
110 |       test_y_channel: true
111 |       better: higher  # the higher, the better.
Default: higher 112 | 113 | # logging settings 114 | logger: 115 | print_freq: 100 116 | save_checkpoint_freq: !!float 5e3 117 | use_tb_logger: true 118 | wandb: 119 | project: ~ 120 | resume_id: ~ 121 | 122 | # dist training settings 123 | dist_params: 124 | backend: nccl 125 | port: 29500 126 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..CARN import ops 4 | 5 | 6 | class Block(nn.Module): 7 | def __init__(self, 8 | in_channels, out_channels, 9 | group=1): 10 | super(Block, self).__init__() 11 | 12 | self.b1 = ops.ResidualBlock(64, 64) 13 | self.b2 = ops.ResidualBlock(64, 64) 14 | self.b3 = ops.ResidualBlock(64, 64) 15 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 16 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 17 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 18 | 19 | def forward(self, x): 20 | c0 = o0 = x 21 | 22 | b1 = self.b1(o0) 23 | c1 = torch.cat([c0, b1], dim=1) 24 | o1 = self.c1(c1) 25 | 26 | b2 = self.b2(o1) 27 | c2 = torch.cat([c1, b2], dim=1) 28 | o2 = self.c2(c2) 29 | 30 | b3 = self.b3(o2) 31 | c3 = torch.cat([c2, b3], dim=1) 32 | o3 = self.c3(c3) 33 | 34 | return o3 35 | 36 | 37 | class Net(nn.Module): 38 | def __init__(self, scale): 39 | super(Net, self).__init__() 40 | 41 | multi_scale = True 42 | group = 1 43 | 44 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 45 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 46 | 47 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 48 | 49 | self.b1 = Block(64, 64) 50 | self.b2 = Block(64, 64) 51 | self.b3 = Block(64, 64) 52 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 53 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 54 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 55 | 56 | self.upsample = ops.UpsampleBlock(64, scale=scale, 57 | multi_scale=multi_scale, 58 | group=group) 59 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 60 | 61 | def forward(self, x, scale=4): 62 | x = self.sub_mean(x) 63 | x = self.entry(x) 64 | c0 = o0 = x 65 | 66 | b1 = self.b1(o0) 67 | c1 = torch.cat([c0, b1], dim=1) 68 | o1 = self.c1(c1) 69 | 70 | b2 = self.b2(o1) 71 | c2 = torch.cat([c1, b2], dim=1) 72 | o2 = self.c2(c2) 73 | 74 | b3 = self.b3(o2) 75 | c3 = torch.cat([c2, b3], dim=1) 76 | o3 = self.c3(c3) 77 | 78 | out = self.upsample(o3, scale=scale) 79 | 80 | out = self.exit(out) 81 | out = self.add_mean(out) 82 | 83 | return out 84 | 85 | 86 | class CARNet(nn.Module): 87 | def __init__(self, factor=4, num_channels=3): 88 | super(CARNet, self).__init__() 89 | 90 | multi_scale = False 91 | group = 1 92 | 93 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 94 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 95 | 96 | self.entry = nn.Conv2d(num_channels, 64, 3, 1, 1) 97 | 98 | self.b1 = Block(64, 64) 99 | self.b2 = Block(64, 64) 100 | self.b3 = Block(64, 64) 101 | self.c1 = ops.BasicBlock(64 * 2, 64, 1, 1, 0) 102 | self.c2 = ops.BasicBlock(64 * 3, 64, 1, 1, 0) 103 | self.c3 = ops.BasicBlock(64 * 4, 64, 1, 1, 0) 104 | 105 | self.upsample = ops.UpsampleBlock(64, scale=factor, 106 | multi_scale=multi_scale, 107 | group=group) 108 | self.exit = nn.Conv2d(64, num_channels, 3, 1, 1) 109 | 110 | def forward(self, x): 111 | x = self.sub_mean(x) 112 | x = self.entry(x) 113 | c0 = o0 = x 114 | 115 | b1 = self.b1(o0) 116 | c1 = torch.cat([c0, b1], dim=1) 117 | o1 = self.c1(c1) 118 | 119 | b2 = 
self.b2(o1)
120 |         c2 = torch.cat([c1, b2], dim=1)
121 |         o2 = self.c2(c2)
122 |
123 |         b3 = self.b3(o2)
124 |         c3 = torch.cat([c2, b3], dim=1)
125 |         o3 = self.c3(c3)
126 |
127 |         out = self.upsample(o3)
128 |
129 |         out = self.exit(out)
130 |         out = self.add_mean(out)
131 |
132 |         return out
--------------------------------------------------------------------------------
/LAM/ModelZoo/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | MODEL_DIR = 'ModelZoo/models'
5 |
6 |
7 | NN_LIST = [
8 |     'RCAN',
9 |     'CARN',
10 |     'RRDBNet',
11 |     'RNAN',
12 |     'SAN',
13 |     'MAN',
14 |     'EDSR'
15 | ]
16 |
17 |
18 | MODEL_LIST = {
19 |     'RCAN': {
20 |         'Base': 'RCAN.pt',
21 |     },
22 |     'CARN': {
23 |         'Base': 'CARN_7400.pth',
24 |     },
25 |     'RRDBNet': {
26 |         'Base': 'RRDBNet_PSNR_SRx4_DF2K_official-150ff491.pth',
27 |     },
28 |     'SAN': {
29 |         'Base': 'SAN_BI4X.pt',
30 |     },
31 |     'RNAN': {
32 |         'Base': 'RNAN_SR_F64G10P48BIX4.pt',
33 |     },
34 |     'EDSR': {
35 |         'Base': 'edsr_baseline_x4-6b446fab.pt',
36 |         'Large': 'edsr_x4-4f62e9ef.pt',
37 |     },
38 |     'MAN': {
39 |         'Base': 'MANx4_DF2K.pth',
40 |         'Light': 'MAN-Light-x4.pth',
41 |         'Tiny': 'MAN-Tiny-x4.pth',
42 |     },
43 | }
44 |
45 | def print_network(model, model_name):
46 |     num_params = 0
47 |     for param in model.parameters():
48 |         num_params += param.numel()
49 |     print('Network [%s] was created. Total number of parameters: %.1fK. '
50 |           'To see the architecture, do print(network).'
51 |           % (model_name, num_params / 1000))
52 |
53 |
54 | def get_model(model_name, training_name='Base', factor=4, num_channels=3):
55 |     """
56 |     All models default to X4 super-resolution and RGB (3-channel) input.
57 |     :param model_name:
58 |     :param factor:
59 |     :param num_channels:
60 |     :return:
61 |     """
62 |     print(f'Getting SR Network {model_name}')
63 |     if model_name.split('-')[0] in NN_LIST:
64 |
65 |         if model_name == 'RCAN':
66 |             from .NN.rcan import RCAN
67 |             net = RCAN(factor=factor, num_channels=num_channels)
68 |
69 |         elif model_name == 'CARN':
70 |             from .CARN.carn import CARNet
71 |             net = CARNet(factor=factor, num_channels=num_channels)
72 |
73 |         elif model_name == 'RRDBNet':
74 |             from .NN.rrdbnet import RRDBNet
75 |             net = RRDBNet(num_in_ch=num_channels, num_out_ch=num_channels)
76 |
77 |         elif model_name == 'SAN':
78 |             from .NN.san import SAN
79 |             net = SAN(factor=factor, num_channels=num_channels)
80 |
81 |         elif model_name == 'RNAN':
82 |             from .NN.rnan import RNAN
83 |             net = RNAN(factor=factor, num_channels=num_channels)
84 |
85 |         elif model_name == 'EDSR':
86 |             from .NN.edsr import EDSR
87 |             if training_name == 'Base':
88 |                 net = EDSR(factor=factor, width=64, depth=16)
89 |             else:
90 |                 net = EDSR(factor=factor, width=256, depth=32, res_scale=0.1)
91 |
92 |         elif model_name == 'MAN':
93 |             from .NN.man import MAN
94 |             if training_name == 'Base':
95 |                 net = MAN(n_resblocks=36, n_feats=180, scale=factor)
96 |             elif training_name == 'Tiny':
97 |                 net = MAN(n_resblocks=5, n_feats=48, scale=factor)
98 |             else:
99 |                 net = MAN(n_resblocks=24, n_feats=60, scale=factor)
100 |         else:
101 |             raise NotImplementedError()
102 |
103 |         print_network(net, model_name)
104 |         return net
105 |     else:
106 |         raise NotImplementedError()
107 |
108 |
109 | def load_model(model_loading_name):
110 |     """
111 |     :param model_loading_name: 'model_name@training_name', e.g. 'MAN@Base'; the training name defaults to 'Base'
112 |     :return:
113 |     """
114 |     splitting = model_loading_name.split('@')
115 |     if len(splitting) == 1:
116 |         model_name = splitting[0]
117 |         training_name = 'Base'
118
| elif len(splitting) == 2: 119 | model_name = splitting[0] 120 | training_name = splitting[1] 121 | else: 122 | raise NotImplementedError() 123 | assert model_name in NN_LIST or model_name in MODEL_LIST.keys(), 'check your model name before @' 124 | net = get_model(model_name,training_name) 125 | state_dict_path = os.path.join(MODEL_DIR, MODEL_LIST[model_name][training_name]) 126 | print(f'Loading model {state_dict_path} for {model_name} network.') 127 | state_dict = torch.load(state_dict_path, map_location='cpu') 128 | if model_loading_name =='MAN@Base': 129 | state_dict = state_dict['params_ema'] 130 | net.load_state_dict(state_dict) 131 | return net 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /LAM/ModelZoo/CARN/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(modules): 9 | pass 10 | 11 | 12 | class MeanShift(nn.Module): 13 | def __init__(self, mean_rgb, sub): 14 | super(MeanShift, self).__init__() 15 | 16 | sign = -1 if sub else 1 17 | r = mean_rgb[0] * sign 18 | g = mean_rgb[1] * sign 19 | b = mean_rgb[2] * sign 20 | 21 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 22 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 23 | self.shifter.bias.data = torch.Tensor([r, g, b]) 24 | 25 | # Freeze the mean shift layer 26 | for params in self.shifter.parameters(): 27 | params.requires_grad = False 28 | 29 | def forward(self, x): 30 | x = self.shifter(x) 31 | return x 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, 36 | in_channels, out_channels, 37 | ksize=3, stride=1, pad=1): 38 | super(BasicBlock, self).__init__() 39 | 40 | self.body = nn.Sequential( 41 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | init_weights(self.modules) 46 | 47 | def forward(self, x): 48 | out = self.body(x) 49 | return out 50 | 51 | 52 | class ResidualBlock(nn.Module): 53 | def __init__(self, 54 | in_channels, out_channels): 55 | super(ResidualBlock, self).__init__() 56 | 57 | self.body = nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 61 | ) 62 | 63 | init_weights(self.modules) 64 | 65 | def forward(self, x): 66 | out = self.body(x) 67 | out = F.relu(out + x) 68 | return out 69 | 70 | 71 | class EResidualBlock(nn.Module): 72 | def __init__(self, 73 | in_channels, out_channels, 74 | group=1): 75 | super(EResidualBlock, self).__init__() 76 | 77 | self.body = nn.Sequential( 78 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 83 | ) 84 | 85 | init_weights(self.modules) 86 | 87 | def forward(self, x): 88 | out = self.body(x) 89 | out = F.relu(out + x) 90 | return out 91 | 92 | 93 | class UpsampleBlock(nn.Module): 94 | def __init__(self, 95 | n_channels, scale, multi_scale, 96 | group=1): 97 | super(UpsampleBlock, self).__init__() 98 | 99 | if multi_scale: 100 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 101 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 102 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group) 103 | else: 104 | self.up = 
_UpsampleBlock(n_channels, scale=scale, group=group)
105 |
106 |         self.multi_scale = multi_scale
107 |
108 |     def forward(self, x, scale=None):
109 |         if self.multi_scale:
110 |             if scale == 2:
111 |                 return self.up2(x)
112 |             elif scale == 3:
113 |                 return self.up3(x)
114 |             elif scale == 4:
115 |                 return self.up4(x)
116 |         else:
117 |             return self.up(x)
118 |
119 |
120 | class _UpsampleBlock(nn.Module):
121 |     def __init__(self,
122 |                  n_channels, scale,
123 |                  group=1):
124 |         super(_UpsampleBlock, self).__init__()
125 |
126 |         modules = []
127 |         if scale == 2 or scale == 4 or scale == 8:
128 |             for _ in range(int(math.log(scale, 2))):
129 |                 modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
130 |                 modules += [nn.PixelShuffle(2)]
131 |         elif scale == 3:
132 |             modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
133 |             modules += [nn.PixelShuffle(3)]
134 |
135 |         self.body = nn.Sequential(*modules)
136 |         init_weights(self.modules)
137 |
138 |     def forward(self, x):
139 |         out = self.body(x)
140 |         return out
--------------------------------------------------------------------------------
/LAM/test_MAN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch, cv2, os, sys, numpy as np, matplotlib.pyplot as plt
3 | from PIL import Image
4 | from ModelZoo.utils import load_as_tensor, Tensor2PIL, PIL2Tensor, _add_batch_one
5 | from ModelZoo import get_model, load_model, print_network
6 | from SaliencyModel.utils import vis_saliency, vis_saliency_kde, click_select_position, grad_abs_norm, grad_norm, prepare_images, make_pil_grid, blend_input
7 | from SaliencyModel.utils import cv2_to_pil, pil_to_cv2, gini
8 | from SaliencyModel.attributes import attr_grad
9 | from SaliencyModel.BackProp import I_gradient, attribution_objective, Path_gradient
10 | from SaliencyModel.BackProp import saliency_map_PG as saliency_map
11 | from SaliencyModel.BackProp import GaussianBlurPath
12 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel
13 | from cal_metrix import calculate_psnr, calculate_ssim, bgr2ycbcr, tensor2img
14 |
15 | # LAM: Interpreting Super-Resolution Networks with Local Attribution Maps
16 |
17 | def LAM(model_names='MAN@Base', image_path='test_images/3.png', w=90, h=120):
18 |     model = load_model(model_names)
19 |     window_size = 16  # window size of the attributed patch D
20 |     img_lr, img_hr = prepare_images(image_path)  # change this image name
21 |     tensor_lr = PIL2Tensor(img_lr)[:3] ; tensor_hr = PIL2Tensor(img_hr)[:3]
22 |     cv2_lr = np.moveaxis(tensor_lr.numpy(), 0, 2) ; cv2_hr = np.moveaxis(tensor_hr.numpy(), 0, 2)
23 |
24 |     plt.imshow(cv2_hr)
25 |
26 |     draw_img = pil_to_cv2(img_hr)
27 |     cv2.rectangle(draw_img, (w, h), (w + window_size, h + window_size), (0, 0, 255), 2)
28 |     position_pil = cv2_to_pil(draw_img)
29 |     plt.imshow(position_pil)
30 |
31 |     sigma = 1.2 ; fold = 50 ; l = 9 ; alpha = 0.5
32 |     attr_objective = attribution_objective(attr_grad, h, w, window=window_size)
33 |     gaus_blur_path_func = GaussianBlurPath(sigma, fold, l)
34 |     interpolated_grad_numpy, result_numpy, interpolated_numpy = Path_gradient(tensor_lr.numpy(), model, attr_objective, gaus_blur_path_func, cuda=True)
35 |     grad_numpy, result = saliency_map(interpolated_grad_numpy, result_numpy)
36 |     abs_normed_grad_numpy = grad_abs_norm(grad_numpy)
37 |     saliency_image_abs = vis_saliency(abs_normed_grad_numpy, zoomin=4)
38 |     saliency_image_kde =
vis_saliency_kde(abs_normed_grad_numpy) 39 | blend_abs_and_input = cv2_to_pil(pil_to_cv2(saliency_image_abs) * (1.0 - alpha) + pil_to_cv2(img_lr.resize(img_hr.size)) * alpha) 40 | blend_kde_and_input = cv2_to_pil(pil_to_cv2(saliency_image_kde) * (1.0 - alpha) + pil_to_cv2(img_lr.resize(img_hr.size)) * alpha) 41 | pil = make_pil_grid( 42 | [position_pil, 43 | saliency_image_abs, 44 | blend_abs_and_input, 45 | blend_kde_and_input, 46 | Tensor2PIL(torch.clamp(result, min=0., max=1.))] 47 | ) 48 | 49 | plt.axis('off') 50 | 51 | gini_index = gini(abs_normed_grad_numpy) 52 | diffusion_index = (1 - gini_index) * 100 53 | 54 | plt.imshow(img_lr.resize(img_hr.size)) 55 | plt.savefig('./lam_results/{}/0lr.png'.format(image_path[-5]), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 56 | 57 | plt.imshow(position_pil) 58 | plt.savefig('./lam_results/{}/1hr.png'.format(image_path[-5]), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 59 | plt.imshow(saliency_image_abs) 60 | plt.savefig('./lam_results/{}/2abs_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 61 | plt.imshow(blend_kde_and_input) 62 | plt.savefig('./lam_results/{}/3kde_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 63 | plt.imshow(Tensor2PIL(torch.clamp(result, min=0., max=1.))) 64 | plt.savefig('./lam_results/{}/4sr_{}.png'.format(image_path[-5],model_names), dpi=300, bbox_inches = 'tight',pad_inches=0.0) 65 | 66 | plt.imshow(pil) 67 | 68 | im_GT = tensor2img(tensor_hr) 69 | im_Gen = tensor2img(result) 70 | im_GT_in = bgr2ycbcr(im_GT) 71 | im_Gen_in = bgr2ycbcr(im_Gen) 72 | 73 | crop_border = 4 74 | 75 | if im_GT_in.ndim == 3: 76 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 77 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 78 | elif im_GT_in.ndim == 2: 79 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 80 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 81 | 82 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 83 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 84 | 85 | print('DI: {:.5f} PSNR: {:.3f}dB SSIM: {:.5f}'.format(diffusion_index,PSNR,SSIM)) 86 | 87 | model_names='EDSR@Large' 88 | image_path='test_images/e.png' 89 | w=120 90 | h=100 91 | LAM(model_names,image_path,w,h) -------------------------------------------------------------------------------- /LAM/ModelZoo/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torchvision 4 | import torch 5 | 6 | 7 | IMG_EXTENSIONS = ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm'] 8 | 9 | 10 | def mkdir(path): 11 | if not os.path.exists(path): 12 | os.makedirs(path) 13 | 14 | 15 | def pil_loader(path, mode='RGB'): 16 | """ 17 | open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 18 | :param path: image path 19 | :return: PIL.Image 20 | """ 21 | assert _is_image_file(path), "%s is not an image" % path 22 | with open(path, 'rb') as f: 23 | with Image.open(f) as img: 24 | return img.convert(mode) 25 | 26 | 27 | def calculate_RF(model): 28 | layers = getLayers(model) 29 | r = 1 30 | for layer in layers[::-1]: 31 | if isinstance(layer, torch.nn.Conv2d): 32 | kernel = layer.kernel_size[0] 33 | padding = layer.padding[0] 34 | stride = layer.stride[0] 35 | r = stride * r + (kernel - stride) 36 | return r 37 | 38 | 39 | def getLayers(model): 40 | 
""" 41 | get each layer's name and its module 42 | :param model: 43 | :return: each layer's name and its module 44 | """ 45 | layers = [] 46 | 47 | def unfoldLayer(model): 48 | """ 49 | unfold each layer 50 | :param model: the given model or a single layer 51 | :param root: root name 52 | :return: 53 | """ 54 | 55 | # get all layers of the model 56 | layer_list = list(model.named_children()) 57 | for item in layer_list: 58 | module = item[1] 59 | sublayer = list(module.named_children()) 60 | sublayer_num = len(sublayer) 61 | 62 | # if current layer contains sublayers, add current layer name on its sublayers 63 | if sublayer_num == 0: 64 | layers.append(module) 65 | # if current layer contains sublayers, unfold them 66 | elif isinstance(module, torch.nn.Module): 67 | unfoldLayer(module) 68 | 69 | unfoldLayer(model) 70 | return layers 71 | 72 | 73 | def load_as_tensor(path, mode='RGB'): 74 | """ 75 | Load image to tensor 76 | :param path: image path 77 | :param mode: 'Y' returns 1 channel tensor, 'RGB' returns 3 channels, 'RGBA' returns 4 channels, 'YCbCr' returns 3 channels 78 | :return: 3D tensor 79 | """ 80 | if mode != 'Y': 81 | return PIL2Tensor(pil_loader(path, mode=mode)) 82 | else: 83 | return PIL2Tensor(pil_loader(path, mode='YCbCr'))[:1] 84 | 85 | 86 | def PIL2Tensor(pil_image): 87 | return torchvision.transforms.functional.to_tensor(pil_image) 88 | 89 | 90 | def Tensor2PIL(tensor_image, mode='RGB'): 91 | if len(tensor_image.size()) == 4 and tensor_image.size()[0] == 1: 92 | tensor_image = tensor_image.view(tensor_image.size()[1:]) 93 | return torchvision.transforms.functional.to_pil_image(tensor_image.detach(), mode=mode) 94 | 95 | 96 | def _is_image_file(filename): 97 | """ 98 | judge if the file is an image file 99 | :param filename: path 100 | :return: bool of judgement 101 | """ 102 | filename_lower = filename.lower() 103 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 104 | 105 | 106 | def image_files(path): 107 | """ 108 | return list of images in the path 109 | :param path: path to Data Folder, absolute path 110 | :return: 1D list of image files absolute path 111 | """ 112 | abs_path = os.path.abspath(path) 113 | image_files = os.listdir(abs_path) 114 | for i in range(len(image_files)): 115 | if (not os.path.isdir(image_files[i])) and (_is_image_file(image_files[i])): 116 | image_files[i] = os.path.join(abs_path, image_files[i]) 117 | return image_files 118 | 119 | 120 | def split_to_batches(l, n): 121 | for i in range(0, len(l), n): 122 | yield l[i:i + n] 123 | 124 | 125 | def _sigmoid_to_tanh(x): 126 | """ 127 | range [0, 1] to range [-1, 1] 128 | :param x: tensor type 129 | :return: tensor 130 | """ 131 | return (x - 0.5) * 2. 
132 | 133 | 134 | def _tanh_to_sigmoid(x): 135 | """ 136 | range [-1, 1] to range [0, 1] 137 | :param x: 138 | :return: 139 | """ 140 | return x * 0.5 + 0.5 141 | 142 | 143 | def _add_batch_one(tensor): 144 | """ 145 | Return a tensor with size (1, ) + tensor.size 146 | :param tensor: 2D or 3D tensor 147 | :return: 3D or 4D tensor 148 | """ 149 | return tensor.view((1, ) + tensor.size()) 150 | 151 | 152 | def _remove_batch(tensor): 153 | """ 154 | Return a tensor with size tensor.size()[1:] 155 | :param tensor: 3D or 4D tensor 156 | :return: 2D or 3D tensor 157 | """ 158 | return tensor.view(tensor.size()[1:]) 159 | 160 | def mod_crop(tensor, scale=4): 161 | B, C, H, W = tensor.shape 162 | return tensor[:, :, :H-H % scale, :W-W % scale] 163 | 164 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rnan.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return RNAN(args) 8 | 9 | 10 | ### RNAN 11 | class _ResGroup(nn.Module): 12 | def __init__(self, conv, n_feats, kernel_size, act, res_scale): 13 | super(_ResGroup, self).__init__() 14 | modules_body = [] 15 | modules_body.append( 16 | common.ResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), 17 | res_scale=1)) 18 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 19 | self.body = nn.Sequential(*modules_body) 20 | 21 | def forward(self, x): 22 | res = self.body(x) 23 | return res 24 | 25 | 26 | class _NLResGroup(nn.Module): 27 | def __init__(self, conv, n_feats, kernel_size, act, res_scale): 28 | super(_NLResGroup, self).__init__() 29 | modules_body = [] 30 | modules_body.append( 31 | common.NLResAttModuleDownUpPlus(conv, n_feats, kernel_size, bias=True, bn=False, act=nn.ReLU(True), 32 | res_scale=1)) 33 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 34 | self.body = nn.Sequential(*modules_body) 35 | 36 | def forward(self, x): 37 | res = self.body(x) 38 | return res 39 | 40 | 41 | class RNAN(nn.Module): 42 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv): 43 | super(RNAN, self).__init__() 44 | 45 | n_resgroup = 10 46 | n_resblock = 16 47 | n_feats = 64 48 | kernel_size = 3 49 | reduction = 16 50 | scale = factor 51 | act = nn.ReLU(True) 52 | 53 | # RGB mean for DIV2K 1-800 54 | rgb_mean = (0.4488, 0.4371, 0.4040) 55 | rgb_std = (1.0, 1.0, 1.0) 56 | self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std) 57 | 58 | # define head module 59 | modules_head = [conv(num_channels, n_feats, kernel_size)] 60 | 61 | # define body module 62 | modules_body_nl_low = [ 63 | _NLResGroup( 64 | conv, n_feats, kernel_size, act=act, res_scale=1.)] 65 | modules_body = [ 66 | _ResGroup( 67 | conv, n_feats, kernel_size, act=act, res_scale=1.) 
\ 68 | for _ in range(n_resgroup - 2)] 69 | modules_body_nl_high = [ 70 | _NLResGroup( 71 | conv, n_feats, kernel_size, act=act, res_scale=1.)] 72 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 73 | 74 | # define tail module 75 | modules_tail = [ 76 | common.Upsampler(conv, scale, n_feats, act=False), 77 | conv(n_feats, num_channels, kernel_size)] 78 | 79 | self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1) 80 | 81 | self.head = nn.Sequential(*modules_head) 82 | self.body_nl_low = nn.Sequential(*modules_body_nl_low) 83 | self.body = nn.Sequential(*modules_body) 84 | self.body_nl_high = nn.Sequential(*modules_body_nl_high) 85 | self.tail = nn.Sequential(*modules_tail) 86 | 87 | def forward(self, x): 88 | 89 | x = self.sub_mean(x * 255.) 90 | feats_shallow = self.head(x) 91 | 92 | res = self.body_nl_low(feats_shallow) 93 | res = self.body(res) 94 | res = self.body_nl_high(res) 95 | res += feats_shallow 96 | 97 | res_main = self.tail(res) 98 | 99 | res_main = self.add_mean(res_main) 100 | 101 | return res_main / 255. 102 | 103 | def load_state_dict(self, state_dict, strict=False): 104 | own_state = self.state_dict() 105 | for name, param in state_dict.items(): 106 | if name in own_state: 107 | if isinstance(param, nn.Parameter): 108 | param = param.data 109 | try: 110 | own_state[name].copy_(param) 111 | except Exception: 112 | if name.find('tail') >= 0: 113 | print('Replace pre-trained upsampler to new one...') 114 | else: 115 | raise RuntimeError('While copying the parameter named {}, ' 116 | 'whose dimensions in the model are {} and ' 117 | 'whose dimensions in the checkpoint are {}.' 118 | .format(name, own_state[name].size(), param.size())) 119 | elif strict: 120 | if name.find('tail') == -1: 121 | raise KeyError('unexpected key "{}" in state_dict' 122 | .format(name)) 123 | 124 | if strict: 125 | missing = set(own_state.keys()) - set(state_dict.keys()) 126 | if len(missing) > 0: 127 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rcan.py: -------------------------------------------------------------------------------- 1 | from ..NN import common 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def make_model(args, parent=False): 7 | return RCAN(args) 8 | 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | self.conv_du = nn.Sequential( 16 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | y = self.avg_pool(x) 24 | y = self.conv_du(y) 25 | return x * y 26 | 27 | 28 | ## Residual Channel Attention Block (RCAB) 29 | class RCAB(nn.Module): 30 | def __init__( 31 | self, conv, n_feat, kernel_size, reduction, 32 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 33 | 34 | super(RCAB, self).__init__() 35 | modules_body = [] 36 | for i in range(2): 37 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 38 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 39 | if i == 0: modules_body.append(act) 40 | modules_body.append(CALayer(n_feat, reduction)) 41 | self.body = nn.Sequential(*modules_body) 42 | self.res_scale = res_scale 43 | 44 | def forward(self, x): 45 | res = self.body(x) 46 | # res = 
self.body(x).mul(self.res_scale) 47 | res += x 48 | return res 49 | 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | 69 | class RCAN(nn.Module): 70 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = 10 74 | n_resblocks = 20 75 | n_feats = 64 76 | kernel_size = 3 77 | reduction = 16 78 | scale = factor 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 1-800 82 | # rgb_mean = (0.4488, 0.4371, 0.4040) 83 | # RGB mean for DIVFlickr2K 1-3450 84 | rgb_mean = (0.4690, 0.4490, 0.4036) 85 | # rgb_mean = (0.4488, 0.4371, 0.4040) 86 | rgb_std = (1.0, 1.0, 1.0) 87 | self.sub_mean = common.MeanShift(255., rgb_mean, rgb_std) 88 | 89 | # define head module 90 | modules_head = [conv(num_channels, n_feats, kernel_size)] 91 | 92 | # define body module 93 | modules_body = [ 94 | ResidualGroup( 95 | conv, n_feats, kernel_size, reduction, act=act, res_scale=1, n_resblocks=n_resblocks) for _ in range(n_resgroups)] 96 | 97 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 98 | 99 | # define tail module 100 | modules_tail = [ 101 | common.Upsampler(conv, scale, n_feats, act=False), 102 | conv(n_feats, num_channels, kernel_size)] 103 | 104 | self.add_mean = common.MeanShift(255., rgb_mean, rgb_std, 1) 105 | 106 | self.head = nn.Sequential(*modules_head) 107 | self.body = nn.Sequential(*modules_body) 108 | self.tail = nn.Sequential(*modules_tail) 109 | 110 | def forward(self, x): 111 | x = self.sub_mean(x * 255) 112 | x = self.head(x) 113 | 114 | res = self.body(x) 115 | res += x 116 | 117 | x = self.tail(res) 118 | x = self.add_mean(x) 119 | 120 | return x / 255 121 | 122 | def load_state_dict(self, state_dict, strict=False): 123 | own_state = self.state_dict() 124 | for name, param in state_dict.items(): 125 | if name in own_state: 126 | if isinstance(param, nn.Parameter): 127 | param = param.data 128 | try: 129 | own_state[name].copy_(param) 130 | except Exception: 131 | if name.find('tail') >= 0: 132 | print('Replace pre-trained upsampler to new one...') 133 | else: 134 | raise RuntimeError('While copying the parameter named {}, ' 135 | 'whose dimensions in the model are {} and ' 136 | 'whose dimensions in the checkpoint are {}.' 137 | .format(name, own_state[name].size(), param.size())) 138 | elif strict: 139 | if name.find('tail') == -1: 140 | raise KeyError('unexpected key "{}" in state_dict' 141 | .format(name)) 142 | 143 | if strict: 144 | missing = set(own_state.keys()) - set(state_dict.keys()) 145 | if len(missing) > 0: 146 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/MPNCOV/python/MPNCOV.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @file: MPNCOV.py 3 | @author: Jiangtao Xie 4 | @author: Peihua Li 5 | 6 | Copyright (C) 2018 Peihua Li and Jiangtao Xie 7 | 8 | All rights reserved. 
9 | ''' 10 | import torch 11 | import numpy as np 12 | from torch.autograd import Function 13 | 14 | class Covpool(Function): 15 | @staticmethod 16 | def forward(ctx, input): 17 | x = input 18 | batchSize = x.data.shape[0] 19 | dim = x.data.shape[1] 20 | h = x.data.shape[2] 21 | w = x.data.shape[3] 22 | M = h*w 23 | x = x.reshape(batchSize,dim,M) 24 | I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device) 25 | I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype) 26 | y = x.bmm(I_hat).bmm(x.transpose(1,2)) 27 | ctx.save_for_backward(input,I_hat) 28 | return y 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | input,I_hat = ctx.saved_tensors 32 | x = input 33 | batchSize = x.data.shape[0] 34 | dim = x.data.shape[1] 35 | h = x.data.shape[2] 36 | w = x.data.shape[3] 37 | M = h*w 38 | x = x.reshape(batchSize,dim,M) 39 | grad_input = grad_output + grad_output.transpose(1,2) 40 | grad_input = grad_input.bmm(x).bmm(I_hat) 41 | grad_input = grad_input.reshape(batchSize,dim,h,w) 42 | return grad_input 43 | 44 | class Sqrtm(Function): 45 | @staticmethod 46 | def forward(ctx, input, iterN): 47 | x = input 48 | batchSize = x.data.shape[0] 49 | dim = x.data.shape[1] 50 | dtype = x.dtype 51 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 52 | normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1) 53 | A = x.div(normA.view(batchSize,1,1).expand_as(x)) 54 | Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device) 55 | Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1) 56 | if iterN < 2: 57 | ZY = 0.5*(I3 - A) 58 | Y[:,0,:,:] = A.bmm(ZY) 59 | else: 60 | ZY = 0.5*(I3 - A) 61 | Y[:,0,:,:] = A.bmm(ZY) 62 | Z[:,0,:,:] = ZY 63 | for i in range(1, iterN-1): 64 | ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:])) 65 | Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY) 66 | Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:]) 67 | ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:])) 68 | y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 69 | ctx.save_for_backward(input, A, ZY, normA, Y, Z) 70 | ctx.iterN = iterN 71 | return y 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | input, A, ZY, normA, Y, Z = ctx.saved_tensors 75 | iterN = ctx.iterN 76 | x = input 77 | batchSize = x.data.shape[0] 78 | dim = x.data.shape[1] 79 | dtype = x.dtype 80 | der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x) 81 | der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA)) 82 | I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 83 | if iterN < 2: 84 | der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom))  # fixed: original referenced an undefined 'der_sacleTrace'; for iterN < 2 the recursion has no Z term, so dldY is just der_postCom 85 | else: 86 | dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) - 87 | Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom)) 88 | dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:]) 89 | for i in range(iterN-3, -1, -1): 90 | YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:]) 91 | ZY = Z[:,i,:,:].bmm(Y[:,i,:,:]) 92 | dldY_ = 0.5*(dldY.bmm(YZ) - 93 | Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) - 94 | ZY.bmm(dldY)) 95 | dldZ_ = 0.5*(YZ.bmm(dldZ) - 96 | Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) - 97 | dldZ.bmm(ZY)) 98 | dldY = dldY_ 99 | dldZ = dldZ_ 100 | der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY)) 101 | grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x)) 102 | grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1) 103 
| for i in range(batchSize): 104 | grad_input[i,:,:] += (der_postComAux[i] \ 105 | - grad_aux[i] / (normA[i] * normA[i])) \ 106 | *torch.ones(dim,device = x.device).diag() 107 | return grad_input, None 108 | 109 | class Triuvec(Function): 110 | @staticmethod 111 | def forward(ctx, input): 112 | x = input 113 | batchSize = x.data.shape[0] 114 | dim = x.data.shape[1] 115 | dtype = x.dtype 116 | x = x.reshape(batchSize, dim*dim) 117 | I = torch.ones(dim,dim).triu().t().reshape(dim*dim) 118 | index = I.nonzero() 119 | y = torch.zeros(batchSize,dim*(dim+1)//2,device = x.device)  # fixed: integer division; torch.zeros needs an int size 120 | for i in range(batchSize): 121 | y[i, :] = x[i, index].t() 122 | ctx.save_for_backward(input,index) 123 | return y 124 | @staticmethod 125 | def backward(ctx, grad_output): 126 | input,index = ctx.saved_tensors 127 | x = input 128 | batchSize = x.data.shape[0] 129 | dim = x.data.shape[1] 130 | dtype = x.dtype 131 | grad_input = torch.zeros(batchSize,dim,dim,device = x.device,requires_grad=False) 132 | grad_input = grad_input.reshape(batchSize,dim*dim) 133 | for i in range(batchSize): 134 | grad_input[i,index] = grad_output[i,:].reshape(index.size(0),1)  # fixed: reshape takes ints, not a torch.Size 135 | grad_input = grad_input.reshape(batchSize,dim,dim) 136 | return grad_input 137 | 138 | def CovpoolLayer(var): 139 | return Covpool.apply(var) 140 | 141 | def SqrtmLayer(var, iterN): 142 | return Sqrtm.apply(var, iterN) 143 | 144 | def TriuvecLayer(var): 145 | return Triuvec.apply(var) 146 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/rrdbnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn import init as init 5 | 6 | def make_layer(basic_block, num_basic_block, **kwarg): 7 | """Make layers by stacking the same blocks. 8 | Args: 9 | basic_block (nn.module): nn.module class for basic block. 10 | num_basic_block (int): number of blocks. 11 | Returns: 12 | nn.Sequential: Stacked blocks in nn.Sequential. 13 | """ 14 | layers = [] 15 | for _ in range(num_basic_block): 16 | layers.append(basic_block(**kwarg)) 17 | return nn.Sequential(*layers) 18 | 19 | 20 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 21 | """Initialize network weights. 22 | Args: 23 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 24 | scale (float): Scale initialized weights, especially for residual 25 | blocks. Default: 1. 26 | bias_fill (float): The value to fill bias. Default: 0 27 | kwargs (dict): Other arguments for initialization function. 28 | """ 29 | if not isinstance(module_list, list): 30 | module_list = [module_list] 31 | for module in module_list: 32 | for m in module.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | init.kaiming_normal_(m.weight, **kwargs) 35 | m.weight.data *= scale 36 | if m.bias is not None: 37 | m.bias.data.fill_(bias_fill) 38 | elif isinstance(m, nn.Linear): 39 | init.kaiming_normal_(m.weight, **kwargs) 40 | m.weight.data *= scale 41 | if m.bias is not None: 42 | m.bias.data.fill_(bias_fill) 43 | elif isinstance(m, nn.modules.batchnorm._BatchNorm):  # fixed: '_BatchNorm' was never imported; reach it through the already-imported nn 44 | init.constant_(m.weight, 1) 45 | if m.bias is not None: 46 | m.bias.data.fill_(bias_fill) 47 | 48 | 49 | 50 | class ResidualDenseBlock(nn.Module): 51 | """Residual Dense Block. 52 | Used in RRDB block in ESRGAN. 53 | Args: 54 | num_feat (int): Channel number of intermediate features. 55 | num_grow_ch (int): Channels for each growth. 
56 | """ 57 | 58 | def __init__(self, num_feat=64, num_grow_ch=32): 59 | super(ResidualDenseBlock, self).__init__() 60 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 61 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 62 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 63 | 1) 64 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 65 | 1) 66 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 67 | 68 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 69 | 70 | # initialization 71 | default_init_weights( 72 | [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 73 | 74 | def forward(self, x): 75 | x1 = self.lrelu(self.conv1(x)) 76 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 77 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 78 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 79 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 80 | # Emperically, we use 0.2 to scale the residual for better performance 81 | return x5 * 0.2 + x 82 | 83 | 84 | class RRDB(nn.Module): 85 | """Residual in Residual Dense Block. 86 | Used in RRDB-Net in ESRGAN. 87 | Args: 88 | num_feat (int): Channel number of intermediate features. 89 | num_grow_ch (int): Channels for each growth. 90 | """ 91 | 92 | def __init__(self, num_feat, num_grow_ch=32): 93 | super(RRDB, self).__init__() 94 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 95 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 96 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 97 | 98 | def forward(self, x): 99 | out = self.rdb1(x) 100 | out = self.rdb2(out) 101 | out = self.rdb3(out) 102 | # Emperically, we use 0.2 to scale the residual for better performance 103 | return out * 0.2 + x 104 | 105 | 106 | class RRDBNet(nn.Module): 107 | """Networks consisting of Residual in Residual Dense Block, which is used 108 | in ESRGAN. 109 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 110 | Currently, it supports x4 upsampling scale factor. 111 | Args: 112 | num_in_ch (int): Channel number of inputs. 113 | num_out_ch (int): Channel number of outputs. 114 | num_feat (int): Channel number of intermediate features. 115 | Default: 64 116 | num_block (int): Block number in the trunk network. Defaults: 23 117 | num_grow_ch (int): Channels for each growth. Default: 32. 
118 | """ 119 | 120 | def __init__(self, 121 | num_in_ch=3, 122 | num_out_ch=3, 123 | num_feat=64, 124 | num_block=23, 125 | num_grow_ch=32): 126 | super(RRDBNet, self).__init__() 127 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 128 | self.body = make_layer( 129 | RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 130 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 131 | # upsample 132 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 133 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 134 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 135 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 136 | 137 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 138 | 139 | def forward(self, x): 140 | feat = self.conv_first(x) 141 | body_feat = self.conv_body(self.body(feat)) 142 | feat = feat + body_feat 143 | # upsample 144 | feat = self.lrelu( 145 | self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 146 | feat = self.lrelu( 147 | self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 148 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 149 | return out 150 | -------------------------------------------------------------------------------- /LAM/cal_metrix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM. 3 | same as MATLAB's results 4 | ''' 5 | import os 6 | import math 7 | import numpy as np 8 | import cv2 9 | import glob 10 | from torchvision.utils import make_grid 11 | 12 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 13 | ''' 14 | Converts a torch Tensor into an image Numpy array 15 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 16 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 17 | ''' 18 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 19 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 20 | n_dim = tensor.dim() 21 | if n_dim == 4: 22 | n_img = len(tensor) 23 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 24 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 25 | elif n_dim == 3: 26 | img_np = tensor.numpy() 27 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 28 | elif n_dim == 2: 29 | img_np = tensor.numpy() 30 | else: 31 | raise TypeError( 32 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 33 | if out_type == np.uint8: 34 | img_np = (img_np * 255.0).round() 35 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 36 | return img_np.astype(out_type) 37 | 38 | def main(): 39 | # Configurations 40 | 41 | # GT - Ground-truth; 42 | # Gen: Generated / Restored / Recovered images 43 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' 44 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' 45 | 46 | crop_border = 4 47 | suffix = '' # suffix for Gen images 48 | test_Y = False # True: test Y channel only; False: test RGB channels 49 | 50 | PSNR_all = [] 51 | SSIM_all = [] 52 | img_list = sorted(glob.glob(folder_GT + '/*')) 53 | 54 | if test_Y: 55 | print('Testing Y channel.') 56 | else: 57 | print('Testing RGB channels.') 58 | 59 | for i, img_path in enumerate(img_list): 60 | base_name = os.path.splitext(os.path.basename(img_path))[0] 61 | im_GT = cv2.imread(img_path) / 255. 
62 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. 63 | 64 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 65 | im_GT_in = bgr2ycbcr(im_GT) 66 | im_Gen_in = bgr2ycbcr(im_Gen) 67 | else: 68 | im_GT_in = im_GT 69 | im_Gen_in = im_Gen 70 | 71 | # crop borders 72 | if im_GT_in.ndim == 3: 73 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 74 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 75 | elif im_GT_in.ndim == 2: 76 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 77 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 78 | else: 79 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) 80 | 81 | # calculate PSNR and SSIM 82 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 83 | 84 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 85 | print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( 86 | i + 1, base_name, PSNR, SSIM)) 87 | PSNR_all.append(PSNR) 88 | SSIM_all.append(SSIM) 89 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( 90 | sum(PSNR_all) / len(PSNR_all), 91 | sum(SSIM_all) / len(SSIM_all))) 92 | 93 | 94 | def calculate_psnr(img1, img2): 95 | # img1 and img2 have range [0, 255] 96 | img1 = img1.astype(np.float64) 97 | img2 = img2.astype(np.float64) 98 | mse = np.mean((img1 - img2)**2) 99 | if mse == 0: 100 | return float('inf') 101 | return 20 * math.log10(255.0 / math.sqrt(mse)) 102 | 103 | 104 | def ssim(img1, img2): 105 | C1 = (0.01 * 255)**2 106 | C2 = (0.03 * 255)**2 107 | 108 | img1 = img1.astype(np.float64) 109 | img2 = img2.astype(np.float64) 110 | kernel = cv2.getGaussianKernel(11, 1.5) 111 | window = np.outer(kernel, kernel.transpose()) 112 | 113 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 114 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 115 | mu1_sq = mu1**2 116 | mu2_sq = mu2**2 117 | mu1_mu2 = mu1 * mu2 118 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 119 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 120 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 121 | 122 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 123 | (sigma1_sq + sigma2_sq + C2)) 124 | return ssim_map.mean() 125 | 126 | 127 | def calculate_ssim(img1, img2): 128 | '''calculate SSIM 129 | the same outputs as MATLAB's 130 | img1, img2: [0, 255] 131 | ''' 132 | if not img1.shape == img2.shape: 133 | raise ValueError('Input images must have the same dimensions.') 134 | if img1.ndim == 2: 135 | return ssim(img1, img2) 136 | elif img1.ndim == 3: 137 | if img1.shape[2] == 3: 138 | ssims = [] 139 | for i in range(3): 140 | ssims.append(ssim(img1[..., i], img2[..., i]))  # fixed: compute SSIM per channel before averaging 141 | return np.array(ssims).mean() 142 | elif img1.shape[2] == 1: 143 | return ssim(np.squeeze(img1), np.squeeze(img2)) 144 | else: 145 | raise ValueError('Wrong input image dimensions.') 146 | 147 | 148 | def bgr2ycbcr(img, only_y=True): 149 | '''same as matlab rgb2ycbcr 150 | only_y: only return Y channel 151 | Input: 152 | uint8, [0, 255] 153 | float, [0, 1] 154 | ''' 155 | in_img_type = img.dtype 156 | img = img.astype(np.float32)  # fixed: astype returns a copy, so assign the result 157 | if in_img_type != np.uint8: 158 | img *= 255. 
159 | # convert 160 | if only_y: 161 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 162 | else: 163 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 164 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 165 | if in_img_type == np.uint8: 166 | rlt = rlt.round() 167 | else: 168 | rlt /= 255. 169 | return rlt.astype(in_img_type) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /LAM/SaliencyModel/BackProp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | from ModelZoo.utils import _add_batch_one, _remove_batch 5 | from SaliencyModel.utils import grad_norm, IG_baseline, interpolation, isotropic_gaussian_kernel 6 | 7 | 8 | def attribution_objective(attr_func, h, w, window=16): 9 | def calculate_objective(image): 10 | return attr_func(image, h, w, window=window) 11 | return calculate_objective 12 | 13 | 14 | def saliency_map_gradient(numpy_image, model, attr_func): 15 | img_tensor = torch.from_numpy(numpy_image) 16 | img_tensor.requires_grad_(True) 17 | result = model(_add_batch_one(img_tensor)) 18 | target = attr_func(result) 19 | target.backward() 20 | return img_tensor.grad.numpy(), result 21 | 22 | 23 | def I_gradient(numpy_image, baseline_image, model, attr_objective, fold, interp='linear'): 24 | interpolated = interpolation(numpy_image, baseline_image, fold, mode=interp).astype(np.float32) 25 | grad_list = np.zeros_like(interpolated, dtype=np.float32) 26 | result_list = [] 27 | for i in range(fold): 28 | img_tensor = torch.from_numpy(interpolated[i]) 29 | img_tensor.requires_grad_(True) 30 | result = model(_add_batch_one(img_tensor)) 31 | target = attr_objective(result) 32 | target.backward() 33 | grad = img_tensor.grad.numpy() 34 | grad_list[i] = grad 35 | result_list.append(result) 36 | results_numpy = np.asarray(result_list) 37 | return grad_list, results_numpy, interpolated 38 | 39 | 40 | def GaussianBlurPath(sigma, fold, l=5): 41 | def path_interpolation_func(cv_numpy_image): 42 | h, w, c = cv_numpy_image.shape 43 | kernel_interpolation = np.zeros((fold + 1, l, l)) 44 | image_interpolation = np.zeros((fold, h, w, c)) 45 | lambda_derivative_interpolation = np.zeros((fold, h, w, c)) 46 | sigma_interpolation = np.linspace(sigma, 0, fold + 1) 47 | for i in range(fold + 1): 48 | kernel_interpolation[i] = isotropic_gaussian_kernel(l, sigma_interpolation[i]) 49 | for i in range(fold): 50 | image_interpolation[i] = cv2.filter2D(cv_numpy_image, -1, kernel_interpolation[i + 1]) 51 | lambda_derivative_interpolation[i] = cv2.filter2D(cv_numpy_image, -1, (kernel_interpolation[i + 1] - kernel_interpolation[i]) * fold) 52 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 53 | np.moveaxis(lambda_derivative_interpolation, 3, 1).astype(np.float32) 54 | return path_interpolation_func 55 | 56 | 57 | def GaussianLinearPath(sigma, fold, l=5): 58 | def path_interpolation_func(cv_numpy_image): 59 | kernel = isotropic_gaussian_kernel(l, sigma) 60 | baseline_image = cv2.filter2D(cv_numpy_image, -1, kernel) 61 | image_interpolation = interpolation(cv_numpy_image, baseline_image, fold, mode='linear').astype(np.float32) 62 | lambda_derivative_interpolation = np.repeat(np.expand_dims(cv_numpy_image - baseline_image, axis=0), fold, axis=0) 63 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 64 | np.moveaxis(lambda_derivative_interpolation, 
3, 1).astype(np.float32) 65 | return path_interpolation_func 66 | 67 | 68 | def LinearPath(fold): 69 | def path_interpolation_func(cv_numpy_image): 70 | baseline_image = np.zeros_like(cv_numpy_image) 71 | image_interpolation = interpolation(cv_numpy_image, baseline_image, fold, mode='linear').astype(np.float32) 72 | lambda_derivative_interpolation = np.repeat(np.expand_dims(cv_numpy_image - baseline_image, axis=0), fold, axis=0) 73 | return np.moveaxis(image_interpolation, 3, 1).astype(np.float32), \ 74 | np.moveaxis(lambda_derivative_interpolation, 3, 1).astype(np.float32) 75 | return path_interpolation_func 76 | 77 | 78 | def Path_gradient(numpy_image, model, attr_objective, path_interpolation_func, cuda=False): 79 | """ 80 | :param path_interpolation_func: 81 | return \lambda(\alpha) and d\lambda(\alpha)/d\alpha, for \alpha\in[0, 1] 82 | This function returns pil_numpy_images 83 | :return: 84 | """ 85 | if cuda: 86 | model = model.cuda() 87 | cv_numpy_image = np.moveaxis(numpy_image, 0, 2) 88 | image_interpolation, lambda_derivative_interpolation = path_interpolation_func(cv_numpy_image) 89 | grad_accumulate_list = np.zeros_like(image_interpolation) 90 | result_list = [] 91 | for i in range(image_interpolation.shape[0]): 92 | img_tensor = torch.from_numpy(image_interpolation[i]) 93 | img_tensor.requires_grad_(True) 94 | if cuda: 95 | result = model(_add_batch_one(img_tensor).cuda()) 96 | target = attr_objective(result) 97 | target.backward() 98 | grad = img_tensor.grad.cpu().numpy() 99 | if np.any(np.isnan(grad)): 100 | grad[np.isnan(grad)] = 0.0 101 | result = result.cpu().detach() 102 | else: 103 | result = model(_add_batch_one(img_tensor)) 104 | target = attr_objective(result) 105 | target.backward() 106 | grad = img_tensor.grad.numpy() 107 | if np.any(np.isnan(grad)): 108 | grad[np.isnan(grad)] = 0.0 109 | 110 | grad_accumulate_list[i] = grad * lambda_derivative_interpolation[i] 111 | result_list.append(result.detach())  # fixed: detach so np.asarray below can convert the tensors 112 | results_numpy = np.asarray(result_list) 113 | return grad_accumulate_list, results_numpy, image_interpolation 114 | 115 | 116 | def saliency_map_PG(grad_list, result_list): 117 | final_grad = grad_list.mean(axis=0) 118 | return final_grad, result_list[-1] 119 | 120 | 121 | def saliency_map_P_gradient( 122 | numpy_image, model, attr_objective, path_interpolation_func): 123 | grad_list, result_list, _ = Path_gradient(numpy_image, model, attr_objective, path_interpolation_func) 124 | final_grad = grad_list.mean(axis=0) 125 | return final_grad, result_list[-1] 126 | 127 | 128 | def saliency_map_I_gradient( 129 | numpy_image, model, attr_objective, baseline='gaus', fold=10, interp='linear'): 130 | """ 131 | :param numpy_image: RGB C x H x W 132 | :param model: 133 | :param attr_objective: 134 | :param baseline: baseline mode for IG_baseline ('gaus', 'bif', 'mean') 135 | :param fold: number of interpolation steps 136 | :param interp: interpolation mode 137 | 138 | :return: 139 | """ 140 | numpy_baseline = np.moveaxis(IG_baseline(np.moveaxis(numpy_image, 0, 2) * 255., mode=baseline) / 255., 2, 0) 141 | grad_list, result_list, _ = I_gradient(numpy_image, numpy_baseline, model, attr_objective, fold, interp=interp)  # fixed: honor the 'interp' argument instead of hard-coding 'linear' 142 | final_grad = grad_list.mean(axis=0) * (numpy_image - numpy_baseline) 143 | return final_grad, result_list[-1] 144 | 145 | -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | 9 
| class Model(nn.Module): 10 | def __init__(self, args, ckp, cpu): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = [4]  # fixed: forward_chop indexes self.scale[self.idx_scale], so keep it a list 15 | self.idx_scale = 0 16 | self.self_ensemble = True 17 | self.chop = True 18 | self.precision = 'single' 19 | self.cpu = cpu 20 | self.device = torch.device('cpu' if args.cpu else 'cuda') 21 | self.n_GPUs = args.n_GPUs 22 | self.save_models = args.save_models 23 | 24 | module = import_module('model.' + args.model.lower()) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | if args.print_model: print(self.model) 38 | 39 | def forward(self, x, idx_scale): 40 | self.idx_scale = idx_scale 41 | target = self.get_model() 42 | if hasattr(target, 'set_scale'): 43 | target.set_scale(idx_scale) 44 | 45 | if self.self_ensemble and not self.training: 46 | if self.chop: 47 | forward_function = self.forward_chop 48 | else: 49 | forward_function = self.model.forward 50 | 51 | return self.forward_x8(x, forward_function) 52 | elif self.chop and not self.training: 53 | return self.forward_chop(x) 54 | else: 55 | return self.model(x) 56 | 57 | def get_model(self): 58 | if self.n_GPUs == 1: 59 | return self.model 60 | else: 61 | return self.model.module 62 | 63 | def state_dict(self, **kwargs): 64 | target = self.get_model() 65 | return target.state_dict(**kwargs) 66 | 67 | def save(self, apath, epoch, is_best=False): 68 | target = self.get_model() 69 | torch.save( 70 | target.state_dict(), 71 | os.path.join(apath, 'model', 'model_latest.pt') 72 | ) 73 | if is_best: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_best.pt') 77 | ) 78 | 79 | if self.save_models: 80 | torch.save( 81 | target.state_dict(), 82 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 83 | ) 84 | 85 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 86 | if cpu: 87 | kwargs = {'map_location': lambda storage, loc: storage} 88 | else: 89 | kwargs = {} 90 | 91 | if resume == -1: 92 | self.get_model().load_state_dict( 93 | torch.load( 94 | os.path.join(apath, 'model', 'model_latest.pt'), 95 | **kwargs 96 | ), 97 | strict=False 98 | ) 99 | elif resume == 0: 100 | if pre_train != '.': 101 | print('Loading model from {}'.format(pre_train)) 102 | self.get_model().load_state_dict( 103 | torch.load(pre_train, **kwargs), 104 | strict=False 105 | ) 106 | else: 107 | self.get_model().load_state_dict( 108 | torch.load( 109 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 110 | **kwargs 111 | ), 112 | strict=False 113 | ) 114 | 115 | # shave = 10, min_size=160000 116 | def forward_chop(self, x, shave=10, min_size=160000): 117 | scale = self.scale[self.idx_scale] 118 | n_GPUs = min(self.n_GPUs, 4) 119 | b, c, h, w = x.size() 120 | h_half, w_half = h // 2, w // 2 121 | h_size, w_size = h_half + shave, w_half + shave 122 | lr_list = [ 123 | x[:, :, 0:h_size, 0:w_size], 124 | x[:, :, 0:h_size, (w - w_size):w], 125 | x[:, :, (h - h_size):h, 0:w_size], 126 | x[:, :, (h - h_size):h, (w - w_size):w]] 127 | 128 | if w_size * h_size < min_size: 129 | sr_list = [] 130 | for i in range(0, 4, n_GPUs): 131 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 132 | sr_batch = self.model(lr_batch) 133 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 134 | 
else: 135 | sr_list = [ 136 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 137 | for patch in lr_list 138 | ] 139 | 140 | h, w = scale * h, scale * w 141 | h_half, w_half = scale * h_half, scale * w_half 142 | h_size, w_size = scale * h_size, scale * w_size 143 | shave *= scale 144 | 145 | output = x.new(b, c, h, w) 146 | output[:, :, 0:h_half, 0:w_half] \ 147 | = sr_list[0][:, :, 0:h_half, 0:w_half] 148 | output[:, :, 0:h_half, w_half:w] \ 149 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 150 | output[:, :, h_half:h, 0:w_half] \ 151 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 152 | output[:, :, h_half:h, w_half:w] \ 153 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 154 | 155 | return output 156 | 157 | def forward_x8(self, x, forward_function): 158 | def _transform(v, op): 159 | if self.precision != 'single': v = v.float() 160 | 161 | v2np = v.data.cpu().numpy() 162 | if op == 'v': 163 | tfnp = v2np[:, :, :, ::-1].copy() 164 | elif op == 'h': 165 | tfnp = v2np[:, :, ::-1, :].copy() 166 | elif op == 't': 167 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 168 | 169 | ret = torch.Tensor(tfnp).to(self.device) 170 | if self.precision == 'half': ret = ret.half() 171 | 172 | return ret 173 | 174 | lr_list = [x] 175 | for tf in 'v', 'h', 't': 176 | lr_list.extend([_transform(t, tf) for t in lr_list]) 177 | 178 | sr_list = [forward_function(aug) for aug in lr_list] 179 | for i in range(len(sr_list)): 180 | if i > 3: 181 | sr_list[i] = _transform(sr_list[i], 't') 182 | if i % 4 > 1: 183 | sr_list[i] = _transform(sr_list[i], 'h') 184 | if (i % 4) % 2 == 1: 185 | sr_list[i] = _transform(sr_list[i], 'v') 186 | 187 | output_cat = torch.cat(sr_list, dim=0) 188 | output = output_cat.mean(dim=0, keepdim=True) 189 | 190 | return output -------------------------------------------------------------------------------- /LAM/ModelZoo/NN/drln_ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(modules): 9 | pass 10 | 11 | 12 | class MeanShift(nn.Module): 13 | def __init__(self, mean_rgb, sub): 14 | super(MeanShift, self).__init__() 15 | 16 | sign = -1 if sub else 1 17 | r = mean_rgb[0] * sign 18 | g = mean_rgb[1] * sign 19 | b = mean_rgb[2] * sign 20 | 21 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 22 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 23 | self.shifter.bias.data = torch.Tensor([r, g, b]) 24 | 25 | # Freeze the mean shift layer 26 | for params in self.shifter.parameters(): 27 | params.requires_grad = False 28 | 29 | def forward(self, x): 30 | x = self.shifter(x) 31 | return x 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, 36 | in_channels, out_channels, 37 | ksize=3, stride=1, pad=1, dilation=1): 38 | super(BasicBlock, self).__init__() 39 | 40 | self.body = nn.Sequential( 41 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, dilation), 42 | nn.ReLU(inplace=True) 43 | ) 44 | 45 | init_weights(self.modules) 46 | 47 | def forward(self, x): 48 | out = self.body(x) 49 | return out 50 | 51 | 52 | class GBasicBlock(nn.Module): 53 | def __init__(self, 54 | in_channels, out_channels, 55 | ksize=3, stride=1, pad=1, dilation=1): 56 | super(GBasicBlock, self).__init__() 57 | 58 | self.body = nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, dilation, groups=4), 60 | 
nn.ReLU(inplace=True) 61 | ) 62 | 63 | init_weights(self.modules) 64 | 65 | def forward(self, x): 66 | out = self.body(x) 67 | return out 68 | 69 | 70 | class BasicBlockSig(nn.Module): 71 | def __init__(self, 72 | in_channels, out_channels, 73 | ksize=3, stride=1, pad=1): 74 | super(BasicBlockSig, self).__init__() 75 | 76 | self.body = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 78 | nn.Sigmoid() 79 | ) 80 | 81 | init_weights(self.modules) 82 | 83 | def forward(self, x): 84 | out = self.body(x) 85 | return out 86 | 87 | 88 | class GBasicBlockSig(nn.Module): 89 | def __init__(self, 90 | in_channels, out_channels, 91 | ksize=3, stride=1, pad=1): 92 | super(GBasicBlockSig, self).__init__() 93 | 94 | self.body = nn.Sequential( 95 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad, groups=4), 96 | nn.Sigmoid() 97 | ) 98 | 99 | init_weights(self.modules) 100 | 101 | def forward(self, x): 102 | out = self.body(x) 103 | return out 104 | 105 | 106 | class ResidualBlock(nn.Module): 107 | def __init__(self, 108 | in_channels, out_channels): 109 | super(ResidualBlock, self).__init__() 110 | 111 | self.body = nn.Sequential( 112 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 115 | ) 116 | 117 | init_weights(self.modules) 118 | 119 | def forward(self, x): 120 | out = self.body(x) 121 | out = F.relu(out + x) 122 | return out 123 | 124 | 125 | class GResidualBlock(nn.Module): 126 | def __init__(self, 127 | in_channels, out_channels): 128 | super(GResidualBlock, self).__init__() 129 | 130 | self.body = nn.Sequential( 131 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=4), 132 | nn.ReLU(inplace=True), 133 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 134 | ) 135 | 136 | init_weights(self.modules) 137 | 138 | def forward(self, x): 139 | out = self.body(x) 140 | out = F.relu(out + x) 141 | return out 142 | 143 | 144 | class EResidualBlock(nn.Module): 145 | def __init__(self, 146 | in_channels, out_channels, 147 | group=1): 148 | super(EResidualBlock, self).__init__() 149 | 150 | self.body = nn.Sequential( 151 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 152 | nn.ReLU(inplace=True), 153 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 156 | ) 157 | 158 | init_weights(self.modules) 159 | 160 | def forward(self, x): 161 | out = self.body(x) 162 | out = F.relu(out + x) 163 | return out 164 | 165 | 166 | class ConvertBlock(nn.Module): 167 | def __init__(self, 168 | in_channels, out_channels, 169 | blocks): 170 | super(ConvertBlock, self).__init__() 171 | 172 | self.body = nn.Sequential( 173 | nn.Conv2d(in_channels * blocks, out_channels * blocks // 2, 3, 1, 1), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(out_channels * blocks // 2, out_channels * blocks // 4, 3, 1, 1), 176 | nn.ReLU(inplace=True), 177 | nn.Conv2d(out_channels * blocks // 4, out_channels, 3, 1, 1), 178 | ) 179 | 180 | init_weights(self.modules) 181 | 182 | def forward(self, x): 183 | out = self.body(x) 184 | # out = F.relu(out + x) 185 | return out 186 | 187 | 188 | class UpsampleBlock(nn.Module): 189 | def __init__(self, 190 | n_channels, scale, multi_scale, 191 | group=1): 192 | super(UpsampleBlock, self).__init__() 193 | 194 | if multi_scale: 195 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 196 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 197 | self.up4 = 
_UpsampleBlock(n_channels, scale=4, group=group) 198 | else: 199 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group) 200 | 201 | self.multi_scale = multi_scale 202 | 203 | def forward(self, x, scale): 204 | if self.multi_scale: 205 | if scale == 2: 206 | return self.up2(x) 207 | elif scale == 3: 208 | return self.up3(x) 209 | elif scale == 4: 210 | return self.up4(x) 211 | else: 212 | return self.up(x) 213 | 214 | 215 | class _UpsampleBlock(nn.Module): 216 | def __init__(self, 217 | n_channels, scale, 218 | group=1): 219 | super(_UpsampleBlock, self).__init__() 220 | 221 | modules = [] 222 | if scale == 2 or scale == 4 or scale == 8: 223 | for _ in range(int(math.log(scale, 2))): 224 | modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 225 | modules += [nn.PixelShuffle(2)] 226 | elif scale == 3: 227 | modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 228 | modules += [nn.PixelShuffle(3)] 229 | 230 | self.body = nn.Sequential(*modules) 231 | init_weights(self.modules) 232 | 233 | def forward(self, x): 234 | out = self.body(x) 235 | return out -------------------------------------------------------------------------------- /LAM/SaliencyModel/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy import stats 4 | import matplotlib.pyplot as plt 5 | import matplotlib as mpl 6 | from PIL import Image 7 | import cv2 8 | 9 | 10 | def cv2_to_pil(img): 11 | image = Image.fromarray(cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)) 12 | return image 13 | 14 | 15 | def pil_to_cv2(img): 16 | image = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 17 | return image 18 | 19 | 20 | def make_pil_grid(pil_image_list): 21 | sizex, sizey = pil_image_list[0].size 22 | for img in pil_image_list: 23 | assert sizex == img.size[0] and sizey == img.size[1], 'check image size' 24 | 25 | target = Image.new('RGB', (sizex * len(pil_image_list), sizey)) 26 | left = 0 27 | right = sizex 28 | for i in range(len(pil_image_list)): 29 | target.paste(pil_image_list[i], (left, 0, right, sizey)) 30 | left += sizex 31 | right += sizex 32 | return target 33 | 34 | 35 | def blend_input(map, input): 36 | return Image.blend(map, input, 0.4) 37 | 38 | 39 | def count_saliency_pixels(map, threshold=0.95): 40 | sum_threshold = map.reshape(-1).sum() * threshold 41 | cum_sum = -np.cumsum(np.sort(-map.reshape(-1))) 42 | return len(cum_sum[cum_sum < sum_threshold]) 43 | 44 | 45 | def plot_diff_of_attrs_kde(A, B, zoomin=4, blend=0.5): 46 | grad_flat = A.reshape((-1)) 47 | datapoint_y, datapoint_x = np.mgrid[0:A.shape[0]:1, 0:A.shape[1]:1] 48 | Y, X = np.mgrid[0:A.shape[0]:1, 0:A.shape[1]:1] 49 | positions = np.vstack([X.ravel(), Y.ravel()]) 50 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 51 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 52 | Za = np.reshape(kernel(positions).T, A.shape) 53 | Za = Za / Za.max() 54 | 55 | grad_flat = B.reshape((-1)) 56 | datapoint_y, datapoint_x = np.mgrid[0:B.shape[0]:1, 0:B.shape[1]:1] 57 | Y, X = np.mgrid[0:B.shape[0]:1, 0:B.shape[1]:1] 58 | positions = np.vstack([X.ravel(), Y.ravel()]) 59 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 60 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 61 | Zb = np.reshape(kernel(positions).T, B.shape) 62 | Zb = Zb / Zb.max() 63 | 64 | diff = Za - Zb 65 | diff_norm = diff / diff.max() 66 | vis = Zb - blend*diff_norm 67 | 
68 | cmap = plt.get_cmap('seismic') 69 | # cmap = plt.get_cmap('Purples') 70 | map_color = (255 * cmap(vis * 0.5 + 0.5)).astype(np.uint8) 71 | # map_color = (255 * cmap(Z)).astype(np.uint8) 72 | Img = Image.fromarray(map_color) 73 | s1, s2 = Img.size 74 | return Img.resize((s1 * zoomin, s2 * zoomin), Image.BICUBIC) 75 | 76 | 77 | def vis_saliency_kde(map, zoomin=4): 78 | grad_flat = map.reshape((-1)) 79 | datapoint_y, datapoint_x = np.mgrid[0:map.shape[0]:1, 0:map.shape[1]:1] 80 | Y, X = np.mgrid[0:map.shape[0]:1, 0:map.shape[1]:1] 81 | positions = np.vstack([X.ravel(), Y.ravel()]) 82 | pixels = np.vstack([datapoint_x.ravel(), datapoint_y.ravel()]) 83 | kernel = stats.gaussian_kde(pixels, weights=grad_flat) 84 | Z = np.reshape(kernel(positions).T, map.shape) 85 | Z = Z / Z.max() 86 | cmap = plt.get_cmap('seismic') 87 | # cmap = plt.get_cmap('Purples') 88 | map_color = (255 * cmap(Z * 0.5 + 0.5)).astype(np.uint8) 89 | # map_color = (255 * cmap(Z)).astype(np.uint8) 90 | Img = Image.fromarray(map_color) 91 | s1, s2 = Img.size 92 | return Img.resize((s1 * zoomin, s2 * zoomin), Image.BICUBIC) 93 | 94 | 95 | def vis_saliency(map, zoomin=4): 96 | """ 97 | :param map: the saliency map, 2D, norm to [0, 1] 98 | :param zoomin: the resize factor, nn upsample 99 | :return: 100 | """ 101 | cmap = plt.get_cmap('seismic') 102 | # cmap = plt.get_cmap('Purples') 103 | map_color = (255 * cmap(map * 0.5 + 0.5)).astype(np.uint8) 104 | # map_color = (255 * cmap(map)).astype(np.uint8) 105 | Img = Image.fromarray(map_color) 106 | s1, s2 = Img.size 107 | Img = Img.resize((s1 * zoomin, s2 * zoomin), Image.NEAREST) 108 | return Img.convert('RGB') 109 | 110 | 111 | def click_select_position(pil_img, window_size=16): 112 | """ 113 | 114 | :param pil_img: 115 | :param window_size: 116 | :return: w, h 117 | """ 118 | cv2_img = pil_to_cv2(pil_img) 119 | position = [-1, -1] 120 | def mouse(event, x, y, flags, param): 121 | """""" 122 | if event == cv2.EVENT_LBUTTONDOWN: 123 | xy = "%d, %d" % (x, y) 124 | position[0] = x 125 | position[1] = y 126 | draw_img = cv2_img.copy() 127 | cv2.rectangle(draw_img, (x, y), (x + window_size, y + window_size), (0,0,255), 2) 128 | cv2.putText(draw_img, xy, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0, (255, 255, 255), thickness = 1) 129 | cv2.imshow("image", draw_img) 130 | 131 | cv2.namedWindow("image") 132 | cv2.imshow("image", cv2_img) 133 | cv2.resizeWindow("image", 800, 600) 134 | cv2.setMouseCallback("image", mouse) 135 | 136 | cv2.waitKey(0) 137 | cv2.destroyAllWindows() 138 | return_img = cv2_img.copy() 139 | cv2.rectangle(return_img, (position[0], position[1]), (position[0] + window_size, position[1] + window_size), (0, 0, 255), 2) 140 | return position[0], position[1], cv2_to_pil(return_img) 141 | 142 | 143 | def prepare_images(hr_path, scale=4): 144 | hr_pil = Image.open(hr_path) 145 | sizex, sizey = hr_pil.size 146 | hr_pil = hr_pil.crop((0, 0, sizex - sizex % scale, sizey - sizey % scale)) 147 | sizex, sizey = hr_pil.size 148 | lr_pil = hr_pil.resize((sizex // scale, sizey // scale), Image.BICUBIC) 149 | return lr_pil, hr_pil 150 | 151 | 152 | def grad_abs_norm(grad): 153 | """ 154 | 155 | :param grad: numpy array 156 | :return: 157 | """ 158 | grad_2d = np.abs(grad.sum(axis=0)) 159 | grad_max = grad_2d.max() 160 | grad_norm = grad_2d / grad_max 161 | return grad_norm 162 | 163 | 164 | def grad_norm(grad): 165 | """ 166 | 167 | :param grad: numpy array 168 | :return: 169 | """ 170 | grad_2d = grad.sum(axis=0) 171 | grad_max = max(grad_2d.max(), abs(grad_2d.min())) 172 | grad_norm = 
grad_2d / grad_max 173 | return grad_norm 174 | 175 | 176 | def grad_abs_norm_singlechannel(grad): 177 | """ 178 | 179 | :param grad: numpy array 180 | :return: 181 | """ 182 | grad_2d = np.abs(grad) 183 | grad_max = grad_2d.max() 184 | grad_norm = grad_2d / grad_max 185 | return grad_norm 186 | 187 | 188 | def IG_baseline(numpy_image, mode='gaus'): 189 | """ 190 | :param numpy_image: cv2 image 191 | :param mode: 192 | :return: 193 | """ 194 | if mode == 'l1': 195 | raise NotImplementedError() 196 | elif mode == 'gaus': 197 | ablated = cv2.GaussianBlur(numpy_image, (7, 7), 0) 198 | elif mode == 'bif': 199 | ablated = cv2.bilateralFilter(numpy_image, 15, 90, 90) 200 | elif mode == 'mean': 201 | ablated = cv2.medianBlur(numpy_image, 5)  # note: the 'mean' mode actually applies a median blur 202 | else: 203 | ablated = cv2.GaussianBlur(numpy_image, (7, 7), 0) 204 | return ablated 205 | 206 | 207 | def interpolation(x, x_prime, fold, mode='linear'): 208 | diff = x - x_prime 209 | l = np.linspace(0, 1, fold).reshape((fold, 1, 1, 1)) 210 | interp_list = l * diff + x_prime 211 | return interp_list 212 | 213 | 214 | def isotropic_gaussian_kernel(l, sigma, epsilon=1e-5): 215 | ax = np.arange(-l // 2 + 1., l // 2 + 1.) 216 | xx, yy = np.meshgrid(ax, ax) 217 | kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * (sigma + epsilon) ** 2)) 218 | return kernel / np.sum(kernel) 219 | 220 | 221 | def gini(array): 222 | """Calculate the Gini coefficient of a numpy array.""" 223 | # based on bottom eq: 224 | # http://www.statsdirect.com/help/generatedimages/equations/equation154.svg 225 | # from: 226 | # http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm 227 | # All values are treated equally, arrays must be 1d: 228 | array = array.flatten() 229 | if np.amin(array) < 0: 230 | # Values cannot be negative: 231 | array -= np.amin(array) 232 | # Values cannot be 0: 233 | array += 0.0000001 234 | # Values must be sorted: 235 | array = np.sort(array) 236 | # Index per array element: 237 | index = np.arange(1,array.shape[0]+1) 238 | # Number of array elements: 239 | n = array.shape[0] 240 | # Gini coefficient: 241 | return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array))) 242 | 243 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Multi-scale Attention Network for Single Image Super-Resolution
Yan Wang, Yusen Li, Gang Wang, Xiaoguang Liu (Nankai University)

![MAN architecture](images/MAN_arch.png)
1. [Requirements](#%EF%B8%8F-requirements)
2. [Datasets](#-datasets)
3. [Implementary Details](#-implementary-details)
4. [Train and Test](#%EF%B8%8F-train-and-test)
5. [Results and Models](#-results-and-models)
6. [Acknowledgments](#-acknowledgments)
7. [Citation](#-citation)
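The YAML configs under `options/` follow BasicSR conventions, so training and testing are presumably driven by the top-level `train.py`/`test.py` scripts with an option file. A hedged sketch; the `-opt` flag and exact commands are assumptions based on the repository layout, not confirmed by this README, so they are shown as comments only:

```python
# Hedged guess at the BasicSR-style workflow implied by train.py, test.py,
# and options/*.yml; treat both commands as assumptions:
#   python train.py -opt options/trian_MAN.yml   # note: the YAML really is spelled "trian_MAN.yml"
#   python test.py  -opt options/test_MAN.yml
```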
![MAN details](images/MAN_details.png)
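For orientation, the attention MAN builds on is VAN's large kernel attention (LKA), which approximates a large receptive field with a depth-wise conv, a dilated depth-wise conv, and a point-wise conv. Below is a minimal single-scale sketch only; the kernel sizes follow VAN, while MAN's actual multi-scale grouped variant lives in `archs/MAN_arch.py`:

```python
import torch
import torch.nn as nn

class LKA(nn.Module):
    """Minimal VAN-style large kernel attention; MAN extends this idea to multiple scales."""
    def __init__(self, dim):
        super().__init__()
        self.dw = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)                # local depth-wise 5x5
        self.dwd = nn.Conv2d(dim, dim, 7, padding=9, groups=dim, dilation=3)   # dilated depth-wise 7x7, d=3
        self.pw = nn.Conv2d(dim, dim, 1)                                       # point-wise channel mixing

    def forward(self, x):
        return x * self.pw(self.dwd(self.dw(x)))  # attention map gates the input element-wise

x = torch.randn(1, 64, 32, 32)
assert LKA(64)(x).shape == x.shape
```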
📊 Results and Models
---
Visual comparison of the models (image crops are under `images/Visual_Results/`):

| | MAN-Tiny | EDSR-Base | MAN-Light | EDSR | MAN |
|:--|:--:|:--:|:--:|:--:|:--:|
|**Params/FLOPs**| 150K/8G | 1518K/114G | 840K/47G | 43090K/2895G | 8712K/495G |
Results of our MAN-tiny/light/base models. The Set5 validation set is used below to show general performance; visual results on five test sets are provided in the last column.
| Methods | Params | FLOPs | PSNR/SSIM (x2) | PSNR/SSIM (x3) | PSNR/SSIM (x4) | Results |
|:---------|:---------:|:--------:|:------:|:------:|:------:|:--------:|
| MAN-tiny | 150K | 8.4G | 37.91/0.9603 | 34.23/0.9258 | 32.07/0.8930 | [x2](https://pan.baidu.com/s/1mYkGvAlz0bSZuCVubkpsmg?pwd=mans)/[x3](https://pan.baidu.com/s/1RP5gGu-QPXTkH1NPH7axag?pwd=mans)/[x4](https://pan.baidu.com/s/1u22su2bT4Pq_idVxAnqWdw?pwd=mans) |
| MAN-light| 840K | 47.1G | 38.18/0.9612 | 34.65/0.9292 | 32.50/0.8988 | [x2](https://pan.baidu.com/s/1AVuPa7bsbb3qMQqMSM-IJQ?pwd=mans)/[x3](https://pan.baidu.com/s/1TRL7-Y23JddVOpEhH0ObEQ?pwd=mans)/[x4](https://pan.baidu.com/s/1T2bPZcjFRxAgMxGWtPv-Lw?pwd=mans) |
| MAN+ | 8712K | 495G | 38.44/0.9623 | 34.97/0.9315 | 32.87/0.9030 | [x2](https://pan.baidu.com/s/1pTb3Fob_7MOxMKIdopI0hQ?pwd=mans)/[x3](https://pan.baidu.com/s/1L3HEtcraU8Y9VY-HpCZdfg?pwd=mans)/[x4](https://pan.baidu.com/s/1FCNqht9zi9HecG3ExRdeWQ?pwd=mans) |
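The PSNR values above follow the MATLAB-style metric implemented in `LAM/cal_metrix.py`, computed on [0, 255] intensities after border cropping. For reference, the core formula reduces to:

```python
import numpy as np

def psnr(img1: np.ndarray, img2: np.ndarray) -> float:
    # PSNR = 20 * log10(255 / sqrt(MSE)); mirrors calculate_psnr in LAM/cal_metrix.py
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 20 * np.log10(255.0 / np.sqrt(mse))
```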

💖 Acknowledgments
---

We would like to thank [VAN](https://github.com/Visual-Attention-Network/VAN-Classification) and [BasicSR](https://github.com/XPixelGroup/BasicSR) for their enlightening work!

🎓 Citation
---
```
@inproceedings{wang2024multi,
  title={Multi-scale Attention Network for Single Image Super-Resolution},
  author={Wang, Yan and Li, Yusen and Wang, Gang and Liu, Xiaoguang},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
  year={2024}
}
```

or

```
@article{wang2022multi,
  title={Multi-scale Attention Network for Single Image Super-Resolution},
  author={Wang, Yan and Li, Yusen and Wang, Gang and Liu, Xiaoguang},
  journal={arXiv preprint arXiv:2209.14145},
  year={2022}
}
```
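For the LAM attribution side of the repository, the pieces defined in `LAM/SaliencyModel/BackProp.py` and `LAM/SaliencyModel/utils.py` compose into a short pipeline. Below is a hedged end-to-end sketch, run from inside `LAM/`; the bicubic "model" and the window-mean objective are toy stand-ins (assumptions) for a ModelZoo network and the real attribution objective:

```python
# Hedged sketch of the LAM pipeline wired from SaliencyModel/BackProp.py and
# SaliencyModel/utils.py; run from inside LAM/ so the imports and paths resolve.
import numpy as np
import torch.nn as nn
from SaliencyModel.utils import prepare_images, pil_to_cv2, vis_saliency, grad_abs_norm
from SaliencyModel.BackProp import attribution_objective, GaussianBlurPath, Path_gradient, saliency_map_PG

def window_mean_objective(image, h, w, window=16):
    return image[0, :, h:h + window, w:w + window].mean()  # toy: mean of the target SR window

model = nn.Upsample(scale_factor=4, mode='bicubic')  # stand-in super-resolver (assumption)
lr_pil, hr_pil = prepare_images('test_images/1.png')  # HR -> mod-cropped HR plus bicubic LR
lr_np = np.moveaxis(pil_to_cv2(lr_pil)[:, :, ::-1] / 255., 2, 0).astype(np.float32)  # RGB, C x H x W, [0, 1]
objective = attribution_objective(window_mean_objective, h=16, w=16, window=16)
path_fn = GaussianBlurPath(sigma=1.2, fold=10, l=9)  # Gaussian-blur integration path
grads, results, _ = Path_gradient(lr_np, model, objective, path_fn)
saliency, _ = saliency_map_PG(grads, results)  # average the accumulated path gradients
vis_saliency(grad_abs_norm(saliency), zoomin=4).save('lam_sketch.png')
```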
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/drln.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ..NN import drln_ops as ops
4 | import torch.nn.functional as F
5 |
6 |
7 | def make_model(args, parent=False):
8 | return DRLN(args)
9 |
10 |
11 | class CALayer(nn.Module):  # channel attention: 1x1-pooled stats probed by three dilated 3x3 convs (d=3/5/7), fused through a sigmoid gate
12 | def __init__(self, channel, reduction=16):
13 | super(CALayer, self).__init__()
14 |
15 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
16 |
17 | self.c1 = ops.BasicBlock(channel, channel // reduction, 3, 1, 3, 3)
18 | self.c2 = ops.BasicBlock(channel, channel // reduction, 3, 1, 5, 5)
19 | self.c3 = ops.BasicBlock(channel, channel // reduction, 3, 1, 7, 7)
20 | self.c4 = ops.BasicBlockSig((channel // reduction) * 3, channel, 3, 1, 1)
21 |
22 | def forward(self, x):
23 | y = self.avg_pool(x)
24 | c1 = self.c1(y)
25 | c2 = self.c2(y)
26 | c3 = self.c3(y)
27 | c_out = torch.cat([c1, c2, c3], dim=1)
28 | y = self.c4(c_out)
29 | return x * y
30 |
31 |
32 | class Block(nn.Module):  # dense residual block: three ResidualBlocks on a widening concat stack, a 1x1 fusion, then CALayer
33 | def __init__(self, in_channels, out_channels, group=1):
34 | super(Block, self).__init__()
35 |
36 | self.r1 = ops.ResidualBlock(in_channels, out_channels)
37 | self.r2 = ops.ResidualBlock(in_channels * 2, out_channels * 2)
38 | self.r3 = ops.ResidualBlock(in_channels * 4, out_channels * 4)
39 | self.g = ops.BasicBlock(in_channels * 8, out_channels, 1, 1, 0)
40 | self.ca = CALayer(in_channels)
41 |
42 | def forward(self, x):
43 | c0 = x
44 |
45 | r1 = self.r1(c0)
46 | c1 = torch.cat([c0, r1], dim=1)
47 |
48 | r2 = self.r2(c1)
49 | c2 = torch.cat([c1, r2], dim=1)
50 |
51 | r3 = self.r3(c2)
52 | c3 = torch.cat([c2, r3], dim=1)
53 |
54 | g = self.g(c3)
55 | out = self.ca(g)
56 | return out
57 |
58 |
59 | class DRLN(nn.Module):  # Densely Residual Laplacian Network: 20 Blocks chained by cascading concat + 1x1 fusions (c1..c20) with six long skips (a1..a6)
60 | def __init__(self):
61 | super(DRLN, self).__init__()
62 |
63 | # n_resgroups = args.n_resgroups
64 | # n_resblocks = args.n_resblocks
65 | # n_feats = args.n_feats
66 | # kernel_size = 3
67 | # reduction = args.reduction
68 | # scale = args.scale[0]
69 | # act = nn.ReLU(True)
70 |
71 | self.scale = 4
72 | chs = 64
73 |
74 | self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
75 | self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)
76 |
77 | self.head = nn.Conv2d(3, chs, 3, 1, 1)
78 |
79 | self.b1 = Block(chs, chs)
80 | self.b2 = Block(chs, chs)
81 | self.b3 = Block(chs, chs)
82 | self.b4 = Block(chs, chs)
83 | self.b5 = Block(chs, chs)
84 | self.b6 = Block(chs, chs)
85 | self.b7 = Block(chs, chs)
86 | self.b8 = Block(chs, chs)
87 | self.b9 = Block(chs, chs)
88 | self.b10 = Block(chs, chs)
89 | self.b11 = Block(chs, chs)
90 | self.b12 = Block(chs, chs)
91 | self.b13 = Block(chs, chs)
92 | self.b14 = Block(chs, chs)
93 | self.b15 = Block(chs, chs)
94 | self.b16 = Block(chs, chs)
95 | self.b17 = Block(chs, chs)
96 | self.b18 = Block(chs, chs)
97 | self.b19 = Block(chs, chs)
98 | self.b20 = Block(chs, chs)
99 |
100 | self.c1 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
101 | self.c2 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
102 | self.c3 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
103 | self.c4 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
104 | self.c5 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
105 | self.c6 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
106 | self.c7 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
107 | self.c8 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
108 | self.c9 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
109 | self.c10 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
110 | self.c11 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
111 | self.c12 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
112 | self.c13 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
113 | self.c14 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
114 | self.c15 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
115 | self.c16 = ops.BasicBlock(chs * 5, chs, 3, 1, 1)
116 | self.c17 = ops.BasicBlock(chs * 2, chs, 3, 1, 1)
117 | self.c18 = ops.BasicBlock(chs * 3, chs, 3, 1, 1)
118 | self.c19 = ops.BasicBlock(chs * 4, chs, 3, 1, 1)
119 | self.c20 = ops.BasicBlock(chs * 5, chs, 3, 1, 1)
120 |
121 | self.upsample = ops.UpsampleBlock(chs, self.scale, multi_scale=False)
122 | # self.convert = ops.ConvertBlock(chs, chs, 20)
123 | self.tail = nn.Conv2d(chs, 3, 3, 1, 1)
124 |
125 | def forward(self, x):
126 | x = self.sub_mean(x * 255.)
127 | x = self.head(x)
128 | c0 = o0 = x
129 |
130 | b1 = self.b1(o0)  # stage pattern repeated through b20/c20: Block -> concat onto the running stack -> 1x1 fuse
131 | c1 = torch.cat([c0, b1], dim=1)
132 | o1 = self.c1(c1)
133 |
134 | b2 = self.b2(o1)
135 | c2 = torch.cat([c1, b2], dim=1)
136 | o2 = self.c2(c2)
137 |
138 | b3 = self.b3(o2)
139 | c3 = torch.cat([c2, b3], dim=1)
140 | o3 = self.c3(c3)
141 | a1 = o3 + c0
142 |
143 | b4 = self.b4(a1)
144 | c4 = torch.cat([o3, b4], dim=1)
145 | o4 = self.c4(c4)
146 |
147 | b5 = self.b5(a1)
148 | c5 = torch.cat([c4, b5], dim=1)
149 | o5 = self.c5(c5)
150 |
151 | b6 = self.b6(o5)
152 | c6 = torch.cat([c5, b6], dim=1)
153 | o6 = self.c6(c6)
154 | a2 = o6 + a1
155 |
156 | b7 = self.b7(a2)
157 | c7 = torch.cat([o6, b7], dim=1)
158 | o7 = self.c7(c7)
159 |
160 | b8 = self.b8(o7)
161 | c8 = torch.cat([c7, b8], dim=1)
162 | o8 = self.c8(c8)
163 |
164 | b9 = self.b9(o8)
165 | c9 = torch.cat([c8, b9], dim=1)
166 | o9 = self.c9(c9)
167 | a3 = o9 + a2
168 |
169 | b10 = self.b10(a3)
170 | c10 = torch.cat([o9, b10], dim=1)
171 | o10 = self.c10(c10)
172 |
173 | b11 = self.b11(o10)
174 | c11 = torch.cat([c10, b11], dim=1)
175 | o11 = self.c11(c11)
176 |
177 | b12 = self.b12(o11)
178 | c12 = torch.cat([c11, b12], dim=1)
179 | o12 = self.c12(c12)
180 | a4 = o12 + a3
181 |
182 | b13 = self.b13(a4)
183 | c13 = torch.cat([o12, b13], dim=1)
184 | o13 = self.c13(c13)
185 |
186 | b14 = self.b14(o13)
187 | c14 = torch.cat([c13, b14], dim=1)
188 | o14 = self.c14(c14)
189 |
190 | b15 = self.b15(o14)
191 | c15 = torch.cat([c14, b15], dim=1)
192 | o15 = self.c15(c15)
193 |
194 | b16 = self.b16(o15)
195 | c16 = torch.cat([c15, b16], dim=1)
196 | o16 = self.c16(c16)
197 | a5 = o16 + a4
198 |
199 | b17 = self.b17(a5)
200 | c17 = torch.cat([o16, b17], dim=1)
201 | o17 = self.c17(c17)
202 |
203 | b18 = self.b18(o17)
204 | c18 = torch.cat([c17, b18], dim=1)
205 | o18 = self.c18(c18)
206 |
207 | b19 = self.b19(o18)
208 | c19 = torch.cat([c18, b19], dim=1)
209 | o19 = self.c19(c19)
210 |
211 | b20 = self.b20(o19)
212 | c20 = torch.cat([c19, b20], dim=1)
213 | o20 = self.c20(c20)
214 | a6 = o20 + a5
215 |
216 | # c_out = torch.cat([b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20], dim=1)
217 |
218 | # b = self.convert(c_out)
219 | b_out = a6 + x
220 | out = self.upsample(b_out, scale=self.scale)
221 |
222 | out = self.tail(out)
223 | f_out = self.add_mean(out)
224 |
225 | return f_out / 255.
226 |
227 | def load_state_dict(self, state_dict, strict=False):
228 | own_state = self.state_dict()
229 | for name, param in state_dict.items():
230 | if name in own_state:
231 | if isinstance(param, nn.Parameter):
232 | param = param.data
233 | try:
234 | own_state[name].copy_(param)
235 | except Exception:
236 | if name.find('tail') >= 0 or name.find('upsample') >= 0:
237 | print('Replace pre-trained upsampler to new one...')
238 | else:
239 |                         raise RuntimeError('Error copying the parameter named {}: '
240 |                                            'its shape in the model is {}, but its '
241 |                                            'shape in the checkpoint is {}.'
242 |                                            .format(name, own_state[name].size(), param.size()))
243 | elif strict:
244 | if name.find('tail') == -1:
245 | raise KeyError('unexpected key "{}" in state_dict'
246 | .format(name))
247 |
248 | if strict:
249 | missing = set(own_state.keys()) - set(state_dict.keys())
250 | if len(missing) > 0:
251 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
252 |
253 |
--------------------------------------------------------------------------------
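
The forward pass above follows a dense cascading pattern: each block output b_i is concatenated onto a running feature bank c_i, a BasicBlock compresses the bank back to chs channels (the o_i), and every few blocks a residual add (a_1 ... a_6) re-injects earlier features. Below is a minimal, self-contained sketch of that pattern; DenseGroup and its plain 3x3 convolutions are hypothetical stand-ins for ops.BasicBlock and the b* blocks, not part of this repository.

    import torch
    import torch.nn as nn

    class DenseGroup(nn.Module):
        """Sketch of one cascading group: concatenate, compress, then a residual add."""
        def __init__(self, chs, n_blocks=3):
            super().__init__()
            self.blocks = nn.ModuleList(
                nn.Conv2d(chs, chs, 3, 1, 1) for _ in range(n_blocks))
            # compressor i sees the running bank plus b_i: (i + 2) * chs channels
            self.compress = nn.ModuleList(
                nn.Conv2d((i + 2) * chs, chs, 3, 1, 1) for i in range(n_blocks))

        def forward(self, x):
            bank, out = x, x
            for block, comp in zip(self.blocks, self.compress):
                b = block(out)
                bank = torch.cat([bank, b], dim=1)  # growing feature bank (the c_i)
                out = comp(bank)                    # compress back to chs (the o_i)
            return out + x                          # group-level residual (the a_i)

    y = DenseGroup(chs=64)(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 64, 16, 16)
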
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright MAN Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from torch.autograd import Variable
9 |
10 |
11 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
12 | return nn.Conv2d(
13 | in_channels, out_channels, kernel_size,
14 | padding=(kernel_size//2), bias=bias)
15 |
16 |
17 | class MeanShift(nn.Conv2d):
18 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
19 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
20 | std = torch.Tensor(rgb_std)
21 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
22 | self.weight.data.div_(std.view(3, 1, 1, 1))
23 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
24 | self.bias.data.div_(std)
25 |         self.requires_grad_(False)  # freeze the parameters (plain `self.requires_grad = False` would not)
26 |
27 |
28 | class DownBlock(nn.Module):
29 | def __init__(self, opt, scale, nFeat=None, in_channels=None, out_channels=None):
30 | super(DownBlock, self).__init__()
31 | negval = opt.negval
32 |
33 | if nFeat is None:
34 | nFeat = opt.n_feats
35 |
36 | if in_channels is None:
37 | in_channels = opt.n_colors
38 |
39 | if out_channels is None:
40 | out_channels = opt.n_colors
41 |
42 | dual_block = [
43 | nn.Sequential(
44 | nn.Conv2d(in_channels, nFeat, kernel_size=3, stride=2, padding=1, bias=False),
45 | nn.LeakyReLU(negative_slope=negval, inplace=True)
46 | )
47 | ]
48 |
49 | for _ in range(1, int(np.log2(scale))):
50 | dual_block.append(
51 | nn.Sequential(
52 | nn.Conv2d(nFeat, nFeat, kernel_size=3, stride=2, padding=1, bias=False),
53 | nn.LeakyReLU(negative_slope=negval, inplace=True)
54 | )
55 | )
56 |
57 | dual_block.append(nn.Conv2d(nFeat, out_channels, kernel_size=3, stride=1, padding=1, bias=False))
58 |
59 | self.dual_module = nn.Sequential(*dual_block)
60 |
61 | def forward(self, x):
62 | x = self.dual_module(x)
63 | return x
64 |
65 |
66 | class BasicBlock(nn.Sequential):
67 | def __init__(
68 | self, in_channels, out_channels, kernel_size, stride=1, bias=False,
69 | bn=True, act=nn.ReLU(True)):
70 |
71 | m = [nn.Conv2d(
72 | in_channels, out_channels, kernel_size,
73 | padding=(kernel_size//2), stride=stride, bias=bias)
74 | ]
75 | if bn: m.append(nn.BatchNorm2d(out_channels))
76 | if act is not None: m.append(act)
77 | super(BasicBlock, self).__init__(*m)
78 |
79 |
80 | class ResBlock(nn.Module):
81 | def __init__(
82 | self, conv, n_feat, kernel_size,
83 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
84 |
85 | super(ResBlock, self).__init__()
86 | m = []
87 | for i in range(2):
88 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
89 | if bn: m.append(nn.BatchNorm2d(n_feat))
90 | if i == 0: m.append(act)
91 |
92 | self.body = nn.Sequential(*m)
93 | self.res_scale = res_scale
94 |
95 | def forward(self, x):
96 | res = self.body(x).mul(self.res_scale)
97 | res += x
98 |
99 | return res
100 |
101 |
102 | class Upsampler(nn.Sequential):
103 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
104 |
105 | m = []
106 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
107 | for _ in range(int(math.log(scale, 2))):
108 | m.append(conv(n_feat, 4 * n_feat, 3, bias))
109 | m.append(nn.PixelShuffle(2))
110 | if bn: m.append(nn.BatchNorm2d(n_feat))
111 | if act: m.append(act())
112 | elif scale == 3:
113 | m.append(conv(n_feat, 9 * n_feat, 3, bias))
114 | m.append(nn.PixelShuffle(3))
115 | if bn: m.append(nn.BatchNorm2d(n_feat))
116 | if act: m.append(act())
117 | else:
118 | raise NotImplementedError
119 |
120 | super(Upsampler, self).__init__(*m)
121 |
122 |
123 | ## add SELayer
124 | class SELayer(nn.Module):
125 | def __init__(self, channel, reduction=16):
126 | super(SELayer, self).__init__()
127 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
128 | self.conv_du = nn.Sequential(
129 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
130 | nn.ReLU(inplace=True),
131 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
132 | nn.Sigmoid()
133 | )
134 |
135 | def forward(self, x):
136 | y = self.avg_pool(x)
137 | y = self.conv_du(y)
138 | return x * y
139 |
140 |
141 | ## add SEResBlock
142 | class SEResBlock(nn.Module):
143 | def __init__(
144 | self, conv, n_feat, kernel_size, reduction,
145 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
146 |
147 | super(SEResBlock, self).__init__()
148 | modules_body = []
149 | for i in range(2):
150 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
151 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
152 | if i == 0: modules_body.append(act)
153 | modules_body.append(SELayer(n_feat, reduction))
154 | self.body = nn.Sequential(*modules_body)
155 | self.res_scale = res_scale
156 |
157 | def forward(self, x):
158 | res = self.body(x)
159 | #res = self.body(x).mul(self.res_scale)
160 | res += x
161 |
162 | return res
163 |
164 |
165 | ## Channel Attention (CA) Layer
166 | class CALayer(nn.Module):
167 | def __init__(self, channel, reduction=16):
168 | super(CALayer, self).__init__()
169 | # global average pooling: feature --> point
170 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
171 | # feature channel downscale and upscale --> channel weight
172 | self.conv_du = nn.Sequential(
173 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
174 | nn.ReLU(inplace=True),
175 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
176 | nn.Sigmoid()
177 | )
178 |
179 | def forward(self, x):
180 | y = self.avg_pool(x)
181 | y = self.conv_du(y)
182 | return x * y
183 |
184 |
185 | ## Residual Channel Attention Block (RCAB)
186 | class RCAB(nn.Module):
187 | def __init__(self, conv, n_feat, kernel_size, reduction=16, bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
188 | super(RCAB, self).__init__()
189 | modules_body = []
190 | for i in range(2):
191 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
192 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
193 | if i == 0: modules_body.append(act)
194 | modules_body.append(CALayer(n_feat, reduction))
195 | self.body = nn.Sequential(*modules_body)
196 | self.res_scale = res_scale
197 |
198 | def forward(self, x):
199 | res = self.body(x)
200 | res += x
201 | return res
202 |
203 |
204 | # add NonLocalBlock2D
205 | # reference: https://github.com/AlexHex7/Non-local_pytorch/blob/master/lib/non_local_simple_version.py
206 | class NonLocalBlock2D(nn.Module):
207 | def __init__(self, in_channels, inter_channels):
208 | super(NonLocalBlock2D, self).__init__()
209 |
210 | self.in_channels = in_channels
211 | self.inter_channels = inter_channels
212 |
213 | self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
214 | padding=0)
215 |
216 | self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1,
217 | padding=0)
218 | # for pytorch 0.3.1
219 | # nn.init.constant(self.W.weight, 0)
220 | # nn.init.constant(self.W.bias, 0)
221 | # for pytorch 0.4.0
222 | nn.init.constant_(self.W.weight, 0)
223 | nn.init.constant_(self.W.bias, 0)
224 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
225 | padding=0)
226 |
227 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
228 | padding=0)
229 |
230 | def forward(self, x):
231 | batch_size = x.size(0)
232 |
233 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
234 |
235 | g_x = g_x.permute(0, 2, 1)
236 |
237 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
238 |
239 | theta_x = theta_x.permute(0, 2, 1)
240 |
241 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
242 |
243 | f = torch.matmul(theta_x, phi_x)
244 |
245 |         f_div_C = F.softmax(f, dim=-1)  # normalize over key positions, as in the referenced implementation
246 |
247 | y = torch.matmul(f_div_C, g_x)
248 |
249 | y = y.permute(0, 2, 1).contiguous()
250 |
251 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
252 | W_y = self.W(y)
253 | z = W_y + x
254 |
255 | return z
256 |
257 |
258 | ## define trunk branch
259 | class TrunkBranch(nn.Module):
260 | def __init__(
261 | self, conv, n_feat, kernel_size,
262 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
263 | super(TrunkBranch, self).__init__()
264 | modules_body = []
265 | for i in range(2):
266 | modules_body.append(
267 | ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
268 | self.body = nn.Sequential(*modules_body)
269 |
270 | def forward(self, x):
271 | tx = self.body(x)
272 |
273 | return tx
274 |
275 |
276 | ## define mask branch
277 | class MaskBranchDownUp(nn.Module):
278 | def __init__(
279 | self, conv, n_feat, kernel_size,
280 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
281 | super(MaskBranchDownUp, self).__init__()
282 |
283 | MB_RB1 = []
284 | MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
285 |
286 | MB_Down = []
287 | MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
288 |
289 | MB_RB2 = []
290 | for i in range(2):
291 | MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
292 |
293 | MB_Up = []
294 | MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
295 |
296 | MB_RB3 = []
297 | MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
298 |
299 | MB_1x1conv = []
300 | MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
301 |
302 | MB_sigmoid = []
303 | MB_sigmoid.append(nn.Sigmoid())
304 |
305 | self.MB_RB1 = nn.Sequential(*MB_RB1)
306 | self.MB_Down = nn.Sequential(*MB_Down)
307 | self.MB_RB2 = nn.Sequential(*MB_RB2)
308 | self.MB_Up = nn.Sequential(*MB_Up)
309 | self.MB_RB3 = nn.Sequential(*MB_RB3)
310 | self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
311 | self.MB_sigmoid = nn.Sequential(*MB_sigmoid)
312 |
313 | def forward(self, x):
314 | x_RB1 = self.MB_RB1(x)
315 | x_Down = self.MB_Down(x_RB1)
316 | x_RB2 = self.MB_RB2(x_Down)
317 | x_Up = self.MB_Up(x_RB2)
318 | x_preRB3 = x_RB1 + x_Up
319 | x_RB3 = self.MB_RB3(x_preRB3)
320 | x_1x1 = self.MB_1x1conv(x_RB3)
321 | mx = self.MB_sigmoid(x_1x1)
322 |
323 | return mx
324 |
325 |
326 | ## define nonlocal mask branch
327 | class NLMaskBranchDownUp(nn.Module):
328 | def __init__(
329 | self, conv, n_feat, kernel_size,
330 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
331 | super(NLMaskBranchDownUp, self).__init__()
332 |
333 | MB_RB1 = []
334 | MB_RB1.append(NonLocalBlock2D(n_feat, n_feat // 2))
335 | MB_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
336 |
337 | MB_Down = []
338 | MB_Down.append(nn.Conv2d(n_feat, n_feat, 3, stride=2, padding=1))
339 |
340 | MB_RB2 = []
341 | for i in range(2):
342 | MB_RB2.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
343 |
344 | MB_Up = []
345 | MB_Up.append(nn.ConvTranspose2d(n_feat, n_feat, 6, stride=2, padding=2))
346 |
347 | MB_RB3 = []
348 | MB_RB3.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
349 |
350 | MB_1x1conv = []
351 | MB_1x1conv.append(nn.Conv2d(n_feat, n_feat, 1, padding=0, bias=True))
352 |
353 | MB_sigmoid = []
354 | MB_sigmoid.append(nn.Sigmoid())
355 |
356 | self.MB_RB1 = nn.Sequential(*MB_RB1)
357 | self.MB_Down = nn.Sequential(*MB_Down)
358 | self.MB_RB2 = nn.Sequential(*MB_RB2)
359 | self.MB_Up = nn.Sequential(*MB_Up)
360 | self.MB_RB3 = nn.Sequential(*MB_RB3)
361 | self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
362 | self.MB_sigmoid = nn.Sequential(*MB_sigmoid)
363 |
364 | def forward(self, x):
365 | x_RB1 = self.MB_RB1(x)
366 | x_Down = self.MB_Down(x_RB1)
367 | x_RB2 = self.MB_RB2(x_Down)
368 | x_Up = self.MB_Up(x_RB2)
369 | x_preRB3 = x_RB1 + x_Up
370 | x_RB3 = self.MB_RB3(x_preRB3)
371 | x_1x1 = self.MB_1x1conv(x_RB3)
372 | mx = self.MB_sigmoid(x_1x1)
373 |
374 | return mx
375 |
376 |
377 | ## define residual attention module
378 | class ResAttModuleDownUpPlus(nn.Module):
379 | def __init__(
380 | self, conv, n_feat, kernel_size,
381 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
382 | super(ResAttModuleDownUpPlus, self).__init__()
383 | RA_RB1 = []
384 | RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
385 | RA_TB = []
386 | RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
387 | RA_MB = []
388 | RA_MB.append(MaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
389 | RA_tail = []
390 | for i in range(2):
391 | RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
392 |
393 | self.RA_RB1 = nn.Sequential(*RA_RB1)
394 | self.RA_TB = nn.Sequential(*RA_TB)
395 | self.RA_MB = nn.Sequential(*RA_MB)
396 | self.RA_tail = nn.Sequential(*RA_tail)
397 |
398 | def forward(self, input):
399 | RA_RB1_x = self.RA_RB1(input)
400 | tx = self.RA_TB(RA_RB1_x)
401 | mx = self.RA_MB(RA_RB1_x)
402 | txmx = tx * mx
403 | hx = txmx + RA_RB1_x
404 | hx = self.RA_tail(hx)
405 |
406 | return hx
407 |
408 |
409 | ## define nonlocal residual attention module
410 | class NLResAttModuleDownUpPlus(nn.Module):
411 | def __init__(
412 | self, conv, n_feat, kernel_size,
413 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
414 | super(NLResAttModuleDownUpPlus, self).__init__()
415 | RA_RB1 = []
416 | RA_RB1.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
417 | RA_TB = []
418 | RA_TB.append(TrunkBranch(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
419 | RA_MB = []
420 | RA_MB.append(NLMaskBranchDownUp(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
421 | RA_tail = []
422 | for i in range(2):
423 | RA_tail.append(ResBlock(conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1))
424 |
425 | self.RA_RB1 = nn.Sequential(*RA_RB1)
426 | self.RA_TB = nn.Sequential(*RA_TB)
427 | self.RA_MB = nn.Sequential(*RA_MB)
428 | self.RA_tail = nn.Sequential(*RA_tail)
429 |
430 | def forward(self, input):
431 | RA_RB1_x = self.RA_RB1(input)
432 | tx = self.RA_TB(RA_RB1_x)
433 | mx = self.RA_MB(RA_RB1_x)
434 | txmx = tx * mx
435 | hx = txmx + RA_RB1_x
436 | hx = self.RA_tail(hx)
437 |
438 | return hx
--------------------------------------------------------------------------------
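
As a quick sanity check of how the blocks in common.py compose, here is a hedged usage sketch (it assumes the working directory is LAM/, as for the test scripts, so that ModelZoo is importable): an RCAB built from default_conv preserves spatial shape, which is what lets RCAN-style models stack residual channel-attention blocks freely.

    import torch
    from ModelZoo.NN import common

    block = common.RCAB(common.default_conv, n_feat=64, kernel_size=3, reduction=16)
    x = torch.randn(1, 64, 32, 32)
    assert block(x).shape == x.shape  # padding = kernel_size // 2 keeps H and W fixed
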
/LAM/ModelZoo/NN/man.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from basicsr.utils.registry import ARCH_REGISTRY
9 |
10 | #LKA from VAN (https://github.com/Visual-Attention-Network)
11 | class LKA(nn.Module):
12 | def __init__(self, dim):
13 | super().__init__()
14 | self.conv0 = nn.Conv2d(dim, dim, 7, padding=7//2, groups=dim)
15 | self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9//2)*4), groups=dim, dilation=4)
16 | self.conv1 = nn.Conv2d(dim, dim, 1)
17 |
18 | def forward(self, x):
19 | u = x.clone()
20 | attn = self.conv0(x)
21 | attn = self.conv_spatial(attn)
22 | attn = self.conv1(attn)
23 |
24 | return u * attn
25 |
26 | class Attention(nn.Module):
27 | def __init__(self, n_feats):
28 | super().__init__()
29 |
30 | self.norm = LayerNorm(n_feats, data_format='channels_first')
31 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
32 |
33 | self.proj_1 = nn.Conv2d(n_feats, n_feats, 1)
34 | self.spatial_gating_unit = LKA(n_feats)
35 | self.proj_2 = nn.Conv2d(n_feats, n_feats, 1)
36 |
37 | def forward(self, x):
38 |         shortcut = x.clone()
39 |         x = self.proj_1(self.norm(x))
40 |         x = self.spatial_gating_unit(x)
41 |         x = self.proj_2(x)
42 |         x = x*self.scale + shortcut
43 | return x
44 | #----------------------------------------------------------------------------------------------------------------
45 |
46 | class MLP(nn.Module):
47 | def __init__(self, n_feats):
48 | super().__init__()
49 |
50 | self.norm = LayerNorm(n_feats, data_format='channels_first')
51 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
52 |
53 | i_feats = 2*n_feats
54 |
55 | self.fc1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
56 | self.act = nn.GELU()
57 | self.fc2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)
58 |
59 | def forward(self, x):
60 | shortcut = x.clone()
61 | x = self.norm(x)
62 | x = self.fc1(x)
63 | x = self.act(x)
64 | x = self.fc2(x)
65 |
66 | return x*self.scale + shortcut
67 |
68 |
69 | class CFF(nn.Module):
70 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'):
71 | super().__init__()
72 | i_feats =n_feats*2
73 |
74 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
75 | self.DWConv1 = nn.Sequential(
76 | nn.Conv2d(i_feats, i_feats, 7, 1, 7//2, groups= n_feats),
77 | nn.GELU())
78 | self.Conv2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)
79 |
80 | self.norm = LayerNorm(n_feats, data_format='channels_first')
81 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
82 |
83 | def forward(self, x):
84 | shortcut = x.clone()
85 |
86 | #Ghost Expand
87 | x = self.Conv1(self.norm(x))
88 | x = self.DWConv1(x)
89 | x = self.Conv2(x)
90 |
91 | return x*self.scale + shortcut
92 |
93 |
94 | class SimpleGate(nn.Module):
95 | def __init__(self, n_feats):
96 | super().__init__()
97 | i_feats =n_feats*2
98 |
99 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
100 | #self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats)
101 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
102 |
103 | self.norm = LayerNorm(n_feats, data_format='channels_first')
104 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
105 |
106 | def forward(self, x):
107 | shortcut = x.clone()
108 |
109 | #Ghost Expand
110 | x = self.Conv1(self.norm(x))
111 | a, x = torch.chunk(x, 2, dim=1)
112 | x = x*a #self.DWConv1(a)
113 | x = self.Conv2(x)
114 |
115 | return x*self.scale + shortcut
116 | #-----------------------------------------------------------------------------------------------------------------
117 | #RCAN-style
118 | class RCBv6(nn.Module):
119 | def __init__(
120 | self, n_feats, k, lk=7, res_scale=1.0, style ='X', act = nn.SiLU(), deploy=False):
121 | super().__init__()
122 | self.LKA = nn.Sequential(
123 |             nn.Conv2d(n_feats, n_feats, 5, 1, 5//2, groups= n_feats),  # padding must match the 5x5 kernel, not lk
124 | nn.Conv2d(n_feats, n_feats, 7, stride=1, padding=9, groups=n_feats, dilation=3),
125 | nn.Conv2d(n_feats, n_feats, 1, 1, 0),
126 | nn.Sigmoid())
127 |
128 | #self.LFE2 = LFEv3(n_feats, attn ='CA')
129 |
130 | self.LFE = nn.Sequential(
131 | nn.Conv2d(n_feats,n_feats,3,1,1),
132 | nn.GELU(),
133 | nn.Conv2d(n_feats,n_feats,3,1,1))
134 |
135 |
136 | def forward(self, x, pre_attn=None, RAA=None):
137 | shortcut = x.clone()
138 | x = self.LFE(x)
139 |
140 | x = self.LKA(x)*x
141 |
142 | return x + shortcut
143 |
144 | #-----------------------------------------------------------------------------------------------------------------
145 |
146 |
147 | class MLKA_Ablation(nn.Module):
148 | def __init__(self, n_feats, k=2, squeeze_factor=15):
149 | super().__init__()
150 | i_feats = 2*n_feats
151 |
152 | self.n_feats= n_feats
153 | self.i_feats = i_feats
154 |
155 | self.norm = LayerNorm(n_feats, data_format='channels_first')
156 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
157 |
158 | k = 2
159 |
160 | #Multiscale Large Kernel Attention
161 | self.LKA7 = nn.Sequential(
162 | nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k),
163 | nn.Conv2d(n_feats//k, n_feats//k, 9, stride=1, padding=(9//2)*4, groups=n_feats//k, dilation=4),
164 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))
165 | self.LKA5 = nn.Sequential(
166 | nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k),
167 | nn.Conv2d(n_feats//k, n_feats//k, 7, stride=1, padding=(7//2)*3, groups=n_feats//k, dilation=3),
168 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))
169 | '''self.LKA3 = nn.Sequential(
170 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k),
171 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2),
172 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))'''
173 |
174 | #self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k)
175 | self.X5 = nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k)
176 | self.X7 = nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k)
177 |
178 | self.proj_first = nn.Sequential(
179 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
180 |
181 | self.proj_last = nn.Sequential(
182 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
183 |
184 |
185 | def forward(self, x, pre_attn=None, RAA=None):
186 | shortcut = x.clone()
187 |
188 | x = self.norm(x)
189 |
190 | x = self.proj_first(x)
191 |
192 | a, x = torch.chunk(x, 2, dim=1)
193 |
194 | #u_1, u_2, u_3= torch.chunk(u, 3, dim=1)
195 | a_1, a_2 = torch.chunk(a, 2, dim=1)
196 |
197 | a = torch.cat([self.LKA7(a_1)*self.X7(a_1), self.LKA5(a_2)*self.X5(a_2)], dim=1)
198 |
199 | x = self.proj_last(x*a)*self.scale + shortcut
200 |
201 | return x
202 | #-----------------------------------------------------------------------------------------------------------------
203 |
204 |
205 | class LayerNorm(nn.Module):
206 |     r""" LayerNorm that supports two data formats, channels_last (default) and channels_first,
207 |     which describe the ordering of the dimensions in the inputs: channels_last corresponds to
208 |     inputs with shape (batch_size, height, width, channels), while channels_first corresponds
209 |     to inputs with shape (batch_size, channels, height, width).
210 |     """
211 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
212 | super().__init__()
213 | self.weight = nn.Parameter(torch.ones(normalized_shape))
214 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
215 | self.eps = eps
216 | self.data_format = data_format
217 | if self.data_format not in ["channels_last", "channels_first"]:
218 | raise NotImplementedError
219 | self.normalized_shape = (normalized_shape, )
220 |
221 | def forward(self, x):
222 | if self.data_format == "channels_last":
223 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
224 | elif self.data_format == "channels_first":
225 | u = x.mean(1, keepdim=True)
226 | s = (x - u).pow(2).mean(1, keepdim=True)
227 | x = (x - u) / torch.sqrt(s + self.eps)
228 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
229 | return x
230 |
231 |
232 |
233 | class SGAB(nn.Module):
234 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'):
235 | super().__init__()
236 | i_feats =n_feats*2
237 |
238 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
239 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats)
240 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
241 |
242 | self.norm = LayerNorm(n_feats, data_format='channels_first')
243 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
244 |
245 | def forward(self, x):
246 | shortcut = x.clone()
247 |
248 | #Ghost Expand
249 | x = self.Conv1(self.norm(x))
250 | a, x = torch.chunk(x, 2, dim=1)
251 | x = x*self.DWConv1(a)
252 | x = self.Conv2(x)
253 |
254 | return x*self.scale + shortcut
255 |
256 | class GroupGLKA(nn.Module):
257 | def __init__(self, n_feats, k=2, squeeze_factor=15):
258 | super().__init__()
259 | i_feats = 2*n_feats
260 |
261 | self.n_feats= n_feats
262 | self.i_feats = i_feats
263 |
264 | self.norm = LayerNorm(n_feats, data_format='channels_first')
265 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
266 |
267 | #Multiscale Large Kernel Attention
268 | self.LKA7 = nn.Sequential(
269 | nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3),
270 | nn.Conv2d(n_feats//3, n_feats//3, 9, stride=1, padding=(9//2)*4, groups=n_feats//3, dilation=4),
271 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
272 | self.LKA5 = nn.Sequential(
273 | nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3),
274 | nn.Conv2d(n_feats//3, n_feats//3, 7, stride=1, padding=(7//2)*3, groups=n_feats//3, dilation=3),
275 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
276 | self.LKA3 = nn.Sequential(
277 | nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3),
278 | nn.Conv2d(n_feats//3, n_feats//3, 5, stride=1, padding=(5//2)*2, groups=n_feats//3, dilation=2),
279 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
280 |
281 | self.X3 = nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3)
282 | self.X5 = nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3)
283 | self.X7 = nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3)
284 |
285 | self.proj_first = nn.Sequential(
286 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
287 |
288 | self.proj_last = nn.Sequential(
289 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
290 |
291 |
292 | def forward(self, x, pre_attn=None, RAA=None):
293 | shortcut = x.clone()
294 |
295 | x = self.norm(x)
296 |
297 | x = self.proj_first(x)
298 |
299 | a, x = torch.chunk(x, 2, dim=1)
300 |
301 | a_1, a_2, a_3= torch.chunk(a, 3, dim=1)
302 |
303 | a = torch.cat([self.LKA3(a_1)*self.X3(a_1), self.LKA5(a_2)*self.X5(a_2), self.LKA7(a_3)*self.X7(a_3)], dim=1)
304 |
305 | x = self.proj_last(x*a)*self.scale + shortcut
306 |
307 | return x
308 |
309 |
310 | # MAB
311 | class MAB(nn.Module):
312 | def __init__(
313 | self, n_feats):
314 | super().__init__()
315 |
316 | self.LKA = GroupGLKA(n_feats)
317 |
318 | self.LFE = SGAB(n_feats)
319 |
320 | def forward(self, x, pre_attn=None, RAA=None):
321 | #large kernel attention
322 | x = self.LKA(x)
323 |
324 | #local feature extraction
325 | x = self.LFE(x)
326 |
327 | return x
328 |
329 | class LKAT(nn.Module):
330 | def __init__(self, n_feats):
331 | super().__init__()
332 |
333 | #self.norm = LayerNorm(n_feats, data_format='channels_first')
334 | #self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
335 |
336 | self.conv0 = nn.Sequential(
337 | nn.Conv2d(n_feats, n_feats, 1, 1, 0),
338 | nn.GELU())
339 |
340 | self.att = nn.Sequential(
341 | nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats),
342 | nn.Conv2d(n_feats, n_feats, 9, stride=1, padding=(9//2)*3, groups=n_feats, dilation=3),
343 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
344 |
345 | self.conv1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
346 |
347 | def forward(self, x):
348 | x = self.conv0(x)
349 | x = x*self.att(x)
350 | x = self.conv1(x)
351 | return x
352 |
353 | class ResGroup(nn.Module):
354 | def __init__(self, n_resblocks, n_feats, res_scale=1.0):
355 | super(ResGroup, self).__init__()
356 | self.body = nn.ModuleList([
357 | MAB(n_feats) \
358 | for _ in range(n_resblocks)])
359 |
360 | self.body_t = LKAT(n_feats)
361 |
362 | def forward(self, x):
363 | res = x.clone()
364 |
365 | for i, block in enumerate(self.body):
366 | res = block(res)
367 |
368 | x = self.body_t(res) + x
369 |
370 | return x
371 |
372 | class MeanShift(nn.Conv2d):
373 | def __init__(
374 | self, rgb_range,
375 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
376 |
377 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
378 | std = torch.Tensor(rgb_std)
379 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
380 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
381 | for p in self.parameters():
382 | p.requires_grad = False
383 |
384 |
385 | class MAN(nn.Module):
386 | def __init__(self, n_resblocks=36, n_resgroups=1, n_colors=3, n_feats=180, scale=2, res_scale = 1.0):
387 | super(MAN, self).__init__()
388 |
389 | #res_scale = res_scale
390 | self.n_resgroups = n_resgroups
391 |
392 | self.sub_mean = MeanShift(1.0)
393 | self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
394 |
395 | # define body module
396 | self.body = nn.ModuleList([
397 | ResGroup(
398 | n_resblocks, n_feats, res_scale=res_scale)
399 | for i in range(n_resgroups)])
400 |
401 | if self.n_resgroups > 1:
402 | self.body_t = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
403 |
404 | # define tail module
405 | self.tail = nn.Sequential(
406 | nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1),
407 | nn.PixelShuffle(scale)
408 | )
409 | self.add_mean = MeanShift(1.0, sign=1)
410 |
411 | def forward(self, x):
412 | x = self.sub_mean(x)
413 | x = self.head(x)
414 | res = x
415 | for i in self.body:
416 | res = i(res)
417 | if self.n_resgroups>1:
418 | res = self.body_t(res) + x
419 | x = self.tail(res)
420 | x = self.add_mean(x)
421 | return x
422 |
423 | def visual_feature(self, x):
424 | fea = []
425 | x = self.head(x)
426 | res = x
427 |
428 | for i in self.body:
429 | temp = res
430 | res = i(res)
431 | fea.append(res)
432 |
433 |         if self.n_resgroups > 1: res = self.body_t(res) + x  # body_t is only defined when n_resgroups > 1
434 |
435 | x = self.tail(res)
436 | return x, fea
437 |
438 | def load_state_dict(self, state_dict, strict=False):
439 | own_state = self.state_dict()
440 | for name, param in state_dict.items():
441 | if name in own_state:
442 | if isinstance(param, nn.Parameter):
443 | param = param.data
444 | try:
445 | own_state[name].copy_(param)
446 | except Exception:
447 | if name.find('tail') >= 0:
448 |                     print('Replacing the pre-trained upsampler with a new one...')
449 | else:
450 |                         raise RuntimeError('Error copying the parameter named {}: '
451 |                                            'its shape in the model is {}, but its '
452 |                                            'shape in the checkpoint is {}.'
453 |                                            .format(name, own_state[name].size(), param.size()))
454 | elif strict:
455 | if name.find('tail') == -1:
456 | raise KeyError('unexpected key "{}" in state_dict'
457 | .format(name))
458 |
459 | if strict:
460 | missing = set(own_state.keys()) - set(state_dict.keys())
461 | if len(missing) > 0:
462 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
--------------------------------------------------------------------------------
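
A minimal smoke test for the network above (a sketch, assuming basicsr is installed and the working directory is LAM/): because the tail is a convolution to n_colors * scale**2 channels followed by PixelShuffle(scale), an HxW input yields a (scale*H)x(scale*W) output.

    import torch
    from ModelZoo.NN.man import MAN

    model = MAN(n_resblocks=36, n_resgroups=1, n_feats=180, scale=2).eval()
    with torch.no_grad():
        sr = model(torch.randn(1, 3, 48, 48))
    assert sr.shape == (1, 3, 96, 96)  # x2 upscaling from the PixelShuffle tail
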
/archs/MAN_arch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from basicsr.utils.registry import ARCH_REGISTRY
9 |
10 | #LKA from VAN (https://github.com/Visual-Attention-Network)
11 | class LKA(nn.Module):
12 | def __init__(self, dim):
13 | super().__init__()
14 | self.conv0 = nn.Conv2d(dim, dim, 7, padding=7//2, groups=dim)
15 | self.conv_spatial = nn.Conv2d(dim, dim, 9, stride=1, padding=((9//2)*4), groups=dim, dilation=4)
16 | self.conv1 = nn.Conv2d(dim, dim, 1)
17 |
18 | def forward(self, x):
19 | u = x.clone()
20 | attn = self.conv0(x)
21 | attn = self.conv_spatial(attn)
22 | attn = self.conv1(attn)
23 |
24 | return u * attn
25 |
26 | class Attention(nn.Module):
27 | def __init__(self, n_feats):
28 | super().__init__()
29 |
30 | self.norm = LayerNorm(n_feats, data_format='channels_first')
31 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
32 |
33 | self.proj_1 = nn.Conv2d(n_feats, n_feats, 1)
34 | self.spatial_gating_unit = LKA(n_feats)
35 | self.proj_2 = nn.Conv2d(n_feats, n_feats, 1)
36 |
37 | def forward(self, x):
38 |         shortcut = x.clone()
39 |         x = self.proj_1(self.norm(x))
40 |         x = self.spatial_gating_unit(x)
41 |         x = self.proj_2(x)
42 |         x = x*self.scale + shortcut
43 | return x
44 | #----------------------------------------------------------------------------------------------------------------
45 |
46 | class MLP(nn.Module):
47 | def __init__(self, n_feats):
48 | super().__init__()
49 |
50 | self.norm = LayerNorm(n_feats, data_format='channels_first')
51 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
52 |
53 | i_feats = 2*n_feats
54 |
55 | self.fc1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
56 | self.act = nn.GELU()
57 | self.fc2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)
58 |
59 | def forward(self, x):
60 | shortcut = x.clone()
61 | x = self.norm(x)
62 | x = self.fc1(x)
63 | x = self.act(x)
64 | x = self.fc2(x)
65 |
66 | return x*self.scale + shortcut
67 |
68 |
69 | class CFF(nn.Module):
70 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'):
71 | super().__init__()
72 | i_feats =n_feats*2
73 |
74 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
75 | self.DWConv1 = nn.Sequential(
76 | nn.Conv2d(i_feats, i_feats, 7, 1, 7//2, groups= n_feats),
77 | nn.GELU())
78 | self.Conv2 = nn.Conv2d(i_feats, n_feats, 1, 1, 0)
79 |
80 | self.norm = LayerNorm(n_feats, data_format='channels_first')
81 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
82 |
83 | def forward(self, x):
84 | shortcut = x.clone()
85 |
86 | #Ghost Expand
87 | x = self.Conv1(self.norm(x))
88 | x = self.DWConv1(x)
89 | x = self.Conv2(x)
90 |
91 | return x*self.scale + shortcut
92 |
93 |
94 | class SimpleGate(nn.Module):
95 | def __init__(self, n_feats):
96 | super().__init__()
97 | i_feats =n_feats*2
98 |
99 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
100 | #self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats)
101 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
102 |
103 | self.norm = LayerNorm(n_feats, data_format='channels_first')
104 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
105 |
106 | def forward(self, x):
107 | shortcut = x.clone()
108 |
109 | #Ghost Expand
110 | x = self.Conv1(self.norm(x))
111 | a, x = torch.chunk(x, 2, dim=1)
112 | x = x*a #self.DWConv1(a)
113 | x = self.Conv2(x)
114 |
115 | return x*self.scale + shortcut
116 | #-----------------------------------------------------------------------------------------------------------------
117 | #RCAN-style
118 | class RCBv6(nn.Module):
119 | def __init__(
120 | self, n_feats, k, lk=7, res_scale=1.0, style ='X', act = nn.SiLU(), deploy=False):
121 | super().__init__()
122 | self.LKA = nn.Sequential(
123 |             nn.Conv2d(n_feats, n_feats, 5, 1, 5//2, groups= n_feats),  # padding must match the 5x5 kernel, not lk
124 | nn.Conv2d(n_feats, n_feats, 7, stride=1, padding=9, groups=n_feats, dilation=3),
125 | nn.Conv2d(n_feats, n_feats, 1, 1, 0),
126 | nn.Sigmoid())
127 |
128 | #self.LFE2 = LFEv3(n_feats, attn ='CA')
129 |
130 | self.LFE = nn.Sequential(
131 | nn.Conv2d(n_feats,n_feats,3,1,1),
132 | nn.GELU(),
133 | nn.Conv2d(n_feats,n_feats,3,1,1))
134 |
135 |
136 | def forward(self, x, pre_attn=None, RAA=None):
137 | shortcut = x.clone()
138 | x = self.LFE(x)
139 |
140 | x = self.LKA(x)*x
141 |
142 | return x + shortcut
143 |
144 | #-----------------------------------------------------------------------------------------------------------------
145 |
146 |
147 | class MLKA_Ablation(nn.Module):
148 | def __init__(self, n_feats, k=2, squeeze_factor=15):
149 | super().__init__()
150 | i_feats = 2*n_feats
151 |
152 | self.n_feats= n_feats
153 | self.i_feats = i_feats
154 |
155 | self.norm = LayerNorm(n_feats, data_format='channels_first')
156 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
157 |
158 | k = 2
159 |
160 | #Multiscale Large Kernel Attention
161 | self.LKA7 = nn.Sequential(
162 | nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k),
163 | nn.Conv2d(n_feats//k, n_feats//k, 9, stride=1, padding=(9//2)*4, groups=n_feats//k, dilation=4),
164 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))
165 | self.LKA5 = nn.Sequential(
166 | nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k),
167 | nn.Conv2d(n_feats//k, n_feats//k, 7, stride=1, padding=(7//2)*3, groups=n_feats//k, dilation=3),
168 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))
169 | '''self.LKA3 = nn.Sequential(
170 | nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k),
171 | nn.Conv2d(n_feats//k, n_feats//k, 5, stride=1, padding=(5//2)*2, groups=n_feats//k, dilation=2),
172 | nn.Conv2d(n_feats//k, n_feats//k, 1, 1, 0))'''
173 |
174 | #self.X3 = nn.Conv2d(n_feats//k, n_feats//k, 3, 1, 1, groups= n_feats//k)
175 | self.X5 = nn.Conv2d(n_feats//k, n_feats//k, 5, 1, 5//2, groups= n_feats//k)
176 | self.X7 = nn.Conv2d(n_feats//k, n_feats//k, 7, 1, 7//2, groups= n_feats//k)
177 |
178 | self.proj_first = nn.Sequential(
179 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
180 |
181 | self.proj_last = nn.Sequential(
182 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
183 |
184 |
185 | def forward(self, x, pre_attn=None, RAA=None):
186 | shortcut = x.clone()
187 |
188 | x = self.norm(x)
189 |
190 | x = self.proj_first(x)
191 |
192 | a, x = torch.chunk(x, 2, dim=1)
193 |
194 | #u_1, u_2, u_3= torch.chunk(u, 3, dim=1)
195 | a_1, a_2 = torch.chunk(a, 2, dim=1)
196 |
197 | a = torch.cat([self.LKA7(a_1)*self.X7(a_1), self.LKA5(a_2)*self.X5(a_2)], dim=1)
198 |
199 | x = self.proj_last(x*a)*self.scale + shortcut
200 |
201 | return x
202 | #-----------------------------------------------------------------------------------------------------------------
203 |
204 |
205 | class LayerNorm(nn.Module):
206 |     r""" LayerNorm that supports two data formats, channels_last (default) and channels_first,
207 |     which describe the ordering of the dimensions in the inputs: channels_last corresponds to
208 |     inputs with shape (batch_size, height, width, channels), while channels_first corresponds
209 |     to inputs with shape (batch_size, channels, height, width).
210 |     """
211 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
212 | super().__init__()
213 | self.weight = nn.Parameter(torch.ones(normalized_shape))
214 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
215 | self.eps = eps
216 | self.data_format = data_format
217 | if self.data_format not in ["channels_last", "channels_first"]:
218 | raise NotImplementedError
219 | self.normalized_shape = (normalized_shape, )
220 |
221 | def forward(self, x):
222 | if self.data_format == "channels_last":
223 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
224 | elif self.data_format == "channels_first":
225 | u = x.mean(1, keepdim=True)
226 | s = (x - u).pow(2).mean(1, keepdim=True)
227 | x = (x - u) / torch.sqrt(s + self.eps)
228 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
229 | return x
230 |
231 |
232 |
233 | class SGAB(nn.Module):
234 | def __init__(self, n_feats, drop=0.0, k=2, squeeze_factor= 15, attn ='GLKA'):
235 | super().__init__()
236 | i_feats =n_feats*2
237 |
238 | self.Conv1 = nn.Conv2d(n_feats, i_feats, 1, 1, 0)
239 | self.DWConv1 = nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats)
240 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
241 |
242 | self.norm = LayerNorm(n_feats, data_format='channels_first')
243 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
244 |
245 | def forward(self, x):
246 | shortcut = x.clone()
247 |
248 | #Ghost Expand
249 | x = self.Conv1(self.norm(x))
250 | a, x = torch.chunk(x, 2, dim=1)
251 | x = x*self.DWConv1(a)
252 | x = self.Conv2(x)
253 |
254 | return x*self.scale + shortcut
255 |
256 | class GroupGLKA(nn.Module):
257 | def __init__(self, n_feats, k=2, squeeze_factor=15):
258 | super().__init__()
259 | i_feats = 2*n_feats
260 |
261 | self.n_feats= n_feats
262 | self.i_feats = i_feats
263 |
264 | self.norm = LayerNorm(n_feats, data_format='channels_first')
265 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
266 |
267 | #Multiscale Large Kernel Attention
268 | self.LKA7 = nn.Sequential(
269 | nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3),
270 | nn.Conv2d(n_feats//3, n_feats//3, 9, stride=1, padding=(9//2)*4, groups=n_feats//3, dilation=4),
271 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
272 | self.LKA5 = nn.Sequential(
273 | nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3),
274 | nn.Conv2d(n_feats//3, n_feats//3, 7, stride=1, padding=(7//2)*3, groups=n_feats//3, dilation=3),
275 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
276 | self.LKA3 = nn.Sequential(
277 | nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3),
278 | nn.Conv2d(n_feats//3, n_feats//3, 5, stride=1, padding=(5//2)*2, groups=n_feats//3, dilation=2),
279 | nn.Conv2d(n_feats//3, n_feats//3, 1, 1, 0))
280 |
281 | self.X3 = nn.Conv2d(n_feats//3, n_feats//3, 3, 1, 1, groups= n_feats//3)
282 | self.X5 = nn.Conv2d(n_feats//3, n_feats//3, 5, 1, 5//2, groups= n_feats//3)
283 | self.X7 = nn.Conv2d(n_feats//3, n_feats//3, 7, 1, 7//2, groups= n_feats//3)
284 |
285 | self.proj_first = nn.Sequential(
286 | nn.Conv2d(n_feats, i_feats, 1, 1, 0))
287 |
288 | self.proj_last = nn.Sequential(
289 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
290 |
291 |
292 | def forward(self, x, pre_attn=None, RAA=None):
293 | shortcut = x.clone()
294 |
295 | x = self.norm(x)
296 |
297 | x = self.proj_first(x)
298 |
299 | a, x = torch.chunk(x, 2, dim=1)
300 |
301 | a_1, a_2, a_3= torch.chunk(a, 3, dim=1)
302 |
303 | a = torch.cat([self.LKA3(a_1)*self.X3(a_1), self.LKA5(a_2)*self.X5(a_2), self.LKA7(a_3)*self.X7(a_3)], dim=1)
304 |
305 | x = self.proj_last(x*a)*self.scale + shortcut
306 |
307 | return x
308 |
309 |
310 | # MAB
311 | class MAB(nn.Module):
312 | def __init__(
313 | self, n_feats):
314 | super().__init__()
315 |
316 | self.LKA = GroupGLKA(n_feats)
317 |
318 | self.LFE = SGAB(n_feats)
319 |
320 | def forward(self, x, pre_attn=None, RAA=None):
321 | #large kernel attention
322 | x = self.LKA(x)
323 |
324 | #local feature extraction
325 | x = self.LFE(x)
326 |
327 | return x
328 |
329 | class LKAT(nn.Module):
330 | def __init__(self, n_feats):
331 | super().__init__()
332 |
333 | #self.norm = LayerNorm(n_feats, data_format='channels_first')
334 | #self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True)
335 |
336 | self.conv0 = nn.Sequential(
337 | nn.Conv2d(n_feats, n_feats, 1, 1, 0),
338 | nn.GELU())
339 |
340 | self.att = nn.Sequential(
341 | nn.Conv2d(n_feats, n_feats, 7, 1, 7//2, groups= n_feats),
342 | nn.Conv2d(n_feats, n_feats, 9, stride=1, padding=(9//2)*3, groups=n_feats, dilation=3),
343 | nn.Conv2d(n_feats, n_feats, 1, 1, 0))
344 |
345 | self.conv1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
346 |
347 | def forward(self, x):
348 | x = self.conv0(x)
349 | x = x*self.att(x)
350 | x = self.conv1(x)
351 | return x
352 |
353 | class ResGroup(nn.Module):
354 | def __init__(self, n_resblocks, n_feats, res_scale=1.0):
355 | super(ResGroup, self).__init__()
356 | self.body = nn.ModuleList([
357 | MAB(n_feats) \
358 | for _ in range(n_resblocks)])
359 |
360 | self.body_t = LKAT(n_feats)
361 |
362 | def forward(self, x):
363 | res = x.clone()
364 |
365 | for i, block in enumerate(self.body):
366 | res = block(res)
367 |
368 | x = self.body_t(res) + x
369 |
370 | return x
371 |
372 | class MeanShift(nn.Conv2d):
373 | def __init__(
374 | self, rgb_range,
375 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
376 |
377 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
378 | std = torch.Tensor(rgb_std)
379 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
380 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
381 | for p in self.parameters():
382 | p.requires_grad = False
383 |
384 |
385 | @ARCH_REGISTRY.register()
386 | class MAN(nn.Module):
387 | def __init__(self, n_resblocks=36, n_resgroups=1, n_colors=3, n_feats=180, scale=2, res_scale = 1.0):
388 | super(MAN, self).__init__()
389 |
390 | #res_scale = res_scale
391 | self.n_resgroups = n_resgroups
392 |
393 | self.sub_mean = MeanShift(1.0)
394 | self.head = nn.Conv2d(n_colors, n_feats, 3, 1, 1)
395 |
396 | # define body module
397 | self.body = nn.ModuleList([
398 | ResGroup(
399 | n_resblocks, n_feats, res_scale=res_scale)
400 | for i in range(n_resgroups)])
401 |
402 | if self.n_resgroups > 1:
403 | self.body_t = nn.Conv2d(n_feats, n_feats, 3, 1, 1)
404 |
405 | # define tail module
406 | self.tail = nn.Sequential(
407 | nn.Conv2d(n_feats, n_colors*(scale**2), 3, 1, 1),
408 | nn.PixelShuffle(scale)
409 | )
410 | self.add_mean = MeanShift(1.0, sign=1)
411 |
412 | def forward(self, x):
413 | x = self.sub_mean(x)
414 | x = self.head(x)
415 | res = x
416 | for i in self.body:
417 | res = i(res)
418 | if self.n_resgroups>1:
419 | res = self.body_t(res) + x
420 | x = self.tail(res)
421 | x = self.add_mean(x)
422 | return x
423 |
424 | def visual_feature(self, x):
425 | fea = []
426 | x = self.head(x)
427 | res = x
428 |
429 | for i in self.body:
430 | temp = res
431 | res = i(res)
432 | fea.append(res)
433 |
434 |         if self.n_resgroups > 1: res = self.body_t(res) + x  # body_t is only defined when n_resgroups > 1
435 |
436 | x = self.tail(res)
437 | return x, fea
438 |
439 | def load_state_dict(self, state_dict, strict=False):
440 | own_state = self.state_dict()
441 | for name, param in state_dict.items():
442 | if name in own_state:
443 | if isinstance(param, nn.Parameter):
444 | param = param.data
445 | try:
446 | own_state[name].copy_(param)
447 | except Exception:
448 | if name.find('tail') >= 0:
449 |                         print('Replacing pre-trained upsampler with a new one...')
450 | else:
451 | raise RuntimeError('While copying the parameter named {}, '
452 | 'whose dimensions in the model are {} and '
453 | 'whose dimensions in the checkpoint are {}.'
454 | .format(name, own_state[name].size(), param.size()))
455 | elif strict:
456 | if name.find('tail') == -1:
457 | raise KeyError('unexpected key "{}" in state_dict'
458 | .format(name))
459 |
460 | if strict:
461 | missing = set(own_state.keys()) - set(state_dict.keys())
462 | if len(missing) > 0:
463 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
464 |
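465 | # ---------------------------------------------------------------------------
466 | # Minimal smoke test (an added sketch, not part of the original architecture
467 | # file). It builds a deliberately reduced MAN configuration (the defaults
468 | # above are n_resblocks=36, n_feats=180) and checks the x2 output shape; the
469 | # reduced sizes here are assumptions chosen only to keep the test fast, and
470 | # n_feats is kept divisible by 3 for the grouped large-kernel attention.
471 | if __name__ == '__main__':
472 |     model = MAN(n_resblocks=5, n_resgroups=1, n_feats=60, scale=2)
473 |     lr = torch.randn(1, 3, 48, 48)  # dummy low-resolution batch
474 |     with torch.no_grad():
475 |         sr = model(lr)
476 |     assert sr.shape == (1, 3, 96, 96), sr.shape  # 48x48 -> 96x96 for scale=2
477 |     print('MAN x2 smoke test passed:', tuple(sr.shape))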
--------------------------------------------------------------------------------
/LAM/ModelZoo/NN/san.py:
--------------------------------------------------------------------------------
1 | from ..NN import common
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from ..NN.MPNCOV.python import MPNCOV
6 |
7 |
8 | # model factory (EDSR-style entry point)
9 | def make_model(args, parent=False):
10 | return SAN(args)
11 |
12 |
13 | ## non-local module
14 | class _NonLocalBlockND(nn.Module):
15 | def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian',
16 | sub_sample=True, bn_layer=True):
17 | super(_NonLocalBlockND, self).__init__()
18 | assert dimension in [1, 2, 3]
19 | assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation']
20 |
21 | # print('Dimension: %d, mode: %s' % (dimension, mode))
22 |
23 | self.mode = mode
24 | self.dimension = dimension
25 | self.sub_sample = sub_sample
26 |
27 | self.in_channels = in_channels
28 | self.inter_channels = inter_channels
29 |
30 | if self.inter_channels is None:
31 | self.inter_channels = in_channels // 2
32 | if self.inter_channels == 0:
33 | self.inter_channels = 1
34 |
35 | if dimension == 3:
36 | conv_nd = nn.Conv3d
37 | max_pool = nn.MaxPool3d
38 | bn = nn.BatchNorm3d
39 | elif dimension == 2:
40 | conv_nd = nn.Conv2d
41 | max_pool = nn.MaxPool2d
42 |             sub_sample = nn.Upsample  # NOTE: overwrites the boolean flag, so the `if sub_sample:` check below is always true for 2D blocks; kept as-is for checkpoint/behavior compatibility
43 | bn = nn.BatchNorm2d
44 | else:
45 | conv_nd = nn.Conv1d
46 | max_pool = nn.MaxPool1d
47 | bn = nn.BatchNorm1d
48 |
49 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
50 | kernel_size=1, stride=1, padding=0)
51 |
52 | if bn_layer:
53 | self.W = nn.Sequential(
54 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
55 | kernel_size=1, stride=1, padding=0),
56 | bn(self.in_channels)
57 | )
58 | nn.init.constant_(self.W[1].weight, 0)
59 | nn.init.constant_(self.W[1].bias, 0)
60 | else:
61 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
62 | kernel_size=1, stride=1, padding=0)
63 | nn.init.constant_(self.W.weight, 0)
64 | nn.init.constant_(self.W.bias, 0)
65 |
66 | self.theta = None
67 | self.phi = None
68 | self.concat_project = None
69 | # self.fc = nn.Linear(64,2304,bias=True)
70 | # self.sub_bilinear = nn.Upsample(size=(48,48),mode='bilinear')
71 | # self.sub_maxpool = nn.AdaptiveMaxPool2d(output_size=(48,48))
72 | if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
73 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
74 | kernel_size=1, stride=1, padding=0)
75 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
76 | kernel_size=1, stride=1, padding=0)
77 |
78 | if mode == 'embedded_gaussian':
79 | self.operation_function = self._embedded_gaussian
80 | elif mode == 'dot_product':
81 | self.operation_function = self._dot_product
82 | elif mode == 'concatenation':
83 | self.operation_function = self._concatenation
84 | self.concat_project = nn.Sequential(
85 | nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
86 | nn.ReLU()
87 | )
88 | elif mode == 'gaussian':
89 | self.operation_function = self._gaussian
90 |
91 | if sub_sample:
92 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2))
93 | if self.phi is None:
94 | self.phi = max_pool(kernel_size=2)
95 | else:
96 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2))
97 |
98 | def forward(self, x):
99 |         '''
100 |         :param x: (b, c, t, h, w) for 3D blocks, (b, c, h, w) for 2D blocks
101 |         :return: tensor of the same shape as the input
102 |         '''
103 |
104 | output = self.operation_function(x)
105 | return output
106 |
107 | def _embedded_gaussian(self, x):
108 | batch_size, C, H, W = x.shape
109 |
110 | # x_sub = self.sub_bilinear(x) # bilinear downsample
111 | # x_sub = self.sub_maxpool(x) # maxpool downsample
112 |
113 | ##
114 | # g_x = x.view(batch_size, self.inter_channels, -1)
115 | # g_x = g_x.permute(0, 2, 1)
116 | #
117 | # # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c)
118 | # # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw)
119 | # # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw)
120 | # theta_x = x.view(batch_size, self.inter_channels, -1)
121 | # theta_x = theta_x.permute(0, 2, 1)
122 | # fc = self.fc(theta_x)
123 | # # phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
124 | # # f = torch.matmul(theta_x, phi_x)
125 | # # return f
126 | # # f_div_C = F.softmax(fc, dim=-1)
127 | # return fc
128 |
129 | ##
130 | # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c)
131 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
132 | g_x = g_x.permute(0, 2, 1)
133 |
134 |         # theta => (b, c, t, h, w) [-> (b, 0.5c, t, h, w)] -> (b, thw, 0.5c)
135 |         # phi   => (b, c, t, h, w) [-> (b, 0.5c, t, h, w)] -> (b, 0.5c, thw)
136 |         # f     => (b, thw, 0.5c) dot (b, 0.5c, thw) = (b, thw, thw)
137 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
138 | theta_x = theta_x.permute(0, 2, 1)
139 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
140 | f = torch.matmul(theta_x, phi_x)
141 | # return f
142 |         f_div_C = F.softmax(f, dim=-1)  # row-wise softmax: attention weights over all positions
143 | # return f_div_C
144 | # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w)
145 | y = torch.matmul(f_div_C, g_x)
146 | y = y.permute(0, 2, 1).contiguous()
147 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
148 | W_y = self.W(y)
149 | z = W_y + x
150 |
151 | return z
152 |
153 | def _gaussian(self, x):
154 | batch_size = x.size(0)
155 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
156 | g_x = g_x.permute(0, 2, 1)
157 |
158 | theta_x = x.view(batch_size, self.in_channels, -1)
159 | theta_x = theta_x.permute(0, 2, 1)
160 |
161 | if self.sub_sample:
162 | phi_x = self.phi(x).view(batch_size, self.in_channels, -1)
163 | else:
164 | phi_x = x.view(batch_size, self.in_channels, -1)
165 |
166 | f = torch.matmul(theta_x, phi_x)
167 | f_div_C = F.softmax(f, dim=-1)
168 |
169 | y = torch.matmul(f_div_C, g_x)
170 | y = y.permute(0, 2, 1).contiguous()
171 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
172 | W_y = self.W(y)
173 | z = W_y + x
174 |
175 | return z
176 |
177 | def _dot_product(self, x):
178 | batch_size = x.size(0)
179 |
180 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
181 | g_x = g_x.permute(0, 2, 1)
182 |
183 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
184 | theta_x = theta_x.permute(0, 2, 1)
185 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
186 | f = torch.matmul(theta_x, phi_x)
187 | N = f.size(-1)
188 | f_div_C = f / N
189 |
190 | y = torch.matmul(f_div_C, g_x)
191 | y = y.permute(0, 2, 1).contiguous()
192 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
193 | W_y = self.W(y)
194 | z = W_y + x
195 |
196 | return z
197 |
198 | def _concatenation(self, x):
199 | batch_size = x.size(0)
200 |
201 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
202 | g_x = g_x.permute(0, 2, 1)
203 |
204 | # (b, c, N, 1)
205 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1, 1)
206 | # (b, c, 1, N)
207 | phi_x = self.phi(x).view(batch_size, self.inter_channels, 1, -1)
208 |
209 | h = theta_x.size(2)
210 | w = phi_x.size(3)
211 | theta_x = theta_x.repeat(1, 1, 1, w)
212 | phi_x = phi_x.repeat(1, 1, h, 1)
213 |
214 | concat_feature = torch.cat([theta_x, phi_x], dim=1)
215 | f = self.concat_project(concat_feature)
216 | b, _, h, w = f.size()
217 | f = f.view(b, h, w)
218 |
219 | N = f.size(-1)
220 | f_div_C = f / N
221 |
222 | y = torch.matmul(f_div_C, g_x)
223 | y = y.permute(0, 2, 1).contiguous()
224 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
225 | W_y = self.W(y)
226 | z = W_y + x
227 |
228 | return z
229 |
230 |
231 | class NONLocalBlock1D(_NonLocalBlockND):
232 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
233 | super(NONLocalBlock1D, self).__init__(in_channels,
234 | inter_channels=inter_channels,
235 | dimension=1, mode=mode,
236 | sub_sample=sub_sample,
237 | bn_layer=bn_layer)
238 |
239 |
240 | class NONLocalBlock2D(_NonLocalBlockND):
241 | def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample=True, bn_layer=True):
242 | super(NONLocalBlock2D, self).__init__(in_channels,
243 | inter_channels=inter_channels,
244 | dimension=2, mode=mode,
245 | sub_sample=sub_sample,
246 | bn_layer=bn_layer)
247 |
248 |
249 | ## Channel Attention (CA) Layer
250 | class CALayer(nn.Module):
251 | def __init__(self, channel, reduction=8):
252 | super(CALayer, self).__init__()
253 | # global average pooling: feature --> point
254 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
255 | self.max_pool = nn.AdaptiveMaxPool2d(1)
256 | # feature channel downscale and upscale --> channel weight
257 | self.conv_du = nn.Sequential(
258 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
259 | nn.ReLU(inplace=True),
260 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
261 | # nn.Sigmoid()
262 | # nn.BatchNorm2d(channel)
263 | )
264 |
265 | def forward(self, x):
266 | _, _, h, w = x.shape
267 | y_ave = self.avg_pool(x)
268 | # y_max = self.max_pool(x)
269 | y_ave = self.conv_du(y_ave)
270 | # y_max = self.conv_du(y_max)
271 | # y = y_ave + y_max
272 | # expand y to C*H*W
273 | # expand_y = y.expand(-1,-1,h,w)
274 |         return y_ave  # returns channel descriptors only (no sigmoid, no rescaling of x); CALayer appears unused in this file
275 |
276 |
277 | ## second-order Channel attention (SOCA)
278 | class SOCA(nn.Module):
279 | def __init__(self, channel, reduction=8):
280 | super(SOCA, self).__init__()
281 | # global average pooling: feature --> point
282 | # self.avg_pool = nn.AdaptiveAvgPool2d(1)
283 | # self.max_pool = nn.AdaptiveMaxPool2d(1)
284 | self.max_pool = nn.MaxPool2d(kernel_size=2)
285 |
286 | # feature channel downscale and upscale --> channel weight
287 | self.conv_du = nn.Sequential(
288 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
289 | nn.ReLU(inplace=True),
290 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
291 | nn.Sigmoid()
292 | # nn.BatchNorm2d(channel)
293 | )
294 |
295 | def forward(self, x):
296 | batch_size, C, h, w = x.shape # x: NxCxHxW
297 |         # Covariance pooling cost grows with the number of spatial positions,
298 |         # so feed it at most a centered 1000x1000 crop. The original if/elif
299 |         # chain missed the cases where one side equals 1000 exactly, which fell
300 |         # through to the full-crop branch and produced negative crop offsets;
301 |         # min() covers every case uniformly without changing the common path.
302 |         h1 = min(h, 1000)
303 |         w1 = min(w, 1000)
304 |         H = (h - h1) // 2
305 |         W = (w - w1) // 2
306 |         x_sub = x[:, :, H:(H + h1), W:(W + w1)]
315 | # subsample
316 | # subsample_scale = 2
317 | # subsample = nn.Upsample(size=(h // subsample_scale, w // subsample_scale), mode='nearest')
318 | # x_sub = subsample(x)
319 | # max_pool = nn.MaxPool2d(kernel_size=2)
320 | # max_pool = nn.AvgPool2d(kernel_size=2)
321 | # x_sub = self.max_pool(x)
322 | ##
323 | ## MPN-COV
324 |         cov_mat = MPNCOV.CovpoolLayer(x_sub)  # global covariance pooling layer
325 |         cov_mat_sqrt = MPNCOV.SqrtmLayer(cov_mat, 5)  # matrix square root layer
326 |         # (pre-normalization, 5 Newton-Schulz iterations, and post-compensation)
327 | ##
328 | cov_mat_sum = torch.mean(cov_mat_sqrt, 1)
329 | cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1)
330 | # y_ave = self.avg_pool(x)
331 | # y_max = self.max_pool(x)
332 | y_cov = self.conv_du(cov_mat_sum)
333 | # y_max = self.conv_du(y_max)
334 | # y = y_ave + y_max
335 | # expand y to C*H*W
336 | # expand_y = y.expand(-1,-1,h,w)
337 | return y_cov * x
338 |
339 |
340 | ## self-attention + channel attention module
341 | class Nonlocal_CA(nn.Module):
342 | def __init__(self, in_feat=64, inter_feat=32, reduction=8, sub_sample=False, bn_layer=True):
343 | super(Nonlocal_CA, self).__init__()
344 | # second-order channel attention
345 | self.soca = SOCA(in_feat, reduction=reduction)
346 | # nonlocal module
347 | self.non_local = _NonLocalBlockND(in_channels=in_feat, inter_channels=inter_feat, dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)
348 |
349 | self.sigmoid = nn.Sigmoid()
350 |
351 | def forward(self, x):
352 |         ## divide the feature map into 4 parts and apply non-local attention per quadrant (keeps the attention matrices smaller)
353 | batch_size, C, H, W = x.shape
354 | H1 = int(H / 2)
355 | W1 = int(W / 2)
356 | nonlocal_feat = torch.zeros_like(x)
357 |
358 | feat_sub_lu = x[:, :, :H1, :W1]
359 | feat_sub_ld = x[:, :, H1:, :W1]
360 | feat_sub_ru = x[:, :, :H1, W1:]
361 | feat_sub_rd = x[:, :, H1:, W1:]
362 |
363 | nonlocal_lu = self.non_local(feat_sub_lu)
364 | nonlocal_ld = self.non_local(feat_sub_ld)
365 | nonlocal_ru = self.non_local(feat_sub_ru)
366 | nonlocal_rd = self.non_local(feat_sub_rd)
367 | nonlocal_feat[:, :, :H1, :W1] = nonlocal_lu
368 | nonlocal_feat[:, :, H1:, :W1] = nonlocal_ld
369 | nonlocal_feat[:, :, :H1, W1:] = nonlocal_ru
370 | nonlocal_feat[:, :, H1:, W1:] = nonlocal_rd
371 |
372 | return nonlocal_feat
373 |
374 |
375 | ## Residual Block (RB)
376 | class RB(nn.Module):
377 | def __init__(self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(inplace=True),
378 | res_scale=1, dilation=2):
379 | super(RB, self).__init__()
380 | modules_body = []
381 |
382 | # self.gamma1 = nn.Parameter(torch.ones(1))
383 | self.gamma1 = 1.0
384 | # self.salayer = SALayer(n_feat, reduction=reduction, dilation=dilation)
385 | # self.salayer = SALayer2(n_feat, reduction=reduction, dilation=dilation)
386 |
387 | self.conv_first = nn.Sequential(conv(n_feat, n_feat, kernel_size, bias=bias),
388 | act,
389 | conv(n_feat, n_feat, kernel_size, bias=bias)
390 | )
391 |
392 | self.res_scale = res_scale
393 |
394 | def forward(self, x):
395 | y = self.conv_first(x)
396 |         y = y + x  # plain residual connection; the stored res_scale is not applied
397 |
398 | return y
399 |
400 |
401 | ## Local-source Residual Attention Group (LSRAG)
402 | class LSRAG(nn.Module):
403 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
404 | super(LSRAG, self).__init__()
405 | ##
406 | self.rcab = nn.ModuleList([RB(conv, n_feat, kernel_size, reduction, \
407 | bias=True, bn=False, act=nn.ReLU(inplace=True), res_scale=1) for _ in
408 | range(n_resblocks)])
409 | self.soca = (SOCA(n_feat, reduction=reduction))
410 | self.conv_last = (conv(n_feat, n_feat, kernel_size))
411 | self.n_resblocks = n_resblocks
412 | ##
413 | # modules_body = []
414 | self.gamma = nn.Parameter(torch.zeros(1))
415 | # self.gamma = 0.2
416 | # for i in range(n_resblocks):
417 | # modules_body.append(RCAB(conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(inplace=True), res_scale=1))
418 | # modules_body.append(SOCA(n_feat,reduction=reduction))
419 | # # modules_body.append(Nonlocal_CA(in_feat=n_feat, inter_feat=n_feat//8, reduction =reduction, sub_sample=False, bn_layer=False))
420 | # modules_body.append(conv(n_feat, n_feat, kernel_size))
421 | # self.body = nn.Sequential(*modules_body)
422 | ##
423 |
424 |     def make_layer(self, block, num_of_layer):
425 |         # NOTE: appends the same block instance repeatedly (shared weights); unused in this class
426 |         layers = []
427 |         for _ in range(num_of_layer):
428 |             layers.append(block)
429 |         return nn.ModuleList(layers)
430 |
431 | def forward(self, x):
432 | residual = x
433 | # batch_size,C,H,W = x.shape
434 | # y_pre = self.body(x)
435 | # y_pre = y_pre + x
436 | # return y_pre
437 |
438 | ## share-source skip connection
439 |
440 | for i, l in enumerate(self.rcab):
441 | # x = l(x) + self.gamma*residual
442 | x = l(x)
443 | x = self.soca(x)
444 | x = self.conv_last(x)
445 |
446 | x = x + residual
447 |
448 | return x
449 | ##
450 |
451 |
452 | ## Second-order Channel Attention Network (SAN)
453 | class SAN(nn.Module):
454 | def __init__(self, factor=4, num_channels=3, conv=common.default_conv):
455 | super(SAN, self).__init__()
456 | n_resgroups = 20
457 | n_resblocks = 10
458 | n_feats = 64
459 | kernel_size = 3
460 | reduction = 16
461 | scale = factor
462 | act = nn.ReLU(inplace=True)
463 |
464 | # RGB mean for DIV2K
465 | rgb_mean = (0.4488, 0.4371, 0.4040)
466 | rgb_std = (1.0, 1.0, 1.0)
467 | self.sub_mean = common.MeanShift(1.0, rgb_mean, rgb_std)
468 | # self.soca= SOCA(n_feats, reduction=reduction)
469 |
470 | # define head module
471 | modules_head = [conv(num_channels, n_feats, kernel_size)]
472 |
473 | # define body module
474 | ## share-source skip connection
475 |
476 | ##
477 | self.gamma = nn.Parameter(torch.zeros(1))
478 | # self.gamma = 0.2
479 | self.n_resgroups = n_resgroups
480 | self.RG = nn.ModuleList([LSRAG(conv, n_feats, kernel_size, reduction, \
481 | act=act, res_scale=1.0, n_resblocks=n_resblocks) for _ in
482 | range(n_resgroups)])
483 | self.conv_last = conv(n_feats, n_feats, kernel_size)
484 |
485 | # modules_body = [
486 | # ResidualGroup(
487 | # conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
488 | # for _ in range(n_resgroups)]
489 | # modules_body.append(conv(n_feats, n_feats, kernel_size))
490 |
491 | # define tail module
492 | modules_tail = [
493 | common.Upsampler(conv, scale, n_feats, act=False),
494 | conv(n_feats, num_channels, kernel_size)]
495 |
496 | self.add_mean = common.MeanShift(1.0, rgb_mean, rgb_std, 1)
497 | self.non_local = Nonlocal_CA(in_feat=n_feats, inter_feat=n_feats // 8, reduction=8, sub_sample=False,
498 | bn_layer=False)
499 |
500 | self.head = nn.Sequential(*modules_head)
501 | # self.body = nn.Sequential(*modules_body)
502 | self.tail = nn.Sequential(*modules_tail)
503 |
504 |     def make_layer(self, block, num_of_layer):
505 |         # NOTE: appends the same block instance repeatedly (shared weights); unused in this class
506 |         layers = []
507 |         for _ in range(num_of_layer):
508 |             layers.append(block)
509 | 
510 |         return nn.ModuleList(layers)
511 |
512 | def forward(self, x):
513 |         x = self.sub_mean(x * 255.)  # rescale [0, 1] input to 0-255 before the mean shift
514 | x = self.head(x)
515 |
516 | ## add nonlocal
517 | xx = self.non_local(x)
518 |
519 | # share-source skip connection
520 | residual = xx
521 |
522 | # res = self.RG(xx)
523 | # res = res + xx
524 |         ## share-source residual group
525 | for i, l in enumerate(self.RG):
526 | xx = l(xx) + self.gamma * residual
527 | # xx = self.gamma*xx + residual
528 | # body part
529 | # res = self.body(xx)
530 | ##
531 | ## add nonlocal
532 | res = self.non_local(xx)
533 | ##
534 | # res = self.soca(res)
535 | # res += x
536 | res = res + x
537 |
538 | x = self.tail(res)
539 | x = self.add_mean(x)
540 |
541 |         return x / 255.  # rescale back to [0, 1]
542 |
543 | def load_state_dict(self, state_dict, strict=False):
544 | own_state = self.state_dict()
545 | for name, param in state_dict.items():
546 | if name in own_state:
547 | if isinstance(param, nn.Parameter):
548 | param = param.data
549 | try:
550 | own_state[name].copy_(param)
551 | except Exception:
552 | if name.find('tail') >= 0:
553 |                         print('Replacing pre-trained upsampler with a new one...')
554 | else:
555 | raise RuntimeError('While copying the parameter named {}, '
556 | 'whose dimensions in the model are {} and '
557 | 'whose dimensions in the checkpoint are {}.'
558 | .format(name, own_state[name].size(), param.size()))
559 | elif strict:
560 | if name.find('tail') == -1:
561 | raise KeyError('unexpected key "{}" in state_dict'
562 | .format(name))
563 |
564 | if strict:
565 | missing = set(own_state.keys()) - set(state_dict.keys())
566 | if len(missing) > 0:
567 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
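568 | 
569 | # ---------------------------------------------------------------------------
570 | # Usage sketch (added; not part of the original file). Because of the
571 | # relative imports above, san.py must be imported as part of the ModelZoo
572 | # package (e.g. with the LAM/ directory on sys.path), not run directly:
573 | #
574 | #     from ModelZoo.NN.san import SAN
575 | #     model = SAN(factor=4).eval()   # x4 SAN: 20 LSRAGs of 10 RBs each
576 | #     with torch.no_grad():
577 | #         sr = model(lr)             # lr: (B, 3, H, W) in [0, 1]
578 | #                                    # sr: (B, 3, 4H, 4W), also in [0, 1]
579 | #
580 | # forward() rescales inputs to the 0-255 range internally, and the four
581 | # non-local quadrants still cost O((HW)^2) memory overall, so large inputs
582 | # are typically processed in tiles (an operational note, not from this file).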
--------------------------------------------------------------------------------