├── .idea ├── .gitignore ├── DAN-Basd-on-Openmmlab.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── configs └── restorers │ └── dan │ └── dan_setting2.py ├── mmedit └── models │ ├── backbones │ └── sr_backbones │ │ └── dan_net.py │ ├── common │ └── dan_preprocess.py │ └── restorers │ └── dan.py └── tools └── data └── super-resolution └── dan_datasets ├── pca_matrix ├── pca_aniso_matrix_x2.pth └── pca_aniso_matrix_x4.pth ├── preprocess_dan_datasets.py └── preprocess_div2k_dataset.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/DAN-Basd-on-Openmmlab.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 35 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAN-Basd-on-Openmmlab 2 | DAN: [Unfolding the Alternating Optimization for Blind Super Resolution](https://arxiv.org/abs/2010.02631) 3 | 4 | We reproduce DAN via [mmediting](https://github.com/open-mmlab/mmediting) based on [open-sourced code](https://github.com/greatlog/DAN). 5 | 6 | ## Requirements 7 | 8 | - PyTorch >= 1.3 9 | - mmediting >= 0.9 10 | 11 | ## DataSets 12 | We use [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) as our training datasets. 13 | For evaluation of Setting 2, we use [DIV2KRK](http://www.wisdom.weizmann.ac.il/~vision/kernelgan/DIV2KRK_public.zip) datasets, 14 | 15 | ## Usages 16 | How to run this repo: copy the file to the mmediting workspace and run the program directly based on the commands in mmediting 17 | 18 | 1. Copy files to MMEditing workspace. 19 | ```shell 20 | cd DAN-Basd-on-Openmmlab/ 21 | mv ./mmedit/models/restorers/dan.py ${mmediting_workspace}/mmedit/models/restorers/ 22 | mv ./mmedit/models/backbones/sr_backbones/dan_net.py ${mmediting_workspace}/mmedit/models/backbones/sr_backbones/ 23 | mv ./mmedit/models/common/DANpreprocess.py ${mmediting_workspace}/mmedit/models/common 24 | mv ./configs/restorers/dan ${mmediting_workspace}/configs/restorers/ 25 | mv ./tools/data/super-resolution/dan_datasets ${mmediting_workspace}/tools/data/super-resolution/ 26 | ``` 27 | 2. Modify the configuration file as follows: 28 | 29 | ```python 30 | pca_matrix_path='${mmediting_workspace}/tools/data/super-resolution/div2k/pca_matrix/pca_aniso_matrix_x4.pth' # your pca_matrix path 31 | # Training 32 | gt_folder='${dataset_workspace}/dataset/DF2K_train_HR_sub' # your train data path 33 | # Testing 34 | lq_folder='${dataset_workspace}/dataset/DIV2KRK/lr_x4' # your test data LR path 35 | gt_folder='${dataset_workspace}/dataset/DIV2KRK/gt' # your test data HR path 36 | ``` 37 | 3. Add script to init file, as follows: 38 | 39 | - modify the `mmedit/models/backbones/sr_backbones/__init__.py`: 40 | ```python 41 | from .dan_net import DAN 42 | # add DAN into __all__ list. 43 | ``` 44 | - modify the `mmedit/models/commons/__init__.py`: 45 | ```python 46 | from .dan_preprocess import SRMDPreprocessing 47 | # add SRMDreprocessing into __all__ list. 48 | ``` 49 | - modify the `mmedit/models/restorers/__init__.py`: 50 | ```python 51 | from .dan import DAN 52 | # add DAN into __all__ list. 53 | ``` 54 | 55 | 4. Training/Test 56 | 57 | Before using it, please download and process the dataset and set the path in the configuration file. 58 | 59 | - Train 60 | 61 | ```shell 62 | # Single GPU 63 | python tools/train.py configs/restorers/dan/dan_setting2.py --work_dir ${YOUR_WORK_DIR} 64 | 65 | # Multiple GPUs 66 | ./tools/dist_train.sh configs/restorers/dan/dan_setting2.py ${GPU_NUM} --work_dir ${YOUR_WORK_DIR} 67 | ``` 68 | 69 | - Test 70 | ```shell 71 | # Single GPU 72 | python tools/test.py configs/restorers/dan/dan_setting2.py ${CHECKPOINT_FILE} [--metrics ${METRICS}] [--out ${RESULT_FILE}] 73 | 74 | # Multiple GPUs 75 | ./tools/dist_test.sh configs/restorers/dan/dan_setting2.py ${CHECKPOINT_FILE} ${GPU_NUM} [--metrics ${METRICS}] [--out ${RESULT_FILE}] 76 | ``` 77 | 78 | ## Result 79 | 80 | ### DIV2KRK 81 | The passwds of download links are all 'ta2o' 82 | 83 | | Method | scale | Datasets | PSNR | SSIM | Download | 84 | | :-----: | :----: | :----: | :----: | :----: | :----:| 85 | | DAN-RGB (paper) | x4 | DIV2KRK | 26.09 | 0.7312 | - | 86 | | DAN-Y (paper) | x4 | DIV2KRK | 27.55 | 0.7582 | - | 87 | | DAN-RGB (Ours) | x4 | DIV2KRK | 27.41 | 0.7666 | [model](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) / [test_pkl](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) | 88 | | DAN-Y (Ours) | x4 | DIV2KRK | 28.88 | 0.7915 | [model](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) / [test_pkl](https://pan.baidu.com/s/1T_BOVR7Ui-NLUIKr6R20-w) | 89 | 90 | ------------ 91 | -------------------------------------------------------------------------------- /configs/restorers/dan/dan_setting2.py: -------------------------------------------------------------------------------- 1 | exp_name = 'dan_setting2_lr6e25-6_320' 2 | 3 | scale = 4 4 | # model settings 5 | model = dict( 6 | type='DAN', 7 | generator=dict( 8 | type='DAN', 9 | nf=64, 10 | nb=40, 11 | input_para=10, 12 | loop=4, 13 | kernel_size=21, 14 | # Your pca matrix path 15 | pca_matrix_path='pca_aniso_matrix_x4.pth'), 16 | pixel_loss=dict(type='MSELoss', loss_weight=1.0, reduction='mean')) 17 | # model training and testing settings 18 | 19 | train_cfg = dict( 20 | # Your pca matrix path 21 | pca_matrix_path='pca_aniso_matrix_x4.pth', 22 | scale=scale, 23 | degradation=dict( 24 | random_kernel=True, 25 | ksize=21, 26 | code_length=10, 27 | sig_min=0.6, 28 | sig_max=5.0, 29 | rate_iso=0, 30 | random_disturb=True)) 31 | test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale) 32 | 33 | # dataset settings 34 | train_dataset_type = 'SRFolderGTDataset' 35 | val_dataset_type = 'SRFolderDataset' 36 | train_pipeline = [ 37 | dict( 38 | type='LoadImageFromFile', 39 | io_backend='disk', 40 | key='gt', 41 | flag='unchanged'), 42 | dict(type='RescaleToZeroOne', keys=['gt']), 43 | dict( 44 | type='Normalize', 45 | keys=['gt'], 46 | mean=[0, 0, 0], 47 | std=[1, 1, 1], 48 | to_rgb=True), 49 | dict(type='Crop', keys=['gt'], crop_size=(256, 256)), 50 | dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='horizontal'), 51 | dict(type='Flip', keys=['gt'], flip_ratio=0.5, direction='vertical'), 52 | dict(type='RandomTransposeHW', keys=['gt'], transpose_ratio=0.5), 53 | dict(type='Collect', keys=['gt'], meta_keys=['gt_path']), 54 | dict(type='ImageToTensor', keys=['gt']) 55 | ] 56 | test_pipeline = [ 57 | dict(type='LoadImageFromFile', io_backend='disk', key='lq', flag='color'), 58 | dict(type='LoadImageFromFile', io_backend='disk', key='gt', flag='color'), 59 | dict(type='RescaleToZeroOne', keys=['lq', 'gt']), 60 | dict( 61 | type='Normalize', 62 | keys=['lq', 'gt'], 63 | mean=[0, 0, 0], 64 | std=[1, 1, 1], 65 | to_rgb=True), 66 | dict(type='Collect', keys=['lq', 'gt'], meta_keys=['lq_path', 'gt_path']), 67 | dict(type='ImageToTensor', keys=['lq', 'gt']) 68 | ] 69 | 70 | data = dict( 71 | workers_per_gpu=4, 72 | train_dataloader=dict(samples_per_gpu=8, drop_last=True), 73 | val_dataloader=dict(samples_per_gpu=1), 74 | test_dataloader=dict(samples_per_gpu=1), 75 | train=dict( 76 | type='RepeatDataset', 77 | times=1000, 78 | dataset=dict( 79 | type=train_dataset_type, 80 | gt_folder='dataset/DF2K_train_HR_sub', # Your training data path 81 | pipeline=train_pipeline, 82 | scale=scale)), 83 | val=dict( 84 | type=val_dataset_type, 85 | lq_folder='dataset/DIV2KRK/lr_x4', # Your validation data LR path 86 | gt_folder='dataset/DIV2KRK/gt', # Your validation data HR path 87 | pipeline=test_pipeline, 88 | scale=scale, 89 | filename_tmpl='{}'), 90 | test=dict( 91 | type=val_dataset_type, 92 | lq_folder='dataset/DIV2KRK/lr_x4', # Your validation data LR path 93 | gt_folder='dataset/DIV2KRK/gt', # Your validation data HR path 94 | pipeline=test_pipeline, 95 | scale=scale, 96 | filename_tmpl='{}')) 97 | 98 | # optimizer 99 | optimizers = dict(generator=dict(type='Adam', lr=6.25e-6, betas=(0.9, 0.999))) 100 | 101 | # learning policy 102 | total_iters = 400000 103 | lr_config = dict( 104 | policy='Step', by_epoch=False, step=[100000, 200000, 300000], gamma=0.5) 105 | 106 | checkpoint_config = dict(interval=1000, save_optimizer=True, by_epoch=False) 107 | evaluation = dict(interval=1000, save_image=False, gpu_collect=True) 108 | log_config = dict( 109 | interval=100, 110 | hooks=[ 111 | dict(type='TextLoggerHook', by_epoch=False), 112 | dict(type='TensorboardLoggerHook'), 113 | # dict(type='PaviLoggerHook', init_kwargs=dict(project='mmedit-sr')) 114 | ]) 115 | visual_config = None 116 | 117 | # runtime settings 118 | dist_params = dict(backend='nccl') 119 | log_level = 'INFO' 120 | work_dir = f'./work_dirs/{exp_name}' 121 | load_from = 'mmediting/danx4_l1_256/iter_40000.pth' 122 | resume_from = None 123 | workflow = [('train', 1)] 124 | -------------------------------------------------------------------------------- /mmedit/models/backbones/sr_backbones/dan_net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mmcv.runner import load_checkpoint 7 | from mmedit.models.registry import BACKBONES 8 | from mmedit.utils import get_root_logger 9 | 10 | 11 | class CALayer(nn.Module): 12 | def __init__(self, nf, reduction=16): 13 | super(CALayer, self).__init__() 14 | self.body = nn.Sequential( 15 | nn.Conv2d(nf, nf // reduction, 1, 1, 0), 16 | nn.LeakyReLU(0.2), 17 | nn.Conv2d(nf // reduction, nf, 1, 1, 0), 18 | nn.Sigmoid(), 19 | ) 20 | self.avg = nn.AdaptiveAvgPool2d(1) 21 | 22 | def forward(self, x): 23 | y = self.avg(x) 24 | y = self.body(y) 25 | return torch.mul(x, y) 26 | 27 | 28 | class CRB_Layer(nn.Module): 29 | def __init__(self, nf1, nf2): 30 | super(CRB_Layer, self).__init__() 31 | 32 | body = [ 33 | nn.Conv2d(nf1 + nf2, nf1 + nf2, 3, 1, 1), 34 | nn.LeakyReLU(0.2, True), 35 | nn.Conv2d(nf1 + nf2, nf1, 3, 1, 1), 36 | CALayer(nf1), 37 | ] 38 | 39 | self.body = nn.Sequential(*body) 40 | 41 | def forward(self, x): 42 | f1, f2 = x 43 | f1 = self.body(torch.cat(x, 1)) + f1 44 | return [f1, f2] 45 | 46 | 47 | class Estimator(nn.Module): 48 | def __init__(self, in_nc=3, nf=64, num_blocks=5, scale=4, kernel_size=4): 49 | super(Estimator, self).__init__() 50 | 51 | self.ksize = kernel_size 52 | 53 | self.head_LR = nn.Conv2d(in_nc, nf // 2, 1, 1, 0) 54 | self.head_HR = nn.Conv2d(in_nc, nf // 2, 9, scale, 4) 55 | 56 | body = [CRB_Layer(nf // 2, nf // 2) for _ in range(num_blocks)] 57 | self.body = nn.Sequential(*body) 58 | 59 | self.out = nn.Conv2d(nf // 2, 10, 3, 1, 1) 60 | self.globalPooling = nn.AdaptiveAvgPool2d((1, 1)) 61 | 62 | def forward(self, GT, LR): 63 | 64 | lrf = self.head_LR(LR) 65 | hrf = self.head_HR(GT) 66 | 67 | f = [lrf, hrf] 68 | f, _ = self.body(f) 69 | f = self.out(f) 70 | f = self.globalPooling(f) 71 | f = f.view(f.size()[:2]) 72 | 73 | return f 74 | 75 | 76 | class Restorer(nn.Module): 77 | def __init__( 78 | self, in_nc=3, out_nc=3, nf=64, nb=8, scale=4, input_para=10, min=0.0, max=1.0 79 | ): 80 | super(Restorer, self).__init__() 81 | self.min = min 82 | self.max = max 83 | self.para = input_para 84 | self.num_blocks = nb 85 | 86 | self.head = nn.Conv2d(in_nc, nf, 3, stride=1, padding=1) 87 | 88 | body = [CRB_Layer(nf, input_para) for _ in range(nb)] 89 | self.body = nn.Sequential(*body) 90 | 91 | self.fusion = nn.Conv2d(nf, nf, 3, 1, 1) 92 | 93 | if scale == 4: # x4 94 | self.upscale = nn.Sequential( 95 | nn.Conv2d( 96 | in_channels=nf, 97 | out_channels=nf * scale, 98 | kernel_size=3, 99 | stride=1, 100 | padding=1, 101 | bias=True, 102 | ), 103 | nn.PixelShuffle(scale // 2), 104 | nn.Conv2d( 105 | in_channels=nf, 106 | out_channels=nf * scale, 107 | kernel_size=3, 108 | stride=1, 109 | padding=1, 110 | bias=True, 111 | ), 112 | nn.PixelShuffle(scale // 2), 113 | nn.Conv2d(nf, 3, 3, 1, 1), 114 | ) 115 | else: # x2, x3 116 | self.upscale = nn.Sequential( 117 | nn.Conv2d( 118 | in_channels=nf, 119 | out_channels=nf * scale ** 2, 120 | kernel_size=3, 121 | stride=1, 122 | padding=1, 123 | bias=True, 124 | ), 125 | nn.PixelShuffle(scale), 126 | nn.Conv2d(nf, 3, 3, 1, 1), 127 | ) 128 | 129 | def forward(self, input, ker_code): 130 | B, C, H, W = input.size() # I_LR batch 131 | B_h, C_h = ker_code.size() # Batch, Len=10 132 | ker_code_exp = ker_code.view((B_h, C_h, 1, 1)).expand( 133 | (B_h, C_h, H, W) 134 | ) # kernel_map stretch 135 | 136 | f = self.head(input) 137 | inputs = [f, ker_code_exp] 138 | f, _ = self.body(inputs) 139 | f = self.fusion(f) 140 | out = self.upscale(f) 141 | 142 | return out # torch.clamp(out, min=self.min, max=self.max) 143 | 144 | 145 | @BACKBONES.register_module() 146 | class DAN(nn.Module): 147 | def __init__( 148 | self, 149 | nf=64, 150 | nb=16, 151 | upscale=4, 152 | input_para=10, 153 | kernel_size=21, 154 | loop=8, 155 | pca_matrix_path=None, 156 | ): 157 | super(DAN, self).__init__() 158 | 159 | self.ksize = kernel_size 160 | self.loop = loop 161 | self.scale = upscale 162 | 163 | self.Restorer = Restorer(nf=nf, nb=nb, scale=self.scale, input_para=input_para) 164 | self.Estimator = Estimator(kernel_size=kernel_size, scale=self.scale) 165 | 166 | self.register_buffer("encoder", torch.load(pca_matrix_path)[None]) 167 | 168 | kernel = torch.zeros(1, self.ksize, self.ksize) 169 | kernel[:, self.ksize // 2, self.ksize // 2] = 1 170 | 171 | self.register_buffer("init_kernel", kernel) 172 | init_ker_map = self.init_kernel.view(1, 1, self.ksize ** 2).matmul( 173 | self.encoder 174 | )[:, 0] 175 | self.register_buffer("init_ker_map", init_ker_map) 176 | 177 | def forward(self, lr): 178 | 179 | srs = [] 180 | ker_maps = [] 181 | 182 | B, C, H, W = lr.shape 183 | ker_map = self.init_ker_map.repeat([B, 1]) 184 | 185 | for i in range(self.loop): 186 | 187 | sr = self.Restorer(lr, ker_map.detach()) 188 | ker_map = self.Estimator(sr.detach(), lr) 189 | 190 | srs.append(sr) 191 | ker_maps.append(ker_map) 192 | return [srs, ker_maps] 193 | 194 | def init_weights(self, pretrained=None, strict=True): 195 | """Init weights for models. 196 | Args: 197 | pretrained (str, optional): Path for pretrained weights. If given 198 | None, pretrained weights will not be loaded. Defaults to None. 199 | strict (boo, optional): Whether strictly load the pretrained model. 200 | Defaults to True. 201 | """ 202 | if isinstance(pretrained, str): 203 | logger = get_root_logger() 204 | load_checkpoint(self, pretrained, strict=strict, logger=logger) 205 | elif pretrained is None: 206 | pass # use default initialization 207 | else: 208 | raise TypeError('"pretrained" must be a str or None. ' 209 | f'But received {type(pretrained)}.') 210 | -------------------------------------------------------------------------------- /mmedit/models/common/dan_preprocess.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from scipy.io import loadmat 8 | from torch.autograd import Variable 9 | from torchvision.utils import make_grid 10 | 11 | 12 | def img2tensor(img): 13 | """ 14 | # BGR to RGB, HWC to CHW, numpy to tensor 15 | Input: img(H, W, C), [0,255], np.uint8 (default) 16 | Output: 3D(C,H,W), RGB order, float tensor 17 | """ 18 | img = img.astype(np.float32) / 255.0 19 | img = img[:, :, [2, 1, 0]] 20 | img = torch.from_numpy(np.ascontiguousarray(np.transpose(img, 21 | (2, 0, 22 | 1)))).float() 23 | return img 24 | 25 | 26 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 27 | """ 28 | Converts a torch Tensor into an image Numpy array 29 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 30 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 31 | """ 32 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 33 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0] 34 | ) # to range [0,1] 35 | n_dim = tensor.dim() 36 | if n_dim == 4: 37 | n_img = len(tensor) 38 | img_np = make_grid( 39 | tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 40 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 41 | elif n_dim == 3: 42 | img_np = tensor.numpy() 43 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 44 | elif n_dim == 2: 45 | img_np = tensor.numpy() 46 | else: 47 | raise TypeError('Only support 4D, 3D and 2D tensor.') 48 | if out_type == np.uint8: 49 | img_np = (img_np * 255.0).round() 50 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 51 | return img_np.astype(out_type) 52 | 53 | 54 | def cubic(x): 55 | absx = torch.abs(x) 56 | absx2 = absx**2 57 | absx3 = absx**3 58 | 59 | weight = (1.5 * absx3 - 2.5 * absx2 + 1) * ( 60 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 61 | 2) * (((absx > 1) * 62 | (absx <= 2)).type_as(absx)) 63 | return weight 64 | 65 | 66 | def calculate_weights_indices(in_length, out_length, scale, kernel, 67 | kernel_width, antialiasing): 68 | if (scale < 1) and (antialiasing): 69 | # Use a modified kernel to simultaneously interpolate 70 | kernel_width = kernel_width / scale 71 | 72 | # Output-space coordinates 73 | x = torch.linspace(1, out_length, out_length) 74 | 75 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 76 | # in output space maps to 0.5 in input space, and 0.5+scale in output 77 | # space maps to 1.5 in input space. 78 | u = x / scale + 0.5 * (1 - 1 / scale) 79 | 80 | # What is the left-most pixel that can be involved in the computation? 81 | left = torch.floor(u - kernel_width / 2) 82 | 83 | # What is the maximum number of pixels that can be involved in the 84 | # computation? Note: it's OK to use an extra pixel here; if the 85 | # corresponding weights are all zero, it will be eliminated at the end 86 | # of this function. 87 | P = math.ceil(kernel_width) + 2 88 | 89 | # The indices of the input pixels involved in computing the k-th output 90 | # pixel are in row k of the indices matrix. 91 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( 92 | 0, P - 1, P).view(1, P).expand(out_length, P) 93 | 94 | # The weights used to compute the k-th output pixel are in row k of the 95 | # weights matrix. 96 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices 97 | # apply cubic kernel 98 | if (scale < 1) and (antialiasing): 99 | weights = scale * cubic(distance_to_center * scale) 100 | else: 101 | weights = cubic(distance_to_center) 102 | # Normalize the weights matrix so that each row sums to 1. 103 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 104 | weights = weights / weights_sum.expand(out_length, P) 105 | 106 | # If a column in weights is all zero, get rid of it. 107 | weights_zero_tmp = torch.sum((weights == 0), 0) 108 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 109 | indices = indices.narrow(1, 1, P - 2) 110 | weights = weights.narrow(1, 1, P - 2) 111 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 112 | indices = indices.narrow(1, 0, P - 2) 113 | weights = weights.narrow(1, 0, P - 2) 114 | weights = weights.contiguous() 115 | indices = indices.contiguous() 116 | sym_len_s = -indices.min() + 1 117 | sym_len_e = indices.max() - in_length 118 | indices = indices + sym_len_s - 1 119 | return weights, indices, int(sym_len_s), int(sym_len_e) 120 | 121 | 122 | def imresize(img, scale, antialiasing=True): 123 | # Now the scale should be the same for H and W 124 | # input: img: CHW RGB [0,1] 125 | # output: CHW RGB [0,1] w/o round 126 | is_numpy = False 127 | if isinstance(img, np.ndarray): 128 | img = torch.from_numpy(img.transpose(2, 0, 1)) 129 | is_numpy = True 130 | device = img.device 131 | 132 | is_batch = True 133 | if len(img.shape) == 3: # C, H, W 134 | img = img[None] 135 | is_batch = False 136 | 137 | B, in_C, in_H, in_W = img.size() 138 | img = img.view(-1, in_H, in_W) 139 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 140 | kernel_width = 4 141 | kernel = 'cubic' 142 | 143 | # Return the desired dimension order for performing the resize. The 144 | # strategy is to perform the resize first along the dimension with the 145 | # smallest scale factor. 146 | # Now we do not support this. 147 | 148 | # get weights and indices 149 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 150 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 151 | weights_H, indices_H = weights_H.to(device), indices_H.to(device) 152 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 153 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 154 | weights_W, indices_W = weights_W.to(device), indices_W.to(device) 155 | # process H dimension 156 | # symmetric copying 157 | img_aug = torch.FloatTensor(B * in_C, in_H + sym_len_Hs + sym_len_He, 158 | in_W).to(device) 159 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) 160 | 161 | sym_patch = img[:, :sym_len_Hs, :] 162 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device) 163 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 164 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) 165 | 166 | sym_patch = img[:, -sym_len_He:, :] 167 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device) 168 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 169 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 170 | 171 | out_1 = torch.FloatTensor(B * in_C, out_H, in_W).to(device) 172 | kernel_width = weights_H.size(1) 173 | for i in range(out_H): 174 | idx = int(indices_H[i][0]) 175 | out_1[:, i, :] = (img_aug[:, idx:idx + kernel_width, :].transpose( 176 | 1, 2).matmul(weights_H[i][None, :, None].repeat(B * in_C, 1, 177 | 1))).squeeze() 178 | 179 | # process W dimension 180 | # symmetric copying 181 | out_1_aug = torch.FloatTensor(B * in_C, out_H, 182 | in_W + sym_len_Ws + sym_len_We).to(device) 183 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) 184 | 185 | sym_patch = out_1[:, :, :sym_len_Ws] 186 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device) 187 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 188 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) 189 | 190 | sym_patch = out_1[:, :, -sym_len_We:] 191 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device) 192 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 193 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 194 | 195 | out_2 = torch.FloatTensor(B * in_C, out_H, out_W).to(device) 196 | kernel_width = weights_W.size(1) 197 | for i in range(out_W): 198 | idx = int(indices_W[i][0]) 199 | out_2[:, :, i] = (out_1_aug[:, :, idx:idx + kernel_width].matmul( 200 | weights_W[i][None, :, None].repeat(B * in_C, 1, 1))).squeeze() 201 | 202 | out_2 = out_2.contiguous().view(B, in_C, out_H, out_W) 203 | if not is_batch: 204 | out_2 = out_2[0] 205 | return out_2.cpu().numpy().transpose(1, 2, 0) if is_numpy else out_2 206 | 207 | 208 | def DUF_downsample(x, scale=4): 209 | """Downsamping with Gaussian kernel used in the DUF official code 210 | Args: 211 | x (Tensor, [B, T, C, H, W]): frames to be downsampled. 212 | scale (int): downsampling factor: 2 | 3 | 4. 213 | """ 214 | 215 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) 216 | 217 | def gkern(kernlen=13, nsig=1.6): 218 | import scipy.ndimage.filters as fi 219 | 220 | inp = np.zeros((kernlen, kernlen)) 221 | # set element at the middle to one, a dirac delta 222 | inp[kernlen // 2, kernlen // 2] = 1 223 | # gaussian-smooth the dirac, resulting in a gaussian filter mask 224 | return fi.gaussian_filter(inp, nsig) 225 | 226 | B, T, C, H, W = x.size() 227 | x = x.view(-1, 1, H, W) 228 | # 6 is the pad of the gaussian filter 229 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 230 | r_h, r_w = 0, 0 231 | if scale == 3: 232 | r_h = 3 - (H % 3) 233 | r_w = 3 - (W % 3) 234 | x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect') 235 | 236 | gaussian_filter = ( 237 | torch.from_numpy(gkern(13, 0.4 * 238 | scale)).type_as(x).unsqueeze(0).unsqueeze(0)) 239 | x = F.conv2d(x, gaussian_filter, stride=scale) 240 | x = x[:, :, 2:-2, 2:-2] 241 | x = x.view(B, T, C, x.size(2), x.size(3)) 242 | return x 243 | 244 | 245 | def PCA(data, k=2): 246 | X = torch.from_numpy(data) 247 | X_mean = torch.mean(X, 0) 248 | X = X - X_mean.expand_as(X) 249 | U, S, V = torch.svd(torch.t(X)) 250 | return U[:, :k] # PCA matrix 251 | 252 | 253 | def random_batch_kernel( 254 | batch, 255 | kernel_size=21, 256 | sig_min=0.2, 257 | sig_max=4.0, 258 | rate_iso=1.0, 259 | tensor=True, 260 | random_disturb=False, 261 | ): 262 | if rate_iso == 1: 263 | 264 | sigma = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) 265 | ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) 266 | xx, yy = np.meshgrid(ax, ax) 267 | xx = xx[None].repeat(batch, 0) 268 | yy = yy[None].repeat(batch, 0) 269 | kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2)) 270 | kernel = kernel / np.sum(kernel, (1, 2), keepdims=True) 271 | return torch.FloatTensor(kernel) if tensor else kernel 272 | 273 | else: 274 | 275 | sigma_x = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) 276 | sigma_y = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) 277 | 278 | D = np.zeros((batch, 2, 2)) 279 | D[:, 0, 0] = sigma_x.squeeze()**2 280 | D[:, 1, 1] = sigma_y.squeeze()**2 281 | 282 | radians = np.random.uniform(-np.pi, np.pi, (batch)) 283 | mask_iso = np.random.uniform(0, 1, (batch)) < rate_iso 284 | radians[mask_iso] = 0 285 | sigma_y[mask_iso] = sigma_x[mask_iso] 286 | 287 | U = np.zeros((batch, 2, 2)) 288 | U[:, 0, 0] = np.cos(radians) 289 | U[:, 0, 1] = -np.sin(radians) 290 | U[:, 1, 0] = np.sin(radians) 291 | U[:, 1, 1] = np.cos(radians) 292 | sigma = np.matmul(U, np.matmul(D, U.transpose(0, 2, 1))) 293 | ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) 294 | xx, yy = np.meshgrid(ax, ax) 295 | xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), 296 | yy.reshape(kernel_size * kernel_size, 297 | 1))).reshape(kernel_size, kernel_size, 2) 298 | xy = xy[None].repeat(batch, 0) 299 | inverse_sigma = np.linalg.inv(sigma)[:, None, None] 300 | kernel = np.exp(-0.5 * np.matmul( 301 | np.matmul(xy[:, :, :, None], inverse_sigma), xy[:, :, :, :, None])) 302 | kernel = kernel.reshape(batch, kernel_size, kernel_size) 303 | if random_disturb: 304 | kernel = kernel + np.random.uniform( 305 | 0, 0.25, (batch, kernel_size, kernel_size)) * kernel 306 | kernel = kernel / np.sum(kernel, (1, 2), keepdims=True) 307 | 308 | return torch.FloatTensor(kernel) if tensor else kernel 309 | 310 | 311 | def stable_batch_kernel(batch, kernel_size=21, sig=2.6, tensor=True): 312 | sigma = sig 313 | ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) 314 | xx, yy = np.meshgrid(ax, ax) 315 | xx = xx[None].repeat(batch, 0) 316 | yy = yy[None].repeat(batch, 0) 317 | kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2)) 318 | kernel = kernel / np.sum(kernel, (1, 2), keepdims=True) 319 | return torch.FloatTensor(kernel) if tensor else kernel 320 | 321 | 322 | def b_Bicubic(variable, scale): 323 | B, C, H, W = variable.size() 324 | # H_new = int(H / scale) 325 | # W_new = int(W / scale) 326 | tensor_v = variable.view((B, C, H, W)) 327 | re_tensor = imresize(tensor_v, 1 / scale) 328 | return re_tensor 329 | 330 | 331 | def random_batch_noise(batch, high, rate_cln=1.0): 332 | noise_level = np.random.uniform(size=(batch, 1)) * high 333 | noise_mask = np.random.uniform(size=(batch, 1)) 334 | noise_mask[noise_mask < rate_cln] = 0 335 | noise_mask[noise_mask >= rate_cln] = 1 336 | return noise_level * noise_mask 337 | 338 | 339 | def b_GaussianNoising(tensor, 340 | sigma, 341 | mean=0.0, 342 | noise_size=None, 343 | min=0.0, 344 | max=1.0): 345 | if noise_size is None: 346 | size = tensor.size() 347 | else: 348 | size = noise_size 349 | noise = torch.mul( 350 | torch.FloatTensor(np.random.normal(loc=mean, scale=1.0, size=size)), 351 | sigma.view(sigma.size() + (1, 1)), 352 | ).to(tensor.device) 353 | return torch.clamp(noise + tensor, min=min, max=max) 354 | 355 | 356 | class BatchSRKernel(object): 357 | 358 | def __init__( 359 | self, 360 | kernel_size=21, 361 | sig=2.6, 362 | sig_min=0.2, 363 | sig_max=4.0, 364 | rate_iso=1.0, 365 | random_disturb=False, 366 | ): 367 | self.kernel_size = kernel_size 368 | self.sig = sig 369 | self.sig_min = sig_min 370 | self.sig_max = sig_max 371 | self.rate = rate_iso 372 | self.random_disturb = random_disturb 373 | 374 | def __call__(self, random, batch, tensor=False): 375 | if random: # random kernel 376 | return random_batch_kernel( 377 | batch, 378 | kernel_size=self.kernel_size, 379 | sig_min=self.sig_min, 380 | sig_max=self.sig_max, 381 | rate_iso=self.rate, 382 | tensor=tensor, 383 | random_disturb=self.random_disturb, 384 | ) 385 | else: # stable kernel 386 | return stable_batch_kernel( 387 | batch, 388 | kernel_size=self.kernel_size, 389 | sig=self.sig, 390 | tensor=tensor) 391 | 392 | 393 | class BatchBlurKernel(object): 394 | 395 | def __init__(self, kernels_path): 396 | kernels = loadmat(kernels_path)['kernels'] 397 | self.num_kernels = kernels.shape[0] 398 | self.kernels = kernels 399 | 400 | def __call__(self, random, batch, tensor=False): 401 | index = np.random.randint(0, self.num_kernels, batch) 402 | kernels = self.kernels[index] 403 | return torch.FloatTensor(kernels).contiguous() if tensor else kernels 404 | 405 | 406 | class PCAEncoder(nn.Module): 407 | 408 | def __init__(self, weight): 409 | super().__init__() 410 | self.register_buffer('weight', weight) 411 | self.size = self.weight.size() 412 | 413 | def forward(self, batch_kernel): 414 | B, H, W = batch_kernel.size() # [B, l, l] 415 | return torch.bmm( 416 | batch_kernel.view((B, 1, H * W)), 417 | self.weight.expand((B, ) + self.size)).view((B, -1)) 418 | 419 | 420 | class BatchBlur(object): 421 | 422 | def __init__(self, kernel_size=15): 423 | self.kernel_size = kernel_size 424 | if kernel_size % 2 == 1: 425 | self.pad = (kernel_size // 2, kernel_size // 2, kernel_size // 2, 426 | kernel_size // 2) 427 | else: 428 | self.pad = (kernel_size // 2, kernel_size // 2 - 1, 429 | kernel_size // 2, kernel_size // 2 - 1) 430 | # self.pad = nn.ZeroPad2d(l // 2) 431 | 432 | def __call__(self, input, kernel): 433 | B, C, H, W = input.size() 434 | pad = F.pad(input, self.pad, mode='reflect') 435 | H_p, W_p = pad.size()[-2:] 436 | 437 | if len(kernel.size()) == 2: 438 | input_CBHW = pad.view((C * B, 1, H_p, W_p)) 439 | kernel_var = kernel.contiguous().view( 440 | (1, 1, self.kernel_size, self.kernel_size)) 441 | return F.conv2d( 442 | input_CBHW, kernel_var, padding=0).view((B, C, H, W)) 443 | else: 444 | input_CBHW = pad.view((1, C * B, H_p, W_p)) 445 | kernel_var = ( 446 | kernel.contiguous().view( 447 | (B, 1, self.kernel_size, 448 | self.kernel_size)).repeat(1, C, 1, 1).view( 449 | (B * C, 1, self.kernel_size, self.kernel_size))) 450 | return F.conv2d( 451 | input_CBHW, kernel_var, groups=B * C).view((B, C, H, W)) 452 | 453 | 454 | class SRMDPreprocessing(object): 455 | 456 | def __init__(self, 457 | scale, 458 | pca_matrix, 459 | ksize=21, 460 | code_length=10, 461 | random_kernel=True, 462 | noise=False, 463 | cuda=False, 464 | random_disturb=False, 465 | sig=0, 466 | sig_min=0, 467 | sig_max=0, 468 | rate_iso=1.0, 469 | rate_cln=1, 470 | noise_high=0, 471 | stored_kernel=False, 472 | pre_kernel_path=None): 473 | self.encoder = PCAEncoder(pca_matrix).cuda() if cuda else PCAEncoder( 474 | pca_matrix) 475 | 476 | self.kernel_gen = ( 477 | BatchSRKernel( 478 | kernel_size=ksize, 479 | sig=sig, 480 | sig_min=sig_min, 481 | sig_max=sig_max, 482 | rate_iso=rate_iso, 483 | random_disturb=random_disturb, 484 | ) if not stored_kernel else BatchBlurKernel(pre_kernel_path)) 485 | 486 | self.blur = BatchBlur(kernel_size=ksize) 487 | self.para_in = code_length 488 | self.kernel_size = ksize 489 | self.noise = noise 490 | self.scale = scale 491 | self.cuda = cuda 492 | self.rate_cln = rate_cln 493 | self.noise_high = noise_high 494 | self.random = random_kernel 495 | 496 | def __call__(self, hr_tensor, kernel=False): 497 | # hr_tensor is tensor, not cuda tensor 498 | 499 | hr_var = Variable(hr_tensor).cuda() if self.cuda else Variable( 500 | hr_tensor) 501 | device = hr_var.device 502 | B, C, H, W = hr_var.size() 503 | 504 | b_kernels = Variable(self.kernel_gen(self.random, B, 505 | tensor=True)).to(device) 506 | hr_blured_var = self.blur(hr_var, b_kernels) 507 | 508 | # B x self.para_input 509 | kernel_code = self.encoder(b_kernels) 510 | 511 | # Down sample 512 | if self.scale != 1: 513 | lr_blured_t = b_Bicubic(hr_blured_var, self.scale) 514 | else: 515 | lr_blured_t = hr_blured_var 516 | 517 | # Noisy 518 | if self.noise: 519 | Noise_level = torch.FloatTensor( 520 | random_batch_noise(B, self.noise_high, self.rate_cln)) 521 | lr_noised_t = b_GaussianNoising(lr_blured_t, self.noise_high) 522 | else: 523 | Noise_level = torch.zeros((B, 1)) 524 | lr_noised_t = lr_blured_t 525 | 526 | Noise_level = Variable(Noise_level).cuda() 527 | re_code = ( 528 | torch.cat([kernel_code, Noise_level * 529 | 10], dim=1) if self.noise else kernel_code) 530 | lr_re = Variable(lr_noised_t).to(device) 531 | 532 | return (lr_re, re_code, b_kernels) if kernel else (lr_re, re_code) 533 | -------------------------------------------------------------------------------- /mmedit/models/restorers/dan.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import os.path as osp 3 | 4 | import mmcv 5 | import torch 6 | from mmcv.runner import auto_fp16 7 | 8 | from mmedit.core import psnr, ssim, tensor2img 9 | from .basic_restorer import BasicRestorer 10 | from ..builder import build_backbone, build_loss 11 | from ..registry import MODELS 12 | from ..common import SRMDPreprocessing 13 | 14 | 15 | @MODELS.register_module() 16 | class DAN(BasicRestorer): 17 | """Basic model for image restoration. 18 | 19 | It must contain a generator that takes an image as inputs and outputs a 20 | restored image. It also has a pixel-wise loss for training. 21 | 22 | The subclasses should overwrite the function `forward_train`, 23 | `forward_test` and `train_step`. 24 | 25 | Args: 26 | generator (dict): Config for the generator structure. 27 | pixel_loss (dict): Config for pixel-wise loss. 28 | train_cfg (dict): Config for training. Default: None. 29 | test_cfg (dict): Config for testing. Default: None. 30 | pretrained (str): Path for pretrained model. Default: None. 31 | """ 32 | allowed_metrics = {'PSNR': psnr, 'SSIM': ssim} 33 | 34 | def __init__(self, 35 | generator, 36 | pixel_loss, 37 | train_cfg=None, 38 | test_cfg=None, 39 | pretrained=None): 40 | super(BasicRestorer, self).__init__() 41 | 42 | self.train_cfg = train_cfg 43 | self.test_cfg = test_cfg 44 | 45 | # support fp16 46 | self.fp16_enabled = False 47 | 48 | # load PCA matrix of enough kernel 49 | if train_cfg: 50 | pca_matrix = torch.load(train_cfg['pca_matrix_path'], map_location=lambda storage, loc: storage) 51 | print('PCA matrix shape:{}'.format(pca_matrix.shape)) 52 | self.prepro = SRMDPreprocessing(train_cfg['scale'], pca_matrix=pca_matrix, cuda=True, **train_cfg['degradation']) 53 | 54 | # generator 55 | self.generator = build_backbone(generator) 56 | # for parameters in self.generator.parameters(): 57 | # pass 58 | # print(parameters) 59 | self.init_weights(pretrained) 60 | 61 | # loss 62 | self.pixel_loss = build_loss(pixel_loss) 63 | 64 | def init_weights(self, pretrained=None, strict=True): 65 | """Init weights for models. 66 | 67 | Args: 68 | pretrained (str, optional): Path for pretrained weights. If given 69 | None, pretrained weights will not be loaded. Defaults to None. 70 | """ 71 | self.generator.init_weights(pretrained, strict) 72 | 73 | @auto_fp16(apply_to=('lq', )) 74 | def forward(self, lq, gt=None, ker_map=None, test_mode=False, **kwargs): 75 | """Forward function. 76 | 77 | Args: 78 | lq (Tensor): Input lq images. 79 | gt (Tensor): Ground-truth image. Default: None. 80 | test_mode (bool): Whether in test mode or not. Default: False. 81 | kwargs (dict): Other arguments. 82 | """ 83 | 84 | if test_mode: 85 | return self.forward_test(lq, gt, **kwargs) 86 | 87 | return self.forward_train(lq, gt, ker_map) 88 | 89 | def forward_train(self, lq, gt, ker_map): 90 | """Training forward function. 91 | 92 | Args: 93 | lq (Tensor): LQ Tensor with shape (n, c, h, w). 94 | gt (Tensor): GT Tensor with shape (n, c, h, w). 95 | ker_map (Tensor) : ker map Tensor 96 | Returns: 97 | Tensor: Output tensor. 98 | """ 99 | losses = dict() 100 | srs, fake_kers = self.generator(lq) 101 | output = srs[-1] 102 | out_ker = fake_kers[-1] 103 | loss_pix = 0 104 | loss_pix += self.pixel_loss(out_ker, ker_map) 105 | loss_pix += self.pixel_loss(output, gt) 106 | 107 | losses['loss_pix'] = loss_pix 108 | outputs = dict( 109 | losses=losses, 110 | num_samples=len(gt.data), 111 | results=dict(lq=lq.cpu(), gt=gt.cpu(), output=output.cpu())) 112 | return outputs 113 | 114 | def evaluate(self, output, gt): 115 | """Evaluation function. 116 | 117 | Args: 118 | output (Tensor): Model output with shape (n, c, h, w). 119 | gt (Tensor): GT Tensor with shape (n, c, h, w). 120 | 121 | Returns: 122 | dict: Evaluation results. 123 | """ 124 | crop_border = self.test_cfg.crop_border 125 | 126 | output = tensor2img(output) 127 | gt = tensor2img(gt) 128 | 129 | eval_result = dict() 130 | for metric in self.test_cfg.metrics: 131 | eval_result[metric] = self.allowed_metrics[metric](output, gt, 132 | crop_border) 133 | return eval_result 134 | 135 | def forward_test(self, 136 | lq, 137 | gt=None, 138 | meta=None, 139 | save_image=False, 140 | save_path=None, 141 | iteration=None): 142 | """Testing forward function. 143 | 144 | Args: 145 | lq (Tensor): LQ Tensor with shape (n, c, h, w). 146 | gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None. 147 | save_image (bool): Whether to save image. Default: False. 148 | save_path (str): Path to save image. Default: None. 149 | iteration (int): Iteration for the saving image name. 150 | Default: None. 151 | 152 | Returns: 153 | dict: Output results. 154 | """ 155 | srs, _ = self.generator(lq) 156 | output = srs[-1] 157 | if self.test_cfg is not None and self.test_cfg.get('metrics', None): 158 | assert gt is not None, ( 159 | 'evaluation with metrics must have gt images.') 160 | results = dict(eval_result=self.evaluate(output, gt)) 161 | else: 162 | results = dict(lq=lq.cpu(), output=output.cpu()) 163 | if gt is not None: 164 | results['gt'] = gt.cpu() 165 | 166 | # save image 167 | if save_image: 168 | lq_path = meta[0]['lq_path'] 169 | folder_name = osp.splitext(osp.basename(lq_path))[0] 170 | if isinstance(iteration, numbers.Number): 171 | save_path = osp.join(save_path, folder_name, 172 | f'{folder_name}-{iteration + 1:06d}.png') 173 | elif iteration is None: 174 | save_path = osp.join(save_path, f'{folder_name}.png') 175 | else: 176 | raise ValueError('iteration should be number or None, ' 177 | f'but got {type(iteration)}') 178 | mmcv.imwrite(tensor2img(output), save_path) 179 | 180 | return results 181 | 182 | def forward_dummy(self, GT_img): 183 | """Forward of networks. 184 | 185 | Args: 186 | gt (Tensor): GT image. 187 | 188 | Returns: 189 | out (Tensor): Predicted super-resolution results (n, 3, 4h, 4w). 190 | """ 191 | LR_img, ker_map = self.prepro(GT_img) 192 | LR_img = (LR_img * 255).round() / 255 193 | GT_img = GT_img.cuda() 194 | LR_img = LR_img.cuda() 195 | ker_map = ker_map.cuda() 196 | return GT_img, LR_img, ker_map 197 | 198 | def train_step(self, data_batch, optimizer): 199 | """Train step. 200 | 201 | Args: 202 | data_batch (dict): A batch of data. 203 | optimizer (obj): Optimizer. 204 | 205 | Returns: 206 | dict: Returned output. 207 | """ 208 | gt = data_batch['gt'] 209 | GT_img, LR_img, ker_map = self.forward_dummy(gt) 210 | outputs = self(LR_img, GT_img, ker_map, test_mode=False) 211 | loss, log_vars = self.parse_losses(outputs.pop('losses')) 212 | 213 | # optimize 214 | optimizer['generator'].zero_grad() 215 | loss.backward() 216 | optimizer['generator'].step() 217 | 218 | outputs.update({'log_vars': log_vars}) 219 | return outputs 220 | 221 | def val_step(self, data_batch, **kwargs): 222 | """Validation step. 223 | 224 | Args: 225 | data_batch (dict): A batch of data. 226 | kwargs (dict): Other arguments for ``val_step``. 227 | 228 | Returns: 229 | dict: Returned output. 230 | """ 231 | output = self.forward_test(**data_batch, **kwargs) 232 | return output 233 | -------------------------------------------------------------------------------- /tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x2.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexZou14/DAN-Basd-on-Openmmlab/fc0d273edf7d874a641e47cedde8d5ca54abcad7/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x2.pth -------------------------------------------------------------------------------- /tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlexZou14/DAN-Basd-on-Openmmlab/fc0d273edf7d874a641e47cedde8d5ca54abcad7/tools/data/super-resolution/dan_datasets/pca_matrix/pca_aniso_matrix_x4.pth -------------------------------------------------------------------------------- /tools/data/super-resolution/dan_datasets/preprocess_dan_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from mmedit.models.common.DANpreprocess import tensor2img, imresize, SRMDPreprocessing, img2tensor 7 | import random 8 | 9 | def set_random_seed(seed): 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | def generate_mod_LR_bic(): 16 | # set parameters 17 | up_scale = 2 18 | mod_scale = 4 19 | # set data dir 20 | sourcedir = "/media/lab216/dbf4a469-6c52-4371-95c4-c24c882bc23b/data/benchmark/Set5/HR" 21 | savedir = "/media/lab216/dbf4a469-6c52-4371-95c4-c24c882bc23b/data/benchmark/Set5/generate/" 22 | 23 | # load PCA matrix of enough kernel 24 | print("load PCA matrix") 25 | pca_matrix = torch.load( 26 | "/home/lab216/Desktop/mmediting/tools/data/super-resolution/div2k/pca_matrix/pca_matrix_x2.pth", map_location=lambda storage, loc: storage 27 | ) 28 | print("PCA matrix shape: {}".format(pca_matrix.shape)) 29 | 30 | degradation_setting = { 31 | "random_kernel": False, 32 | "code_length": 10, 33 | "ksize": 21, 34 | "pca_matrix": pca_matrix, 35 | "scale": up_scale, 36 | "cuda": True, 37 | "rate_iso": 1.0 38 | } 39 | 40 | # set random seed 41 | set_random_seed(0) 42 | 43 | saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale)) 44 | saveLRpath = os.path.join(savedir, "LR", "x" + str(up_scale)) 45 | saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale)) 46 | saveLRblurpath = os.path.join(savedir, "LRblur", "x" + str(up_scale)) 47 | 48 | if not os.path.isdir(sourcedir): 49 | print("Error: No source data found") 50 | exit(0) 51 | if not os.path.isdir(savedir): 52 | os.mkdir(savedir) 53 | 54 | if not os.path.isdir(os.path.join(savedir, "HR")): 55 | os.mkdir(os.path.join(savedir, "HR")) 56 | if not os.path.isdir(os.path.join(savedir, "LR")): 57 | os.mkdir(os.path.join(savedir, "LR")) 58 | if not os.path.isdir(os.path.join(savedir, "Bic")): 59 | os.mkdir(os.path.join(savedir, "Bic")) 60 | if not os.path.isdir(os.path.join(savedir, "LRblur")): 61 | os.mkdir(os.path.join(savedir, "LRblur")) 62 | 63 | if not os.path.isdir(saveHRpath): 64 | os.mkdir(saveHRpath) 65 | else: 66 | print("It will cover " + str(saveHRpath)) 67 | 68 | if not os.path.isdir(saveLRpath): 69 | os.mkdir(saveLRpath) 70 | else: 71 | print("It will cover " + str(saveLRpath)) 72 | 73 | if not os.path.isdir(saveBicpath): 74 | os.mkdir(saveBicpath) 75 | else: 76 | print("It will cover " + str(saveBicpath)) 77 | 78 | if not os.path.isdir(saveLRblurpath): 79 | os.mkdir(saveLRblurpath) 80 | else: 81 | print("It will cover " + str(saveLRblurpath)) 82 | 83 | filepaths = sorted([f for f in os.listdir(sourcedir) if f.endswith(".png")]) 84 | print(filepaths) 85 | num_files = len(filepaths) 86 | 87 | # kernel_map_tensor = torch.zeros((num_files, 1, 10)) # each kernel map: 1*10 88 | 89 | # prepare data with augementation 90 | 91 | for i in range(num_files): 92 | filename = filepaths[i] 93 | print("No.{} -- Processing {}".format(i, filename)) 94 | # read image 95 | image = cv2.imread(os.path.join(sourcedir, filename)) 96 | 97 | width = int(np.floor(image.shape[1] / mod_scale)) 98 | height = int(np.floor(image.shape[0] / mod_scale)) 99 | # modcrop 100 | if len(image.shape) == 3: 101 | image_HR = image[0: mod_scale * height, 0: mod_scale * width, :] 102 | else: 103 | image_HR = image[0: mod_scale * height, 0: mod_scale * width] 104 | # LR_blur, by random gaussian kernel 105 | img_HR = img2tensor(image_HR) 106 | C, H, W = img_HR.size() 107 | 108 | for sig in np.linspace(1.8, 3.2, 8): 109 | prepro = SRMDPreprocessing(sig=sig, **degradation_setting) 110 | 111 | LR_img, ker_map = prepro(img_HR.view(1, C, H, W)) 112 | image_LR_blur = tensor2img(LR_img) 113 | cv2.imwrite(os.path.join(saveLRblurpath, 'sig{}_{}'.format(sig, filename)), image_LR_blur) 114 | cv2.imwrite(os.path.join(saveHRpath, 'sig{}_{}'.format(sig, filename)), image_HR) 115 | # LR 116 | image_LR = imresize(image_HR, 1 / up_scale, True) 117 | # bic 118 | image_Bic = imresize(image_LR, up_scale, True) 119 | 120 | # cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) 121 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) 122 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) 123 | 124 | # kernel_map_tensor[i] = ker_map 125 | # save dataset corresponding kernel maps 126 | # torch.save(kernel_map_tensor, './Set5_sig2.6_kermap.pth') 127 | print("Image Blurring & Down smaple Done: X" + str(up_scale)) 128 | 129 | 130 | if __name__ == "__main__": 131 | generate_mod_LR_bic() 132 | -------------------------------------------------------------------------------- /tools/data/super-resolution/dan_datasets/preprocess_div2k_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import re 5 | import sys 6 | from multiprocessing import Pool 7 | 8 | import cv2 9 | import lmdb 10 | import mmcv 11 | import numpy as np 12 | 13 | 14 | def main_extract_subimages(args): 15 | """A multi-thread tool to crop large images to sub-images for faster IO. 16 | 17 | It is used for DIV2K dataset. 18 | 19 | opt (dict): Configuration dict. It contains: 20 | n_thread (int): Thread number. 21 | compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. 22 | A higher value means a smaller size and longer compression time. 23 | Use 0 for faster CPU decompression. Default: 3, same in cv2. 24 | 25 | input_folder (str): Path to the input folder. 26 | save_folder (str): Path to save folder. 27 | crop_size (int): Crop size. 28 | step (int): Step for overlapped sliding window. 29 | thresh_size (int): Threshold size. Patches whose size is lower 30 | than thresh_size will be dropped. 31 | 32 | Usage: 33 | For each folder, run this script. 34 | Typically, there are four folders to be processed for DIV2K dataset. 35 | DIV2K_train_HR 36 | DIV2K_train_LR_bicubic/X2 37 | DIV2K_train_LR_bicubic/X3 38 | DIV2K_train_LR_bicubic/X4 39 | After process, each sub_folder should have the same number of 40 | subimages. 41 | Remember to modify opt configurations according to your settings. 42 | """ 43 | 44 | opt = {} 45 | opt['n_thread'] = args.n_thread 46 | opt['compression_level'] = args.compression_level 47 | 48 | # HR images 49 | opt['input_folder'] = osp.join(args.data_root, 'DF2K_train_HR') 50 | opt['save_folder'] = osp.join(args.data_root, 'DF2K_train_HR_sub') 51 | opt['crop_size'] = args.crop_size 52 | opt['step'] = args.step 53 | opt['thresh_size'] = args.thresh_size 54 | extract_subimages(opt) 55 | 56 | 57 | def extract_subimages(opt): 58 | """Crop images to subimages. 59 | 60 | Args: 61 | opt (dict): Configuration dict. It contains: 62 | input_folder (str): Path to the input folder. 63 | save_folder (str): Path to save folder. 64 | n_thread (int): Thread number. 65 | """ 66 | input_folder = opt['input_folder'] 67 | save_folder = opt['save_folder'] 68 | if not osp.exists(save_folder): 69 | os.makedirs(save_folder) 70 | print(f'mkdir {save_folder} ...') 71 | else: 72 | print(f'Folder {save_folder} already exists. Exit.') 73 | sys.exit(1) 74 | 75 | img_list = list(mmcv.scandir(input_folder)) 76 | img_list = [osp.join(input_folder, v) for v in img_list] 77 | 78 | prog_bar = mmcv.ProgressBar(len(img_list)) 79 | pool = Pool(opt['n_thread']) 80 | for path in img_list: 81 | pool.apply_async( 82 | worker, args=(path, opt), callback=lambda arg: prog_bar.update()) 83 | pool.close() 84 | pool.join() 85 | print('All processes done.') 86 | 87 | 88 | def worker(path, opt): 89 | """Worker for each process. 90 | 91 | Args: 92 | path (str): Image path. 93 | opt (dict): Configuration dict. It contains: 94 | crop_size (int): Crop size. 95 | step (int): Step for overlapped sliding window. 96 | thresh_size (int): Threshold size. Patches whose size is smaller 97 | than thresh_size will be dropped. 98 | save_folder (str): Path to save folder. 99 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 100 | 101 | Returns: 102 | process_info (str): Process information displayed in progress bar. 103 | """ 104 | crop_size = opt['crop_size'] 105 | step = opt['step'] 106 | thresh_size = opt['thresh_size'] 107 | img_name, extension = osp.splitext(osp.basename(path)) 108 | 109 | # remove the x2, x3, x4 and x8 in the filename for DIV2K 110 | img_name = re.sub('x[2348]', '', img_name) 111 | 112 | img = mmcv.imread(path, flag='unchanged') 113 | 114 | if img.ndim == 2 or img.ndim == 3: 115 | h, w = img.shape[:2] 116 | else: 117 | raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}') 118 | 119 | h_space = np.arange(0, h - crop_size + 1, step) 120 | if h - (h_space[-1] + crop_size) > thresh_size: 121 | h_space = np.append(h_space, h - crop_size) 122 | w_space = np.arange(0, w - crop_size + 1, step) 123 | if w - (w_space[-1] + crop_size) > thresh_size: 124 | w_space = np.append(w_space, w - crop_size) 125 | 126 | index = 0 127 | for x in h_space: 128 | for y in w_space: 129 | index += 1 130 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...] 131 | cv2.imwrite( 132 | osp.join(opt['save_folder'], 133 | f'{img_name}_s{index:03d}{extension}'), cropped_img, 134 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 135 | process_info = f'Processing {img_name} ...' 136 | return process_info 137 | 138 | 139 | def make_lmdb_for_div2k(data_root): 140 | """Create lmdb files for DIV2K dataset. 141 | 142 | Args: 143 | data_root (str): Data root path. 144 | 145 | Usage: 146 | Typically, there are four folders to be processed for DIV2K dataset. 147 | DIV2K_train_HR_sub 148 | DIV2K_train_LR_bicubic/X2_sub 149 | DIV2K_train_LR_bicubic/X3_sub 150 | DIV2K_train_LR_bicubic/X4_sub 151 | Remember to modify opt configurations according to your settings. 152 | """ 153 | 154 | folder_paths = [ 155 | osp.join(data_root, 'DF2K_train_HR_sub') 156 | ] 157 | lmdb_paths = [ 158 | osp.join(data_root, 'DF2K_train_HR_sub.lmdb') 159 | ] 160 | 161 | for folder_path, lmdb_path in zip(folder_paths, lmdb_paths): 162 | img_path_list, keys = prepare_keys_div2k(folder_path) 163 | make_lmdb(folder_path, lmdb_path, img_path_list, keys) 164 | 165 | 166 | def prepare_keys_div2k(folder_path): 167 | """Prepare image path list and keys for DIV2K dataset. 168 | 169 | Args: 170 | folder_path (str): Folder path. 171 | 172 | Returns: 173 | list[str]: Image path list. 174 | list[str]: Key list. 175 | """ 176 | print('Reading image path list ...') 177 | img_path_list = sorted( 178 | list(mmcv.scandir(folder_path, suffix='png', recursive=False))) 179 | keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] 180 | 181 | return img_path_list, keys 182 | 183 | 184 | def make_lmdb(data_path, 185 | lmdb_path, 186 | img_path_list, 187 | keys, 188 | batch=5000, 189 | compress_level=1, 190 | multiprocessing_read=False, 191 | n_thread=40): 192 | """Make lmdb. 193 | 194 | Contents of lmdb. The file structure is: 195 | example.lmdb 196 | ├── data.mdb 197 | ├── lock.mdb 198 | ├── meta_info.txt 199 | 200 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 201 | https://lmdb.readthedocs.io/en/release/ for more details. 202 | 203 | The meta_info.txt is a specified txt file to record the meta information 204 | of our datasets. It will be automatically created when preparing 205 | datasets by our provided dataset tools. 206 | Each line in the txt file records 1)image name (with extension), 207 | 2)image shape, and 3)compression level, separated by a white space. 208 | 209 | For example, the meta information could be: 210 | `000_00000000.png (720,1280,3) 1`, which means: 211 | 1) image name (with extension): 000_00000000.png; 212 | 2) image shape: (720,1280,3); 213 | 3) compression level: 1 214 | 215 | We use the image name without extension as the lmdb key. 216 | 217 | If `multiprocessing_read` is True, it will read all the images to memory 218 | using multiprocessing. Thus, your server needs to have enough memory. 219 | 220 | Args: 221 | data_path (str): Data path for reading images. 222 | lmdb_path (str): Lmdb save path. 223 | img_path_list (str): Image path list. 224 | keys (str): Used for lmdb keys. 225 | batch (int): After processing batch images, lmdb commits. 226 | Default: 5000. 227 | compress_level (int): Compress level when encoding images. Default: 1. 228 | multiprocessing_read (bool): Whether use multiprocessing to read all 229 | the images to memory. Default: False. 230 | n_thread (int): For multiprocessing. 231 | """ 232 | assert len(img_path_list) == len(keys), ( 233 | 'img_path_list and keys should have the same length, ' 234 | f'but got {len(img_path_list)} and {len(keys)}') 235 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 236 | print(f'Total images: {len(img_path_list)}') 237 | if not lmdb_path.endswith('.lmdb'): 238 | raise ValueError("lmdb_path must end with '.lmdb'.") 239 | if osp.exists(lmdb_path): 240 | print(f'Folder {lmdb_path} already exists. Exit.') 241 | sys.exit(1) 242 | 243 | if multiprocessing_read: 244 | # read all the images to memory (multiprocessing) 245 | dataset = {} # use dict to keep the order for multiprocessing 246 | shapes = {} 247 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 248 | prog_bar = mmcv.ProgressBar(len(img_path_list)) 249 | 250 | def callback(arg): 251 | """get the image data and update prog_bar.""" 252 | key, dataset[key], shapes[key] = arg 253 | prog_bar.update() 254 | 255 | pool = Pool(n_thread) 256 | for path, key in zip(img_path_list, keys): 257 | pool.apply_async( 258 | read_img_worker, 259 | args=(osp.join(data_path, path), key, compress_level), 260 | callback=callback) 261 | pool.close() 262 | pool.join() 263 | print(f'Finish reading {len(img_path_list)} images.') 264 | 265 | # create lmdb environment 266 | # obtain data size for one image 267 | img = mmcv.imread(osp.join(data_path, img_path_list[0]), flag='unchanged') 268 | _, img_byte = cv2.imencode('.png', img, 269 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 270 | data_size_per_img = img_byte.nbytes 271 | print('Data size per image is: ', data_size_per_img) 272 | data_size = data_size_per_img * len(img_path_list) 273 | env = lmdb.open(lmdb_path, map_size=data_size * 10) 274 | 275 | # write data to lmdb 276 | prog_bar = mmcv.ProgressBar(len(img_path_list)) 277 | txn = env.begin(write=True) 278 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 279 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 280 | prog_bar.update() 281 | key_byte = key.encode('ascii') 282 | if multiprocessing_read: 283 | img_byte = dataset[key] 284 | h, w, c = shapes[key] 285 | else: 286 | _, img_byte, img_shape = read_img_worker( 287 | osp.join(data_path, path), key, compress_level) 288 | h, w, c = img_shape 289 | 290 | txn.put(key_byte, img_byte) 291 | # write meta information 292 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 293 | if idx % batch == 0: 294 | txn.commit() 295 | txn = env.begin(write=True) 296 | txn.commit() 297 | env.close() 298 | txt_file.close() 299 | print('\nFinish writing lmdb.') 300 | 301 | 302 | def read_img_worker(path, key, compress_level): 303 | """Read image worker 304 | 305 | Args: 306 | path (str): Image path. 307 | key (str): Image key. 308 | compress_level (int): Compress level when encoding images. 309 | 310 | Returns: 311 | str: Image key. 312 | byte: Image byte. 313 | tuple[int]: Image shape. 314 | """ 315 | img = mmcv.imread(path, flag='unchanged') 316 | if img.ndim == 2: 317 | h, w = img.shape 318 | c = 1 319 | else: 320 | h, w, c = img.shape 321 | _, img_byte = cv2.imencode('.png', img, 322 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 323 | return (key, img_byte, (h, w, c)) 324 | 325 | 326 | def parse_args(): 327 | parser = argparse.ArgumentParser( 328 | description='Prepare DIV2K dataset', 329 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 330 | parser.add_argument('--data-root', help='dataset root') 331 | parser.add_argument( 332 | '--crop-size', 333 | nargs='?', 334 | default=256, 335 | help='cropped size for HR images') 336 | parser.add_argument( 337 | '--step', nargs='?', default=128, help='step size for HR images') 338 | parser.add_argument( 339 | '--thresh-size', 340 | nargs='?', 341 | default=0, 342 | help='threshold size for HR images') 343 | parser.add_argument( 344 | '--compression-level', 345 | nargs='?', 346 | default=3, 347 | help='compression level when save png images') 348 | parser.add_argument( 349 | '--n-thread', 350 | nargs='?', 351 | default=4, 352 | help='thread number when using multiprocessing') 353 | parser.add_argument( 354 | '--make-lmdb', 355 | action='store_true', 356 | help='whether to prepare lmdb files') 357 | args = parser.parse_args() 358 | return args 359 | 360 | 361 | if __name__ == '__main__': 362 | args = parse_args() 363 | # # extract subimages 364 | main_extract_subimages(args) 365 | # # prepare lmdb files if necessary 366 | if args.make_lmdb: 367 | make_lmdb_for_div2k(args.data_root) 368 | --------------------------------------------------------------------------------