├── DiffMSR_Main ├── __init__.py ├── archs │ ├── CATL.py │ ├── S1_arch.py │ ├── S2_arch.py │ ├── __init__.py │ ├── arch_util.py │ ├── attention.py │ ├── common.py │ └── srvgg_arch.py ├── data │ ├── __init__.py │ ├── data_util.py │ ├── diffmsr_paired_dataset.py │ └── transforms.py ├── losses │ ├── __init__.py │ └── my_loss.py ├── models │ ├── DiffMSR_S1_model.py │ ├── DiffMSR_S2_model.py │ ├── __init__.py │ └── lr_scheduler.py ├── test.py ├── train.py ├── train_pipeline.py ├── utils.py ├── utils │ ├── __init__.py │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py └── version.py ├── README.md ├── basicsr ├── __init__.py ├── archs │ ├── __init__.py │ ├── arch_util.py │ ├── basicvsr_arch.py │ ├── basicvsrpp_arch.py │ ├── dfdnet_arch.py │ ├── dfdnet_util.py │ ├── discriminator_arch.py │ ├── duf_arch.py │ ├── ecbsr_arch.py │ ├── edsr_arch.py │ ├── edvr_arch.py │ ├── hifacegan_arch.py │ ├── hifacegan_util.py │ ├── inception.py │ ├── rcan_arch.py │ ├── ridnet_arch.py │ ├── rrdbnet_arch.py │ ├── spynet_arch.py │ ├── srresnet_arch.py │ ├── srvgg_arch.py │ ├── stylegan2_arch.py │ ├── swinir_arch.py │ ├── tof_arch.py │ └── vgg_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── ffhq_dataset.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── realesrgan_dataset.py │ ├── realesrgan_paired_dataset.py │ ├── reds_dataset.py │ ├── single_image_dataset.py │ ├── transforms.py │ ├── video_test_dataset.py │ └── vimeo90k_dataset.py ├── losses │ ├── __init__.py │ ├── basic_loss.py │ ├── gan_loss.py │ └── loss_util.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── niqe.py │ ├── niqe_pris_params.npz │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── edvr_model.py │ ├── esrgan_model.py │ ├── hifacegan_model.py │ ├── lr_scheduler.py │ ├── realesrgan_model.py │ ├── realesrnet_model.py │ ├── sr_model.py │ ├── srgan_model.py │ ├── stylegan2_model.py │ ├── swinir_model.py │ ├── video_base_model.py │ ├── video_gan_model.py │ ├── video_recurrent_gan_model.py │ └── video_recurrent_model.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── test.py ├── train.py ├── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py └── version.py ├── complex_data_demo ├── dc_mask │ ├── get_dc_mask.m │ ├── lr_2x.mat │ └── lr_4x.mat ├── fastMRI │ ├── train │ │ ├── data_02_01.mat │ │ ├── data_02_02.mat │ │ ├── data_02_03.mat │ │ ├── data_02_04.mat │ │ ├── data_02_05.mat │ │ ├── data_02_06.mat │ │ ├── data_02_07.mat │ │ └── data_02_08.mat │ └── valid │ │ ├── data_02_09.mat │ │ └── data_02_10.mat ├── fastMRI_process.m ├── fft2c.m └── ifft2c.m ├── fig └── model.png ├── ldm ├── 
__pycache__ │ ├── ddim.cpython-37.pyc │ ├── ddpm.cpython-37.pyc │ ├── util.cpython-37.pyc │ └── util2.cpython-37.pyc ├── classifier.py ├── ddim.py ├── ddpm.py ├── lr_scheduler.py ├── util.py └── util2.py ├── metrics ├── DISTS │ ├── DISTS_pytorch │ │ ├── DISTS_pt.py │ │ └── weights.pt │ ├── DISTS_tensorflow │ │ └── DISTS_tf.py │ ├── LICENSE │ ├── requirements.txt │ └── weights │ │ └── alpha_beta.mat ├── LPIPS.py ├── PSNR.py ├── PieAPP │ ├── LICENSE.txt │ ├── PieAPP_PT.py │ ├── README.md │ ├── dataset │ │ ├── TERMS_OF_USE.pdf │ │ └── dataset_README.md │ ├── imgs │ │ └── images.md │ ├── model │ │ ├── PieAPPv0pt1_PT.py │ │ ├── PieAPPv0pt1_TF.py │ │ └── __init__.py │ ├── run.sh │ ├── scripts │ │ ├── download_PieAPPv0.1_PT_weights.sh │ │ ├── download_PieAPPv0.1_TF_weights.sh │ │ └── download_scripts.md │ ├── test_PieAPP_TF.py │ ├── test_PieAPP_TF_folder.py │ └── utils │ │ ├── __init__.py │ │ ├── image_utils.py │ │ └── model_utils.py └── dists.py ├── options ├── test.yml ├── train_DiffMSR_S1_x4.yml └── train_DiffMSR_S2_x4.yml ├── requirements.txt ├── test.sh ├── train_S1.sh └── train_S2.sh /DiffMSR_Main/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .losses import * 3 | from .archs import * 4 | from .data import * 5 | from .models import * 6 | from .utils import * 7 | from .version import * 8 | -------------------------------------------------------------------------------- /DiffMSR_Main/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'DiffMSR_Main.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /DiffMSR_Main/archs/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch as th 4 | class QKVAttentionLegacy(nn.Module): 5 | """ 6 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 7 | """ 8 | 9 | def __init__(self, n_heads): 10 | super().__init__() 11 | self.n_heads = n_heads 12 | self.scale=math.sqrt(10) 13 | 14 | def forward(self, qkv): 15 | """ 16 | Apply QKV attention. 17 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 18 | :return: an [N x (H * C) x T] tensor after attention. 19 | """ 20 | bs, width, length = qkv.shape 21 | assert width % (3 * self.n_heads) == 0 22 | ch = width // (3 * self.n_heads) 23 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 24 | #scale = 1 / math.sqrt(math.sqrt(ch)) 25 | weight = th.einsum( 26 | "bct,bcs->bts", q * self.scale, k * self.scale 27 | ) # More stable with f16 than dividing afterwards 28 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 29 | a = th.einsum("bts,bcs->bct", weight, v) 30 | return a.reshape(bs, -1, length) 31 | 32 | class QKVAttention(nn.Module): 33 | """ 34 | A module which performs QKV attention and splits in a different order. 
35 | """ 36 | 37 | def __init__(self, n_heads): 38 | super().__init__() 39 | self.n_heads = n_heads 40 | self.scale=math.sqrt(10) 41 | 42 | def forward(self, qkv): 43 | """ 44 | Apply QKV attention. 45 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 46 | :return: an [N x (H * C) x T] tensor after attention. 47 | """ 48 | bs, width, length = qkv.shape 49 | assert width % (3 * self.n_heads) == 0 50 | ch = width // (3 * self.n_heads) 51 | q, k, v = qkv.chunk(3, dim=1) 52 | #scale = 1 / math.sqrt(math.sqrt(ch)) 53 | weight = th.einsum( 54 | "bct,bcs->bts", 55 | (q * self.scale).view(bs * self.n_heads, ch, length), 56 | (k * self.scale).view(bs * self.n_heads, ch, length), 57 | ) # More stable with f16 than dividing afterwards 58 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 59 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 60 | return a.reshape(bs, -1, length) 61 | 62 | class AttentionBlock(nn.Module): 63 | def __init__( 64 | self, 65 | channels, 66 | num_heads=1, 67 | num_head_channels=-1, 68 | use_new_attention_order=False, 69 | ): 70 | super().__init__() 71 | self.channels = channels 72 | if num_head_channels == -1: 73 | self.num_heads = num_heads 74 | else: 75 | assert ( 76 | channels % num_head_channels == 0 77 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 78 | self.num_heads = channels // num_head_channels 79 | self.qkv = nn.Conv2d(channels,channels*3,3,padding=1) 80 | if use_new_attention_order: 81 | # split qkv before split heads 82 | self.attention = QKVAttention(self.num_heads) 83 | else: 84 | # split heads before split qkv 85 | self.attention = QKVAttentionLegacy(self.num_heads) 86 | 87 | 88 | def forward(self, x): 89 | res = self.qkv(x[0]) 90 | b, c, *spatial = res.shape 91 | res = res.reshape(b, c, -1) 92 | h = self.attention(res) 93 | b, c, *spatial = x[0].shape 94 | h= h.reshape(b, c, *spatial) 95 | return [x[0] + h,x[1]] -------------------------------------------------------------------------------- /DiffMSR_Main/archs/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 9 | 10 | class ResBlock(nn.Module): 11 | def __init__( 12 | self, conv, n_feats, kernel_size, 13 | bias=True, bn=False, act=nn.LeakyReLU(0.1, inplace=True), res_scale=1): 14 | 15 | super(ResBlock, self).__init__() 16 | m = [] 17 | for i in range(2): 18 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 19 | if bn: 20 | m.append(nn.BatchNorm2d(n_feats)) 21 | if i == 0: 22 | m.append(act) 23 | 24 | self.body = nn.Sequential(*m) 25 | # self.res_scale = res_scale 26 | 27 | def forward(self, x): 28 | res = self.body(x) 29 | res += x 30 | 31 | return res 32 | 33 | class MeanShift(nn.Conv2d): 34 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 35 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 36 | std = torch.Tensor(rgb_std) 37 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 38 | self.weight.data.div_(std.view(3, 1, 1, 1)) 39 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 40 | self.bias.data.div_(std) 41 | self.weight.requires_grad = False 42 | self.bias.requires_grad = False 43 | 44 | 45 | class Upsampler(nn.Sequential): 46 | def __init__(self, conv, 
scale, n_feat, act=False, bias=True): 47 | m = [] 48 | if (int(scale) & (int(scale) - 1)) == 0: # Is scale = 2^n? 49 | for _ in range(int(math.log(scale, 2))): 50 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 51 | m.append(nn.PixelShuffle(2)) 52 | if act: m.append(act()) 53 | elif scale == 3: 54 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 55 | m.append(nn.PixelShuffle(3)) 56 | if act: m.append(act()) 57 | else: 58 | raise NotImplementedError 59 | 60 | super(Upsampler, self).__init__(*m) -------------------------------------------------------------------------------- /DiffMSR_Main/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | @ARCH_REGISTRY.register() 7 | class SRVGGNetCompact(nn.Module): 8 | """A compact VGG-style network structure for super-resolution. 9 | 10 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 11 | conducted on the HR feature space. 12 | 13 | Args: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_out_ch (int): Channel number of outputs. Default: 3. 16 | num_feat (int): Channel number of intermediate features. Default: 64. 17 | num_conv (int): Number of convolution layers in the body network. Default: 16. 18 | upscale (int): Upsampling factor. Default: 4. 19 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 20 | """ 21 | 22 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 23 | super(SRVGGNetCompact, self).__init__() 24 | self.num_in_ch = num_in_ch 25 | self.num_out_ch = num_out_ch 26 | self.num_feat = num_feat 27 | self.num_conv = num_conv 28 | self.upscale = upscale 29 | self.act_type = act_type 30 | 31 | self.body = nn.ModuleList() 32 | # the first conv 33 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 34 | # the first activation 35 | if act_type == 'relu': 36 | activation = nn.ReLU(inplace=True) 37 | elif act_type == 'prelu': 38 | activation = nn.PReLU(num_parameters=num_feat) 39 | elif act_type == 'leakyrelu': 40 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 41 | self.body.append(activation) 42 | 43 | # the body structure 44 | for _ in range(num_conv): 45 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 46 | # activation 47 | if act_type == 'relu': 48 | activation = nn.ReLU(inplace=True) 49 | elif act_type == 'prelu': 50 | activation = nn.PReLU(num_parameters=num_feat) 51 | elif act_type == 'leakyrelu': 52 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 53 | self.body.append(activation) 54 | 55 | # the last conv 56 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 57 | # upsample 58 | self.upsampler = nn.PixelShuffle(upscale) 59 | 60 | def forward(self, x): 61 | out = x 62 | for i in range(0, len(self.body)): 63 | out = self.body[i](out) 64 | 65 | out = self.upsampler(out) 66 | # add the nearest upsampled image, so that the network learns the residual 67 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 68 | out += base 69 | return out 70 | -------------------------------------------------------------------------------- /DiffMSR_Main/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path 
as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'DiffMSR_Main.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /DiffMSR_Main/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_loss.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'DiffMSR_Main.losses.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /DiffMSR_Main/losses/my_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | 6 | 7 | @LOSS_REGISTRY.register() 8 | class KDLoss(nn.Module): 9 | """ 10 | Args: 11 | loss_weight (float): Loss weight for KD loss. Default: 1.0. 12 | """ 13 | 14 | def __init__(self, loss_weight=1.0, temperature = 0.15): 15 | super(KDLoss, self).__init__() 16 | 17 | self.loss_weight = loss_weight 18 | self.temperature = temperature 19 | 20 | def forward(self, S1_fea, S2_fea): 21 | """ 22 | Args: 23 | S1_fea (List): contain shape (N, L) vector. 24 | S2_fea (List): contain shape (N, L) vector. 25 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 
26 | """ 27 | loss_KD_dis = 0 28 | loss_KD_abs = 0 29 | for i in range(len(S1_fea)): 30 | S2_distance = F.log_softmax(S2_fea[i] / self.temperature, dim=1) 31 | S1_distance = F.softmax(S1_fea[i].detach()/ self.temperature, dim=1) 32 | loss_KD_dis += F.kl_div( 33 | S2_distance, S1_distance, reduction='batchmean') 34 | loss_KD_abs += nn.L1Loss()(S2_fea[i], S1_fea[i].detach()) 35 | return self.loss_weight * loss_KD_dis, self.loss_weight * loss_KD_abs -------------------------------------------------------------------------------- /DiffMSR_Main/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'DiffMSR_Main.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /DiffMSR_Main/test.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os 3 | import os.path as osp 4 | import sys 5 | 6 | sys.path.append('/mnt/e/CVPR2024/DiffMSR/') 7 | os.environ['RANK'] = str(0) 8 | import os.path as osp 9 | from basicsr.test import test_pipeline 10 | 11 | import DiffMSR_Main.archs 12 | import DiffMSR_Main.data 13 | import DiffMSR_Main.models 14 | 15 | if __name__ == '__main__': 16 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 17 | test_pipeline(root_path) 18 | -------------------------------------------------------------------------------- /DiffMSR_Main/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os 3 | import os.path as osp 4 | import sys 5 | # 6 | sys.path.append('/mnt/e/CVPR2024/DiffMSR/') 7 | os.environ['RANK'] = str(0) 8 | from DiffMSR_Main.train_pipeline import train_pipeline 9 | 10 | # import DiffMSR_Main.archs 11 | # import DiffMSR_Main.data 12 | # import DiffMSR_Main.models 13 | # import DiffMSR_Main.losses 14 | import warnings 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | if __name__ == '__main__': 19 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 20 | train_pipeline(root_path) 21 | -------------------------------------------------------------------------------- /DiffMSR_Main/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_client import FileClient 2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP 3 | from .logger import (MessageLogger, get_env_info, get_root_logger, 4 | init_tb_logger, init_wandb_logger) 5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 
21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | 'padding', 33 | 'padding_DP', 34 | 'imfrombytesDP', 35 | 'create_lmdb_for_reds', 36 | 'create_lmdb_for_gopro', 37 | 'create_lmdb_for_rain13k', 38 | ] 39 | -------------------------------------------------------------------------------- /DiffMSR_Main/utils/bundle_submissions.py: -------------------------------------------------------------------------------- 1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de) 2 | 3 | # This file is part of the implementation as described in the CVPR 2017 paper: 4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs. 5 | # Please see the file LICENSE.txt for the license governing this code. 6 | 7 | 8 | import numpy as np 9 | import scipy.io as sio 10 | import os 11 | import h5py 12 | 13 | def bundle_submissions_raw(submission_folder,session): 14 | ''' 15 | Bundles submission data for raw denoising 16 | 17 | submission_folder Folder where denoised images reside 18 | 19 | Output is written to /bundled/. Please submit 20 | the content of this folder. 21 | ''' 22 | 23 | out_folder = os.path.join(submission_folder, session) 24 | # out_folder = os.path.join(submission_folder, "bundled/") 25 | try: 26 | os.mkdir(out_folder) 27 | except:pass 28 | 29 | israw = True 30 | eval_version="1.0" 31 | 32 | for i in range(50): 33 | Idenoised = np.zeros((20,), dtype=np.object) 34 | for bb in range(20): 35 | filename = '%04d_%02d.mat'%(i+1,bb+1) 36 | s = sio.loadmat(os.path.join(submission_folder,filename)) 37 | Idenoised_crop = s["Idenoised_crop"] 38 | Idenoised[bb] = Idenoised_crop 39 | filename = '%04d.mat'%(i+1) 40 | sio.savemat(os.path.join(out_folder, filename), 41 | {"Idenoised": Idenoised, 42 | "israw": israw, 43 | "eval_version": eval_version}, 44 | ) 45 | 46 | def bundle_submissions_srgb(submission_folder,session): 47 | ''' 48 | Bundles submission data for sRGB denoising 49 | 50 | submission_folder Folder where denoised images reside 51 | 52 | Output is written to /bundled/. Please submit 53 | the content of this folder. 54 | ''' 55 | out_folder = os.path.join(submission_folder, session) 56 | # out_folder = os.path.join(submission_folder, "bundled/") 57 | try: 58 | os.mkdir(out_folder) 59 | except:pass 60 | israw = False 61 | eval_version="1.0" 62 | 63 | for i in range(50): 64 | Idenoised = np.zeros((20,), dtype=np.object) 65 | for bb in range(20): 66 | filename = '%04d_%02d.mat'%(i+1,bb+1) 67 | s = sio.loadmat(os.path.join(submission_folder,filename)) 68 | Idenoised_crop = s["Idenoised_crop"] 69 | Idenoised[bb] = Idenoised_crop 70 | filename = '%04d.mat'%(i+1) 71 | sio.savemat(os.path.join(out_folder, filename), 72 | {"Idenoised": Idenoised, 73 | "israw": israw, 74 | "eval_version": eval_version}, 75 | ) 76 | 77 | 78 | 79 | def bundle_submissions_srgb_v1(submission_folder,session): 80 | ''' 81 | Bundles submission data for sRGB denoising 82 | 83 | submission_folder Folder where denoised images reside 84 | 85 | Output is written to /bundled/. Please submit 86 | the content of this folder. 
87 | ''' 88 | out_folder = os.path.join(submission_folder, session) 89 | # out_folder = os.path.join(submission_folder, "bundled/") 90 | try: 91 | os.mkdir(out_folder) 92 | except:pass 93 | israw = False 94 | eval_version="1.0" 95 | 96 | for i in range(50): 97 | Idenoised = np.zeros((20,), dtype=np.object) 98 | for bb in range(20): 99 | filename = '%04d_%d.mat'%(i+1,bb+1) 100 | s = sio.loadmat(os.path.join(submission_folder,filename)) 101 | Idenoised_crop = s["Idenoised_crop"] 102 | Idenoised[bb] = Idenoised_crop 103 | filename = '%04d.mat'%(i+1) 104 | sio.savemat(os.path.join(out_folder, filename), 105 | {"Idenoised": Idenoised, 106 | "israw": israw, 107 | "eval_version": eval_version}, 108 | ) -------------------------------------------------------------------------------- /DiffMSR_Main/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | 4 | from DiffMSR_Main.utils import scandir 5 | from DiffMSR_Main.utils.lmdb_util import make_lmdb_from_imgs 6 | 7 | def prepare_keys(folder_path, suffix='png'): 8 | """Prepare image path list and keys for DIV2K dataset. 9 | 10 | Args: 11 | folder_path (str): Folder path. 12 | 13 | Returns: 14 | list[str]: Image path list. 15 | list[str]: Key list. 16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 
'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /DiffMSR_Main/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | 
torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /DiffMSR_Main/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | from .misc import sizeof_fmt 6 | 7 | 8 | def download_file_from_google_drive(file_id, save_path): 9 | """Download files from google drive. 10 | 11 | Ref: 12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 13 | 14 | Args: 15 | file_id (str): File id. 16 | save_path (str): Save path. 
17 | """ 18 | 19 | session = requests.Session() 20 | URL = 'https://docs.google.com/uc?export=download' 21 | params = {'id': file_id} 22 | 23 | response = session.get(URL, params=params, stream=True) 24 | token = get_confirm_token(response) 25 | if token: 26 | params['confirm'] = token 27 | response = session.get(URL, params=params, stream=True) 28 | 29 | # get file size 30 | response_file_size = session.get( 31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | if 'Content-Range' in response_file_size.headers: 33 | file_size = int( 34 | response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, 49 | destination, 50 | file_size=None, 51 | chunk_size=32768): 52 | if file_size is not None: 53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 54 | 55 | readable_file_size = sizeof_fmt(file_size) 56 | else: 57 | pbar = None 58 | 59 | with open(destination, 'wb') as f: 60 | downloaded_size = 0 61 | for chunk in response.iter_content(chunk_size): 62 | downloaded_size += chunk_size 63 | if pbar is not None: 64 | pbar.update(1) 65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 66 | f'/ {readable_file_size}') 67 | if chunk: # filter out keep-alive new chunks 68 | f.write(chunk) 69 | if pbar is not None: 70 | pbar.close() 71 | -------------------------------------------------------------------------------- /DiffMSR_Main/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /DiffMSR_Main/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Mon Sep 18 21:12:57 2023 3 | __version__ = '0.2.5.0' 4 | __gitsha__ = 'unknown' 5 | version_info = (0, 2, 5, 0) 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rethinking Diffusion Model for Multi-Contrast MRI Super-Resolution 2 | 3 | Guangyuan Li, Chen Rao, Juncheng Mo, Zhanjie Zhang, Wei Xing, Lei Zhao, "Rethinking Diffusion Model for Multi-Contrast MRI Super-Resolution", CVPR2024 4 | 5 | >**Abstract:** Recently, diffusion models (DM) have been introduced in magnetic resonance imaging (MRI) super-resolution (SR) reconstruction, exhibiting impressive performance, particularly with regard to detailed reconstruction. 
However, the current DM-based SR reconstruction methods still face the following issues: (1) They require a large number of iterations to reconstruct the final image, which is inefficient and consumes a significant amount of computational resources. (2) The results reconstructed by these methods are often misaligned with the real high-resolution images, leading to noticeable distortion in the reconstructed MR images. To address the aforementioned issues, we propose an efficient diffusion model for multi-contrast MRI SR, named McDiff. Specifically, we apply DM in a highly compact low-dimensional latent space to generate prior knowledge with high-frequency detail information. The highly compact latent space ensures that DM requires only a few simple iterations to produce accurate prior knowledge. In addition, we design the Prior-Guide Large Window Transformer (PLWformer) as the decoder for DM, which can extend the receptive field while fully utilizing the prior knowledge generated by DM to ensure that the reconstructed MR image remains undistorted. Extensive experiments on public and clinical datasets demonstrate that our McDiff outperforms state-of-the-art methods. 6 | 7 | >

8 | > *(model overview figure: fig/model.png)* 9 | >

10 | 11 | ### To Run Our Code 12 | - Train the model 13 | ```bash 14 | bash train_S1.sh 15 | ``` 16 | ```bash 17 | bash train_S2.sh 18 | ``` 19 | 20 | - Test the model 21 | ```bash 22 | bash test.sh 23 | ``` 24 | 25 | ## Acknowledgements 26 | This code is built on [BasicSR](https://github.com/XPixelGroup/BasicSR), [DiffIR](https://github.com/Zj-BinXia/DiffIR), and [SRFormer](https://github.com/HVision-NKU/SRFormer). 27 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/archs/dfdnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from torch.nn.utils.spectral_norm import spectral_norm 6 | 7 | 8 | class BlurFunctionBackward(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, grad_output, kernel, kernel_flip): 12 | ctx.save_for_backward(kernel, kernel_flip) 13 | grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]) 14 | return grad_input 15 | 16 | @staticmethod 17 | def backward(ctx, gradgrad_output): 18 | kernel, _ = ctx.saved_tensors 19 | grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]) 20 | return grad_input, None, None 21 | 22 | 23 | class BlurFunction(Function): 24 | 25 | @staticmethod 26 | def forward(ctx, x, kernel, kernel_flip): 27 | ctx.save_for_backward(kernel, kernel_flip) 28 | output = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) 29 | return output 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | kernel, kernel_flip = ctx.saved_tensors 34 | grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip) 35 | return grad_input, None, None 36 | 37 | 38 | blur = BlurFunction.apply 39 | 40 | 41 | class 
Blur(nn.Module): 42 | 43 | def __init__(self, channel): 44 | super().__init__() 45 | kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32) 46 | kernel = kernel.view(1, 1, 3, 3) 47 | kernel = kernel / kernel.sum() 48 | kernel_flip = torch.flip(kernel, [2, 3]) 49 | 50 | self.kernel = kernel.repeat(channel, 1, 1, 1) 51 | self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1) 52 | 53 | def forward(self, x): 54 | return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x)) 55 | 56 | 57 | def calc_mean_std(feat, eps=1e-5): 58 | """Calculate mean and std for adaptive_instance_normalization. 59 | 60 | Args: 61 | feat (Tensor): 4D tensor. 62 | eps (float): A small value added to the variance to avoid 63 | divide-by-zero. Default: 1e-5. 64 | """ 65 | size = feat.size() 66 | assert len(size) == 4, 'The input feature should be 4D tensor.' 67 | n, c = size[:2] 68 | feat_var = feat.view(n, c, -1).var(dim=2) + eps 69 | feat_std = feat_var.sqrt().view(n, c, 1, 1) 70 | feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1) 71 | return feat_mean, feat_std 72 | 73 | 74 | def adaptive_instance_normalization(content_feat, style_feat): 75 | """Adaptive instance normalization. 76 | 77 | Adjust the reference features to have the similar color and illuminations 78 | as those in the degradate features. 79 | 80 | Args: 81 | content_feat (Tensor): The reference feature. 82 | style_feat (Tensor): The degradate features. 83 | """ 84 | size = content_feat.size() 85 | style_mean, style_std = calc_mean_std(style_feat) 86 | content_mean, content_std = calc_mean_std(content_feat) 87 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 88 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 89 | 90 | 91 | def AttentionBlock(in_channel): 92 | return nn.Sequential( 93 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True), 94 | spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1))) 95 | 96 | 97 | def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 98 | """Conv block used in MSDilationBlock.""" 99 | 100 | return nn.Sequential( 101 | spectral_norm( 102 | nn.Conv2d( 103 | in_channels, 104 | out_channels, 105 | kernel_size=kernel_size, 106 | stride=stride, 107 | dilation=dilation, 108 | padding=((kernel_size - 1) // 2) * dilation, 109 | bias=bias)), 110 | nn.LeakyReLU(0.2), 111 | spectral_norm( 112 | nn.Conv2d( 113 | out_channels, 114 | out_channels, 115 | kernel_size=kernel_size, 116 | stride=stride, 117 | dilation=dilation, 118 | padding=((kernel_size - 1) // 2) * dilation, 119 | bias=bias)), 120 | ) 121 | 122 | 123 | class MSDilationBlock(nn.Module): 124 | """Multi-scale dilation block.""" 125 | 126 | def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True): 127 | super(MSDilationBlock, self).__init__() 128 | 129 | self.conv_blocks = nn.ModuleList() 130 | for i in range(4): 131 | self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias)) 132 | self.conv_fusion = spectral_norm( 133 | nn.Conv2d( 134 | in_channels * 4, 135 | in_channels, 136 | kernel_size=kernel_size, 137 | stride=1, 138 | padding=(kernel_size - 1) // 2, 139 | bias=bias)) 140 | 141 | def forward(self, x): 142 | out = [] 143 | for i in range(4): 144 | out.append(self.conv_blocks[i](x)) 145 | out = torch.cat(out, 1) 146 | out = self.conv_fusion(out) + x 147 | return out 148 | 149 | 150 | class UpResBlock(nn.Module): 151 | 152 
| def __init__(self, in_channel): 153 | super(UpResBlock, self).__init__() 154 | self.body = nn.Sequential( 155 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 156 | nn.LeakyReLU(0.2, True), 157 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 158 | ) 159 | 160 | def forward(self, x): 161 | out = x + self.body(x) 162 | return out 163 | -------------------------------------------------------------------------------- /basicsr/archs/edsr_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class EDSR(nn.Module): 10 | """EDSR network structure. 11 | 12 | Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution. 13 | Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch 14 | 15 | Args: 16 | num_in_ch (int): Channel number of inputs. 17 | num_out_ch (int): Channel number of outputs. 18 | num_feat (int): Channel number of intermediate features. 19 | Default: 64. 20 | num_block (int): Block number in the trunk network. Default: 16. 21 | upscale (int): Upsampling factor. Support 2^n and 3. 22 | Default: 4. 23 | res_scale (float): Used to scale the residual in residual block. 24 | Default: 1. 25 | img_range (float): Image range. Default: 255. 26 | rgb_mean (tuple[float]): Image mean in RGB orders. 27 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 28 | """ 29 | 30 | def __init__(self, 31 | num_in_ch, 32 | num_out_ch, 33 | num_feat=64, 34 | num_block=16, 35 | upscale=4, 36 | res_scale=1, 37 | img_range=255., 38 | rgb_mean=(0.4488, 0.4371, 0.4040)): 39 | super(EDSR, self).__init__() 40 | 41 | self.img_range = img_range 42 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 43 | 44 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 45 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True) 46 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 47 | self.upsample = Upsample(upscale, num_feat) 48 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 49 | 50 | def forward(self, x): 51 | self.mean = self.mean.type_as(x) 52 | 53 | x = (x - self.mean) * self.img_range 54 | x = self.conv_first(x) 55 | res = self.conv_after_body(self.body(x)) 56 | res += x 57 | 58 | x = self.conv_last(self.upsample(res)) 59 | x = x / self.img_range + self.mean 60 | 61 | return x 62 | -------------------------------------------------------------------------------- /basicsr/archs/rcan_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import Upsample, make_layer 6 | 7 | 8 | class ChannelAttention(nn.Module): 9 | """Channel attention used in RCAN. 10 | 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 
14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 20 | nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid()) 21 | 22 | def forward(self, x): 23 | y = self.attention(x) 24 | return x * y 25 | 26 | 27 | class RCAB(nn.Module): 28 | """Residual Channel Attention Block (RCAB) used in RCAN. 29 | 30 | Args: 31 | num_feat (int): Channel number of intermediate features. 32 | squeeze_factor (int): Channel squeeze factor. Default: 16. 33 | res_scale (float): Scale the residual. Default: 1. 34 | """ 35 | 36 | def __init__(self, num_feat, squeeze_factor=16, res_scale=1): 37 | super(RCAB, self).__init__() 38 | self.res_scale = res_scale 39 | 40 | self.rcab = nn.Sequential( 41 | nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1), 42 | ChannelAttention(num_feat, squeeze_factor)) 43 | 44 | def forward(self, x): 45 | res = self.rcab(x) * self.res_scale 46 | return res + x 47 | 48 | 49 | class ResidualGroup(nn.Module): 50 | """Residual Group of RCAB. 51 | 52 | Args: 53 | num_feat (int): Channel number of intermediate features. 54 | num_block (int): Block number in the body network. 55 | squeeze_factor (int): Channel squeeze factor. Default: 16. 56 | res_scale (float): Scale the residual. Default: 1. 57 | """ 58 | 59 | def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1): 60 | super(ResidualGroup, self).__init__() 61 | 62 | self.residual_group = make_layer( 63 | RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale) 64 | self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 65 | 66 | def forward(self, x): 67 | res = self.conv(self.residual_group(x)) 68 | return res + x 69 | 70 | 71 | @ARCH_REGISTRY.register() 72 | class RCAN(nn.Module): 73 | """Residual Channel Attention Networks. 74 | 75 | Paper: Image Super-Resolution Using Very Deep Residual Channel Attention 76 | Networks 77 | Ref git repo: https://github.com/yulunzhang/RCAN. 78 | 79 | Args: 80 | num_in_ch (int): Channel number of inputs. 81 | num_out_ch (int): Channel number of outputs. 82 | num_feat (int): Channel number of intermediate features. 83 | Default: 64. 84 | num_group (int): Number of ResidualGroup. Default: 10. 85 | num_block (int): Number of RCAB in ResidualGroup. Default: 16. 86 | squeeze_factor (int): Channel squeeze factor. Default: 16. 87 | upscale (int): Upsampling factor. Support 2^n and 3. 88 | Default: 4. 89 | res_scale (float): Used to scale the residual in residual block. 90 | Default: 1. 91 | img_range (float): Image range. Default: 255. 92 | rgb_mean (tuple[float]): Image mean in RGB orders. 93 | Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset. 
94 | """ 95 | 96 | def __init__(self, 97 | num_in_ch, 98 | num_out_ch, 99 | num_feat=64, 100 | num_group=10, 101 | num_block=16, 102 | squeeze_factor=16, 103 | upscale=4, 104 | res_scale=1, 105 | img_range=255., 106 | rgb_mean=(0.4488, 0.4371, 0.4040)): 107 | super(RCAN, self).__init__() 108 | 109 | self.img_range = img_range 110 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 111 | 112 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 113 | self.body = make_layer( 114 | ResidualGroup, 115 | num_group, 116 | num_feat=num_feat, 117 | num_block=num_block, 118 | squeeze_factor=squeeze_factor, 119 | res_scale=res_scale) 120 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 121 | self.upsample = Upsample(upscale, num_feat) 122 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 123 | 124 | def forward(self, x): 125 | self.mean = self.mean.type_as(x) 126 | 127 | x = (x - self.mean) * self.img_range 128 | x = self.conv_first(x) 129 | res = self.conv_after_body(self.body(x)) 130 | res += x 131 | 132 | x = self.conv_last(self.upsample(res)) 133 | x = x / self.img_range + self.mean 134 | 135 | return x 136 | -------------------------------------------------------------------------------- /basicsr/archs/rrdbnet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | from .arch_util import default_init_weights, make_layer, pixel_unshuffle 7 | 8 | 9 | class ResidualDenseBlock(nn.Module): 10 | """Residual Dense Block. 11 | 12 | Used in RRDB block in ESRGAN. 13 | 14 | Args: 15 | num_feat (int): Channel number of intermediate features. 16 | num_grow_ch (int): Channels for each growth. 17 | """ 18 | 19 | def __init__(self, num_feat=64, num_grow_ch=32): 20 | super(ResidualDenseBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) 22 | self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) 23 | self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) 24 | self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) 25 | self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) 26 | 27 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | 29 | # initialization 30 | default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 31 | 32 | def forward(self, x): 33 | x1 = self.lrelu(self.conv1(x)) 34 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 35 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 36 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 37 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 38 | # Empirically, we use 0.2 to scale the residual for better performance 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | """Residual in Residual Dense Block. 44 | 45 | Used in RRDB-Net in ESRGAN. 46 | 47 | Args: 48 | num_feat (int): Channel number of intermediate features. 49 | num_grow_ch (int): Channels for each growth. 
50 | """ 51 | 52 | def __init__(self, num_feat, num_grow_ch=32): 53 | super(RRDB, self).__init__() 54 | self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) 55 | self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) 56 | self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) 57 | 58 | def forward(self, x): 59 | out = self.rdb1(x) 60 | out = self.rdb2(out) 61 | out = self.rdb3(out) 62 | # Empirically, we use 0.2 to scale the residual for better performance 63 | return out * 0.2 + x 64 | 65 | 66 | @ARCH_REGISTRY.register() 67 | class RRDBNet(nn.Module): 68 | """Networks consisting of Residual in Residual Dense Block, which is used 69 | in ESRGAN. 70 | 71 | ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. 72 | 73 | We extend ESRGAN for scale x2 and scale x1. 74 | Note: This is one option for scale 1, scale 2 in RRDBNet. 75 | We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size 76 | and enlarge the channel size before feeding inputs into the main ESRGAN architecture. 77 | 78 | Args: 79 | num_in_ch (int): Channel number of inputs. 80 | num_out_ch (int): Channel number of outputs. 81 | num_feat (int): Channel number of intermediate features. 82 | Default: 64 83 | num_block (int): Block number in the trunk network. Defaults: 23 84 | num_grow_ch (int): Channels for each growth. Default: 32. 85 | """ 86 | 87 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): 88 | super(RRDBNet, self).__init__() 89 | self.scale = scale 90 | if scale == 2: 91 | num_in_ch = num_in_ch * 4 92 | elif scale == 1: 93 | num_in_ch = num_in_ch * 16 94 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 95 | self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) 96 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 97 | # upsample 98 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 99 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 100 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 101 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 102 | 103 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 104 | 105 | def forward(self, x): 106 | if self.scale == 2: 107 | feat = pixel_unshuffle(x, scale=2) 108 | elif self.scale == 1: 109 | feat = pixel_unshuffle(x, scale=4) 110 | else: 111 | feat = x 112 | feat = self.conv_first(feat) 113 | body_feat = self.conv_body(self.body(feat)) 114 | feat = feat + body_feat 115 | # upsample 116 | feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) 117 | feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) 118 | out = self.conv_last(self.lrelu(self.conv_hr(feat))) 119 | return out 120 | -------------------------------------------------------------------------------- /basicsr/archs/spynet_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | from .arch_util import flow_warp 8 | 9 | 10 | class BasicModule(nn.Module): 11 | """Basic Module for SpyNet. 
12 | """ 13 | 14 | def __init__(self): 15 | super(BasicModule, self).__init__() 16 | 17 | self.basic_module = nn.Sequential( 18 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 19 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 20 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 21 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False), 22 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 23 | 24 | def forward(self, tensor_input): 25 | return self.basic_module(tensor_input) 26 | 27 | 28 | @ARCH_REGISTRY.register() 29 | class SpyNet(nn.Module): 30 | """SpyNet architecture. 31 | 32 | Args: 33 | load_path (str): path for pretrained SpyNet. Default: None. 34 | """ 35 | 36 | def __init__(self, load_path=None): 37 | super(SpyNet, self).__init__() 38 | self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)]) 39 | if load_path: 40 | self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params']) 41 | 42 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 43 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 44 | 45 | def preprocess(self, tensor_input): 46 | tensor_output = (tensor_input - self.mean) / self.std 47 | return tensor_output 48 | 49 | def process(self, ref, supp): 50 | flow = [] 51 | 52 | ref = [self.preprocess(ref)] 53 | supp = [self.preprocess(supp)] 54 | 55 | for level in range(5): 56 | ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False)) 57 | supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False)) 58 | 59 | flow = ref[0].new_zeros( 60 | [ref[0].size(0), 2, 61 | int(math.floor(ref[0].size(2) / 2.0)), 62 | int(math.floor(ref[0].size(3) / 2.0))]) 63 | 64 | for level in range(len(ref)): 65 | upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0 66 | 67 | if upsampled_flow.size(2) != ref[level].size(2): 68 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate') 69 | if upsampled_flow.size(3) != ref[level].size(3): 70 | upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate') 71 | 72 | flow = self.basic_module[level](torch.cat([ 73 | ref[level], 74 | flow_warp( 75 | supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'), 76 | upsampled_flow 77 | ], 1)) + upsampled_flow 78 | 79 | return flow 80 | 81 | def forward(self, ref, supp): 82 | assert ref.size() == supp.size() 83 | 84 | h, w = ref.size(2), ref.size(3) 85 | w_floor = math.floor(math.ceil(w / 32.0) * 32.0) 86 | h_floor = math.floor(math.ceil(h / 32.0) * 32.0) 87 | 88 | ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 89 | supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False) 90 | 91 | flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False) 92 | 93 | flow[:, 0, :, :] *= float(w) / float(w_floor) 94 | flow[:, 1, :, :] *= float(h) / float(h_floor) 95 | 96 | return flow 97 | -------------------------------------------------------------------------------- /basicsr/archs/srresnet_arch.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class MSRResNet(nn.Module): 10 | """Modified SRResNet. 11 | 12 | A compacted version modified from SRResNet in 13 | "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" 14 | It uses residual blocks without BN, similar to EDSR. 15 | Currently, it supports x2, x3 and x4 upsampling scale factor. 16 | 17 | Args: 18 | num_in_ch (int): Channel number of inputs. Default: 3. 19 | num_out_ch (int): Channel number of outputs. Default: 3. 20 | num_feat (int): Channel number of intermediate features. Default: 64. 21 | num_block (int): Block number in the body network. Default: 16. 22 | upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4. 23 | """ 24 | 25 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4): 26 | super(MSRResNet, self).__init__() 27 | self.upscale = upscale 28 | 29 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 30 | self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat) 31 | 32 | # upsampling 33 | if self.upscale in [2, 3]: 34 | self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1) 35 | self.pixel_shuffle = nn.PixelShuffle(self.upscale) 36 | elif self.upscale == 4: 37 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 38 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 39 | self.pixel_shuffle = nn.PixelShuffle(2) 40 | 41 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 42 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 43 | 44 | # activation function 45 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 46 | 47 | # initialization 48 | default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1) 49 | if self.upscale == 4: 50 | default_init_weights(self.upconv2, 0.1) 51 | 52 | def forward(self, x): 53 | feat = self.lrelu(self.conv_first(x)) 54 | out = self.body(feat) 55 | 56 | if self.upscale == 4: 57 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 58 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 59 | elif self.upscale in [2, 3]: 60 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 61 | 62 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 63 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 64 | out += base 65 | return out 66 | -------------------------------------------------------------------------------- /basicsr/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | 6 | 7 | @ARCH_REGISTRY.register(suffix='basicsr') 8 | class SRVGGNetCompact(nn.Module): 9 | """A compact VGG-style network structure for super-resolution. 10 | 11 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 12 | conducted on the HR feature space. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_out_ch (int): Channel number of outputs. Default: 3. 17 | num_feat (int): Channel number of intermediate features. Default: 64. 
18 | num_conv (int): Number of convolution layers in the body network. Default: 16. 19 | upscale (int): Upsampling factor. Default: 4. 20 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 21 | """ 22 | 23 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 24 | super(SRVGGNetCompact, self).__init__() 25 | self.num_in_ch = num_in_ch 26 | self.num_out_ch = num_out_ch 27 | self.num_feat = num_feat 28 | self.num_conv = num_conv 29 | self.upscale = upscale 30 | self.act_type = act_type 31 | 32 | self.body = nn.ModuleList() 33 | # the first conv 34 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 35 | # the first activation 36 | if act_type == 'relu': 37 | activation = nn.ReLU(inplace=True) 38 | elif act_type == 'prelu': 39 | activation = nn.PReLU(num_parameters=num_feat) 40 | elif act_type == 'leakyrelu': 41 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 42 | self.body.append(activation) 43 | 44 | # the body structure 45 | for _ in range(num_conv): 46 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 47 | # activation 48 | if act_type == 'relu': 49 | activation = nn.ReLU(inplace=True) 50 | elif act_type == 'prelu': 51 | activation = nn.PReLU(num_parameters=num_feat) 52 | elif act_type == 'leakyrelu': 53 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 54 | self.body.append(activation) 55 | 56 | # the last conv 57 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 58 | # upsample 59 | self.upsampler = nn.PixelShuffle(upscale) 60 | 61 | def forward(self, x): 62 | out = x 63 | for i in range(0, len(self.body)): 64 | out = self.body[i](out) 65 | 66 | out = self.upsampler(out) 67 | # add the nearest upsampled image, so that the network learns the residual 68 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 69 | out += base 70 | return out 71 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 
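Example (an added, minimal sketch, not part of the original docstring; 'SingleImageDataset' and the folder path are placeholder choices for illustration):
    >>> dataset_opt = {'name': 'demo', 'type': 'SingleImageDataset',
    ...                'dataroot_lq': 'datasets/demo_lq',
    ...                'io_backend': {'type': 'disk'}}
    >>> dataset = build_dataset(dataset_opt)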
32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 
8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/ffhq_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | from os import path as osp 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.transforms import augment 8 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 9 | from basicsr.utils.registry import DATASET_REGISTRY 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class FFHQDataset(data.Dataset): 14 | """FFHQ dataset for StyleGAN. 15 | 16 | Args: 17 | opt (dict): Config for train datasets. It contains the following keys: 18 | dataroot_gt (str): Data root path for gt. 19 | io_backend (dict): IO backend type and other kwarg. 20 | mean (list | tuple): Image mean. 21 | std (list | tuple): Image std. 22 | use_hflip (bool): Whether to horizontally flip. 
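Example (an added, illustrative sketch; the lmdb path and mean/std values are placeholders, not files or settings shipped with this repo):
    >>> opt = {'dataroot_gt': 'datasets/ffhq/ffhq_256.lmdb',
    ...        'io_backend': {'type': 'lmdb'},
    ...        'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5],
    ...        'use_hflip': True}
    >>> dataset = FFHQDataset(opt)  # the lmdb folder is assumed to contain a meta_info.txt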
23 | 24 | """ 25 | 26 | def __init__(self, opt): 27 | super(FFHQDataset, self).__init__() 28 | self.opt = opt 29 | # file client (io backend) 30 | self.file_client = None 31 | self.io_backend_opt = opt['io_backend'] 32 | 33 | self.gt_folder = opt['dataroot_gt'] 34 | self.mean = opt['mean'] 35 | self.std = opt['std'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = self.gt_folder 39 | if not self.gt_folder.endswith('.lmdb'): 40 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 41 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 42 | self.paths = [line.split('.')[0] for line in fin] 43 | else: 44 | # FFHQ has 70000 images in total 45 | self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load gt image 52 | gt_path = self.paths[index] 53 | # avoid errors caused by high latency in reading files 54 | retry = 3 55 | while retry > 0: 56 | try: 57 | img_bytes = self.file_client.get(gt_path) 58 | except Exception as e: 59 | logger = get_root_logger() 60 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 61 | # change to another file to read (randint's upper bound is inclusive) 62 | index = random.randint(0, self.__len__() - 1) 63 | gt_path = self.paths[index] 64 | time.sleep(1) # sleep 1s for occasional server congestion 65 | else: 66 | break 67 | finally: 68 | retry -= 1 69 | img_gt = imfrombytes(img_bytes, float32=True) 70 | 71 | # random horizontal flip 72 | img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False) 73 | # BGR to RGB, HWC to CHW, numpy to tensor 74 | img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True) 75 | # normalize 76 | normalize(img_gt, self.mean, self.std, inplace=True) 77 | return {'gt': img_gt, 'gt_path': gt_path} 78 | 79 | def __len__(self): 80 | return len(self.paths) 81 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 1. 'lmdb': Use lmdb files. 18 | If opt['io_backend'] == lmdb. 19 | 2. 'meta_info_file': Use meta information file to generate paths. 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. 'folder': Scan folders to generate paths. 22 | The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info_file (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename.
Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 35 | 36 | scale (bool): Scale, which will be added automatically. 37 | phase (str): 'train' or 'val'. 38 | """ 39 | 40 | def __init__(self, opt): 41 | super(PairedImageDataset, self).__init__() 42 | self.opt = opt 43 | # file client (io backend) 44 | self.file_client = None 45 | self.io_backend_opt = opt['io_backend'] 46 | self.mean = opt['mean'] if 'mean' in opt else None 47 | self.std = opt['std'] if 'std' in opt else None 48 | 49 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 50 | if 'filename_tmpl' in opt: 51 | self.filename_tmpl = opt['filename_tmpl'] 52 | else: 53 | self.filename_tmpl = '{}' 54 | 55 | if self.io_backend_opt['type'] == 'lmdb': 56 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 57 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 58 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 59 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 60 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 61 | self.opt['meta_info_file'], self.filename_tmpl) 62 | else: 63 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 64 | 65 | def __getitem__(self, index): 66 | if self.file_client is None: 67 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 68 | 69 | scale = self.opt['scale'] 70 | 71 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 72 | # image range: [0, 1], float32. 
73 | gt_path = self.paths[index]['gt_path'] 74 | img_bytes = self.file_client.get(gt_path, 'gt') 75 | img_gt = imfrombytes(img_bytes, float32=True) 76 | lq_path = self.paths[index]['lq_path'] 77 | img_bytes = self.file_client.get(lq_path, 'lq') 78 | img_lq = imfrombytes(img_bytes, float32=True) 79 | 80 | # augmentation for training 81 | if self.opt['phase'] == 'train': 82 | gt_size = self.opt['gt_size'] 83 | # random crop 84 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 85 | # flip, rotation 86 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 87 | 88 | # color space transform 89 | if 'color' in self.opt and self.opt['color'] == 'y': 90 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 91 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 92 | 93 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 94 | # TODO: It is better to update the datasets, rather than force to crop 95 | if self.opt['phase'] != 'train': 96 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 97 | 98 | # BGR to RGB, HWC to CHW, numpy to tensor 99 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 100 | # normalize 101 | if self.mean is not None or self.std is not None: 102 | normalize(img_lq, self.mean, self.std, inplace=True) 103 | normalize(img_gt, self.mean, self.std, inplace=True) 104 | 105 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 
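Example (an added, illustrative sketch; `train_loader` stands for any existing PyTorch dataloader):
    >>> prefetcher = CPUPrefetcher(train_loader)
    >>> batch = prefetcher.next()  # returns None once the underlying loader is exhausted
    >>> prefetcher.reset()         # start a fresh pass over the loader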
68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 6 | from basicsr.data.transforms import augment, paired_random_crop 7 | from basicsr.utils import FileClient, imfrombytes, img2tensor 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register(suffix='basicsr') 12 | class RealESRGANPairedDataset(data.Dataset): 13 | """Paired image dataset for image restoration. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 16 | 17 | There are three modes: 18 | 1. 'lmdb': Use lmdb files. 19 | If opt['io_backend'] == lmdb. 20 | 2. 'meta_info': Use meta information file to generate paths. 21 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None. 22 | 3. 'folder': Scan folders to generate paths. 23 | The rest. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | dataroot_lq (str): Data root path for lq. 29 | meta_info (str): Path for meta information file. 30 | io_backend (dict): IO backend type and other kwarg. 31 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 32 | Default: '{}'. 33 | gt_size (int): Cropped patched size for gt patches. 34 | use_hflip (bool): Use horizontal flips. 35 | use_rot (bool): Use rotation (use vertical flip and transposing h 36 | and w for implementation). 37 | 38 | scale (bool): Scale, which will be added automatically. 39 | phase (str): 'train' or 'val'. 
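Example (an added, illustrative 'folder'-mode sketch; the dataroot paths and values are placeholders, not paths used by this repo's configs):
    >>> opt = {'phase': 'train', 'scale': 4, 'gt_size': 256,
    ...        'dataroot_gt': 'datasets/demo/HR', 'dataroot_lq': 'datasets/demo/LR_x4',
    ...        'io_backend': {'type': 'disk'},
    ...        'use_hflip': True, 'use_rot': True}
    >>> dataset = RealESRGANPairedDataset(opt)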
40 | """ 41 | 42 | def __init__(self, opt): 43 | super(RealESRGANPairedDataset, self).__init__() 44 | self.opt = opt 45 | self.file_client = None 46 | self.io_backend_opt = opt['io_backend'] 47 | # mean and std for normalizing the input images 48 | self.mean = opt['mean'] if 'mean' in opt else None 49 | self.std = opt['std'] if 'std' in opt else None 50 | 51 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 52 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 53 | 54 | # file client (lmdb io backend) 55 | if self.io_backend_opt['type'] == 'lmdb': 56 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 57 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 58 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 59 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 60 | # disk backend with meta_info 61 | # Each line in the meta_info describes the relative path to an image 62 | with open(self.opt['meta_info']) as fin: 63 | paths = [line.strip() for line in fin] 64 | self.paths = [] 65 | for path in paths: 66 | gt_path, lq_path = path.split(', ') 67 | gt_path = os.path.join(self.gt_folder, gt_path) 68 | lq_path = os.path.join(self.lq_folder, lq_path) 69 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 70 | else: 71 | # disk backend 72 | # it will scan the whole folder to get meta info 73 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 74 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 75 | 76 | def __getitem__(self, index): 77 | if self.file_client is None: 78 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 79 | 80 | scale = self.opt['scale'] 81 | 82 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 83 | # image range: [0, 1], float32. 
84 | gt_path = self.paths[index]['gt_path'] 85 | img_bytes = self.file_client.get(gt_path, 'gt') 86 | img_gt = imfrombytes(img_bytes, float32=True) 87 | lq_path = self.paths[index]['lq_path'] 88 | img_bytes = self.file_client.get(lq_path, 'lq') 89 | img_lq = imfrombytes(img_bytes, float32=True) 90 | 91 | # augmentation for training 92 | if self.opt['phase'] == 'train': 93 | gt_size = self.opt['gt_size'] 94 | # random crop 95 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 96 | # flip, rotation 97 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 98 | 99 | # BGR to RGB, HWC to CHW, numpy to tensor 100 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 101 | # normalize 102 | if self.mean is not None or self.std is not None: 103 | normalize(img_lq, self.mean, self.std, inplace=True) 104 | normalize(img_gt, self.mean, self.std, inplace=True) 105 | 106 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 107 | 108 | def __len__(self): 109 | return len(self.paths) 110 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SingleImageDataset(data.Dataset): 12 | """Read only lq images in the test phase. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 15 | 16 | There are two modes: 17 | 1. 'meta_info_file': Use meta information file to generate paths. 18 | 2. 'folder': Scan folders to generate paths. 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_lq (str): Data root path for lq. 23 | meta_info_file (str): Path for meta information file. 24 | io_backend (dict): IO backend type and other kwarg. 
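Example (an added, illustrative 'folder'-mode sketch; the path is a placeholder):
    >>> opt = {'dataroot_lq': 'datasets/demo_lq', 'io_backend': {'type': 'disk'}}
    >>> dataset = SingleImageDataset(opt)
    >>> sample = dataset[0]  # dict with 'lq' (tensor) and 'lq_path' (str)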
25 | """ 26 | 27 | def __init__(self, opt): 28 | super(SingleImageDataset, self).__init__() 29 | self.opt = opt 30 | # file client (io backend) 31 | self.file_client = None 32 | self.io_backend_opt = opt['io_backend'] 33 | self.mean = opt['mean'] if 'mean' in opt else None 34 | self.std = opt['std'] if 'std' in opt else None 35 | self.lq_folder = opt['dataroot_lq'] 36 | 37 | if self.io_backend_opt['type'] == 'lmdb': 38 | self.io_backend_opt['db_paths'] = [self.lq_folder] 39 | self.io_backend_opt['client_keys'] = ['lq'] 40 | self.paths = paths_from_lmdb(self.lq_folder) 41 | elif 'meta_info_file' in self.opt: 42 | with open(self.opt['meta_info_file'], 'r') as fin: 43 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 44 | else: 45 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | # load lq image 52 | lq_path = self.paths[index] 53 | img_bytes = self.file_client.get(lq_path, 'lq') 54 | img_lq = imfrombytes(img_bytes, float32=True) 55 | 56 | # color space transform 57 | if 'color' in self.opt and self.opt['color'] == 'y': 58 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 59 | 60 | # BGR to RGB, HWC to CHW, numpy to tensor 61 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 62 | # normalize 63 | if self.mean is not None or self.std is not None: 64 | normalize(img_lq, self.mean, self.std, inplace=True) 65 | return {'lq': img_lq, 'lq_path': lq_path} 66 | 67 | def __len__(self): 68 | return len(self.paths) 69 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 
11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 
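Example (an added, illustrative sketch showing the expected shapes; the random input is a placeholder):
    >>> residual = torch.rand(1, 1, 64, 64)
    >>> get_local_weights(residual, ksize=7).shape
    torch.Size([1, 1, 64, 64])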
107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 130 | 131 | Returns: 132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 133 | (calculated based on both local and global observations). 134 | """ 135 | 136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) 137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 138 | 139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) 140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) 141 | overall_weight = patch_level_weight * pixel_level_weight 142 | 143 | overall_weight[residual_sr < residual_ema] = 0 144 | 145 | return overall_weight 146 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .niqe import calculate_niqe 5 | from .psnr_ssim import calculate_psnr, calculate_ssim 6 | 7 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe'] 8 | 9 | 10 | def calculate_metric(data, opt): 11 | """Calculate metric from data and options. 12 | 13 | Args: 14 | opt (dict): Configuration. It must contain: 15 | type (str): Model type. 16 | """ 17 | opt = deepcopy(opt) 18 | metric_type = opt.pop('type') 19 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 20 | return metric 21 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. 
Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 54 | and X_2 ~ N(mu_2, C_2) is 55 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 56 | Stable version by Dougal J. Sutherland. 57 | 58 | Args: 59 | mu1 (np.array): The sample mean over activations. 60 | sigma1 (np.array): The covariance matrix over activations for 61 | generated samples. 62 | mu2 (np.array): The sample mean over activations, precalculated on an 63 | representative data set. 64 | sigma2 (np.array): The covariance matrix over activations, 65 | precalculated on an representative data set. 66 | 67 | Returns: 68 | float: The Frechet Distance. 69 | """ 70 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 71 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 72 | 73 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 74 | 75 | # Product might be almost singular 76 | if not np.isfinite(cov_sqrt).all(): 77 | print('Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates') 78 | offset = np.eye(sigma1.shape[0]) * eps 79 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 80 | 81 | # Numerical error might give slight imaginary component 82 | if np.iscomplexobj(cov_sqrt): 83 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 84 | m = np.max(np.abs(cov_sqrt.imag)) 85 | raise ValueError(f'Imaginary component {m}') 86 | cov_sqrt = cov_sqrt.real 87 | 88 | mean_diff = mu1 - mu2 89 | mean_norm = mean_diff @ mean_diff 90 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 91 | fid = mean_norm + trace 92 | 93 | return fid 94 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 
34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 46 | -------------------------------------------------------------------------------- /basicsr/metrics/niqe_pris_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/basicsr/metrics/niqe_pris_params.npz -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /basicsr/models/edvr_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils import get_root_logger 2 | from basicsr.utils.registry import MODEL_REGISTRY 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class EDVRModel(VideoBaseModel): 8 | """EDVR Model. 9 | 10 | Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. 
# noqa: E501 11 | """ 12 | 13 | def __init__(self, opt): 14 | super(EDVRModel, self).__init__(opt) 15 | if self.is_train: 16 | self.train_tsa_iter = opt['train'].get('tsa_iter') 17 | 18 | def setup_optimizers(self): 19 | train_opt = self.opt['train'] 20 | dcn_lr_mul = train_opt.get('dcn_lr_mul', 1) 21 | logger = get_root_logger() 22 | logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.') 23 | if dcn_lr_mul == 1: 24 | optim_params = self.net_g.parameters() 25 | else: # separate dcn params and normal params for different lr 26 | normal_params = [] 27 | dcn_params = [] 28 | for name, param in self.net_g.named_parameters(): 29 | if 'dcn' in name: 30 | dcn_params.append(param) 31 | else: 32 | normal_params.append(param) 33 | optim_params = [ 34 | { # add normal params first 35 | 'params': normal_params, 36 | 'lr': train_opt['optim_g']['lr'] 37 | }, 38 | { 39 | 'params': dcn_params, 40 | 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul 41 | }, 42 | ] 43 | 44 | optim_type = train_opt['optim_g'].pop('type') 45 | self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 46 | self.optimizers.append(self.optimizer_g) 47 | 48 | def optimize_parameters(self, current_iter): 49 | if self.train_tsa_iter: 50 | if current_iter == 1: 51 | logger = get_root_logger() 52 | logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.') 53 | for name, param in self.net_g.named_parameters(): 54 | if 'fusion' not in name: 55 | param.requires_grad = False 56 | elif current_iter == self.train_tsa_iter: 57 | logger = get_root_logger() 58 | logger.warning('Train all the parameters.') 59 | for param in self.net_g.parameters(): 60 | param.requires_grad = True 61 | 62 | super(EDVRModel, self).optimize_parameters(current_iter) 63 | -------------------------------------------------------------------------------- /basicsr/models/esrgan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .srgan_model import SRGANModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class ESRGANModel(SRGANModel): 10 | """ESRGAN model for single image super-resolution.""" 11 | 12 | def optimize_parameters(self, current_iter): 13 | # optimize net_g 14 | for p in self.net_d.parameters(): 15 | p.requires_grad = False 16 | 17 | self.optimizer_g.zero_grad() 18 | self.output = self.net_g(self.lq) 19 | 20 | l_g_total = 0 21 | loss_dict = OrderedDict() 22 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 23 | # pixel loss 24 | if self.cri_pix: 25 | l_g_pix = self.cri_pix(self.output, self.gt) 26 | l_g_total += l_g_pix 27 | loss_dict['l_g_pix'] = l_g_pix 28 | # perceptual loss 29 | if self.cri_perceptual: 30 | l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt) 31 | if l_g_percep is not None: 32 | l_g_total += l_g_percep 33 | loss_dict['l_g_percep'] = l_g_percep 34 | if l_g_style is not None: 35 | l_g_total += l_g_style 36 | loss_dict['l_g_style'] = l_g_style 37 | # gan loss (relativistic gan) 38 | real_d_pred = self.net_d(self.gt).detach() 39 | fake_g_pred = self.net_d(self.output) 40 | l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False) 41 | l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False) 42 | l_g_gan = (l_g_real + l_g_fake) / 2 43 | 44 | l_g_total += l_g_gan 45 | loss_dict['l_g_gan'] = l_g_gan 46 | 47 | l_g_total.backward() 48 | 
self.optimizer_g.step() 49 | 50 | # optimize net_d 51 | for p in self.net_d.parameters(): 52 | p.requires_grad = True 53 | 54 | self.optimizer_d.zero_grad() 55 | # gan loss (relativistic gan) 56 | 57 | # In order to avoid the error in distributed training: 58 | # "Error detected in CudnnBatchNormBackward: RuntimeError: one of 59 | # the variables needed for gradient computation has been modified by 60 | # an inplace operation", 61 | # we separate the backwards for real and fake, and also detach the 62 | # tensor for calculating mean. 63 | 64 | # real 65 | fake_d_pred = self.net_d(self.output).detach() 66 | real_d_pred = self.net_d(self.gt) 67 | l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5 68 | l_d_real.backward() 69 | # fake 70 | fake_d_pred = self.net_d(self.output.detach()) 71 | l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5 72 | l_d_fake.backward() 73 | self.optimizer_d.step() 74 | 75 | loss_dict['l_d_real'] = l_d_real 76 | loss_dict['l_d_fake'] = l_d_fake 77 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 78 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 79 | 80 | self.log_dict = self.reduce_loss_dict(loss_dict) 81 | 82 | if self.ema_decay > 0: 83 | self.model_ema(decay=self.ema_decay) 84 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 
48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/models/swinir_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import MODEL_REGISTRY 5 | from .sr_model import SRModel 6 | 7 | 8 | @MODEL_REGISTRY.register() 9 | class SwinIRModel(SRModel): 10 | 11 | def test(self): 12 | # pad to multiplication of window_size 13 | window_size = self.opt['network_g']['window_size'] 14 | scale = self.opt.get('scale', 1) 15 | mod_pad_h, mod_pad_w = 0, 0 16 | _, _, h, w = self.lq.size() 17 | if h % window_size != 0: 18 | mod_pad_h = window_size - h % window_size 19 | if w % window_size != 0: 20 | mod_pad_w = window_size - w % window_size 21 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 22 | if hasattr(self, 'net_g_ema'): 23 | self.net_g_ema.eval() 24 | with torch.no_grad(): 25 | self.output = self.net_g_ema(img) 26 | else: 27 | self.net_g.eval() 28 | with torch.no_grad(): 29 | self.output = self.net_g(img) 30 | self.net_g.train() 31 | 32 | _, _, h, w = self.output.size() 33 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 34 | -------------------------------------------------------------------------------- /basicsr/models/video_gan_model.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import MODEL_REGISTRY 2 | from .srgan_model import SRGANModel 3 | from .video_base_model import VideoBaseModel 4 | 5 | 6 | @MODEL_REGISTRY.register() 7 | class 
VideoGANModel(SRGANModel, VideoBaseModel): 8 | """Video GAN model. 9 | 10 | Use multiple inheritance. 11 | It will first use the functions of SRGANModel: 12 | init_training_settings 13 | setup_optimizers 14 | optimize_parameters 15 | save 16 | Then find functions in VideoBaseModel. 17 | """ 18 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // from 
https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set ramdom seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], 
dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | 9 | __all__ = [ 10 | # color_util.py 11 | 'bgr2ycbcr', 12 | 'rgb2ycbcr', 13 | 'rgb2ycbcr_pt', 14 | 'ycbcr2bgr', 15 | 'ycbcr2rgb', 16 | # file_client.py 17 | 'FileClient', 18 | # img_util.py 19 | 'img2tensor', 20 | 'tensor2img', 21 | 'imfrombytes', 22 | 'imwrite', 23 | 'crop_border', 24 | # logger.py 25 | 'MessageLogger', 26 | 'AvgTimer', 27 | 'init_tb_logger', 28 | 'init_wandb_logger', 29 | 'get_root_logger', 30 | 'get_env_info', 31 | # misc.py 32 | 'set_random_seed', 33 | 'get_time_str', 34 | 'mkdir_and_rename', 35 | 'make_exp_dirs', 36 | 'scandir', 37 | 'check_resume', 38 | 'sizeof_fmt', 39 | # diffjpeg 40 | 'DiffJPEG', 41 | # img_process_util 42 | 'USMSharp', 43 | 'usm_sharp' 44 | ] 45 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 
37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Ref: 15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 16 | 17 | Args: 18 | file_id (str): File id. 19 | save_path (str): Save path. 
20 | """ 21 | 22 | session = requests.Session() 23 | URL = 'https://docs.google.com/uc?export=download' 24 | params = {'id': file_id} 25 | 26 | response = session.get(URL, params=params, stream=True) 27 | token = get_confirm_token(response) 28 | if token: 29 | params['confirm'] = token 30 | response = session.get(URL, params=params, stream=True) 31 | 32 | # get file size 33 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 34 | if 'Content-Range' in response_file_size.headers: 35 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 36 | else: 37 | file_size = None 38 | 39 | save_response_content(response, save_path, file_size) 40 | 41 | 42 | def get_confirm_token(response): 43 | for key, value in response.cookies.items(): 44 | if key.startswith('download_warning'): 45 | return value 46 | return None 47 | 48 | 49 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 50 | if file_size is not None: 51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 52 | 53 | readable_file_size = sizeof_fmt(file_size) 54 | else: 55 | pbar = None 56 | 57 | with open(destination, 'wb') as f: 58 | downloaded_size = 0 59 | for chunk in response.iter_content(chunk_size): 60 | downloaded_size += chunk_size 61 | if pbar is not None: 62 | pbar.update(1) 63 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 64 | if chunk: # filter out keep-alive new chunks 65 | f.write(chunk) 66 | if pbar is not None: 67 | pbar.close() 68 | 69 | 70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 71 | """Load file form http url, will download models if necessary. 72 | 73 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 74 | 75 | Args: 76 | url (str): URL to be downloaded. 77 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 78 | Default: None. 79 | progress (bool): Whether to show the download progress. Default: True. 80 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 81 | 82 | Returns: 83 | str: The path to the downloaded file. 
84 | """ 85 | if model_dir is None: # use the pytorch hub_dir 86 | hub_dir = get_dir() 87 | model_dir = os.path.join(hub_dir, 'checkpoints') 88 | 89 | os.makedirs(model_dir, exist_ok=True) 90 | 91 | parts = urlparse(url) 92 | filename = os.path.basename(parts.path) 93 | if file_name is not None: 94 | filename = file_name 95 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 96 | if not os.path.exists(cached_file): 97 | print(f'Downloading: "{url}" to {cached_file}\n') 98 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 99 | return cached_file 100 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 
48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 
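Example (illustrative; the directory name is a placeholder)::

            # relative paths of every '.png' file under 'datasets/train', including sub-folders
            for img_path in scandir('datasets/train', suffix='.png', recursive=True):
                print(img_path)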
66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 
28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Ref: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/\ 72 | tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 73 | 74 | Args: 75 | values (list): A list of values to be smoothed. 76 | smooth_weight (float): Smooth weight. 77 | """ 78 | values_sm = [] 79 | last_sm_value = values[0] 80 | for value in values: 81 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 82 | values_sm.append(value_sm) 83 | last_sm_value = value_sm 84 | return values_sm 85 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Tue Sep 19 09:47:20 2023 3 | __version__ = '1.4.2' 4 | __gitsha__ = 'unknown' 5 | version_info = (1, 4, 2) 6 | -------------------------------------------------------------------------------- /complex_data_demo/dc_mask/get_dc_mask.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | 4 | mask_2 = zeros([256,256]); 5 | mask_2(64:191,64:191)=1; 6 | mask_4 = zeros([256,256]); 7 | mask_4(96:159,96:159)=1; 8 | 9 | mas_2 = logical(mask_2); 10 | a_2 = sum(mas_2(:))/256/256; 11 | 12 | mas_4 = logical(mask_4); 13 | a_4 = sum(mas_4(:))/256/256; 14 | 15 | lr_mask = mask_2; 16 | % lr_mask = mask_4; 17 | 18 | save(['lr_2x.mat'],'lr_mask'); 19 | 20 | -------------------------------------------------------------------------------- /complex_data_demo/dc_mask/lr_2x.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/dc_mask/lr_2x.mat -------------------------------------------------------------------------------- /complex_data_demo/dc_mask/lr_4x.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/dc_mask/lr_4x.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_01.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_01.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_02.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_02.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_03.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_03.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_04.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_04.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_05.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_05.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_06.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_06.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_07.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_07.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/train/data_02_08.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/train/data_02_08.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/valid/data_02_09.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/valid/data_02_09.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI/valid/data_02_10.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/complex_data_demo/fastMRI/valid/data_02_10.mat -------------------------------------------------------------------------------- /complex_data_demo/fastMRI_process.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all; 4 | 5 | %%%% Please download the data in dicom format from the FastMRI public dataset. 
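%%%% Note on the layout this script assumes: each selected study folder should contain two DICOM series;
%%%% the series whose ScanOptions is 'FS' (fat-suppressed) is treated as the T2 contrast and the other as T1.
%%%% Both are cropped in k-space (via fft2c/ifft2c) to 256x256, 128x128 and 64x64 to build the multi-contrast
%%%% HR/LR pairs saved under mc_knee/valid. The absolute path below is machine-specific and should point at
%%%% your local copy of the selected fastMRI knee studies.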
6 | path = '/Users/yymacpro13/Desktop/MRI/DuSR_data/select_knee_2/'; 7 | d1 = dir(path); 8 | shape_1=[]; 9 | shape_2=[]; 10 | 11 | for i= [2,6,13,15,24,27,28,31,33,34,40,45,47,61,64,73,79,81,84,95] 12 | data_name = d1(i+3).name; 13 | d2 = dir(strcat(path,data_name,'/')); 14 | mc_name_1 = d2(3).name; 15 | mc_name_2 = d2(4).name; 16 | dir_mc_name_1 = dir(strcat(path,data_name,'/',mc_name_1,'/')); 17 | dir_mc_name_2 = dir(strcat(path,data_name,'/',mc_name_2,'/')); 18 | 19 | for j=length(dir_mc_name_1)-2 20 | 21 | pattern_g = '.dcm'; 22 | str_1 = regexp(dir_mc_name_1(j+2).name, pattern_g, 'split'); 23 | str_2 = regexp(dir_mc_name_2(j+2).name, pattern_g, 'split'); 24 | 25 | dir_mc_1 = dicominfo(strcat(path,data_name,'/',mc_name_1,'/',str_1{1,1},'')); 26 | dir_mc_2 = dicominfo(strcat(path,data_name,'/',mc_name_2,'/',str_2{1,1},'')); 27 | 28 | shape_1(j)=dir_mc_1.Width; 29 | shape_2(j)=dir_mc_2.Width; 30 | disp(dir_mc_1.ScanOptions); 31 | disp(dir_mc_2.ScanOptions); 32 | 33 | if (strcmp(dir_mc_1.ScanOptions,'FS')) 34 | T2 = dicomread(dir_mc_1); 35 | T1 = dicomread(dir_mc_2); 36 | else 37 | T2 = dicomread(dir_mc_2); 38 | T1 = dicomread(dir_mc_1); 39 | end 40 | 41 | if(dir_mc_1.Width==320) 42 | T1 = T1/max(T1(:)); 43 | T1_ks=fft2c(T1); 44 | %---------T1 256 45 | T1_256_ks = T1_ks(32:287,32:287); 46 | T1_256_im = ifft2c(T1_256_ks); 47 | %---------T1 128 48 | k_T1_128_lr = T1_256_ks(64:191,64:191,:); 49 | im_T1_128_lr = ifft2c(k_T1_128_lr); 50 | %---------T1 64 51 | k_T1_64_lr = k_T1_128_lr(32:95,32:95,:); 52 | im_T1_64_lr = ifft2c(k_T1_64_lr); 53 | %%%%%%%%%%%%%%%%%%%%%%% 54 | T2 = T2/max(T2(:)); 55 | T2_ks = fft2c(T2); 56 | %---------T2 256 57 | T2_256_ks = T2_ks(32:287,32:287); 58 | T2_256_im = ifft2c(T2_256_ks); 59 | %---------T2 128 60 | k_T2_128_lr = T2_256_ks(64:191,64:191,:); 61 | im_T2_128_lr = ifft2c(k_T2_128_lr); 62 | %---------T2 64 63 | k_T2_64_lr = k_T2_128_lr(32:95,32:95,:); 64 | im_T2_64_lr = ifft2c(k_T2_64_lr); 65 | 66 | T1 = T1_256_im; 67 | T1_128 = im_T1_128_lr; 68 | T1_64 = im_T1_64_lr; 69 | 70 | T2 = T2_256_im; 71 | T2_128 = im_T2_128_lr; 72 | T2_64 = im_T2_64_lr; 73 | 74 | mkdir(strcat('mc_knee/valid/')); 75 | k = j-10; 76 | if (j-10<10) 77 | save(strcat('mc_knee/valid/',data_name,'_0',num2str(k),'.mat'),"T1","T1_128","T1_64","T2","T2_128","T2_64"); 78 | else 79 | save(strcat('mc_knee/valid/',data_name,'_',num2str(k),'.mat'),"T1","T1_128","T1_64","T2","T2_128","T2_64"); 80 | end 81 | 82 | end 83 | 84 | 85 | 86 | end 87 | % as(T1); 88 | % as(T2); 89 | 90 | 91 | 92 | 93 | end 94 | -------------------------------------------------------------------------------- /complex_data_demo/fft2c.m: -------------------------------------------------------------------------------- 1 | function res = fft2c(x) 2 | 3 | S = size(x); 4 | fctr = S(1)*S(2); 5 | 6 | x = reshape(x,S(1),S(2),prod(S(3:end))); 7 | 8 | res = zeros(size(x)); 9 | for n=1:size(x,3) 10 | res(:,:,n) = 1/sqrt(fctr)*fftshift(fft2(ifftshift(x(:,:,n)))); 11 | end 12 | 13 | res = reshape(res,S); 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /complex_data_demo/ifft2c.m: -------------------------------------------------------------------------------- 1 | function res = ifft2c(x) 2 | 3 | S = size(x); 4 | fctr = S(1)*S(2); 5 | 6 | x = reshape(x,S(1),S(2),prod(S(3:end))); 7 | 8 | res = zeros(size(x)); 9 | for n=1:size(x,3) 10 | res(:,:,n) = sqrt(fctr)*fftshift(ifft2(ifftshift(x(:,:,n)))); 11 | end 12 | 13 | 14 | res = reshape(res,S); 15 | 16 | 
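% Note: fft2c.m and ifft2c.m form a centered, orthonormal 2-D Fourier pair: the
% 1/sqrt(M*N) and sqrt(M*N) scalings make ifft2c(fft2c(x)) reproduce x while
% preserving the l2 norm, and the fftshift/ifftshift calls keep the DC component
% at the array centre, so low-resolution k-space is obtained by simple center
% cropping (as in fastMRI_process.m and dc_mask/get_dc_mask.m).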
-------------------------------------------------------------------------------- /fig/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/fig/model.png -------------------------------------------------------------------------------- /ldm/__pycache__/ddim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/ldm/__pycache__/ddim.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/ddpm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/ldm/__pycache__/ddpm.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/ldm/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/ldm/__pycache__/util2.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /metrics/DISTS/DISTS_pytorch/weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/metrics/DISTS/DISTS_pytorch/weights.pt -------------------------------------------------------------------------------- /metrics/DISTS/DISTS_tensorflow/DISTS_tf.py: -------------------------------------------------------------------------------- 1 | # This is a tensorflow implementation of DISTS metric. 
2 | # Requirements: python >= 3.6, tensorflow-gpu >= 1.15 3 | 4 | import tensorflow.compat.v1 as tf 5 | import numpy as np 6 | import time 7 | import scipy.io as scio 8 | from PIL import Image 9 | import argparse 10 | # tf.enable_eager_execution() 11 | tf.disable_eager_execution() 12 | 13 | class DISTS(): 14 | def __init__(self): 15 | self.parameters = scio.loadmat('../weights/net_param.mat') 16 | self.chns = [3,64,128,256,512,512] 17 | self.mean = tf.constant(self.parameters['vgg_mean'], dtype=tf.float32, shape=(1,1,1,3),name="img_mean") 18 | self.std = tf.constant(self.parameters['vgg_std'], dtype=tf.float32, shape=(1,1,1,3),name="img_std") 19 | # self.alpha = tf.Variable(tf.random_normal(shape=(1,1,1,sum(self.chns)), mean=0.1, stddev=0.01),name="alpha") 20 | # self.beta = tf.Variable(tf.random_normal(shape=(1,1,1,sum(self.chns)), mean=0.1, stddev=0.01),name="beta") 21 | self.weights = scio.loadmat('../weights/alpha_beta.mat') 22 | self.alpha = tf.constant(np.reshape(self.weights['alpha'],(1,1,1,sum(self.chns))),name="alpha") 23 | self.beta = tf.constant(np.reshape(self.weights['beta'],(1,1,1,sum(self.chns))),name="beta") 24 | 25 | def get_features(self, img): 26 | 27 | x = (img - self.mean)/self.std 28 | 29 | self.conv1_1 = self.conv_layer(x, "conv1_1") 30 | self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2") 31 | self.pool1 = self.pool_layer(self.conv1_2, name="pool_1") 32 | 33 | self.conv2_1 = self.conv_layer(self.pool1, "conv2_1") 34 | self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2") 35 | self.pool2 = self.pool_layer(self.conv2_2, name="pool_2") 36 | 37 | self.conv3_1 = self.conv_layer(self.pool2, "conv3_1") 38 | self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2") 39 | self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3") 40 | self.pool3 = self.pool_layer(self.conv3_3, name="pool_3") 41 | 42 | self.conv4_1 = self.conv_layer(self.pool3, "conv4_1") 43 | self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2") 44 | self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3") 45 | self.pool4 = self.pool_layer(self.conv4_3, name="pool_4") 46 | 47 | self.conv5_1 = self.conv_layer(self.pool4, "conv5_1") 48 | self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2") 49 | self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3") 50 | 51 | return [img, self.conv1_2,self.conv2_2,self.conv3_3,self.conv4_3,self.conv5_3] 52 | 53 | def conv_layer(self, input, name): 54 | with tf.variable_scope(name) as _: 55 | filter = self.get_conv_filter(name) 56 | conv = tf.nn.conv2d(input, filter, strides=1, padding="SAME") 57 | bias = self.get_bias(name) 58 | conv = tf.nn.relu(tf.nn.bias_add(conv, bias)) 59 | return conv 60 | 61 | def pool_layer(self, input, name): 62 | # return tf.nn.max_pool(input, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 63 | with tf.variable_scope(name) as _: 64 | filter = tf.squeeze(tf.constant(self.parameters['L2'+name], name = "filter"),3) 65 | conv = tf.nn.conv2d(input**2, filter, strides=2, padding=[[0, 0], [1, 0], [1, 0], [0, 0]]) 66 | return tf.sqrt(tf.maximum(conv, 1e-12)) 67 | 68 | def get_conv_filter(self, name): 69 | return tf.constant(self.parameters[name+'_weight'], name = "filter") 70 | 71 | def get_bias(self, name): 72 | return tf.constant(np.squeeze(self.parameters[name+'_bias']), name = "bias") 73 | 74 | def get_score(self, img1, img2): 75 | feats0 = self.get_features(img1) 76 | feats1 = self.get_features(img2) 77 | dist1 = 0 78 | dist2 = 0 79 | c1 = 1e-6 80 | c2 = 1e-6 81 | w_sum = tf.reduce_sum(self.alpha) + tf.reduce_sum(self.beta) 82 | alpha = 
tf.split(self.alpha/w_sum, self.chns, axis=3) 83 | beta = tf.split(self.beta/w_sum, self.chns, axis=3) 84 | for k in range(len(self.chns)): 85 | x_mean = tf.reduce_mean(feats0[k],[1,2], keepdims=True) 86 | y_mean = tf.reduce_mean(feats1[k],[1,2], keepdims=True) 87 | S1 = (2*x_mean*y_mean+c1)/(x_mean**2+y_mean**2+c1) 88 | dist1 = dist1+tf.reduce_sum(alpha[k]*S1, 3, keepdims=True) 89 | x_var = tf.reduce_mean((feats0[k]-x_mean)**2,[1,2], keepdims=True) 90 | y_var = tf.reduce_mean((feats1[k]-y_mean)**2,[1,2], keepdims=True) 91 | xy_cov = tf.reduce_mean(feats0[k]*feats1[k],[1,2], keepdims=True) - x_mean*y_mean 92 | S2 = (2*xy_cov+c2)/(x_var+y_var+c2) 93 | dist2 = dist2+tf.reduce_sum(beta[k]*S2, 3, keepdims=True) 94 | 95 | dist = 1-tf.squeeze(dist1+dist2) 96 | return dist 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--ref', type=str, default='../images/r0.png') 102 | parser.add_argument('--dist', type=str, default='../images/r1.png') 103 | args = parser.parse_args() 104 | model = DISTS() 105 | 106 | ref = np.array(Image.open(args.ref).convert("RGB")) 107 | ref = np.expand_dims(ref,axis=0)/255. 108 | dist = np.array(Image.open(args.dist).convert("RGB")) 109 | dist = np.expand_dims(dist,axis=0)/255. 110 | 111 | x = tf.placeholder(dtype=tf.float32, shape=ref.shape, name= "ref") 112 | y = tf.placeholder(dtype=tf.float32, shape=dist.shape, name= "dist") 113 | score = model.get_score(x,y) 114 | with tf.Session() as sess: 115 | sess.run(tf.global_variables_initializer()) 116 | score = sess.run(score, feed_dict={x: ref, y: dist}) 117 | print(score) 118 | 119 | 120 | -------------------------------------------------------------------------------- /metrics/DISTS/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Keyan Ding 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /metrics/DISTS/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0 -------------------------------------------------------------------------------- /metrics/DISTS/weights/alpha_beta.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/metrics/DISTS/weights/alpha_beta.mat -------------------------------------------------------------------------------- /metrics/LPIPS.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | import numpy as np 5 | import os.path as osp 6 | from torchvision.transforms.functional import normalize 7 | import argparse 8 | from basicsr.utils import img2tensor 9 | 10 | 11 | import lpips 12 | 13 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--folder_gt', type=str, default='/root/datasets/DIV2K100/HR') 18 | parser.add_argument('--folder_restored', type=str, default='/root/results/DIV2K100') 19 | args = parser.parse_args() 20 | 21 | img_list = sorted(glob.glob(osp.join(args.folder_gt, '*.png'))) 22 | lr_list = sorted(glob.glob(osp.join(args.folder_restored, '*.png'))) 23 | 24 | 25 | loss_fn_vgg = lpips.LPIPS(net='alex').cuda() # RGB, normalized to [-1,1] 26 | lpips_all = [] 27 | 28 | mean = [0.5, 0.5, 0.5] 29 | std = [0.5, 0.5, 0.5] 30 | for i, (img_path, lr_path) in enumerate(zip(img_list,lr_list)): 31 | 32 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 33 | img_restored = cv2.imread(lr_path, cv2.IMREAD_UNCHANGED).astype( 34 | np.float32) / 255. 
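# The two images are read as BGR arrays in [0, 1] above; img2tensor(bgr2rgb=True) below converts them
# to RGB CHW tensors, the GT is additionally cropped to a multiple of 4, and normalize(mean=0.5, std=0.5)
# maps both to [-1, 1], the input range expected by the AlexNet-based LPIPS model created above.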
35 | 36 | img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True) 37 | _,h,w=img_gt.shape 38 | img_gt=img_gt[:,:h//4*4,:w//4*4] 39 | # norm to [-1, 1] 40 | normalize(img_gt, mean, std, inplace=True) 41 | normalize(img_restored, mean, std, inplace=True) 42 | 43 | # calculate lpips 44 | lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda()) 45 | # print(lpips_val) 46 | lpips_all.append(lpips_val.item()) 47 | 48 | print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}') 49 | 50 | 51 | if __name__ == '__main__': 52 | main() 53 | -------------------------------------------------------------------------------- /metrics/PSNR.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import numpy as np 4 | import os.path as osp 5 | from torchvision.transforms.functional import normalize 6 | from basicsr.utils import img2tensor 7 | import argparse 8 | from basicsr.metrics import calculate_psnr, calculate_ssim 9 | 10 | 11 | def main(): 12 | # Configurations 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--folder_gt', type=str, default='/root/datasets/DIV2K100/HR') 15 | parser.add_argument('--folder_restored', type=str, default='/root/results/DIV2K100') 16 | args = parser.parse_args() 17 | psnr_all = [] 18 | ssim_all = [] 19 | img_list = sorted(glob.glob(osp.join(args.folder_gt, '*.png'))) 20 | lr_list = sorted(glob.glob(osp.join(args.folder_restored, '*.png'))) 21 | for i, (img_path, lr_path) in enumerate(zip(img_list,lr_list)): 22 | basename, ext = osp.splitext(osp.basename(img_path)) 23 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 24 | img_restored = cv2.imread(osp.join(lr_path), cv2.IMREAD_UNCHANGED) 25 | psnr=calculate_psnr(img_restored, img_gt, crop_border=4, test_y_channel=True) 26 | ssim=calculate_ssim(img_restored, img_gt, crop_border=4, test_y_channel=True) 27 | psnr_all.append(psnr) 28 | ssim_all.append(ssim) 29 | 30 | print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f}') 31 | print(f'Average: SSIM: {sum(ssim_all) / len(ssim_all):.6f}') 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /metrics/PieAPP/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Modified BSD-2 License - for Non-Commercial Use Only 2 | 3 | Copyright (c) 2018, The Regents of the University of California 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, 7 | are permitted for non-commercial use only provided that the following conditions are met: 8 | 1. Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 17 | FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE 18 | COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 19 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 20 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 23 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 24 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 25 | POSSIBILITY OF SUCH DAMAGE. 26 | 27 | For permission to use for commercial purposes, please contact UCSB’s 28 | Office of Technology & Industry Alliances at 805-893-5180 or info@tia.ucsb.edu. 29 | -------------------------------------------------------------------------------- /metrics/PieAPP/PieAPP_PT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import sys 4 | import torch 5 | from torch.autograd import Variable 6 | sys.path.append('model/') 7 | from model.PieAPPv0pt1_PT import PieAPP 8 | sys.path.append('utils/') 9 | from utils.image_utils import * 10 | import argparse 11 | import os 12 | 13 | ######## check for model and download if not present 14 | if not os.path.isfile('weights/PieAPPv0.1.pth'): 15 | print("downloading weights") 16 | os.system("bash scripts/download_PieAPPv0.1_PT_weights.sh") 17 | if not os.path.isfile('weights/PieAPPv0.1.pth'): 18 | print("PieAPPv0.1.pth not downloaded") 19 | sys.exit() 20 | 21 | ######## variables 22 | patch_size = 64 23 | batch_size = 1 24 | 25 | ######## input args 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--ref_path", dest='ref_path', type=str, default='/data1/liangjie/BasicSR_ALL/scripts/metrics/PieAPP/imgs/Ref.png', help="specify input reference") 28 | parser.add_argument("--A_path", dest='A_path', type=str, default='/data1/liangjie/BasicSR_ALL/scripts/metrics/PieAPP/imgs/A.png', help="specify input image") 29 | parser.add_argument("--sampling_mode", dest='sampling_mode', type=str, default='dense', help="specify sparse or dense sampling of patches to compute PieAPP") 30 | parser.add_argument("--gpu_id", dest='gpu_id', type=str, default='3', help="specify which GPU to use") 31 | 32 | args = parser.parse_args() 33 | 34 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 35 | 36 | imagesA = np.expand_dims(cv2.imread(args.A_path),axis =0).astype('float32') 37 | imagesRef = np.expand_dims(cv2.imread(args.ref_path),axis =0).astype('float32') 38 | _,rows,cols,ch = imagesRef.shape 39 | 40 | if args.sampling_mode == 'sparse': 41 | stride_val = 27 42 | else: 43 | stride_val = 6 44 | 45 | try: 46 | gpu_num = float(args.gpu_id) 47 | use_gpu = 1 48 | except ValueError: 49 | use_gpu = 0 50 | except TypeError: 51 | use_gpu = 0 52 | 53 | y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val),np.array([rows - patch_size])), axis=0) 54 | num_y = len(y_loc) 55 | x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val),np.array([cols - patch_size])), axis=0) 56 | num_x = len(x_loc) 57 | num_patches_per_dim = 10 58 | num_patches = 10 59 | 60 | 61 | # state_dict = torch.load('weights/PieAPPv0.1.pth') 62 | # for name, weights in state_dict.items(): 63 | # print(name, weights.size()) # inspect the parameter names and weight shapes in the checkpoint 64 | # if name == 'ref_score_subtract.weight': # condition for the weight whose shape needs fixing 65 | # state_dict[name] = weights.unsqueeze(0) # add a leading dimension so the saved shape matches the model definition 66 | # # print(name,weights.squeeze(0).size()) # inspect the converted parameter names and shapes 67 | # torch.save(state_dict,
'weights/PieAPPv0.1_.pth') 68 | 69 | ######## initialize the model 70 | PieAPP_net = PieAPP(batch_size, num_patches_per_dim) 71 | PieAPP_net.load_state_dict(torch.load('weights/PieAPPv0.1_.pth')) 72 | 73 | if use_gpu == 1: 74 | PieAPP_net.cuda() 75 | 76 | score_accum = 0.0 77 | weight_accum = 0.0 78 | 79 | # iterate through smaller size sub-images (to prevent memory overload) 80 | for x_iter in range(0, -(-num_x//num_patches)): 81 | for y_iter in range(0, -(-num_y//num_patches)): 82 | # compute the size of the subimage 83 | if (num_patches_per_dim*(x_iter + 1) >= num_x): 84 | size_slice_cols = cols - x_loc[num_patches_per_dim*x_iter] 85 | else: 86 | size_slice_cols = x_loc[num_patches_per_dim*(x_iter + 1)] - x_loc[num_patches_per_dim*x_iter] + patch_size - stride_val 87 | if (num_patches_per_dim*(y_iter + 1) >= num_y): 88 | size_slice_rows = rows - y_loc[num_patches_per_dim*y_iter] 89 | else: 90 | size_slice_rows = y_loc[num_patches_per_dim*(y_iter + 1)] - y_loc[num_patches_per_dim*y_iter] + patch_size - stride_val 91 | # obtain the subimage and samples patches 92 | A_sub_im = imagesA[:, y_loc[num_patches_per_dim*y_iter]:y_loc[num_patches_per_dim*y_iter]+size_slice_rows, x_loc[num_patches_per_dim*x_iter]:x_loc[num_patches_per_dim*x_iter]+size_slice_cols,:] 93 | ref_sub_im = imagesRef[:, y_loc[num_patches_per_dim*y_iter]:y_loc[num_patches_per_dim*y_iter]+size_slice_rows, x_loc[num_patches_per_dim*x_iter]:x_loc[num_patches_per_dim*x_iter]+size_slice_cols,:] 94 | A_patches, ref_patches = sample_patches(A_sub_im, ref_sub_im, patch_size=64, strideval=stride_val, random_selection=False, uniform_grid_mode = 'strided') 95 | num_patches_curr = A_patches.shape[0]/batch_size 96 | 97 | PieAPP_net.num_patches = num_patches_curr 98 | 99 | # initialize variable to be fed to PieAPP_net 100 | A_patches_var = Variable(torch.from_numpy(np.transpose(A_patches,(0,3,1,2))), requires_grad=False) 101 | ref_patches_var = Variable(torch.from_numpy(np.transpose(ref_patches,(0,3,1,2))), requires_grad=False) 102 | if use_gpu == 1: 103 | A_patches_var = A_patches_var.cuda() 104 | ref_patches_var = ref_patches_var.cuda() 105 | 106 | # forward pass 107 | _, PieAPP_patchwise_errors, PieAPP_patchwise_weights = PieAPP_net.compute_score(A_patches_var.float(), ref_patches_var.float()) 108 | curr_err = PieAPP_patchwise_errors.cpu().data.numpy() 109 | curr_weights = PieAPP_patchwise_weights.cpu().data.numpy() 110 | score_accum += np.sum(np.multiply(curr_err, curr_weights)) 111 | weight_accum += np.sum(curr_weights) 112 | 113 | print('PieAPP value of '+args.A_path+ ' with respect to: '+str(score_accum/weight_accum)) 114 | -------------------------------------------------------------------------------- /metrics/PieAPP/README.md: -------------------------------------------------------------------------------- 1 | # Perceptual Image Error Metric (PieAPP v0.1) 2 | This is the repository for the [**"PieAPP"** metric](http://civc.ucsb.edu/graphics/Papers/CVPR2018_PieAPP/) which measures the perceptual error of a distorted image with respect to a reference and the associated [dataset](https://github.com/prashnani/PerceptualImageError/blob/master/dataset/dataset_README.md). 3 | 4 | Technical details about the metric can be found in our paper "**[PieAPP: Perceptual Image-Error Assessment through Pairwise Preference](https://prashnani.github.io/index_files/Prashnani_CVPR_2018_PieAPP_paper.pdf)**", published at CVPR 2018, and also on the [project webpage](http://civc.ucsb.edu/graphics/Papers/CVPR2018_PieAPP/). 
The directions to use the metric can be found in this repository. 5 | 6 | 7 | 8 | ## Using PieAPP 9 | In this repo, we provide the Tensorflow and PyTorch implementations of our evaluation code for PieAPP v0.1 along with the trained models. We also provide a Win64 command-line executable. 10 | 11 | UPDATE: The default patch sampling is changed to "dense" in the demo scripts [`test_PieAPP_TF.py`](test_PieAPP_TF.py) and [`test_PieAPP_PT.py`](test_PieAPP_PT.py) (see "Expected input and output" for details). 12 | This has been the recommended setting for evaluating PieAPP's accuracy against other image error evaluation methods since the release of PieAPP. 13 | 14 | ### Dependencies 15 | The code uses Python 2.7, numpy, opencv and PyTorch 0.3.1 (tested with cuda 9.0; wheel can be found [here](https://pytorch.org/get-started/previous-versions/)) (files ending with _PT_) or [Tensorflow](https://www.tensorflow.org/versions/r1.4/) 1.4 (files ending with _TF_). 16 | 17 | ### Expected input and output 18 | The inputs to PieAPPv0.1 are two images: a reference image, R, and a distorted image, A; the output is the PieAPP value of A with respect to R, a number that quantifies the perceptual error of A with respect to R. 19 | 20 | Since PieAPPv0.1 is computed from a weighted combination of patchwise errors, the number of patches extracted affects both the speed and the accuracy of the computed error. We have two modes of operation: 21 | - "Dense" sampling (default): selects 64x64 patches with a stride of 6 pixels for PieAPP computation; this mode is recommended when evaluating PieAPP's accuracy against other image error evaluation methods. 22 | - "Sparse" sampling: selects 64x64 patches with a stride of 27 pixels for PieAPP computation (recommended for high-speed processing, for example when used in a pipeline that requires fast execution time). 23 | 24 | For large images, to avoid holding all sampled patches in memory, we recommend fetching patchwise errors and weights for sub-images, followed by a weighted averaging of the patchwise errors to get the overall image error (see the demo scripts [`test_PieAPP_TF.py`](test_PieAPP_TF.py) and [`test_PieAPP_PT.py`](test_PieAPP_PT.py)). 25 | 26 | 27 | ### PieAPPv0.1 with Tensorflow: 28 | Script [`test_PieAPP_TF.py`](test_PieAPP_TF.py) demonstrates inference using Tensorflow. 29 | 30 | Download trained model: 31 | 32 | bash scripts/download_PieAPPv0.1_TF_weights.sh 33 | 34 | Run the demo script: 35 | 36 | python test_PieAPP_TF.py --ref_path <path to reference image> --A_path <path to distorted image> --sampling_mode <dense or sparse> --gpu_id <gpu id> 37 | 38 | For example: 39 | 40 | python test_PieAPP_TF.py --ref_path imgs/ref.png --A_path imgs/A.png --sampling_mode sparse --gpu_id 0 41 | 42 | 43 | 44 | ### PieAPPv0.1 with PyTorch: 45 | Script [`test_PieAPP_PT.py`](test_PieAPP_PT.py) demonstrates inference using PyTorch. 46 | 47 | Download trained model: 48 | 49 | bash scripts/download_PieAPPv0.1_PT_weights.sh 50 | 51 | Run the demo script: 52 | 53 | python test_PieAPP_PT.py --ref_path <path to reference image> --A_path <path to distorted image> --sampling_mode <dense or sparse> --gpu_id <gpu id> 54 | 55 | For example: 56 | 57 | python test_PieAPP_PT.py --ref_path imgs/ref.png --A_path imgs/A.png --sampling_mode sparse --gpu_id 0 58 | 59 | 60 | ### PieAPPv0.1 Win64 command-line executable: 61 | We also provide a Win64 command-line executable for PieAPPv0.1.
To use it, [download the executable](https://www.ece.ucsb.edu/~ekta/projects/PieAPPv0.1/PieAPPv0.1_win64_exe.zip), open a Windows command prompt and run the following command: 62 | 63 | PieAPPv0.1 --ref_path <path to reference image> --A_path <path to distorted image> --sampling_mode <dense or sparse> 64 | 65 | For example: 66 | 67 | PieAPPv0.1 --ref_path imgs/ref.png --A_path imgs/A.png --sampling_mode sparse 68 | 69 | ## The PieAPP dataset 70 | The dataset subdirectory contains information about the PieAPP dataset, terms of usage, and links to downloading the dataset. 71 | 72 | ## Citing PieAPPv0.1 73 | @InProceedings{Prashnani_2018_CVPR, 74 | author = {Prashnani, Ekta and Cai, Hong and Mostofi, Yasamin and Sen, Pradeep}, 75 | title = {PieAPP: Perceptual Image-Error Assessment Through Pairwise Preference}, 76 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 77 | month = {June}, 78 | year = {2018} 79 | } 80 | 81 | ## Acknowledgements 82 | This project was supported in part by NSF grants IIS-1321168 and IIS-1619376, as well as a Fall 2017 AI Grant (awarded to Ekta Prashnani). 83 | -------------------------------------------------------------------------------- /metrics/PieAPP/dataset/TERMS_OF_USE.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuangYuanKK/DiffMSR/df72b5c0ae54340eaccf590ced07e652c1b94347/metrics/PieAPP/dataset/TERMS_OF_USE.pdf -------------------------------------------------------------------------------- /metrics/PieAPP/imgs/images.md: -------------------------------------------------------------------------------- 1 | A.png 2 | ref.png 3 | teaser_PieAPPv0.1.png 4 | -------------------------------------------------------------------------------- /metrics/PieAPP/model/PieAPPv0pt1_PT.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | class PieAPP(nn.Module): # How to ensure that everything goes on a GPU? do I need to fetch?
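# PieAPP compares 64x64 patches of the distorted image A against the corresponding reference patches:
# compute_features() builds a multi-scale feature vector per patch, and compute_score() maps the feature
# differences to a per-patch error and a per-patch weight, returning the weight-averaged error as the image-level PieAPP value.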
9 | def __init__(self,batch_size,num_patches): 10 | super(PieAPP, self).__init__() 11 | self.conv1 = nn.Conv2d(3,64,3,padding=1) 12 | self.conv2 = nn.Conv2d(64,64,3,padding=1) 13 | self.pool2 = nn.MaxPool2d(2,2) 14 | self.conv3 = nn.Conv2d(64,64,3,padding=1) 15 | self.conv4 = nn.Conv2d(64,128,3,padding=1) 16 | self.pool4 = nn.MaxPool2d(2,2) 17 | self.conv5 = nn.Conv2d(128,128,3,padding=1) 18 | self.conv6 = nn.Conv2d(128,128,3,padding=1) 19 | self.pool6 = nn.MaxPool2d(2,2) 20 | self.conv7 = nn.Conv2d(128,256,3,padding=1) 21 | self.conv8 = nn.Conv2d(256,256,3,padding=1) 22 | self.pool8 = nn.MaxPool2d(2,2) 23 | self.conv9 = nn.Conv2d(256,256,3,padding=1) 24 | self.conv10 = nn.Conv2d(256,512,3,padding=1) 25 | self.pool10 = nn.MaxPool2d(2,2) 26 | self.conv11 = nn.Conv2d(512,512,3,padding=1) 27 | self.fc1_score = nn.Linear(120832, 512) 28 | self.fc2_score = nn.Linear(512,1) 29 | self.fc1_weight = nn.Linear(2048,512) 30 | self.fc2_weight = nn.Linear(512,1) 31 | self.ref_score_subtract = nn.Linear(1, 1) 32 | self.batch_size = batch_size 33 | self.num_patches = num_patches 34 | 35 | def flatten(self,matrix): # takes NxCxHxW input and outputs NxHWC 36 | return matrix.view((self.batch_size*self.num_patches,-1)) 37 | 38 | def compute_features(self,input): 39 | #conv1 -> relu -> conv2 -> relu -> pool2 -> conv3 -> relu 40 | x3 = F.relu(self.conv3(self.pool2(F.relu(self.conv2(F.relu(self.conv1(input))))))) 41 | # conv4 -> relu -> pool4 -> conv5 -> relu 42 | x5 = F.relu(self.conv5(self.pool4(F.relu(self.conv4(x3))))) 43 | # conv6 -> relu -> pool6 -> conv7 -> relu 44 | x7 = F.relu(self.conv7(self.pool6(F.relu(self.conv6(x5))))) 45 | # conv8 -> relu -> pool8 -> conv9 -> relu 46 | x9 = F.relu(self.conv9(self.pool8(F.relu(self.conv8(x7))))) 47 | # conv10 -> relu -> pool10 -> conv11 -> relU 48 | x11 = self.flatten(F.relu(self.conv11(self.pool10(F.relu(self.conv10(x9)))))) 49 | # flatten and concatenate 50 | feature_ms = torch.cat((self.flatten(x3),self.flatten(x5),self.flatten(x7),self.flatten(x9),x11),1) 51 | return feature_ms, x11 52 | 53 | def compute_score(self,image_A_patches, image_ref_patches): 54 | A_multi_scale, A_coarse = self.compute_features(image_A_patches) 55 | ref_multi_scale, ref_coarse = self.compute_features(image_ref_patches) 56 | diff_ms = ref_multi_scale - A_multi_scale 57 | diff_coarse = ref_coarse - A_coarse 58 | # per patch score: fc1_score -> relu -> fc2_score 59 | per_patch_score = self.ref_score_subtract(0.01*self.fc2_score(F.relu(self.fc1_score(diff_ms)))) 60 | per_patch_score.view((-1,self.num_patches)) 61 | # per patch weight: fc1_weight -> relu -> fc2_weight 62 | const = Variable(torch.from_numpy(0.000001*np.ones((1,))).float(), requires_grad=False) 63 | const_cuda = const.cuda() 64 | per_patch_weight = self.fc2_weight(F.relu(self.fc1_weight(diff_coarse)))+const_cuda 65 | per_patch_weight.view((-1,self.num_patches)) 66 | product_val = torch.mul(per_patch_weight,per_patch_score) 67 | dot_product_val = torch.sum(product_val) 68 | norm_factor = torch.sum(per_patch_weight) 69 | return torch.div(dot_product_val, norm_factor), per_patch_score, per_patch_weight 70 | -------------------------------------------------------------------------------- /metrics/PieAPP/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /metrics/PieAPP/run.sh: -------------------------------------------------------------------------------- 1 | python test_PieAPP_TF_folder.py 2 
| python test_PieAPP_TF_folder.py 3 | python test_PieAPP_TF_folder.py 4 | python test_PieAPP_TF_folder.py 5 | python test_PieAPP_TF_folder.py 6 | python test_PieAPP_TF_folder.py 7 | python test_PieAPP_TF_folder.py 8 | python test_PieAPP_TF_folder.py 9 | python test_PieAPP_TF_folder.py 10 | python test_PieAPP_TF_folder.py 11 | python test_PieAPP_TF_folder.py 12 | python test_PieAPP_TF_folder.py 13 | python test_PieAPP_TF_folder.py 14 | python test_PieAPP_TF_folder.py 15 | python test_PieAPP_TF_folder.py 16 | python test_PieAPP_TF_folder.py 17 | python test_PieAPP_TF_folder.py 18 | python test_PieAPP_TF_folder.py 19 | python test_PieAPP_TF_folder.py 20 | python test_PieAPP_TF_folder.py 21 | python test_PieAPP_TF_folder.py 22 | python test_PieAPP_TF_folder.py 23 | python test_PieAPP_TF_folder.py 24 | python test_PieAPP_TF_folder.py 25 | python test_PieAPP_TF_folder.py 26 | python test_PieAPP_TF_folder.py 27 | python test_PieAPP_TF_folder.py 28 | python test_PieAPP_TF_folder.py 29 | python test_PieAPP_TF_folder.py 30 | python test_PieAPP_TF_folder.py 31 | python test_PieAPP_TF_folder.py 32 | python test_PieAPP_TF_folder.py 33 | python test_PieAPP_TF_folder.py 34 | python test_PieAPP_TF_folder.py 35 | python test_PieAPP_TF_folder.py 36 | python test_PieAPP_TF_folder.py 37 | python test_PieAPP_TF_folder.py 38 | python test_PieAPP_TF_folder.py 39 | python test_PieAPP_TF_folder.py 40 | python test_PieAPP_TF_folder.py 41 | python test_PieAPP_TF_folder.py 42 | python test_PieAPP_TF_folder.py 43 | python test_PieAPP_TF_folder.py 44 | python test_PieAPP_TF_folder.py 45 | python test_PieAPP_TF_folder.py 46 | python test_PieAPP_TF_folder.py 47 | python test_PieAPP_TF_folder.py 48 | python test_PieAPP_TF_folder.py -------------------------------------------------------------------------------- /metrics/PieAPP/scripts/download_PieAPPv0.1_PT_weights.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir weights 3 | 4 | wget https://web.ece.ucsb.edu/~ekta/projects/PieAPPv0.1/weights/PieAPPv0.1.pth --no-check-certificate 5 | 6 | mv PieAPPv0.1.pth weights/PieAPPv0.1.pth -------------------------------------------------------------------------------- /metrics/PieAPP/scripts/download_PieAPPv0.1_TF_weights.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir weights 3 | 4 | wget https://web.ece.ucsb.edu/~ekta/projects/PieAPPv0.1/weights/PieAPPv0.1_TF.tar.gz --no-check-certificate 5 | 6 | tar -xzf PieAPPv0.1_TF.tar.gz -C weights 7 | rm PieAPPv0.1_TF.tar.gz 8 | -------------------------------------------------------------------------------- /metrics/PieAPP/scripts/download_scripts.md: -------------------------------------------------------------------------------- 1 | download_PieAPPv0.1_PT_weights.sh 2 | download_PieAPPv0.1_TF_weights.sh 3 | -------------------------------------------------------------------------------- /metrics/PieAPP/test_PieAPP_TF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import cv2 4 | import sys 5 | sys.path.append('model/') 6 | from model.PieAPPv0pt1_TF import PieAPP 7 | import argparse 8 | import os 9 | import glob 10 | 11 | ######## check for model and download if not present 12 | if not len(glob.glob('weights/PieAPP_model_v0.1.ckpt.*')) == 3: 13 | # print "downloading dataset" 14 | os.system("bash scripts/download_PieAPPv0.1_TF_weights.sh") 15 | if not len(glob.glob('weights/PieAPP_model_v0.1.ckpt.*')) 
== 3: 16 | # print "PieAPP_model_v0.1.ckpt files not downloaded" 17 | sys.exit() 18 | 19 | ######## variables 20 | patch_size = 64 21 | batch_size = 1 22 | 23 | ######## input args 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--ref_path", dest='ref_path', type=str, default='/data1/liangjie/BasicSR_ALL/scripts/metrics/PieAPP/imgs/Ref.png', help="specify input reference") 26 | parser.add_argument("--A_path", dest='A_path', type=str, default='/data1/liangjie/BasicSR_ALL/scripts/metrics/PieAPP/imgs/A.png', help="specify input image") 27 | parser.add_argument("--sampling_mode", dest='sampling_mode', type=str, default='dense', help="specify sparse or dense sampling of patches to compte PieAPP") 28 | parser.add_argument("--gpu_id", dest='gpu_id', type=str, default='7', help="specify which GPU to use (don't specify this argument if using CPU only)") 29 | 30 | args = parser.parse_args() 31 | 32 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 33 | 34 | imagesRef = np.expand_dims(cv2.imread(args.ref_path).astype('float32'),axis=0) 35 | imagesA = np.expand_dims(cv2.imread(args.A_path).astype('float32'),axis=0) 36 | _,rows,cols,ch = imagesRef.shape 37 | if args.sampling_mode == 'sparse': 38 | stride_val = 27 39 | if args.sampling_mode == 'dense': 40 | stride_val = 6 41 | y_loc = np.concatenate((np.arange(0, rows - patch_size, stride_val),np.array([rows - patch_size])), axis=0) 42 | num_y = len(y_loc) 43 | x_loc = np.concatenate((np.arange(0, cols - patch_size, stride_val),np.array([cols - patch_size])), axis=0) 44 | num_x = len(x_loc) 45 | num_patches = 10 46 | 47 | ######## TF placeholder for graph input 48 | image_A_batch = tf.placeholder(tf.float32) 49 | image_ref_batch = tf.placeholder(tf.float32) #, [None, rows, cols, ch] 50 | 51 | ######## initialize the model 52 | PieAPP_net = PieAPP(batch_size, args.sampling_mode) 53 | PieAPP_value, patchwise_errors, patchwise_weights = PieAPP_net.forward(image_A_batch, image_ref_batch) 54 | saverPieAPP = tf.train.Saver() 55 | 56 | ######## compute PieAPP 57 | with tf.Session() as sess: 58 | sess.run(tf.local_variables_initializer()) 59 | sess.run(tf.global_variables_initializer()) 60 | saverPieAPP.restore(sess, 'weights/PieAPP_model_v0.1.ckpt') # restore weights 61 | # iterate through smaller size sub-images (to prevent memory overload) 62 | score_accum = 0.0 63 | weight_accum = 0.0 64 | for x_iter in range(0, -(-num_x//num_patches)): 65 | for y_iter in range(0, -(-num_y//num_patches)): 66 | # compute scores on subimage to avoid memory issues 67 | # NOTE if image is 512x512 or smaller, PieAPP_value_fetched below gives the overall PieAPP value 68 | if (num_patches*(x_iter + 1) >= num_x): 69 | size_slice_cols = cols - x_loc[num_patches*x_iter] 70 | else: 71 | size_slice_cols = x_loc[num_patches*(x_iter + 1)] - x_loc[num_patches*x_iter] + patch_size - stride_val 72 | if (num_patches*(y_iter + 1) >= num_y): 73 | size_slice_rows = rows - y_loc[num_patches*y_iter] 74 | else: 75 | size_slice_rows = y_loc[num_patches*(y_iter + 1)] - y_loc[num_patches*y_iter] + patch_size - stride_val 76 | im_A = imagesA[:, y_loc[num_patches*y_iter]:y_loc[num_patches*y_iter]+size_slice_rows, x_loc[num_patches*x_iter]:x_loc[num_patches*x_iter]+size_slice_cols,:] 77 | im_Ref = imagesRef[:, y_loc[num_patches*y_iter]:y_loc[num_patches*y_iter]+size_slice_rows, x_loc[num_patches*x_iter]:x_loc[num_patches*x_iter]+size_slice_cols,:] 78 | # forward pass 79 | PieAPP_value_fetched, PieAPP_patchwise_errors, PieAPP_patchwise_weights = sess.run([PieAPP_value, 
patchwise_errors, patchwise_weights], 80 | feed_dict={ 81 | image_A_batch: im_A, 82 | image_ref_batch: im_Ref 83 | }) 84 | score_accum += np.sum(np.multiply(PieAPP_patchwise_errors,PieAPP_patchwise_weights),axis=1) 85 | weight_accum += np.sum(PieAPP_patchwise_weights, axis=1) 86 | 87 | print('PieAPP value of '+args.A_path+ ' with respect to: '+str(score_accum/weight_accum)) 88 | 89 | -------------------------------------------------------------------------------- /metrics/PieAPP/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /metrics/PieAPP/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | import cv2 4 | 5 | def generate_images_from_list(A_names, B_names, ref_names, gt_labels, num_imgs_to_get, cursor, max_images): 6 | 7 | image_A = cv2.imread(A_names[cursor]) # sloppy 8 | 9 | batch_A = np.zeros((num_imgs_to_get,image_A.shape[0],image_A.shape[1],image_A.shape[2])) 10 | batch_B = np.zeros((num_imgs_to_get,image_A.shape[0],image_A.shape[1],image_A.shape[2])) 11 | batch_ref = np.zeros((num_imgs_to_get,image_A.shape[0],image_A.shape[1],image_A.shape[2])) 12 | batch_label = np.zeros((num_imgs_to_get,1)) 13 | 14 | for n in range(0,num_imgs_to_get): 15 | batch_A[n,:,:,:] = cv2.imread(A_names[cursor]).astype('float32') 16 | # print B_names[cursor] 17 | batch_B[n,:,:,:] = cv2.imread(B_names[cursor]).astype('float32') 18 | batch_ref[n,:,:,:] = cv2.imread(ref_names[cursor]).astype('float32') 19 | batch_label[n] = gt_labels[cursor] 20 | cursor += 1 21 | if cursor == max_images: 22 | cursor = 0 23 | 24 | return batch_A, batch_B, batch_ref, batch_label, cursor 25 | 26 | def sample_patches(batch_A, batch_ref, patch_size=None, patches_per_image=None, seed='', random_selection=True, uniform_grid_mode = 'strided',strideval = None): 27 | 28 | # sampling modes: 29 | # 1. random selection (optionally with a seed value): samples patches_per_image number of patches randomly; required vars: patches_per_image, random_selection=True; optional: seed 30 | # 2. fixed sampling with stride on a uniform grid: select all patches from image with a stride of strideval; required vars: strideVal, random_selection=False 31 | # 3. 
fixed sampling of patches_per_image patches on a uniform grid: patches from image on a sqrt(patches_per_image)xsqrt(patches_per_image) grid; required vars: patches_per_image, random_selection=False 32 | 33 | num_rows = batch_A.shape[1] 34 | num_cols = batch_A.shape[2] 35 | num_channels = batch_A.shape[3] 36 | batch_size = batch_A.shape[0] 37 | 38 | if not random_selection: 39 | # sequentially select patches_per_image number of patches from ref and image A 40 | if uniform_grid_mode == 'strided': 41 | temp_r = np.int_(np.floor(np.arange(0,num_rows-patch_size+1,strideval))) # patches_per_image is square-rootable 42 | temp_c = np.int_(np.floor(np.arange(0,num_cols-patch_size+1,strideval))) 43 | else: 44 | temp_r = np.int_(np.floor(np.linspace(0,num_rows-patch_size+1,np.sqrt(patches_per_image)))) # patches_per_image is square-rootable 45 | temp_c = np.int_(np.floor(np.linspace(0,num_cols-patch_size+1,np.sqrt(patches_per_image)))) 46 | select_cols,select_rows = np.meshgrid(temp_c,temp_r) 47 | select_cols = np.reshape(select_cols,(select_cols.shape[0]*select_cols.shape[1],)) 48 | select_rows = np.reshape(select_rows,(select_rows.shape[0]*select_rows.shape[1],)) 49 | patches_per_image = select_rows.shape[0] 50 | 51 | # patch output pre-allocation 52 | patch_batch_A = np.zeros((batch_size*patches_per_image,patch_size,patch_size,num_channels)) 53 | patch_batch_ref = np.zeros((batch_size*patches_per_image,patch_size,patch_size,num_channels)) 54 | 55 | location = 0 56 | 57 | for iter_batch in range(0,batch_size): 58 | if random_selection: 59 | # randomly select patches_per_image number of patches 60 | if len(seed) > 0: 61 | np.random.seed(seed) 62 | select_rows = np.random.choice(num_rows-patch_size+1,patches_per_image) 63 | select_cols = np.random.choice(num_cols-patch_size+1,patches_per_image) 64 | 65 | for iter_patch in range(0,patches_per_image): 66 | # load batch A 67 | patch_batch_A[location:location+1,:,:,:] = batch_A[iter_batch,select_rows[iter_patch]:select_rows[iter_patch]+patch_size, 68 | select_cols[iter_patch]:select_cols[iter_patch]+patch_size,:] 69 | # load batch ref 70 | patch_batch_ref[location:location+1,:,:,:] = batch_ref[iter_batch,select_rows[iter_patch]:select_rows[iter_patch]+patch_size, 71 | select_cols[iter_patch]:select_cols[iter_patch]+patch_size,:] 72 | location += 1 73 | 74 | return patch_batch_A, patch_batch_ref 75 | 76 | 77 | -------------------------------------------------------------------------------- /metrics/PieAPP/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def extract_image_patches(image_batch, patch_size,patch_stride): 5 | patches = tf.extract_image_patches(images =image_batch,ksizes=[1,patch_size,patch_size,1],rates=[1,1,1,1],strides=[1,patch_stride,patch_stride,1],padding='VALID') 6 | patches_shape = patches.get_shape().as_list() 7 | return tf.reshape(patches,[-1,patch_size,patch_size,3])#, patches_shape[1]*patches_shape[2] # NOTE: assuming 3 channels 8 | 9 | 10 | def conv_init(name,input_channels, filter_height, filter_width, num_filters, groups=1): 11 | weights = get_scope_variable(name, 'weights', shape=[filter_height, filter_width, input_channels/groups, num_filters], trainable=False) 12 | biases = get_scope_variable(name, 'biases', shape = [num_filters],trainable=False) 13 | 14 | 15 | def fc_init(name, num_in, num_out): 16 | weights = get_scope_variable(name, 'weights', shape=[num_in, num_out], trainable=False) 17 | biases = get_scope_variable(name, 
'biases', shape=[num_out], trainable=False) 18 | 19 | 20 | def conv(x, filter_height, filter_width, num_filters, stride_y, stride_x, name, padding='SAME', relu=True): 21 | input_channels = int(x.get_shape().as_list()[3]) 22 | convolve = lambda i, k: tf.nn.conv2d(i, k, strides = [1, stride_y, stride_x, 1], padding = padding) 23 | weights = get_scope_variable(name, 'weights', shape=[filter_height, filter_width, input_channels, num_filters]) 24 | biases = get_scope_variable(name, 'biases', shape = [num_filters]) 25 | conv = convolve(x, weights) 26 | bias_val = tf.reshape(tf.nn.bias_add(conv, biases), tf.shape(conv)) 27 | if relu == True: 28 | relu = tf.nn.relu(bias_val, name = name) 29 | return relu 30 | else: 31 | return bias_val 32 | 33 | 34 | def fc(x, num_in, num_out, name, relu = True): 35 | weights = get_scope_variable(name, 'weights', shape=[num_in, num_out]) 36 | biases = get_scope_variable(name, 'biases', shape=[num_out]) 37 | act = tf.nn.xw_plus_b(x, weights, biases, name=name) 38 | if relu == True: 39 | relu = tf.nn.relu(act) 40 | return relu 41 | else: 42 | return act 43 | 44 | 45 | def max_pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'): 46 | return tf.nn.max_pool(x, ksize=[1, filter_height, filter_width, 1], strides = [1, stride_y, stride_x, 1], padding = padding, name = name) 47 | 48 | 49 | def dropout(x, keep_prob): 50 | return tf.nn.dropout(x, keep_prob) 51 | 52 | 53 | def get_scope_variable(scope_name, var, shape=None, initialvals=None,trainable=False): 54 | with tf.variable_scope(scope_name,reuse=tf.AUTO_REUSE) as scope: 55 | v = tf.get_variable(var,shape,dtype=tf.float32, initializer=initialvals,trainable=trainable) 56 | return v 57 | 58 | 59 | -------------------------------------------------------------------------------- /options/test.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: your_experiment_name 3 | model_type: DiffMSRS2Model 4 | scale: 4 5 | num_gpu: auto # auto: can infer from your visible devices automatically. 
official: 4 GPUs 6 | manual_seed: 0 7 | 8 | 9 | 10 | # dataset and data loader settings 11 | datasets: 12 | 13 | test_1: 14 | name: Brain 15 | type: MCSRPairedDataset #PairedImageDataset 16 | dataroot_gt: mri_data_complex/mc_brain/valid 17 | dataroot_lq: mri_data_complex/mc_brain/valid 18 | dataroot_mask: "mri_data_complex/dc_mask.mat" 19 | filename_tmpl: '{}' 20 | io_backend: 21 | type: disk 22 | 23 | 24 | # network structures 25 | network_g: 26 | type: DiffMSR_S2 27 | n_encoder_res: 9 28 | inp_channels: 2 29 | out_channels: 2 30 | dim: 32 31 | num_blocks: [6,6,6,6] 32 | num_refinement_blocks: 6 33 | heads: [4,4,4,4] 34 | ffn_expansion_factor: 2.2 35 | bias: False 36 | LayerNorm_type: BiasFree 37 | n_denoise_res: 1 38 | linear_start: 0.1 39 | linear_end: 0.99 40 | timesteps: 4 41 | 42 | # network structures 43 | network_S1: 44 | type: DiffMSR_S1 45 | n_encoder_res: 9 46 | inp_channels: 2 47 | out_channels: 2 48 | dim: 32 49 | num_blocks: [6,6,6,6] 50 | num_refinement_blocks: 6 51 | heads: [4,4,4,4] 52 | ffn_expansion_factor: 2.2 53 | bias: False 54 | LayerNorm_type: BiasFree 55 | 56 | # path 57 | path: 58 | pretrain_network_S1: experiments/Stage_one/models/net_g_latest.pth 59 | pretrain_network_g: experiments/Stage_two/models/net_g_latest.pth 60 | param_key_g: params_ema 61 | strict_load_g: False 62 | ignore_resume_networks: network_S1 63 | 64 | train: 65 | ema_decay: 0.999 66 | 67 | optim_g: 68 | type: Adam 69 | lr: !!float 2e-4 70 | weight_decay: 0 71 | betas: [0.9, 0.99] 72 | 73 | scheduler: 74 | type: MultiStepLR 75 | milestones: [250000,400000, 450000, 475000] 76 | gamma: 0.5 77 | 78 | total_iter: 500000 79 | warmup_iter: -1 # no warm up 80 | 81 | encoder_iter: 0 82 | lr_encoder: !!float 2e-4 83 | lr_sr: !!float 2e-4 84 | gamma_encoder: 0.1 85 | gamma_sr: 0.5 86 | lr_decay_encoder: 30000 87 | lr_decay_sr: 300000 88 | 89 | val: 90 | window_size: 8 91 | save_img: True 92 | suffix: ~ # add suffix to saved images, if None, use exp name 93 | 94 | metrics: 95 | psnr: # metric name 96 | type: calculate_psnr 97 | crop_border: 0 98 | test_y_channel: true 99 | ssim: # metric name 100 | type: calculate_ssim 101 | crop_border: 0 102 | test_y_channel: true 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /options/train_DiffMSR_S1_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: your_experiment_name 3 | model_type: DiffMSRS1Model 4 | scale: 4 5 | num_gpu: auto # auto: can infer from your visible devices automatically. 
official: 4 GPUs 6 | manual_seed: 0 7 | 8 | gt_size: 256 9 | # dataset and data loader settings 10 | datasets: 11 | train: 12 | name: MCSR 13 | type: MCSRPairedDataset #PairedImageDataset 14 | dataroot_gt: mri_data_complex/mc_brain/train 15 | dataroot_lq: mri_data_complex/mc_brain/train 16 | dataroot_mask: "mri_data_complex/dc_mask.mat" 17 | io_backend: 18 | type: disk 19 | 20 | gt_size: 256 21 | use_hflip: true 22 | use_rot: true 23 | 24 | # data loader 25 | num_worker_per_gpu: 1 26 | batch_size_per_gpu: 1 27 | dataset_enlarge_ratio: 100 28 | prefetch_mode: ~ 29 | 30 | # Uncomment these for validation 31 | val_1: 32 | name: Brain 33 | type: MCSRPairedDataset #PairedImageDataset 34 | dataroot_gt: mri_data_complex/mc_brain/valid 35 | dataroot_lq: mri_data_complex/mc_brain/valid 36 | dataroot_mask: "mri_data_complex/dc_mask.mat" 37 | io_backend: 38 | type: disk 39 | 40 | # network structures 41 | network_g: 42 | type: DiffMSR_S1 43 | n_encoder_res: 9 44 | inp_channels: 2 45 | out_channels: 2 46 | dim: 32 47 | num_blocks: [6,6,6,6] 48 | num_refinement_blocks: 6 49 | heads: [4,4,4,4] 50 | ffn_expansion_factor: 2.2 51 | bias: False 52 | LayerNorm_type: BiasFree 53 | 54 | # path 55 | path: 56 | pretrain_network_g: ~ 57 | param_key_g: params_ema 58 | strict_load_g: true 59 | resume_state: ~ 60 | 61 | # training settings 62 | train: 63 | ema_decay: 0.999 64 | 65 | optim_g: 66 | type: Adam 67 | lr: !!float 1e-5 68 | weight_decay: 0 69 | betas: [0.9, 0.99] 70 | 71 | scheduler: 72 | type: MultiStepLR 73 | milestones: [250000,400000, 450000, 475000] 74 | gamma: 0.5 75 | 76 | total_iter: 500000 77 | warmup_iter: -1 # no warm up 78 | 79 | # losses 80 | pixel_opt: 81 | type: L1Loss 82 | loss_weight: 1.0 83 | reduction: mean 84 | 85 | # Uncomment these for validation 86 | # validation settings 87 | val: 88 | window_size: 8 89 | val_freq: !!float 5e3 90 | save_img: True 91 | 92 | metrics: 93 | psnr: # metric name 94 | type: calculate_psnr 95 | crop_border: 0 96 | test_y_channel: true 97 | ssim: # metric name 98 | type: calculate_ssim 99 | crop_border: 0 100 | test_y_channel: true 101 | 102 | # logging settings 103 | logger: 104 | print_freq: 1000 105 | save_checkpoint_freq: !!float 5e3 106 | use_tb_logger: true 107 | wandb: 108 | project: ~ 109 | resume_id: ~ 110 | 111 | # dist training settings 112 | dist_params: 113 | backend: nccl 114 | port: 29500 -------------------------------------------------------------------------------- /options/train_DiffMSR_S2_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: your_experiment_name 3 | model_type: DiffMSRS2Model 4 | scale: 4 5 | num_gpu: auto # auto: can infer from your visible devices automatically.
official: 4 GPUs 6 | manual_seed: 0 7 | 8 | gt_size: 256 9 | # dataset and data loader settings 10 | datasets: 11 | train: 12 | name: MCSR 13 | type: MCSRPairedDataset #PairedImageDataset 14 | dataroot_gt: mri_data_complex/mc_brain/train 15 | dataroot_lq: mri_data_complex/mc_brain/train 16 | dataroot_mask: "mri_data_complex/dc_mask.mat" 17 | filename_tmpl: '{}' 18 | io_backend: 19 | type: disk 20 | # (for lmdb) 21 | # type: lmdb 22 | 23 | 24 | gt_size: 256 25 | use_hflip: true 26 | use_rot: true 27 | 28 | # data loader 29 | num_worker_per_gpu: 1 30 | batch_size_per_gpu: 1 31 | dataset_enlarge_ratio: 100 32 | prefetch_mode: ~ 33 | 34 | # Uncomment these for validation 35 | val_1: 36 | name: Brain 37 | type: MCSRPairedDataset #PairedImageDataset 38 | dataroot_gt: mri_data_complex/mc_brain/valid 39 | dataroot_lq: mri_data_complex/mc_brain/valid 40 | dataroot_mask: "mri_data_complex/dc_mask.mat" 41 | # filename_tmpl: '{}x4' 42 | io_backend: 43 | type: disk 44 | 45 | 46 | # network structures 47 | network_g: 48 | type: DiffMSR_S2 49 | n_encoder_res: 9 50 | inp_channels: 2 51 | out_channels: 2 52 | dim: 32 53 | num_blocks: [6,6,6,6] 54 | num_refinement_blocks: 6 55 | heads: [4,4,4,4] 56 | ffn_expansion_factor: 2.2 57 | bias: False 58 | LayerNorm_type: BiasFree 59 | n_denoise_res: 1 60 | linear_start: 0.1 61 | linear_end: 0.99 62 | timesteps: 8 63 | 64 | # network structures 65 | network_S1: 66 | type: DiffMSR_S1 67 | n_encoder_res: 9 68 | inp_channels: 2 69 | out_channels: 2 70 | dim: 32 71 | num_blocks: [6,6,6,6] 72 | num_refinement_blocks: 6 73 | heads: [4,4,4,4] 74 | ffn_expansion_factor: 2.2 75 | bias: False 76 | LayerNorm_type: BiasFree 77 | 78 | 79 | # path 80 | path: 81 | pretrain_network_g: experiments/Stage_one/models/net_g_*.pth 82 | pretrain_network_S1: experiments/Stage_one/models/net_g_*.pth 83 | param_key_g: params_ema 84 | strict_load_g: False 85 | resume_state: ~ 86 | ignore_resume_networks: network_S1 87 | 88 | # training settings 89 | train: 90 | ema_decay: 0.999 91 | 92 | optim_g: 93 | type: Adam 94 | lr: !!float 2e-4 95 | weight_decay: 0 96 | betas: [0.9, 0.99] 97 | 98 | scheduler: 99 | type: MultiStepLR 100 | milestones: [250000,400000, 450000, 475000] 101 | gamma: 0.5 102 | 103 | total_iter: 500000 104 | warmup_iter: -1 # no warm up 105 | 106 | encoder_iter: 0 107 | lr_encoder: !!float 2e-4 108 | lr_sr: !!float 2e-4 109 | gamma_encoder: 0.1 110 | gamma_sr: 0.5 111 | lr_decay_encoder: 30000 112 | lr_decay_sr: 300000 113 | 114 | # losses 115 | pixel_opt: 116 | type: L1Loss 117 | loss_weight: 1.0 118 | reduction: mean 119 | 120 | kd_opt: 121 | type: KDLoss 122 | loss_weight: 1 123 | temperature: 0.15 124 | 125 | # Uncomment these for validation 126 | # validation settings 127 | val: 128 | window_size: 8 129 | val_freq: !!float 5e3 130 | save_img: True 131 | 132 | metrics: 133 | psnr: # metric name 134 | type: calculate_psnr 135 | crop_border: 0 136 | test_y_channel: true 137 | ssim: # metric name 138 | type: calculate_ssim 139 | crop_border: 0 140 | test_y_channel: true 141 | 142 | # logging settings 143 | logger: 144 | print_freq: 1000 145 | save_checkpoint_freq: !!float 5e3 146 | use_tb_logger: true 147 | wandb: 148 | project: ~ 149 | resume_id: ~ 150 | 151 | # dist training settings 152 | dist_params: 153 | backend: nccl 154 | port: 29500 155 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | facexlib>=0.2.0.3 2 | gfpgan>=0.2.1 3 | numpy 4
| opencv-python 5 | Pillow 6 | torch>=1.8 7 | torchvision 8 | tqdm 9 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python3 DiffMSR_Main/test.py -opt options/test.yml -------------------------------------------------------------------------------- /train_S1.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python DiffMSR_Main/train.py -opt options/train_DiffMSR_S1_x4.yml 2 | -------------------------------------------------------------------------------- /train_S2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python DiffMSR_Main/train.py -opt options/train_DiffMSR_S2_x4.yml 2 | --------------------------------------------------------------------------------