├── utils ├── __init__.py ├── util2.py └── util.py ├── figs ├── process.jpg ├── aniso-quan.jpg └── iso-quan.jpg ├── data ├── benchmark.py ├── df2k.py ├── common.py ├── __init__.py └── multiscalesrdata.py ├── main_stage1.py ├── main_stage1.sh ├── main_stage2.py ├── main_stage3.py ├── main_stage4.py ├── test_iso_stage4.sh ├── test_anisoAnoise_stage4.sh ├── loss ├── vgg.py ├── discriminator.py ├── adversarial.py └── __init__.py ├── template.py ├── model_ST ├── common.py ├── __init__.py └── blindsr.py ├── model_TA ├── common.py ├── __init__.py └── blindsr.py ├── model ├── blindsr.py ├── common.py └── __init__.py ├── model_meta_stage3 ├── blindsr.py ├── common.py └── __init__.py ├── main_stage2.sh ├── main_stage3.sh ├── main_stage4.sh ├── model_meta ├── common.py ├── blindsr.py └── __init__.py ├── README.md ├── dataloader.py ├── trainer_stage1.py ├── utility2.py ├── utility.py ├── option.py ├── trainer_stage3.py ├── trainer_stage4.py └── trainer_stage2.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figs/process.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/process.jpg -------------------------------------------------------------------------------- /figs/aniso-quan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/aniso-quan.jpg -------------------------------------------------------------------------------- /figs/iso-quan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/iso-quan.jpg -------------------------------------------------------------------------------- /data/benchmark.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from data import common 3 | from data import multiscalesrdata as srdata 4 | 5 | 6 | class Benchmark(srdata.SRData): 7 | def __init__(self, args, name='', train=True): 8 | super(Benchmark, self).__init__( 9 | args, name=name, train=train, benchmark=True 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data,'benchmark', self.name) 14 | self.dir_hr = os.path.join(self.apath, 'HR') 15 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 16 | self.ext = ('.png','.png') 17 | print(self.dir_hr) 18 | print(self.dir_lr) 19 | -------------------------------------------------------------------------------- /main_stage1.py: -------------------------------------------------------------------------------- 1 | from option import args 2 | import torch 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from trainer_stage1 import Trainer 8 | 9 | 10 | if __name__ == '__main__': 11 | torch.manual_seed(args.seed) 12 | checkpoint = utility.checkpoint(args) 13 | if checkpoint.ok: 14 | loader = data.Data(args) 15 | model = model.Model(args, checkpoint) 16 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 17 | t = Trainer(args, loader, model, loss, checkpoint) 18 | while not t.terminate(): 19 | t.train() 20 | t.test() 21 | 22 | 23 | checkpoint.done() 24 | -------------------------------------------------------------------------------- /main_stage1.sh: -------------------------------------------------------------------------------- 1 | ## noise-free degradations with isotropic Gaussian blurs or anisotropic + noise 2 | # using oracle degradation training 3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage1.py --dir_data='/root/datasets' \ 4 | --model='blindsr' \ 5 | --scale='4' \ 6 | --n_GPUs=1 \ 7 | --epochs_encoder 0 \ 8 | --epochs_sr 600 \ 9 | --data_test Set14 \ 10 | --st_save_epoch 590 \ 11 | --n_feats 128 \ 12 | 
--batch_size 64 \ 13 | --patch_size 64 \ 14 | --data_train DF2K \ 15 | --save stage1 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /main_stage2.py: -------------------------------------------------------------------------------- 1 | from option import args 2 | import torch 3 | import utility 4 | import data 5 | import model_meta 6 | import model 7 | import loss 8 | from trainer_stage2 import Trainer 9 | 10 | 11 | if __name__ == '__main__': 12 | torch.manual_seed(args.seed) 13 | checkpoint = utility.checkpoint(args) 14 | if checkpoint.ok: 15 | loader = data.Data(args) 16 | model_meta = model_meta.Model(args, checkpoint) 17 | model_meta_copy = model.Model(args, checkpoint) 18 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 19 | t = Trainer(args, loader, model_meta, model_meta_copy, loss, checkpoint) 20 | while not t.terminate(): 21 | t.train() 22 | t.test() 23 | 24 | 25 | checkpoint.done() -------------------------------------------------------------------------------- /main_stage3.py: -------------------------------------------------------------------------------- 1 | from option import args 2 | import torch 3 | import utility 4 | import data 5 | import model_TA 6 | import model_meta_stage3 7 | import loss 8 | from trainer_stage3 import Trainer 9 | 10 | 11 | def count_param(model): 12 | param_count = 0 13 | for param in model.parameters(): 14 | param_count += param.view(-1).size()[0] 15 | return param_count 16 | 17 | if __name__ == '__main__': 18 | torch.manual_seed(args.seed) 19 | checkpoint = utility.checkpoint(args) 20 | if checkpoint.ok: 21 | loader = data.Data(args) 22 | model_TA = model_TA.Model(args, checkpoint) 23 | print(count_param(model_TA)) 24 | model_meta_copy = model_meta_stage3.Model(args, checkpoint) 25 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 26 | t = Trainer(args, loader, model_meta_copy, model_TA, loss, checkpoint) 27 | 
while not t.terminate(): 28 | t.train() 29 | t.test() 30 | 31 | 32 | checkpoint.done() -------------------------------------------------------------------------------- /main_stage4.py: -------------------------------------------------------------------------------- 1 | from option import args 2 | import torch 3 | import utility 4 | import data 5 | import model_TA 6 | import model_ST 7 | import model_meta_stage3 8 | import loss 9 | from trainer_stage4 import Trainer 10 | 11 | 12 | # def count_param(model): 13 | # param_count = 0 14 | # for param in model.parameters(): 15 | # param_count += param.view(-1).size()[0] 16 | # return param_count 17 | 18 | if __name__ == '__main__': 19 | torch.manual_seed(args.seed) 20 | checkpoint = utility.checkpoint(args) 21 | if checkpoint.ok: 22 | loader = data.Data(args) 23 | model_TA = model_TA.Model(args, checkpoint) 24 | model_meta_copy = model_meta_stage3.Model(args, checkpoint) 25 | model_ST = model_ST.Model(args, checkpoint) 26 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, model_ST, model_TA, model_meta_copy, loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | 33 | 34 | checkpoint.done() -------------------------------------------------------------------------------- /data/df2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import multiscalesrdata 3 | 4 | 5 | class DF2K(multiscalesrdata.SRData): 6 | def __init__(self, args, name='DF2K', train=True, benchmark=False): 7 | data_range = [r.split('-') for r in args.data_range.split('/')] 8 | if train: 9 | data_range = data_range[0] 10 | else: 11 | if args.test_only and len(data_range) == 1: 12 | data_range = data_range[0] 13 | else: 14 | data_range = data_range[1] 15 | 16 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 17 | super(DF2K, self).__init__(args, name=name, train=train, benchmark=benchmark) 18 | 19 | def 
_scan(self): 20 | names_hr = super(DF2K, self)._scan() 21 | names_hr = names_hr[self.begin - 1:self.end] 22 | 23 | return names_hr 24 | 25 | def _set_filesystem(self, dir_data): 26 | super(DF2K, self)._set_filesystem(dir_data) 27 | self.dir_hr = os.path.join(self.apath, 'HR') 28 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 29 | 30 | -------------------------------------------------------------------------------- /test_iso_stage4.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 main_stage4.py --dir_data='/mnt/bn/xiabinsr/datasets' \ 2 | --model='blindsr' \ 3 | --scale='4' \ 4 | --blur_type='iso_gaussian' \ 5 | --noise=0.0 \ 6 | --sig_min=0.2 \ 7 | --sig_max=4.0 \ 8 | --sig 3.6 \ 9 | --save ours_iso_36\ 10 | --n_GPUs=1 \ 11 | --epochs_encoder 100 \ 12 | --epochs_sr 500 \ 13 | --data_test Set14 \ 14 | --st_save_epoch 480 \ 15 | --n_feats 128 \ 16 | --patch_size 48 \ 17 | --task_iter 5 \ 18 | --test_iter 5 \ 19 | --meta_batch_size 5 \ 20 | --batch_size 16 \ 21 | --lr_sr 1e-4 \ 22 | --lr_task 1e-2 \ 23 | --lr_encoder 1e-4 \ 24 | --pre_train_ST="./experiment/iso_ST.pt" \ 25 | --pre_train="./experiment/iso_ST.pt" \ 26 | --save_results False \ 27 | --resume 0 \ 28 | --test_only -------------------------------------------------------------------------------- /test_anisoAnoise_stage4.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python3 main_stage4.py --dir_data='/mnt/bn/xiabinsr/datasets' \ 2 | --model='blindsr' \ 3 | --scale='4' \ 4 | --blur_type='aniso_gaussian' \ 5 | --noise=10.0 \ 6 | --lambda_1 3.5 \ 7 | --lambda_2 1.5 \ 8 | --theta 20 \ 9 | --save ours_iso_36\ 10 | --n_GPUs=1 \ 11 | --epochs_encoder 100 \ 12 | --epochs_sr 500 \ 13 | --data_test Set14 \ 14 | --st_save_epoch 480 \ 15 | --n_feats 128 \ 16 | --patch_size 48 \ 17 | --task_iter 5 \ 18 | --test_iter 5 \ 19 | --meta_batch_size 5 \ 20 | --batch_size 16 \ 21 | 
--lr_sr 1e-4 \ 22 | --lr_task 1e-2 \ 23 | --lr_encoder 1e-4 \ 24 | --pre_train_ST="./experiment/aniso_noise_ST.pt" \ 25 | --pre_train="./experiment/aniso_noise_ST.pt" \ 26 | --save_results False\ 27 | --resume 0 \ 28 | --test_only -------------------------------------------------------------------------------- /loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | from torch.autograd import Variable 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, conv_index, rgb_range=1): 11 | super(VGG, self).__init__() 12 | vgg_features = models.vgg19(pretrained=True).features 13 | modules = [m for m in vgg_features] 14 | if conv_index == '22': 15 | self.vgg = nn.Sequential(*modules[:8]) 16 | elif conv_index == '54': 17 | self.vgg = nn.Sequential(*modules[:35]) 18 | 19 | vgg_mean = (0.485, 0.456, 0.406) 20 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 21 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 22 | self.vgg.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.lr_decay = 100 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale 
= 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.lr_decay = 500 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.lr_decay = 150 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | -------------------------------------------------------------------------------- /loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, args, gan_type='GAN'): 7 | super(Discriminator, self).__init__() 8 | 9 | in_channels = 3 10 | out_channels = 64 11 | depth = 7 12 | #bn = not gan_type == 'WGAN_GP' 13 | bn = True 14 | act = nn.LeakyReLU(negative_slope=0.2, inplace=True) 15 | 16 | m_features = [ 17 | common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act) 18 | ] 19 | for i in range(depth): 20 | in_channels = out_channels 21 | if i % 2 == 1: 22 | stride = 1 23 | out_channels *= 2 24 | else: 25 | stride = 2 26 | m_features.append(common.BasicBlock( 27 | in_channels, out_channels, 3, stride=stride, bn=bn, act=act 28 | )) 29 | 30 | self.features = nn.Sequential(*m_features) 31 | 32 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 33 | m_classifier = [ 34 | nn.Linear(out_channels * patch_size**2, 1024), 35 | act, 36 | nn.Linear(1024, 1) 37 | ] 38 | self.classifier = nn.Sequential(*m_classifier) 39 | 40 | def forward(self, x): 41 | features = self.features(x) 42 
| output = self.classifier(features.view(features.size(0), -1)) 43 | 44 | return output 45 | 46 | -------------------------------------------------------------------------------- /model_ST/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 9 | 10 | 11 | class MeanShift(nn.Conv2d): 12 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 13 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 14 | std = torch.Tensor(rgb_std) 15 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 16 | self.weight.data.div_(std.view(3, 1, 1, 1)) 17 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 18 | self.bias.data.div_(std) 19 | self.weight.requires_grad = False 20 | self.bias.requires_grad = False 21 | 22 | 23 | class Upsampler(nn.Sequential): 24 | def __init__(self, conv, scale, n_feat, act=False, bias=True): 25 | m = [] 26 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
27 | for _ in range(int(math.log(scale, 2))): 28 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 29 | m.append(nn.PixelShuffle(2)) 30 | if act: m.append(act()) 31 | elif scale == 3: 32 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 33 | m.append(nn.PixelShuffle(3)) 34 | if act: m.append(act()) 35 | else: 36 | raise NotImplementedError 37 | 38 | super(Upsampler, self).__init__(*m) 39 | 40 | -------------------------------------------------------------------------------- /model_TA/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 9 | 10 | 11 | class MeanShift(nn.Conv2d): 12 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 13 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 14 | std = torch.Tensor(rgb_std) 15 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 16 | self.weight.data.div_(std.view(3, 1, 1, 1)) 17 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 18 | self.bias.data.div_(std) 19 | self.weight.requires_grad = False 20 | self.bias.requires_grad = False 21 | 22 | 23 | class Upsampler(nn.Sequential): 24 | def __init__(self, conv, scale, n_feat, act=False, bias=True): 25 | m = [] 26 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
27 | for _ in range(int(math.log(scale, 2))): 28 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 29 | m.append(nn.PixelShuffle(2)) 30 | if act: m.append(act()) 31 | elif scale == 3: 32 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 33 | m.append(nn.PixelShuffle(3)) 34 | if act: m.append(act()) 35 | else: 36 | raise NotImplementedError 37 | 38 | super(Upsampler, self).__init__(*m) 39 | 40 | -------------------------------------------------------------------------------- /data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | 9 | def get_patch(hr, patch_size=48, scale=2, multi=False, input_large=False): 10 | ih, iw = hr.shape[:2] 11 | 12 | ip = scale * patch_size 13 | 14 | ix = random.randrange(0, iw - ip + 1) 15 | iy = random.randrange(0, ih - ip + 1) 16 | 17 | hr = hr[int(iy):int(iy + ip), int(ix):int(ix + ip), :] 18 | 19 | return hr 20 | 21 | 22 | def set_channel(hr, n_channels=3): 23 | def _set_channel(img): 24 | if img.ndim == 2: 25 | img = np.expand_dims(img, axis=2) 26 | 27 | c = img.shape[2] 28 | if n_channels == 1 and c == 3: 29 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 30 | elif n_channels == 3 and c == 1: 31 | img = np.concatenate([img] * n_channels, 2) 32 | 33 | return img 34 | 35 | return _set_channel(hr) 36 | 37 | 38 | def np2Tensor(hr, rgb_range=255): 39 | def _np2Tensor(img): 40 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 41 | tensor = torch.from_numpy(np_transpose).float() 42 | tensor.mul_(rgb_range / 255) 43 | 44 | return tensor 45 | 46 | return _np2Tensor(hr) 47 | 48 | 49 | def augment(hr, hflip=True, rot=True): 50 | hflip = hflip and random.random() < 0.5 51 | vflip = rot and random.random() < 0.5 52 | rot90 = rot and random.random() < 0.5 53 | 54 | def _augment(img): 55 | if hflip: img = img[:, ::-1, :] 56 | if vflip: img = img[::-1, :, :] 57 | if rot90: img = 
img.transpose(1, 0, 2) 58 | 59 | return img 60 | 61 | return _augment(hr) 62 | 63 | -------------------------------------------------------------------------------- /model/blindsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model.common as common 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_model(args): 8 | return MLN(args) 9 | 10 | 11 | class MLN(nn.Module): 12 | def __init__(self, args, conv=common.default_conv): 13 | super(MLN, self).__init__() 14 | 15 | n_feats = args.n_feats 16 | kernel_size = 3 17 | scale = args.scale[0] 18 | act = nn.LeakyReLU(0.1, True) 19 | self.head = nn.Sequential( 20 | nn.Conv2d(3, n_feats, kernel_size=3, padding=1), 21 | act 22 | ) 23 | m_body = [ 24 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 25 | act, 26 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 27 | act, 28 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 29 | act, 30 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 31 | act, 32 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 33 | act, 34 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 35 | act 36 | ] 37 | m_tail = [ 38 | common.Upsampler(conv, scale, n_feats, act=False), 39 | nn.Conv2d( 40 | n_feats, args.n_colors, kernel_size, 41 | padding=(kernel_size // 2) 42 | ) 43 | ] 44 | self.tail = nn.Sequential(*m_tail) 45 | 46 | self.body = nn.Sequential(*m_body) 47 | 48 | 49 | def forward(self, lr,lr_bic): 50 | res = self.head(lr) 51 | res = self.body(res) 52 | res = self.tail(res) 53 | res +=lr_bic 54 | 55 | return res 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /model_meta_stage3/blindsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model.common as common 4 | import torch.nn.functional as F 5 | 6 | 7 | def 
make_model(args): 8 | return MLN(args) 9 | 10 | 11 | class MLN(nn.Module): 12 | def __init__(self, args, conv=common.default_conv): 13 | super(MLN, self).__init__() 14 | 15 | n_feats = args.n_feats 16 | kernel_size = 3 17 | scale = args.scale[0] 18 | act = nn.LeakyReLU(0.1, True) 19 | self.head = nn.Sequential( 20 | nn.Conv2d(3, n_feats, kernel_size=3, padding=1), 21 | act 22 | ) 23 | m_body = [ 24 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 25 | act, 26 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 27 | act, 28 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 29 | act, 30 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 31 | act, 32 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 33 | act, 34 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 35 | act 36 | ] 37 | m_tail = [ 38 | common.Upsampler(conv, scale, n_feats, act=False), 39 | nn.Conv2d( 40 | n_feats, args.n_colors, kernel_size, 41 | padding=(kernel_size // 2) 42 | ) 43 | ] 44 | self.tail = nn.Sequential(*m_tail) 45 | 46 | self.body = nn.Sequential(*m_body) 47 | 48 | self.ad = nn.AdaptiveAvgPool2d(1) 49 | 50 | 51 | def forward(self, lr,lr_bic): 52 | res = self.head(lr) 53 | res = self.body(res) 54 | deg_repre = res 55 | res = self.tail[0](res) 56 | res = self.tail[1](res) 57 | res +=lr_bic 58 | 59 | 60 | return res, deg_repre 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, 
idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | print(d) 23 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 24 | m = import_module('data.' + module_name.lower()) 25 | datasets.append(getattr(m, module_name)(args, name=d)) 26 | 27 | self.loader_train = dataloader.DataLoader( 28 | MyConcatDataset(datasets), 29 | batch_size=args.batch_size, 30 | shuffle=True, 31 | pin_memory=not args.cpu, 32 | num_workers=args.n_threads, 33 | ) 34 | 35 | self.loader_test = [] 36 | for d in args.data_test: 37 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 38 | m = import_module('data.benchmark') 39 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 40 | else: 41 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 42 | m = import_module('data.' + module_name.lower()) 43 | testset = getattr(m, module_name)(args, train=False, name=d) 44 | 45 | self.loader_test.append( 46 | dataloader.DataLoader( 47 | testset, 48 | batch_size=1, 49 | shuffle=False, 50 | pin_memory=not args.cpu, 51 | num_workers=args.n_threads, 52 | ) 53 | ) 54 | 55 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 9 | 10 | class ResBlock(nn.Module): 11 | def __init__( 12 | self, conv, n_feats, kernel_size, 13 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 14 | 15 | super(ResBlock, self).__init__() 16 | m = [] 17 | for i in range(2): 18 | m.append(conv(n_feats, n_feats, 
kernel_size, bias=bias)) 19 | if bn: 20 | m.append(nn.BatchNorm2d(n_feats)) 21 | if i == 0: 22 | m.append(act) 23 | 24 | self.body = nn.Sequential(*m) 25 | self.res_scale = res_scale 26 | 27 | def forward(self, x): 28 | res = self.body(x).mul(self.res_scale) 29 | res += x 30 | 31 | return res 32 | 33 | class MeanShift(nn.Conv2d): 34 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 35 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 36 | std = torch.Tensor(rgb_std) 37 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 38 | self.weight.data.div_(std.view(3, 1, 1, 1)) 39 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 40 | self.bias.data.div_(std) 41 | self.weight.requires_grad = False 42 | self.bias.requires_grad = False 43 | 44 | 45 | class Upsampler(nn.Sequential): 46 | def __init__(self, conv, scale, n_feat, act=False, bias=True): 47 | m = [] 48 | if (int(scale) & (int(scale) - 1)) == 0: # Is scale = 2^n? 49 | for _ in range(int(math.log(scale, 2))): 50 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 51 | m.append(nn.PixelShuffle(2)) 52 | if act: m.append(act()) 53 | elif scale == 3: 54 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 55 | m.append(nn.PixelShuffle(3)) 56 | if act: m.append(act()) 57 | else: 58 | raise NotImplementedError 59 | 60 | super(Upsampler, self).__init__(*m) 61 | 62 | -------------------------------------------------------------------------------- /model_meta_stage3/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 9 | 10 | class ResBlock(nn.Module): 11 | def __init__( 12 | self, conv, n_feats, kernel_size, 13 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 14 | 15 | super(ResBlock, self).__init__() 
16 | m = [] 17 | for i in range(2): 18 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 19 | if bn: 20 | m.append(nn.BatchNorm2d(n_feats)) 21 | if i == 0: 22 | m.append(act) 23 | 24 | self.body = nn.Sequential(*m) 25 | self.res_scale = res_scale 26 | 27 | def forward(self, x): 28 | res = self.body(x).mul(self.res_scale) 29 | res += x 30 | 31 | return res 32 | 33 | class MeanShift(nn.Conv2d): 34 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 35 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 36 | std = torch.Tensor(rgb_std) 37 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 38 | self.weight.data.div_(std.view(3, 1, 1, 1)) 39 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 40 | self.bias.data.div_(std) 41 | self.weight.requires_grad = False 42 | self.bias.requires_grad = False 43 | 44 | 45 | class Upsampler(nn.Sequential): 46 | def __init__(self, conv, scale, n_feat, act=False, bias=True): 47 | m = [] 48 | if (int(scale) & (int(scale) - 1)) == 0: # Is scale = 2^n? 
49 | for _ in range(int(math.log(scale, 2))): 50 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 51 | m.append(nn.PixelShuffle(2)) 52 | if act: m.append(act()) 53 | elif scale == 3: 54 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 55 | m.append(nn.PixelShuffle(3)) 56 | if act: m.append(act()) 57 | else: 58 | raise NotImplementedError 59 | 60 | super(Upsampler, self).__init__(*m) 61 | 62 | -------------------------------------------------------------------------------- /main_stage2.sh: -------------------------------------------------------------------------------- 1 | ## noise-free degradations with isotropic Gaussian blurs 2 | # training knowledge distillation 3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage2.py --dir_data='/root/datasets' \ 4 | --model='blindsr' \ 5 | --scale='4' \ 6 | --blur_type='iso_gaussian' \ 7 | --noise=0.0 \ 8 | --sig_min=0.2 \ 9 | --sig_max=4.0 \ 10 | --sig 2.6 \ 11 | --n_GPUs=1 \ 12 | --epochs_encoder 0 \ 13 | --epochs_sr 600 \ 14 | --data_test Set5 \ 15 | --st_save_epoch 95 \ 16 | --n_feats 128 \ 17 | --patch_size 64 \ 18 | --task_iter 5 \ 19 | --test_iter 5 \ 20 | --meta_batch_size 5 \ 21 | --task_batch_size 16 \ 22 | --lr_sr 1e-4 \ 23 | --lr_task 1e-2 \ 24 | --pre_train="./experiment/stage1.pt" \ 25 | --resume 0 \ 26 | --test_every 240 \ 27 | --print_every 40 \ 28 | --lr_decay_sr 150 \ 29 | --data_train DF2K \ 30 | --save stage2 31 | 32 | #--batch_size 32 \ 33 | 34 | # anisotropic + noise 35 | # CUDA_VISIBLE_DEVICES=0 python3 main_stage2.py --dir_data='/root/datasets' \ 36 | # --model='blindsr' \ 37 | # --scale='4' \ 38 | # --blur_type='aniso_gaussian' \ 39 | # --noise=25.0 \ 40 | # --lambda_min=0.2 \ 41 | # --lambda_max=4.0 \ 42 | # --n_GPUs=1 \ 43 | # --epochs_encoder 0 \ 44 | # --epochs_sr 600 \ 45 | # --data_test Set5 \ 46 | # --st_save_epoch 95 \ 47 | # --n_feats 128 \ 48 | # --patch_size 64 \ 49 | # --task_iter 5 \ 50 | # --test_iter 5 \ 51 | # --meta_batch_size 5 \ 52 | # --task_batch_size 16 \ 53 | # --lr_sr 1e-4 \ 54 | # --lr_task 
1e-2 \ 55 | # --pre_train="./experiment/stage1.pt" \ 56 | # --resume 0 \ 57 | # --test_every 240 \ 58 | # --print_every 40 \ 59 | # --lr_decay_sr 150 \ 60 | # --data_train DF2K \ 61 | # --save stage2 -------------------------------------------------------------------------------- /main_stage3.sh: -------------------------------------------------------------------------------- 1 | ## noise-free degradations with isotropic Gaussian blurs 2 | # training knowledge distillation 3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage3.py --dir_data='/root/datasets' \ 4 | --model='blindsr' \ 5 | --scale='4' \ 6 | --blur_type='iso_gaussian' \ 7 | --noise=0.0 \ 8 | --sig_min=0.2 \ 9 | --sig_max=4.0 \ 10 | --sig 2.6 \ 11 | --n_GPUs=1 \ 12 | --epochs_encoder 0 \ 13 | --epochs_sr 500 \ 14 | --data_test Set5 \ 15 | --st_save_epoch 480 \ 16 | --n_feats 128 \ 17 | --patch_size 64 \ 18 | --task_iter 5 \ 19 | --test_iter 5 \ 20 | --meta_batch_size 5 \ 21 | --batch_size 16 \ 22 | --lr_sr 1e-4 \ 23 | --lr_task 1e-2 \ 24 | --pre_train_meta="./experiment/stage2.pt" \ 25 | --resume 0 \ 26 | --test_every 1500 \ 27 | --print_every 300 \ 28 | --lr_decay_sr 125 \ 29 | --data_train DF2K \ 30 | --save stage3 31 | 32 | #--batch_size 32 \ 33 | 34 | # anisotropic + noise 35 | # CUDA_VISIBLE_DEVICES=0 python3 main_stage3.py --dir_data='/root/datasets' \ 36 | # --model='blindsr' \ 37 | # --scale='4' \ 38 | # --blur_type='aniso_gaussian' \ 39 | # --noise=25.0 \ 40 | # --lambda_min=0.2 \ 41 | # --lambda_max=4.0 \ 42 | # --n_GPUs=1 \ 43 | # --epochs_encoder 0 \ 44 | # --epochs_sr 500 \ 45 | # --data_test Set14 \ 46 | # --st_save_epoch 480 \ 47 | # --n_feats 128 \ 48 | # --patch_size 64 \ 49 | # --task_iter 5 \ 50 | # --test_iter 5 \ 51 | # --meta_batch_size 5 \ 52 | # --batch_size 16 \ 53 | # --lr_sr 1e-4 \ 54 | # --lr_task 1e-2 \ 55 | # --pre_train_meta="./experiment/stage2.pt" \ 56 | # --resume 0 \ 57 | # --test_every 1500 \ 58 | # --print_every 300 \ 59 | # --lr_decay_sr 125 \ 60 | # --data_train DF2K \ 61 
| # --save stage3 62 | -------------------------------------------------------------------------------- /main_stage4.sh: -------------------------------------------------------------------------------- 1 | ## noise-free degradations with isotropic Gaussian blurs 2 | # training knowledge distillation 3 | CUDA_VISIBLE_DEVICES=2 python3 main_stage4.py --dir_data='/root/datasets' \ 4 | --model='blindsr' \ 5 | --scale='4' \ 6 | --blur_type='iso_gaussian' \ 7 | --noise=0.0 \ 8 | --sig_min=0.2 \ 9 | --sig_max=4.0 \ 10 | --sig 2.6 \ 11 | --n_GPUs=1 \ 12 | --epochs_encoder 100 \ 13 | --epochs_sr 500 \ 14 | --data_test Set5 \ 15 | --st_save_epoch 480 \ 16 | --n_feats 128 \ 17 | --patch_size 64 \ 18 | --task_iter 5 \ 19 | --test_iter 5 \ 20 | --meta_batch_size 5 \ 21 | --batch_size 16 \ 22 | --lr_sr 1e-4 \ 23 | --lr_task 1e-2 \ 24 | --lr_encoder 1e-4 \ 25 | --pre_train_TA="./experiment/stage3.pt" \ 26 | --pre_train_ST="./experiment/stage3.pt" \ 27 | --pre_train_meta="./experiment/stage2.pt" \ 28 | --resume 0 \ 29 | --test_every 800 \ 30 | --print_every 250 \ 31 | --lr_decay_sr 125 \ 32 | --data_train DF2K \ 33 | --save stage4 34 | 35 | 36 | #anisotropic Gaussian kenrnel + noise 37 | # CUDA_VISIBLE_DEVICES=0 python3 main_stage4.py --dir_data='/root/datasets' \ 38 | # --model='blindsr' \ 39 | # --scale='4' \ 40 | # --blur_type='aniso_gaussian' \ 41 | # --noise=25.0 \ 42 | # --lambda_min=0.2 \ 43 | # --lambda_max=4.0 \ 44 | # --n_GPUs=1 \ 45 | # --epochs_encoder 100 \ 46 | # --epochs_sr 500 \ 47 | # --data_test Set5 \ 48 | # --st_save_epoch 480 \ 49 | # --n_feats 128 \ 50 | # --patch_size 64 \ 51 | # --task_iter 5 \ 52 | # --test_iter 5 \ 53 | # --meta_batch_size 5 \ 54 | # --batch_size 16 \ 55 | # --lr_sr 1e-4 \ 56 | # --lr_task 1e-2 \ 57 | # --lr_encoder 1e-4 \ 58 | # --pre_train_TA="./experiment/stage3.pt" \ 59 | # --pre_train_ST="./experiment/stage3.pt" \ 60 | # --pre_train_meta="./experiment/stage2.pt" \ 61 | # --resume 0 \ 62 | # --test_every 800 \ 63 | # --print_every 250 
\ 64 | # --lr_decay_sr 125 \ 65 | # --data_train DF2K \ 66 | # --save stage4 -------------------------------------------------------------------------------- /model_meta/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | def linear(input, weight, bias=None): 8 | if bias is None: 9 | return F.linear(input, weight) 10 | else: 11 | return F.linear(input, weight, bias) 12 | 13 | def conv2d(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1): 14 | return F.conv2d(input, weight, bias, stride, padding, dilation, groups) 15 | 16 | def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True, eps=1e-5, momentum=0.1): 17 | ''' momentum = 1 restricts stats to the current mini-batch ''' 18 | # This hack only works when momentum is 1 and avoids needing to track running stats 19 | # by substuting dummy variables 20 | # running_mean = torch.zeros(np.prod(np.array(input.data.size()[1]))).cuda() 21 | # running_var = torch.ones(np.prod(np.array(input.data.size()[1]))).cuda() 22 | return F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps) 23 | 24 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 25 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) 26 | 27 | class ResBlock(nn.Module): 28 | def __init__( 29 | self, conv, n_feats, kernel_size, 30 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 31 | 32 | super(ResBlock, self).__init__() 33 | m = [] 34 | for i in range(2): 35 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 36 | if bn: 37 | m.append(nn.BatchNorm2d(n_feats)) 38 | if i == 0: 39 | m.append(act) 40 | 41 | self.body = nn.Sequential(*m) 42 | self.res_scale = res_scale 43 | 44 | def forward(self, x): 45 | res = self.body(x).mul(self.res_scale) 46 | 
res += x 47 | 48 | return res 49 | 50 | class MeanShift(nn.Conv2d): 51 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 52 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 53 | std = torch.Tensor(rgb_std) 54 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 55 | self.weight.data.div_(std.view(3, 1, 1, 1)) 56 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 57 | self.bias.data.div_(std) 58 | self.weight.requires_grad = False 59 | self.bias.requires_grad = False 60 | 61 | 62 | class Upsampler(nn.Sequential): 63 | def __init__(self, conv, scale, n_feat, act=False, bias=True): 64 | m = [] 65 | if (int(scale) & (int(scale) - 1)) == 0: # Is scale = 2^n? 66 | for _ in range(int(math.log(scale, 2))): 67 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 68 | m.append(nn.PixelShuffle(2)) 69 | if act: m.append(act()) 70 | elif scale == 3: 71 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 72 | m.append(nn.PixelShuffle(3)) 73 | if act: m.append(act()) 74 | else: 75 | raise NotImplementedError 76 | 77 | super(Upsampler, self).__init__(*m) 78 | 79 | -------------------------------------------------------------------------------- /model_meta/blindsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model_meta.common as common 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_model(args): 8 | return MLN(args) 9 | 10 | 11 | class MLN(nn.Module): 12 | def __init__(self, args, conv=common.default_conv): 13 | super(MLN, self).__init__() 14 | 15 | n_feats = args.n_feats 16 | kernel_size = 3 17 | scale = args.scale[0] 18 | act = nn.LeakyReLU(0.1, True) 19 | self.head = nn.Sequential( 20 | nn.Conv2d(3, n_feats, kernel_size=3, padding=1), 21 | act 22 | ) 23 | m_body = [ 24 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 25 | act, 26 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 27 | act, 28 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 29 | 
act, 30 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 31 | act, 32 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 33 | act, 34 | nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1), 35 | act 36 | ] 37 | m_tail = [ 38 | common.Upsampler(conv, scale, n_feats, act=False), 39 | nn.Conv2d( 40 | n_feats, args.n_colors, kernel_size, 41 | padding=(kernel_size // 2) 42 | ) 43 | ] 44 | self.tail = nn.Sequential(*m_tail) 45 | 46 | self.body = nn.Sequential(*m_body) 47 | 48 | 49 | def forward(self, lr,lr_bic,weights,base=''): 50 | #************************head******************** 51 | res = common.conv2d(lr, weights[base + 'head.0.weight'], weights[base + 'head.0.bias'], stride=1, padding=1) 52 | res = F.leaky_relu(res, 0.1, True) 53 | #************************body******************** 54 | res = common.conv2d(res, weights[base + 'body.0.weight'], weights[base + 'body.0.bias'], stride=1, padding=1) 55 | res = F.leaky_relu(res, 0.1, True) 56 | res = common.conv2d(res, weights[base + 'body.2.weight'], weights[base + 'body.2.bias'], stride=1, padding=1) 57 | res = F.leaky_relu(res, 0.1, True) 58 | res = common.conv2d(res, weights[base + 'body.4.weight'], weights[base + 'body.4.bias'], stride=1, padding=1) 59 | res = F.leaky_relu(res, 0.1, True) 60 | res = common.conv2d(res, weights[base + 'body.6.weight'], weights[base + 'body.6.bias'], stride=1, padding=1) 61 | res = F.leaky_relu(res, 0.1, True) 62 | res = common.conv2d(res, weights[base + 'body.8.weight'], weights[base + 'body.8.bias'], stride=1, padding=1) 63 | res = F.leaky_relu(res, 0.1, True) 64 | res = common.conv2d(res, weights[base + 'body.10.weight'], weights[base + 'body.10.bias'], stride=1, padding=1) 65 | res = F.leaky_relu(res, 0.1, True) 66 | #**********************tailx4*********************** 67 | res = common.conv2d(res, weights[base +"tail.0.0.weight"], weights[base +"tail.0.0.bias"], stride=1, padding=1) 68 | res = F.pixel_shuffle(res, 2) 69 | res = common.conv2d(res, weights[base 
+"tail.0.2.weight"], weights[base +"tail.0.2.bias"], stride=1, padding=1) 70 | res = F.pixel_shuffle(res, 2) 71 | res = common.conv2d(res, weights[base + "tail.1.weight"], weights[base + "tail.1.bias"], stride=1, padding=1) 72 | # res = self.tail(res) 73 | res +=lr_bic 74 | 75 | return res 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from model import common 3 | from loss import discriminator 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | 11 | class Adversarial(nn.Module): 12 | def __init__(self, args, gan_type): 13 | super(Adversarial, self).__init__() 14 | self.gan_type = gan_type 15 | self.gan_k = args.gan_k 16 | self.discriminator = discriminator.Discriminator(args, gan_type) 17 | if gan_type != 'WGAN_GP': 18 | self.optimizer = utility.make_optimizer(args, self.discriminator) 19 | else: 20 | self.optimizer = optim.Adam( 21 | self.discriminator.parameters(), 22 | betas=(0, 0.9), eps=1e-8, lr=1e-5 23 | ) 24 | self.scheduler = utility.make_scheduler(args, self.optimizer) 25 | 26 | def forward(self, fake, real): 27 | fake_detach = fake.detach() 28 | 29 | self.loss = 0 30 | for _ in range(self.gan_k): 31 | self.optimizer.zero_grad() 32 | d_fake = self.discriminator(fake_detach) 33 | d_real = self.discriminator(real) 34 | if self.gan_type == 'GAN': 35 | label_fake = torch.zeros_like(d_fake) 36 | label_real = torch.ones_like(d_real) 37 | loss_d \ 38 | = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ 39 | + F.binary_cross_entropy_with_logits(d_real, label_real) 40 | elif self.gan_type.find('WGAN') >= 0: 41 | loss_d = (d_fake - d_real).mean() 42 | if self.gan_type.find('GP') >= 0: 43 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 44 | hat = fake_detach.mul(1 - 
epsilon) + real.mul(epsilon) 45 | hat.requires_grad = True 46 | d_hat = self.discriminator(hat) 47 | gradients = torch.autograd.grad( 48 | outputs=d_hat.sum(), inputs=hat, 49 | retain_graph=True, create_graph=True, only_inputs=True 50 | )[0] 51 | gradients = gradients.view(gradients.size(0), -1) 52 | gradient_norm = gradients.norm(2, dim=1) 53 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 54 | loss_d += gradient_penalty 55 | 56 | # Discriminator update 57 | self.loss += loss_d.item() 58 | loss_d.backward() 59 | self.optimizer.step() 60 | 61 | if self.gan_type == 'WGAN': 62 | for p in self.discriminator.parameters(): 63 | p.data.clamp_(-1, 1) 64 | 65 | self.loss /= self.gan_k 66 | 67 | d_fake_for_g = self.discriminator(fake) 68 | if self.gan_type == 'GAN': 69 | loss_g = F.binary_cross_entropy_with_logits( 70 | d_fake_for_g, label_real 71 | ) 72 | elif self.gan_type.find('WGAN') >= 0: 73 | loss_g = -d_fake_for_g.mean() 74 | 75 | # Generator loss 76 | return loss_g 77 | 78 | def state_dict(self, *args, **kwargs): 79 | state_discriminator = self.discriminator.state_dict(*args, **kwargs) 80 | state_optimizer = self.optimizer.state_dict() 81 | 82 | return dict(**state_discriminator, **state_optimizer) 83 | 84 | # Some references 85 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 86 | # OR 87 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MRDA 2 | 3 | 4 | This project is the official implementation of 'Meta-Learning based Degradation Representation for Blind Super-Resolution', TIP2023 5 | > **Meta-Learning based Degradation Representation for Blind Super-Resolution [[Paper](https://arxiv.org/pdf/2207.13963.pdf)] [[Project](https://github.com/Zj-BinXia/MRDA)]** 6 | 7 | This is code for MRDA (for classic degradation model, ie y=kx+n) 8 | 9 
|

10 | 11 |

12 | 13 | --- 14 | 15 | ## Dependencies and Installation 16 | 17 | - Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 18 | - [PyTorch >= 1.10](https://pytorch.org/) 19 | 20 | ## Dataset Preparation 21 | 22 | We use DF2K, which combines [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (800 images) and [Flickr2K](https://github.com/LimBee/NTIRE2017) (2650 images). 23 | 24 | --- 25 | 26 | ## Training 27 | 28 | 1. Train the Meta-Learning Network (MLN) with bicubic pretraining: 29 | 30 | ```bash 31 | sh main_stage1.sh 32 | ``` 33 | 34 | ### Isotropic Gaussian Kernels 35 | 36 | 37 | 2. Train the MLN with the meta-learning scheme. **Note: first set `pre_train` in main_stage2.sh to the path of the trained main_stage1 checkpoint.** Then run 38 | 39 | ```bash 40 | sh main_stage2.sh 41 | ``` 42 | 43 | 3. Train the MLN together with the teacher MRDA_T. **Note: first set `pre_train_meta` in main_stage3.sh to the path of the trained main_stage2 checkpoint.** Then run 44 | 45 | ```bash 46 | sh main_stage3.sh 47 | ``` 48 | 49 | 50 | 4. Train the student MRDA_S. **Note: first set `pre_train_meta` in main_stage4.sh to the path of the trained main_stage2 checkpoint, and set both `pre_train_TA` and `pre_train_ST` to the path of the trained main_stage3 checkpoint.** Then run 51 | 52 | ```bash 53 | sh main_stage4.sh 54 | ``` 55 | 56 | ### Anisotropic Gaussian Kernels plus noise 57 | 58 | Its training process is the same as for isotropic Gaussian kernels, except that the anisotropic Gaussian kernel settings in main_stage2.sh, main_stage3.sh, and main_stage4.sh are used. 59 | 60 | 2. Train the MLN with the meta-learning scheme. **Note: first set `pre_train` in main_stage2.sh to the path of the trained main_stage1 checkpoint.** Then run 61 | 62 | ```bash 63 | sh main_stage2.sh 64 | ``` 65 | 66 | 3. Train the MLN together with the teacher MRDA_T.
**Note: first set `pre_train_meta` in main_stage3.sh to the path of the trained main_stage2 checkpoint.** Then run 67 | 68 | ```bash 69 | sh main_stage3.sh 70 | ``` 71 | 72 | 73 | 4. Train the student MRDA_S. **Note: first set `pre_train_meta` in main_stage4.sh to the path of the trained main_stage2 checkpoint, and set both `pre_train_TA` and `pre_train_ST` to the path of the trained main_stage3 checkpoint.** Then run 74 | 75 | ```bash 76 | sh main_stage4.sh 77 | ``` 78 | 79 | --- 80 | 81 | ## :european_castle: Model Zoo 82 | 83 | Please download the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1gB-q3k_e_XeZbvOFXMNlS4aHDkpDqMtU?usp=sharing). 84 | 85 | --- 86 | 87 | ## Testing 88 | 89 | ### Isotropic Gaussian Kernels 90 | 91 | ```bash 92 | sh test_iso_stage4.sh 93 | ``` 94 | 95 | ### Anisotropic Gaussian Kernels plus noise 96 | 97 | ```bash 98 | sh test_anisoAnoise_stage4.sh 99 | ``` 100 | 101 | --- 102 | 103 | ## Results 104 |

105 | 106 |

107 | 108 |

109 | 110 |

111 | 112 | 113 | --- 114 | 115 | ## BibTeX 116 | 117 | @article{xia2022meta, 118 | title={Meta-learning based degradation representation for blind super-resolution}, 119 | author={Xia, Bin and Tian, Yapeng and Zhang, Yulun and Hang, Yucheng and Yang, Wenming and Liao, Qingmin}, 120 | journal={IEEE Transactions on Image Processing}, 121 | year={2023} 122 | } 123 | 124 | ## 📧 Contact 125 | 126 | If you have any question, please email `zjbinxia@gmail.com`. 127 | 128 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class Loss(nn.modules.loss._Loss): 14 | def __init__(self, args, ckp): 15 | super(Loss, self).__init__() 16 | print('Preparing loss function:') 17 | 18 | self.n_GPUs = args.n_GPUs 19 | self.loss = [] 20 | self.loss_module = nn.ModuleList() 21 | for loss in args.loss.split('+'): 22 | weight, loss_type = loss.split('*') 23 | if loss_type == 'MSE': 24 | loss_function = nn.MSELoss() 25 | elif loss_type == 'L1': 26 | loss_function = nn.L1Loss() 27 | elif loss_type == 'CE': 28 | loss_function = nn.CrossEntropyLoss() 29 | elif loss_type.find('VGG') >= 0: 30 | module = import_module('loss.vgg') 31 | loss_function = getattr(module, 'VGG')( 32 | loss_type[3:], 33 | rgb_range=args.rgb_range 34 | ) 35 | elif loss_type.find('GAN') >= 0: 36 | module = import_module('loss.adversarial') 37 | loss_function = getattr(module, 'Adversarial')( 38 | args, 39 | loss_type 40 | ) 41 | 42 | self.loss.append({ 43 | 'type': loss_type, 44 | 'weight': float(weight), 45 | 'function': loss_function} 46 | ) 47 | if loss_type.find('GAN') >= 0: 48 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 49 | 50 | if len(self.loss) > 
1: 51 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 52 | 53 | for l in self.loss: 54 | if l['function'] is not None: 55 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 56 | self.loss_module.append(l['function']) 57 | 58 | self.log = torch.Tensor() 59 | 60 | device = torch.device('cpu' if args.cpu else 'cuda') 61 | self.loss_module.to(device) 62 | if args.precision == 'half': self.loss_module.half() 63 | if not args.cpu and args.n_GPUs > 1: 64 | self.loss_module = nn.DataParallel( 65 | self.loss_module, range(args.n_GPUs) 66 | ) 67 | 68 | if args.load != '.': self.load(ckp.dir, cpu=args.cpu) 69 | 70 | def forward(self, sr, hr): 71 | losses = [] 72 | for i, l in enumerate(self.loss): 73 | if l['function'] is not None: 74 | loss = l['function'](sr, hr) 75 | effective_loss = l['weight'] * loss 76 | losses.append(effective_loss) 77 | self.log[-1, i] += effective_loss.item() 78 | elif l['type'] == 'DIS': 79 | self.log[-1, i] += self.loss[i - 1]['function'].loss 80 | 81 | loss_sum = sum(losses) 82 | if len(self.loss) > 1: 83 | self.log[-1, -1] += loss_sum.item() 84 | 85 | return loss_sum 86 | 87 | def step(self): 88 | for l in self.get_loss_module(): 89 | if hasattr(l, 'scheduler'): 90 | l.scheduler.step() 91 | 92 | def start_log(self): 93 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 94 | 95 | def end_log(self, n_batches): 96 | self.log[-1].div_(n_batches) 97 | 98 | def display_loss(self, batch): 99 | n_samples = batch + 1 100 | log = [] 101 | for l, c in zip(self.loss, self.log[-1]): 102 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 103 | 104 | return ''.join(log) 105 | 106 | def plot_loss(self, apath, epoch): 107 | axis = np.linspace(1, epoch, epoch) 108 | for i, l in enumerate(self.loss): 109 | label = '{} Loss'.format(l['type']) 110 | fig = plt.figure() 111 | plt.title(label) 112 | plt.plot(axis, self.log[:, i].numpy(), label=label) 113 | plt.legend() 114 | plt.xlabel('Epochs') 115 | plt.ylabel('Loss') 
116 | plt.grid(True) 117 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) 118 | plt.close(fig) 119 | 120 | def get_loss_module(self): 121 | if self.n_GPUs == 1: 122 | return self.loss_module 123 | else: 124 | return self.loss_module.module 125 | 126 | def save(self, apath): 127 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 128 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 129 | 130 | def load(self, apath, cpu=False): 131 | if cpu: 132 | kwargs = {'map_location': lambda storage, loc: storage} 133 | else: 134 | kwargs = {} 135 | 136 | self.load_state_dict(torch.load( 137 | os.path.join(apath, 'loss.pt'), 138 | **kwargs 139 | )) 140 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 141 | for l in self.loss_module: 142 | if hasattr(l, 'scheduler'): 143 | for _ in range(len(self.log)): l.scheduler.step() -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | import queue 4 | import random 5 | import collections 6 | 7 | import torch 8 | import torch.multiprocessing as multiprocessing 9 | 10 | from torch._C import _set_worker_signal_handlers 11 | from torch.utils.data.dataloader import DataLoader 12 | from torch.utils.data.dataloader import _DataLoaderIter 13 | from torch.utils.data import _utils 14 | 15 | if sys.version_info[0] == 2: 16 | import Queue as queue 17 | else: 18 | import queue 19 | 20 | def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id): 21 | global _use_shared_memory 22 | _use_shared_memory = True 23 | _set_worker_signal_handlers() 24 | 25 | torch.set_num_threads(1) 26 | torch.manual_seed(seed) 27 | while True: 28 | r = index_queue.get() 29 | if r is None: 30 | break 31 | idx, batch_indices = r 32 | try: 33 | idx_scale = 0 34 | if len(scale) > 1 and dataset.train: 35 | idx_scale = random.randrange(0, 
len(scale)) 36 | dataset.set_scale(idx_scale) 37 | 38 | samples = collate_fn([dataset[i] for i in batch_indices]) 39 | samples.append(idx_scale) 40 | 41 | except Exception: 42 | data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info()))) 43 | else: 44 | data_queue.put((idx, samples)) 45 | 46 | class _MSDataLoaderIter(_DataLoaderIter): 47 | def __init__(self, loader): 48 | self.dataset = loader.dataset 49 | self.scale = loader.scale 50 | self.collate_fn = loader.collate_fn 51 | self.batch_sampler = loader.batch_sampler 52 | self.num_workers = loader.num_workers 53 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 54 | self.timeout = loader.timeout 55 | self.done_event = threading.Event() 56 | 57 | self.sample_iter = iter(self.batch_sampler) 58 | 59 | if self.num_workers > 0: 60 | self.worker_init_fn = loader.worker_init_fn 61 | self.index_queues = [ 62 | multiprocessing.Queue() for _ in range(self.num_workers) 63 | ] 64 | self.worker_queue_idx = 0 65 | self.worker_result_queue = multiprocessing.Queue() 66 | self.batches_outstanding = 0 67 | self.worker_pids_set = False 68 | self.shutdown = False 69 | self.send_idx = 0 70 | self.rcvd_idx = 0 71 | self.reorder_dict = {} 72 | 73 | base_seed = torch.LongTensor(1).random_()[0] 74 | self.workers = [ 75 | multiprocessing.Process( 76 | target=_ms_loop, 77 | args=( 78 | self.dataset, 79 | self.index_queues[i], 80 | self.worker_result_queue, 81 | self.collate_fn, 82 | self.scale, 83 | base_seed + i, 84 | self.worker_init_fn, 85 | i 86 | ) 87 | ) 88 | for i in range(self.num_workers)] 89 | 90 | if self.pin_memory or self.timeout > 0: 91 | self.data_queue = queue.Queue() 92 | if self.pin_memory: 93 | maybe_device_id = torch.cuda.current_device() 94 | else: 95 | # do not initialize cuda context if not necessary 96 | maybe_device_id = None 97 | self.pin_memory_thread = threading.Thread( 98 | target=_utils.pin_memory._pin_memory_loop, 99 | args=(self.worker_result_queue, self.data_queue, maybe_device_id, 
self.done_event)) 100 | self.pin_memory_thread.daemon = True 101 | self.pin_memory_thread.start() 102 | else: 103 | self.data_queue = self.worker_result_queue 104 | 105 | for w in self.workers: 106 | w.daemon = True # ensure that the worker exits on process exit 107 | w.start() 108 | 109 | _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers)) 110 | _utils.signal_handling._set_SIGCHLD_handler() 111 | self.worker_pids_set = True 112 | 113 | # prime the prefetch loop 114 | for _ in range(2 * self.num_workers): 115 | self._put_indices() 116 | 117 | class MSDataLoader(DataLoader): 118 | def __init__( 119 | self, args, dataset, batch_size=1, shuffle=False, 120 | sampler=None, batch_sampler=None, 121 | collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=True, 122 | timeout=0, worker_init_fn=None): 123 | 124 | super(MSDataLoader, self).__init__( 125 | dataset, batch_size=batch_size, shuffle=shuffle, 126 | sampler=sampler, batch_sampler=batch_sampler, 127 | num_workers=args.n_threads, collate_fn=collate_fn, 128 | pin_memory=pin_memory, drop_last=drop_last, 129 | timeout=timeout, worker_init_fn=worker_init_fn) 130 | 131 | self.scale = args.scale 132 | 133 | def __iter__(self): 134 | return _MSDataLoaderIter(self) 135 | -------------------------------------------------------------------------------- /data/multiscalesrdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | 14 | class SRData(data.Dataset): 15 | def __init__(self, args, name='', train=True, benchmark=False): 16 | self.args = args 17 | self.name = name 18 | self.train = train 19 | self.split = 'train' if train else 'test' 20 | self.do_eval = True 21 | self.benchmark = benchmark 22 | self.input_large = (args.model == 'VDSR') 
23 | self.scale = args.scale 24 | self.idx_scale = 0 25 | 26 | self._set_filesystem(args.dir_data) 27 | if args.ext.find('img') < 0: 28 | path_bin = os.path.join(self.apath, 'bin') 29 | os.makedirs(path_bin, exist_ok=True) 30 | 31 | list_hr, list_lr = self._scan() 32 | if args.ext.find('img') >= 0 or benchmark: 33 | self.images_hr, self.images_lr = list_hr, list_lr 34 | elif args.ext.find('sep') >= 0: 35 | os.makedirs( 36 | self.dir_hr.replace(self.apath, path_bin), 37 | exist_ok=True 38 | ) 39 | for s in self.scale: 40 | os.makedirs( 41 | os.path.join( 42 | self.dir_lr.replace(self.apath, path_bin), 43 | 'X{}'.format(int(s)) 44 | ), 45 | exist_ok=True 46 | ) 47 | 48 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 49 | for h in list_hr: 50 | b = h.replace(self.apath, path_bin) 51 | b = b.replace(self.ext[0], '.pt') 52 | self.images_hr.append(b) 53 | self._check_and_load(args.ext, h, b, verbose=True) 54 | for i, ll in enumerate(list_lr): 55 | for l in ll: 56 | b = l.replace(self.apath, path_bin) 57 | b = b.replace(self.ext[1], '.pt') 58 | self.images_lr[i].append(b) 59 | self._check_and_load(args.ext, l, b, verbose=True) 60 | if train: 61 | n_patches = args.batch_size * args.test_every 62 | n_images = len(args.data_train) * len(self.images_hr) 63 | if n_images == 0: 64 | self.repeat = 0 65 | else: 66 | self.repeat = max(n_patches // n_images, 1) 67 | 68 | # Below functions as used to prepare images 69 | def _scan(self): 70 | names_hr = sorted( 71 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 72 | ) 73 | names_lr = [[] for _ in self.scale] 74 | for f in names_hr: 75 | filename, _ = os.path.splitext(os.path.basename(f)) 76 | for si, s in enumerate(self.scale): 77 | names_lr[si].append(os.path.join( 78 | self.dir_lr, 'X{}/{}x{}{}'.format( 79 | int(s), filename, int(s), self.ext[1] 80 | ) 81 | )) 82 | 83 | return names_hr, names_lr 84 | 85 | def _set_filesystem(self, dir_data): 86 | self.apath = os.path.join(dir_data, self.name) 87 | 
self.dir_hr = os.path.join(self.apath, 'HR') 88 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 89 | if self.input_large: self.dir_lr += 'L' 90 | self.ext = ('.png', '.png') 91 | 92 | def _check_and_load(self, ext, img, f, verbose=True): 93 | if not os.path.isfile(f) or ext.find('reset') >= 0: 94 | if verbose: 95 | print('Making a binary: {}'.format(f)) 96 | with open(f, 'wb') as _f: 97 | pickle.dump(imageio.imread(img), _f) 98 | 99 | def __getitem__(self, idx): 100 | hr, filename = self._load_file(idx) 101 | hr = self.get_patch( hr) 102 | hr = common.set_channel(hr, n_channels=self.args.n_colors) 103 | hr = common.np2Tensor(hr, rgb_range=self.args.rgb_range) 104 | 105 | return hr, filename 106 | 107 | def __len__(self): 108 | if self.train: 109 | return len(self.images_hr) * self.repeat 110 | else: 111 | return len(self.images_hr) 112 | 113 | def _get_index(self, idx): 114 | if self.train: 115 | return idx % len(self.images_hr) 116 | else: 117 | return idx 118 | 119 | def _load_file(self, idx): 120 | idx = self._get_index(idx) 121 | f_hr = self.images_hr[idx] 122 | 123 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 124 | if self.args.ext == 'img' or self.benchmark: 125 | hr = imageio.imread(f_hr) 126 | elif self.args.ext.find('sep') >= 0: 127 | with open(f_hr, 'rb') as _f: 128 | hr = pickle.load(_f) 129 | 130 | return hr, filename 131 | 132 | def get_patch(self, hr): 133 | scale = self.scale[self.idx_scale] 134 | if self.train: 135 | hr = common.get_patch( 136 | hr, 137 | patch_size=self.args.patch_size, 138 | scale=scale, 139 | multi=(len(self.scale) > 1), 140 | input_large=self.input_large 141 | ) 142 | if not self.args.no_augment: hr = common.augment( hr) 143 | else: 144 | ih, iw = hr.shape[:2] 145 | ih = ih//scale 146 | iw = iw // scale 147 | hr = hr[0:int(ih * scale), 0:int(iw * scale)] 148 | 149 | return hr 150 | 151 | def set_scale(self, idx_scale): 152 | if not self.input_large: 153 | self.idx_scale = idx_scale 154 | else: 155 | 
self.idx_scale = random.randint(0, len(self.scale) - 1) 156 | 157 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | self.args = args 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.save = args.save 23 | 24 | module = import_module('model.'+args.model) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | 38 | def forward(self, lr,lr_bic): 39 | if self.self_ensemble and not self.training: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | return self.forward_x8(lr, forward_function) 46 | elif self.chop and not self.training: 47 | return self.forward_chop(lr) 48 | else: 49 | return self.model(lr,lr_bic) 50 | 51 | def get_model(self): 52 | if self.n_GPUs <= 1 or self.cpu: 53 | return self.model 54 | else: 55 | return self.model.module 56 | 57 | def state_dict(self, **kwargs): 58 | target = self.get_model() 59 | return target.state_dict(**kwargs) 60 | 61 | def save(self, apath, epoch, is_best=False): 62 | target = self.get_model() 63 | torch.save( 64 | 
target.state_dict(), 65 | os.path.join(apath, 'model', 'model_latest.pt') 66 | ) 67 | if is_best: 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_best.pt') 71 | ) 72 | 73 | if self.save_models: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 77 | ) 78 | 79 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 80 | if cpu: 81 | kwargs = {'map_location': lambda storage, loc: storage} 82 | else: 83 | kwargs = {} 84 | 85 | if resume == -1: 86 | self.get_model().load_state_dict( 87 | torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs), 88 | strict=True 89 | ) 90 | 91 | elif resume == 0: 92 | if pre_train != '.': 93 | self.get_model().load_state_dict( 94 | torch.load(pre_train, **kwargs), 95 | strict=True 96 | ) 97 | 98 | elif resume > 0: 99 | self.get_model().load_state_dict( 100 | torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs), 101 | strict=False 102 | ) 103 | 104 | def forward_chop(self, x, shave=10, min_size=160000): 105 | scale = self.scale[self.idx_scale] 106 | n_GPUs = min(self.n_GPUs, 4) 107 | b, c, h, w = x.size() 108 | h_half, w_half = h // 2, w // 2 109 | h_size, w_size = h_half + shave, w_half + shave 110 | lr_list = [ 111 | x[:, :, 0:h_size, 0:w_size], 112 | x[:, :, 0:h_size, (w - w_size):w], 113 | x[:, :, (h - h_size):h, 0:w_size], 114 | x[:, :, (h - h_size):h, (w - w_size):w]] 115 | 116 | if w_size * h_size < min_size: 117 | sr_list = [] 118 | for i in range(0, 4, n_GPUs): 119 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 120 | sr_batch = self.model(lr_batch) 121 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 122 | else: 123 | sr_list = [ 124 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 125 | for patch in lr_list 126 | ] 127 | 128 | h, w = scale * h, scale * w 129 | h_half, w_half = scale * h_half, scale * w_half 130 | h_size, w_size = scale * h_size, scale * w_size 131 | shave 
*= scale 132 | 133 | output = x.new(b, c, h, w) 134 | output[:, :, 0:h_half, 0:w_half] \ 135 | = sr_list[0][:, :, 0:h_half, 0:w_half] 136 | output[:, :, 0:h_half, w_half:w] \ 137 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 138 | output[:, :, h_half:h, 0:w_half] \ 139 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 140 | output[:, :, h_half:h, w_half:w] \ 141 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 142 | 143 | return output 144 | 145 | def forward_x8(self, x, forward_function): 146 | def _transform(v, op): 147 | if self.precision != 'single': v = v.float() 148 | 149 | v2np = v.data.cpu().numpy() 150 | if op == 'v': 151 | tfnp = v2np[:, :, :, ::-1].copy() 152 | elif op == 'h': 153 | tfnp = v2np[:, :, ::-1, :].copy() 154 | elif op == 't': 155 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 156 | 157 | ret = torch.Tensor(tfnp).to(self.device) 158 | if self.precision == 'half': ret = ret.half() 159 | 160 | return ret 161 | 162 | lr_list = [x] 163 | for tf in 'v', 'h', 't': 164 | lr_list.extend([_transform(t, tf) for t in lr_list]) 165 | 166 | sr_list = [forward_function(aug) for aug in lr_list] 167 | for i in range(len(sr_list)): 168 | if i > 3: 169 | sr_list[i] = _transform(sr_list[i], 't') 170 | if i % 4 > 1: 171 | sr_list[i] = _transform(sr_list[i], 'h') 172 | if (i % 4) % 2 == 1: 173 | sr_list[i] = _transform(sr_list[i], 'v') 174 | 175 | output_cat = torch.cat(sr_list, dim=0) 176 | output = output_cat.mean(dim=0, keepdim=True) 177 | 178 | return output 179 | 180 | -------------------------------------------------------------------------------- /model_ST/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | self.args = args 
13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.save = args.save 23 | 24 | module = import_module('model_ST.'+args.model) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train_ST, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | 38 | def forward(self, x): 39 | if self.self_ensemble and not self.training: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | return self.forward_x8(x, forward_function) 46 | elif self.chop and not self.training: 47 | return self.forward_chop(x) 48 | else: 49 | return self.model(x) 50 | 51 | def get_model(self): 52 | if self.n_GPUs <= 1 or self.cpu: 53 | return self.model 54 | else: 55 | return self.model.module 56 | 57 | def state_dict(self, **kwargs): 58 | target = self.get_model() 59 | return target.state_dict(**kwargs) 60 | 61 | def save(self, apath, epoch, is_best=False): 62 | target = self.get_model() 63 | torch.save( 64 | target.state_dict(), 65 | os.path.join(apath, 'model', 'model_ST_latest.pt') 66 | ) 67 | if is_best: 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_ST_best.pt') 71 | ) 72 | 73 | if self.save_models: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_ST_{}.pt'.format(epoch)) 77 | ) 78 | 79 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 80 | if cpu: 81 | kwargs = {'map_location': lambda storage, loc: storage} 82 | else: 83 | kwargs = {} 84 | 85 | 
if resume == -1: 86 | self.get_model().load_state_dict( 87 | torch.load(os.path.join(apath, 'model', 'model_ST_latest.pt'), **kwargs), 88 | strict=True 89 | ) 90 | 91 | elif resume == 0: 92 | if pre_train != '.': 93 | self.get_model().load_state_dict( 94 | torch.load(pre_train, **kwargs), 95 | strict=False 96 | ) 97 | 98 | elif resume > 0: 99 | self.get_model().load_state_dict( 100 | torch.load(os.path.join(apath, 'model', 'model_ST_{}.pt'.format(resume)), **kwargs), 101 | strict=False 102 | ) 103 | 104 | def forward_chop(self, x, shave=10, min_size=160000): 105 | scale = self.scale[self.idx_scale] 106 | n_GPUs = min(self.n_GPUs, 4) 107 | b, c, h, w = x.size() 108 | h_half, w_half = h // 2, w // 2 109 | h_size, w_size = h_half + shave, w_half + shave 110 | lr_list = [ 111 | x[:, :, 0:h_size, 0:w_size], 112 | x[:, :, 0:h_size, (w - w_size):w], 113 | x[:, :, (h - h_size):h, 0:w_size], 114 | x[:, :, (h - h_size):h, (w - w_size):w]] 115 | 116 | if w_size * h_size < min_size: 117 | sr_list = [] 118 | for i in range(0, 4, n_GPUs): 119 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 120 | sr_batch = self.model(lr_batch) 121 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 122 | else: 123 | sr_list = [ 124 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 125 | for patch in lr_list 126 | ] 127 | 128 | h, w = scale * h, scale * w 129 | h_half, w_half = scale * h_half, scale * w_half 130 | h_size, w_size = scale * h_size, scale * w_size 131 | shave *= scale 132 | 133 | output = x.new(b, c, h, w) 134 | output[:, :, 0:h_half, 0:w_half] \ 135 | = sr_list[0][:, :, 0:h_half, 0:w_half] 136 | output[:, :, 0:h_half, w_half:w] \ 137 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 138 | output[:, :, h_half:h, 0:w_half] \ 139 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 140 | output[:, :, h_half:h, w_half:w] \ 141 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 142 | 143 | return output 144 | 145 | def 
forward_x8(self, x, forward_function): 146 | def _transform(v, op): 147 | if self.precision != 'single': v = v.float() 148 | 149 | v2np = v.data.cpu().numpy() 150 | if op == 'v': 151 | tfnp = v2np[:, :, :, ::-1].copy() 152 | elif op == 'h': 153 | tfnp = v2np[:, :, ::-1, :].copy() 154 | elif op == 't': 155 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 156 | 157 | ret = torch.Tensor(tfnp).to(self.device) 158 | if self.precision == 'half': ret = ret.half() 159 | 160 | return ret 161 | 162 | lr_list = [x] 163 | for tf in 'v', 'h', 't': 164 | lr_list.extend([_transform(t, tf) for t in lr_list]) 165 | 166 | sr_list = [forward_function(aug) for aug in lr_list] 167 | for i in range(len(sr_list)): 168 | if i > 3: 169 | sr_list[i] = _transform(sr_list[i], 't') 170 | if i % 4 > 1: 171 | sr_list[i] = _transform(sr_list[i], 'h') 172 | if (i % 4) % 2 == 1: 173 | sr_list[i] = _transform(sr_list[i], 'v') 174 | 175 | output_cat = torch.cat(sr_list, dim=0) 176 | output = output_cat.mean(dim=0, keepdim=True) 177 | 178 | return output 179 | 180 | -------------------------------------------------------------------------------- /model_TA/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | self.args = args 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.save = args.save 23 | 24 | module = import_module('model_TA.'+args.model) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': 
self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train_TA, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | 38 | def forward(self, x, deg_repre): 39 | if self.self_ensemble and not self.training: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | return self.forward_x8(x, forward_function) 46 | elif self.chop and not self.training: 47 | return self.forward_chop(x) 48 | else: 49 | return self.model(x, deg_repre) 50 | 51 | def get_model(self): 52 | if self.n_GPUs <= 1 or self.cpu: 53 | return self.model 54 | else: 55 | return self.model.module 56 | 57 | def state_dict(self, **kwargs): 58 | target = self.get_model() 59 | return target.state_dict(**kwargs) 60 | 61 | def save(self, apath, epoch, is_best=False): 62 | target = self.get_model() 63 | torch.save( 64 | target.state_dict(), 65 | os.path.join(apath, 'model', 'model_latest.pt') 66 | ) 67 | if is_best: 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_best.pt') 71 | ) 72 | 73 | if self.save_models: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 77 | ) 78 | 79 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 80 | if cpu: 81 | kwargs = {'map_location': lambda storage, loc: storage} 82 | else: 83 | kwargs = {} 84 | 85 | if resume == -1: 86 | self.get_model().load_state_dict( 87 | torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs), 88 | strict=True 89 | ) 90 | 91 | elif resume == 0: 92 | if pre_train != '.': 93 | self.get_model().load_state_dict( 94 | torch.load(pre_train, **kwargs), 95 | strict=True 96 | ) 97 | 98 | elif resume > 0: 99 | self.get_model().load_state_dict( 100 | torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs), 101 | strict=False 
102 | ) 103 | 104 | def forward_chop(self, x, shave=10, min_size=160000): 105 | scale = self.scale[self.idx_scale] 106 | n_GPUs = min(self.n_GPUs, 4) 107 | b, c, h, w = x.size() 108 | h_half, w_half = h // 2, w // 2 109 | h_size, w_size = h_half + shave, w_half + shave 110 | lr_list = [ 111 | x[:, :, 0:h_size, 0:w_size], 112 | x[:, :, 0:h_size, (w - w_size):w], 113 | x[:, :, (h - h_size):h, 0:w_size], 114 | x[:, :, (h - h_size):h, (w - w_size):w]] 115 | 116 | if w_size * h_size < min_size: 117 | sr_list = [] 118 | for i in range(0, 4, n_GPUs): 119 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 120 | sr_batch = self.model(lr_batch) 121 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 122 | else: 123 | sr_list = [ 124 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 125 | for patch in lr_list 126 | ] 127 | 128 | h, w = scale * h, scale * w 129 | h_half, w_half = scale * h_half, scale * w_half 130 | h_size, w_size = scale * h_size, scale * w_size 131 | shave *= scale 132 | 133 | output = x.new(b, c, h, w) 134 | output[:, :, 0:h_half, 0:w_half] \ 135 | = sr_list[0][:, :, 0:h_half, 0:w_half] 136 | output[:, :, 0:h_half, w_half:w] \ 137 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 138 | output[:, :, h_half:h, 0:w_half] \ 139 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 140 | output[:, :, h_half:h, w_half:w] \ 141 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 142 | 143 | return output 144 | 145 | def forward_x8(self, x, forward_function): 146 | def _transform(v, op): 147 | if self.precision != 'single': v = v.float() 148 | 149 | v2np = v.data.cpu().numpy() 150 | if op == 'v': 151 | tfnp = v2np[:, :, :, ::-1].copy() 152 | elif op == 'h': 153 | tfnp = v2np[:, :, ::-1, :].copy() 154 | elif op == 't': 155 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 156 | 157 | ret = torch.Tensor(tfnp).to(self.device) 158 | if self.precision == 'half': ret = ret.half() 159 | 160 | return ret 161 | 162 | 
lr_list = [x] 163 | for tf in 'v', 'h', 't': 164 | lr_list.extend([_transform(t, tf) for t in lr_list]) 165 | 166 | sr_list = [forward_function(aug) for aug in lr_list] 167 | for i in range(len(sr_list)): 168 | if i > 3: 169 | sr_list[i] = _transform(sr_list[i], 't') 170 | if i % 4 > 1: 171 | sr_list[i] = _transform(sr_list[i], 'h') 172 | if (i % 4) % 2 == 1: 173 | sr_list[i] = _transform(sr_list[i], 'v') 174 | 175 | output_cat = torch.cat(sr_list, dim=0) 176 | output = output_cat.mean(dim=0, keepdim=True) 177 | 178 | return output 179 | 180 | -------------------------------------------------------------------------------- /model_meta_stage3/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | self.args = args 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.save = args.save 23 | 24 | module = import_module('model_meta_stage3.'+args.model) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train_meta, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | 38 | def forward(self, lr,lr_bic): 39 | if self.self_ensemble and not self.training: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | return 
self.forward_x8(lr, forward_function) 46 | elif self.chop and not self.training: 47 | return self.forward_chop(lr) 48 | else: 49 | return self.model(lr,lr_bic) 50 | 51 | def get_model(self): 52 | if self.n_GPUs <= 1 or self.cpu: 53 | return self.model 54 | else: 55 | return self.model.module 56 | 57 | def state_dict(self, **kwargs): 58 | target = self.get_model() 59 | return target.state_dict(**kwargs) 60 | 61 | def save(self, apath, epoch, is_best=False): 62 | target = self.get_model() 63 | torch.save( 64 | target.state_dict(), 65 | os.path.join(apath, 'model', 'model_latest.pt') 66 | ) 67 | if is_best: 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_best.pt') 71 | ) 72 | 73 | if self.save_models: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 77 | ) 78 | 79 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 80 | if cpu: 81 | kwargs = {'map_location': lambda storage, loc: storage} 82 | else: 83 | kwargs = {} 84 | 85 | if resume == -1: 86 | self.get_model().load_state_dict( 87 | torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs), 88 | strict=True 89 | ) 90 | 91 | elif resume == 0: 92 | if pre_train != '.': 93 | self.get_model().load_state_dict( 94 | torch.load(pre_train, **kwargs), 95 | strict=True 96 | ) 97 | 98 | elif resume > 0: 99 | self.get_model().load_state_dict( 100 | torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs), 101 | strict=False 102 | ) 103 | 104 | def forward_chop(self, x, shave=10, min_size=160000): 105 | scale = self.scale[self.idx_scale] 106 | n_GPUs = min(self.n_GPUs, 4) 107 | b, c, h, w = x.size() 108 | h_half, w_half = h // 2, w // 2 109 | h_size, w_size = h_half + shave, w_half + shave 110 | lr_list = [ 111 | x[:, :, 0:h_size, 0:w_size], 112 | x[:, :, 0:h_size, (w - w_size):w], 113 | x[:, :, (h - h_size):h, 0:w_size], 114 | x[:, :, (h - h_size):h, (w - w_size):w]] 115 | 116 | if w_size * 
h_size < min_size: 117 | sr_list = [] 118 | for i in range(0, 4, n_GPUs): 119 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 120 | sr_batch = self.model(lr_batch) 121 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 122 | else: 123 | sr_list = [ 124 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 125 | for patch in lr_list 126 | ] 127 | 128 | h, w = scale * h, scale * w 129 | h_half, w_half = scale * h_half, scale * w_half 130 | h_size, w_size = scale * h_size, scale * w_size 131 | shave *= scale 132 | 133 | output = x.new(b, c, h, w) 134 | output[:, :, 0:h_half, 0:w_half] \ 135 | = sr_list[0][:, :, 0:h_half, 0:w_half] 136 | output[:, :, 0:h_half, w_half:w] \ 137 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 138 | output[:, :, h_half:h, 0:w_half] \ 139 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 140 | output[:, :, h_half:h, w_half:w] \ 141 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 142 | 143 | return output 144 | 145 | def forward_x8(self, x, forward_function): 146 | def _transform(v, op): 147 | if self.precision != 'single': v = v.float() 148 | 149 | v2np = v.data.cpu().numpy() 150 | if op == 'v': 151 | tfnp = v2np[:, :, :, ::-1].copy() 152 | elif op == 'h': 153 | tfnp = v2np[:, :, ::-1, :].copy() 154 | elif op == 't': 155 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 156 | 157 | ret = torch.Tensor(tfnp).to(self.device) 158 | if self.precision == 'half': ret = ret.half() 159 | 160 | return ret 161 | 162 | lr_list = [x] 163 | for tf in 'v', 'h', 't': 164 | lr_list.extend([_transform(t, tf) for t in lr_list]) 165 | 166 | sr_list = [forward_function(aug) for aug in lr_list] 167 | for i in range(len(sr_list)): 168 | if i > 3: 169 | sr_list[i] = _transform(sr_list[i], 't') 170 | if i % 4 > 1: 171 | sr_list[i] = _transform(sr_list[i], 'h') 172 | if (i % 4) % 2 == 1: 173 | sr_list[i] = _transform(sr_list[i], 'v') 174 | 175 | output_cat = torch.cat(sr_list, dim=0) 176 | output = 
output_cat.mean(dim=0, keepdim=True) 177 | 178 | return output 179 | 180 | -------------------------------------------------------------------------------- /model_meta/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | self.args = args 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.save = args.save 23 | 24 | module = import_module('model_meta.'+args.model) 25 | self.model = module.make_model(args).to(self.device) 26 | if args.precision == 'half': self.model.half() 27 | 28 | if not args.cpu and args.n_GPUs > 1: 29 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 30 | 31 | self.load( 32 | ckp.dir, 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | 38 | def forward(self, lr,lr_bic, weights): 39 | if self.self_ensemble and not self.training: 40 | if self.chop: 41 | forward_function = self.forward_chop 42 | else: 43 | forward_function = self.model.forward 44 | 45 | return self.forward_x8(lr, forward_function) 46 | elif self.chop and not self.training: 47 | return self.forward_chop(lr) 48 | else: 49 | return self.model(lr,lr_bic, weights) 50 | 51 | def get_model(self): 52 | if self.n_GPUs <= 1 or self.cpu: 53 | return self.model 54 | else: 55 | return self.model.module 56 | 57 | def state_dict(self, **kwargs): 58 | target = self.get_model() 59 | return target.state_dict(**kwargs) 60 | 61 | def save(self, apath, epoch, is_best=False): 62 | target = self.get_model() 
63 | torch.save( 64 | target.state_dict(), 65 | os.path.join(apath, 'model', 'model_meta_latest.pt') 66 | ) 67 | if is_best: 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_meta_best.pt') 71 | ) 72 | 73 | if self.save_models: 74 | torch.save( 75 | target.state_dict(), 76 | os.path.join(apath, 'model', 'model_meta_{}.pt'.format(epoch)) 77 | ) 78 | 79 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 80 | if cpu: 81 | kwargs = {'map_location': lambda storage, loc: storage} 82 | else: 83 | kwargs = {} 84 | 85 | if resume == -1: 86 | self.get_model().load_state_dict( 87 | torch.load(os.path.join(apath, 'model', 'model_meta_latest.pt'), **kwargs), 88 | strict=True 89 | ) 90 | 91 | elif resume == 0: 92 | if pre_train != '.': 93 | self.get_model().load_state_dict( 94 | torch.load(pre_train, **kwargs), 95 | strict=True 96 | ) 97 | 98 | elif resume > 0: 99 | self.get_model().load_state_dict( 100 | torch.load(os.path.join(apath, 'model', 'model_meta_{}.pt'.format(resume)), **kwargs), 101 | strict=False 102 | ) 103 | 104 | def forward_chop(self, x, shave=10, min_size=160000): 105 | scale = self.scale[self.idx_scale] 106 | n_GPUs = min(self.n_GPUs, 4) 107 | b, c, h, w = x.size() 108 | h_half, w_half = h // 2, w // 2 109 | h_size, w_size = h_half + shave, w_half + shave 110 | lr_list = [ 111 | x[:, :, 0:h_size, 0:w_size], 112 | x[:, :, 0:h_size, (w - w_size):w], 113 | x[:, :, (h - h_size):h, 0:w_size], 114 | x[:, :, (h - h_size):h, (w - w_size):w]] 115 | 116 | if w_size * h_size < min_size: 117 | sr_list = [] 118 | for i in range(0, 4, n_GPUs): 119 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 120 | sr_batch = self.model(lr_batch) 121 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 122 | else: 123 | sr_list = [ 124 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 125 | for patch in lr_list 126 | ] 127 | 128 | h, w = scale * h, scale * w 129 | h_half, w_half = scale * h_half, scale * w_half 130 | h_size, w_size 
= scale * h_size, scale * w_size 131 | shave *= scale 132 | 133 | output = x.new(b, c, h, w) 134 | output[:, :, 0:h_half, 0:w_half] \ 135 | = sr_list[0][:, :, 0:h_half, 0:w_half] 136 | output[:, :, 0:h_half, w_half:w] \ 137 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 138 | output[:, :, h_half:h, 0:w_half] \ 139 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 140 | output[:, :, h_half:h, w_half:w] \ 141 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 142 | 143 | return output 144 | 145 | def forward_x8(self, x, forward_function): 146 | def _transform(v, op): 147 | if self.precision != 'single': v = v.float() 148 | 149 | v2np = v.data.cpu().numpy() 150 | if op == 'v': 151 | tfnp = v2np[:, :, :, ::-1].copy() 152 | elif op == 'h': 153 | tfnp = v2np[:, :, ::-1, :].copy() 154 | elif op == 't': 155 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 156 | 157 | ret = torch.Tensor(tfnp).to(self.device) 158 | if self.precision == 'half': ret = ret.half() 159 | 160 | return ret 161 | 162 | lr_list = [x] 163 | for tf in 'v', 'h', 't': 164 | lr_list.extend([_transform(t, tf) for t in lr_list]) 165 | 166 | sr_list = [forward_function(aug) for aug in lr_list] 167 | for i in range(len(sr_list)): 168 | if i > 3: 169 | sr_list[i] = _transform(sr_list[i], 't') 170 | if i % 4 > 1: 171 | sr_list[i] = _transform(sr_list[i], 'h') 172 | if (i % 4) % 2 == 1: 173 | sr_list[i] = _transform(sr_list[i], 'v') 174 | 175 | output_cat = torch.cat(sr_list, dim=0) 176 | output = output_cat.mean(dim=0, keepdim=True) 177 | 178 | return output 179 | 180 | -------------------------------------------------------------------------------- /trainer_stage1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utility 3 | import torch 4 | from decimal import Decimal 5 | import torch.nn.functional as F 6 | from utils import util 7 | import numpy as np 8 | 9 | 10 | class Trainer(): 11 | def 
__init__(self, args, loader, my_model, my_loss, ckp): 12 | self.args = args 13 | self.scale = args.scale 14 | self.test_res_psnr = [] 15 | self.test_res_ssim = [] 16 | self.ckp = ckp 17 | self.loader_train = loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer = utility.make_optimizer(args, self.model) 22 | self.scheduler = utility.make_scheduler(args, self.optimizer) 23 | 24 | if self.args.load != '.': 25 | self.optimizer.load_state_dict( 26 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 27 | ) 28 | for _ in range(len(ckp.log)): self.scheduler.step() 29 | 30 | def train(self): 31 | self.scheduler.step() 32 | self.loss.step() 33 | epoch = self.scheduler.last_epoch + 1 34 | 35 | lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr)) 36 | for param_group in self.optimizer.param_groups: 37 | param_group['lr'] = lr 38 | 39 | self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))) 40 | self.loss.start_log() 41 | self.model.train() 42 | 43 | degrade = util.BicubicPreprocessing( 44 | self.scale[0], 45 | rgb_range = self.args.rgb_range 46 | ) 47 | 48 | timer = utility.timer() 49 | losses_sr = utility.AverageMeter() 50 | 51 | for batch, (hr, _,) in enumerate(self.loader_train): 52 | hr = hr.cuda() # b, c, h, w 53 | lr,lr_bic = degrade(hr) # b, c, h, w 54 | #b_kernels, noise_level = degradation 55 | self.optimizer.zero_grad() 56 | 57 | timer.tic() 58 | # forward 59 | ## train the whole network 60 | sr = self.model(lr,lr_bic) 61 | loss_SR = self.loss(sr, hr) 62 | loss = loss_SR 63 | losses_sr.update(loss_SR.item()) 64 | 65 | # backward 66 | loss.backward() 67 | self.optimizer.step() 68 | timer.hold() 69 | 70 | if (batch + 1) % self.args.print_every == 0: 71 | self.ckp.write_log( 72 | 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 73 | 'Loss [SR loss:{:.3f}]\t' 74 | 'Time [{:.1f}s]'.format( 75 | epoch, (batch + 1) * 
self.args.batch_size, len(self.loader_train.dataset), 76 | losses_sr.avg, 77 | timer.release(), 78 | )) 79 | 80 | self.loss.end_log(len(self.loader_train)) 81 | 82 | # save model 83 | if epoch > self.args.st_save_epoch or epoch%30==0: 84 | target = self.model.get_model() 85 | model_dict = target.state_dict() 86 | torch.save( 87 | model_dict, 88 | os.path.join(self.ckp.dir, 'model', 'model_{}.pt'.format(epoch)) 89 | ) 90 | 91 | target = self.model.get_model() 92 | model_dict = target.state_dict() 93 | torch.save( 94 | model_dict, 95 | os.path.join(self.ckp.dir, 'model', 'model_last.pt') 96 | ) 97 | 98 | def test(self): 99 | self.ckp.write_log('\nEvaluation:') 100 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 101 | self.model.eval() 102 | 103 | timer_test = utility.timer() 104 | degrade = util.BicubicPreprocessing( 105 | self.scale[0], 106 | rgb_range=self.args.rgb_range 107 | ) 108 | 109 | with torch.no_grad(): 110 | for idx_scale, d in enumerate(self.loader_test): 111 | for idx_scale, scale in enumerate(self.scale): 112 | d.dataset.set_scale(idx_scale) 113 | eval_psnr = 0 114 | eval_ssim = 0 115 | for idx_img, (hr, filename) in enumerate(d): 116 | hr = hr.cuda() # b, c, h, w 117 | hr = self.crop_border(hr, scale) 118 | lr,lr_bic = degrade(hr) # b, c, h, w 119 | 120 | # inference 121 | timer_test.tic() 122 | sr = self.model(lr,lr_bic) 123 | timer_test.hold() 124 | 125 | sr = utility.quantize(sr, self.args.rgb_range) 126 | hr = utility.quantize(hr, self.args.rgb_range) 127 | 128 | # metrics 129 | eval_psnr += utility.calc_psnr( 130 | sr, hr, scale, self.args.rgb_range, 131 | benchmark=True 132 | ) 133 | eval_ssim += utility.calc_ssim( 134 | (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale, 135 | benchmark=True 136 | ) 137 | 138 | # save results 139 | if self.args.save_results: 140 | save_list = [sr] 141 | filename = filename[0] 142 | self.ckp.save_results(filename, save_list, scale) 143 | 144 | if len(self.test_res_psnr) > 10: 145 | 
self.test_res_psnr.pop(0) 146 | self.test_res_ssim.pop(0) 147 | self.test_res_psnr.append(eval_psnr / len(self.loader_test)) 148 | self.test_res_ssim.append(eval_ssim / len(self.loader_test)) 149 | 150 | self.ckp.log[-1, idx_scale] = eval_psnr / len(self.loader_test) 151 | self.ckp.write_log( 152 | '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format( 153 | self.args.resume, 154 | self.args.data_test, 155 | scale, 156 | eval_psnr / len(self.loader_test), 157 | eval_ssim / len(self.loader_test), 158 | np.mean(self.test_res_psnr), 159 | np.mean(self.test_res_ssim) 160 | )) 161 | 162 | def crop_border(self, img_hr, scale): 163 | b, c, h, w = img_hr.size() 164 | 165 | img_hr = img_hr[:, :, :int(h//scale*scale), :int(w//scale*scale)] 166 | 167 | return img_hr 168 | 169 | def terminate(self): 170 | if self.args.test_only: 171 | self.test() 172 | return True 173 | else: 174 | epoch = self.scheduler.last_epoch + 1 175 | return epoch >= self.args.epochs_sr 176 | 177 | -------------------------------------------------------------------------------- /model_TA/blindsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model.common as common 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_model(args): 8 | return BlindSR(args) 9 | 10 | class RDA_conv(nn.Module): 11 | def __init__(self, channels_in, channels_out, kernel_size, reduction): 12 | super(RDA_conv, self).__init__() 13 | self.channels_out = channels_out 14 | self.channels_in = channels_in 15 | self.kernel_size = kernel_size 16 | 17 | self.kernel = nn.Sequential( 18 | nn.Linear(64, 64, bias=False), 19 | nn.LeakyReLU(0.1, True), 20 | nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False) 21 | ) 22 | self.conv = common.default_conv(channels_in, 1, 1) 23 | self.relu = nn.LeakyReLU(0.1, True) 24 | self.sig = nn.Sigmoid() 25 | def forward(self, x): 26 | ''' 27 | :param x[0]: 
feature map: B * C * H * W 28 | :param x[1]: degradation representation: B * C 29 | ''' 30 | b, c, h, w = x[0].size() 31 | 32 | # branch 1 33 | kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size) 34 | out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2)) 35 | out = out.view(b, -1, h, w) 36 | M = self.sig(self.conv(x[0])) 37 | # branch 2 38 | out = out*M + x[0] 39 | 40 | return out 41 | 42 | 43 | class RDAB(nn.Module): 44 | def __init__(self, conv, n_feat, kernel_size, reduction): 45 | super(RDAB, self).__init__() 46 | 47 | self.da_conv1 = RDA_conv(n_feat, n_feat, kernel_size, reduction) 48 | self.da_conv2 = RDA_conv(n_feat, n_feat, kernel_size, reduction) 49 | self.conv1 = conv(n_feat, n_feat, kernel_size) 50 | self.conv2 = conv(n_feat, n_feat, kernel_size) 51 | 52 | self.relu = nn.LeakyReLU(0.1, True) 53 | 54 | def forward(self, x): 55 | ''' 56 | :param x[0]: feature map: B * C * H * W 57 | :param x[1]: degradation representation: B * C 58 | ''' 59 | 60 | out = self.relu(self.da_conv1(x)) 61 | out = self.relu(self.conv1(out)) 62 | out = self.relu(self.da_conv2([out, x[1]])) 63 | out = self.conv2(out) + x[0] 64 | 65 | return out 66 | 67 | 68 | class RDAG(nn.Module): 69 | def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks): 70 | super(RDAG, self).__init__() 71 | self.n_blocks = n_blocks 72 | modules_body = [ 73 | RDAB(conv, n_feat, kernel_size, reduction) \ 74 | for _ in range(n_blocks) 75 | ] 76 | #modules_body.append(conv(n_feat, n_feat, kernel_size)) 77 | 78 | self.body = nn.Sequential(*modules_body) 79 | 80 | def forward(self, x): 81 | ''' 82 | :param x[0]: feature map: B * C * H * W 83 | :param x[1]: degradation representation: B * C 84 | ''' 85 | res = x[0] 86 | for i in range(self.n_blocks): 87 | res = self.body[i]([res, x[1]]) 88 | 89 | 90 | return res 91 | 92 | 93 | class RDAN(nn.Module): 94 | def __init__(self, args, conv=common.default_conv): 95 | super(RDAN, 
self).__init__() 96 | n_blocks = 27 97 | n_feats = 64 98 | kernel_size = 3 99 | reduction = 8 100 | scale = int(args.scale[0]) 101 | 102 | # RGB mean for DIV2K 103 | rgb_mean = (0.4488, 0.4371, 0.4040) 104 | rgb_std = (1.0, 1.0, 1.0) 105 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 106 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 107 | 108 | # head module 109 | modules_head = [conv(3, n_feats, kernel_size)] 110 | self.head = nn.Sequential(*modules_head) 111 | 112 | # compress 113 | self.compress = nn.Sequential( 114 | nn.Linear(256, 64, bias=False), 115 | nn.LeakyReLU(0.1, True) 116 | ) 117 | 118 | # body 119 | modules_body = [ 120 | RDAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks) 121 | ] 122 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 123 | self.body = nn.Sequential(*modules_body) 124 | 125 | # tail 126 | modules_tail = [common.Upsampler(conv, scale, n_feats, act=False), 127 | conv(n_feats, 3, kernel_size)] 128 | self.tail = nn.Sequential(*modules_tail) 129 | 130 | def forward(self, x, k_v): 131 | k_v = self.compress(k_v) 132 | 133 | # sub mean 134 | x = self.sub_mean(x) 135 | 136 | # head 137 | x = self.head(x) 138 | 139 | # body 140 | res = x 141 | res = self.body[0]([res, k_v]) 142 | res = self.body[-1](res) 143 | res = res + x 144 | 145 | # tail 146 | x = self.tail(res) 147 | 148 | # add mean 149 | x = self.add_mean(x) 150 | 151 | return x 152 | 153 | 154 | class DEN(nn.Module): 155 | def __init__(self,args): 156 | super(DEN, self).__init__() 157 | n_feats = args.n_feats 158 | self.E = nn.Sequential( 159 | nn.Conv2d(n_feats, 64, kernel_size=3, padding=1), 160 | nn.LeakyReLU(0.1, True), 161 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 162 | nn.LeakyReLU(0.1, True), 163 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 164 | nn.LeakyReLU(0.1, True), 165 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 166 | nn.LeakyReLU(0.1, True), 167 | nn.Conv2d(128, 256, 
kernel_size=3, stride=2, padding=1), 168 | nn.LeakyReLU(0.1, True), 169 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 170 | nn.LeakyReLU(0.1, True), 171 | nn.AdaptiveAvgPool2d(1), 172 | ) 173 | self.mlp = nn.Sequential( 174 | nn.Linear(256, 256), 175 | nn.LeakyReLU(0.1, True), 176 | nn.Linear(256, 256), 177 | ) 178 | 179 | def forward(self, x): 180 | fea = self.E(x).squeeze(-1).squeeze(-1) 181 | T_fea = [] 182 | for i in range(len(self.mlp)): 183 | fea = self.mlp[i](fea) 184 | if i==2: 185 | T_fea.append(fea) 186 | 187 | return fea,T_fea 188 | 189 | 190 | class BlindSR(nn.Module): 191 | def __init__(self, args): 192 | super(BlindSR, self).__init__() 193 | 194 | # Generator 195 | self.G = RDAN(args) 196 | 197 | self.E = DEN(args) 198 | 199 | 200 | def forward(self, x, deg_repre): 201 | if self.training: 202 | 203 | # degradation-aware represenetion learning 204 | deg_repre, T_fea = self.E(deg_repre) 205 | 206 | # degradation-aware SR 207 | sr = self.G(x, deg_repre) 208 | 209 | return sr, T_fea 210 | else: 211 | # degradation-aware represenetion learning 212 | deg_repre, _ = self.E(deg_repre) 213 | 214 | # degradation-aware SR 215 | sr = self.G(x, deg_repre) 216 | 217 | return sr 218 | -------------------------------------------------------------------------------- /model_ST/blindsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import model.common as common 4 | import torch.nn.functional as F 5 | 6 | 7 | import torch 8 | from torch import nn 9 | import model_ST.common as common 10 | import torch.nn.functional as F 11 | 12 | 13 | def make_model(args): 14 | return BlindSR(args) 15 | 16 | class RDA_conv(nn.Module): 17 | def __init__(self, channels_in, channels_out, kernel_size, reduction): 18 | super(RDA_conv, self).__init__() 19 | self.channels_out = channels_out 20 | self.channels_in = channels_in 21 | self.kernel_size = kernel_size 22 | 23 | self.kernel = nn.Sequential( 24 | 
nn.Linear(64, 64, bias=False), 25 | nn.LeakyReLU(0.1, True), 26 | nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False) 27 | ) 28 | self.conv = common.default_conv(channels_in, 1, 1) 29 | self.relu = nn.LeakyReLU(0.1, True) 30 | self.sig = nn.Sigmoid() 31 | def forward(self, x): 32 | ''' 33 | :param x[0]: feature map: B * C * H * W 34 | :param x[1]: degradation representation: B * C 35 | ''' 36 | b, c, h, w = x[0].size() 37 | 38 | # branch 1 39 | kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size) 40 | out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2)) 41 | out = out.view(b, -1, h, w) 42 | M = self.sig(self.conv(x[0])) 43 | # branch 2 44 | out = out*M + x[0] 45 | 46 | return out 47 | 48 | 49 | class RDAB(nn.Module): 50 | def __init__(self, conv, n_feat, kernel_size, reduction): 51 | super(RDAB, self).__init__() 52 | 53 | self.da_conv1 = RDA_conv(n_feat, n_feat, kernel_size, reduction) 54 | self.da_conv2 = RDA_conv(n_feat, n_feat, kernel_size, reduction) 55 | self.conv1 = conv(n_feat, n_feat, kernel_size) 56 | self.conv2 = conv(n_feat, n_feat, kernel_size) 57 | 58 | self.relu = nn.LeakyReLU(0.1, True) 59 | 60 | def forward(self, x): 61 | ''' 62 | :param x[0]: feature map: B * C * H * W 63 | :param x[1]: degradation representation: B * C 64 | ''' 65 | 66 | out = self.relu(self.da_conv1(x)) 67 | out = self.relu(self.conv1(out)) 68 | out = self.relu(self.da_conv2([out, x[1]])) 69 | out = self.conv2(out) + x[0] 70 | 71 | return out 72 | 73 | 74 | class RDAG(nn.Module): 75 | def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks): 76 | super(RDAG, self).__init__() 77 | self.n_blocks = n_blocks 78 | modules_body = [ 79 | RDAB(conv, n_feat, kernel_size, reduction) \ 80 | for _ in range(n_blocks) 81 | ] 82 | #modules_body.append(conv(n_feat, n_feat, kernel_size)) 83 | 84 | self.body = nn.Sequential(*modules_body) 85 | 86 | def forward(self, x): 87 | ''' 88 | :param x[0]: 
feature map: B * C * H * W 89 | :param x[1]: degradation representation: B * C 90 | ''' 91 | res = x[0] 92 | for i in range(self.n_blocks): 93 | res = self.body[i]([res, x[1]]) 94 | 95 | 96 | return res 97 | 98 | 99 | class RDAN(nn.Module): 100 | def __init__(self, args, conv=common.default_conv): 101 | super(RDAN, self).__init__() 102 | n_blocks = 27 103 | n_feats = 64 104 | kernel_size = 3 105 | reduction = 8 106 | scale = int(args.scale[0]) 107 | 108 | # RGB mean for DIV2K 109 | rgb_mean = (0.4488, 0.4371, 0.4040) 110 | rgb_std = (1.0, 1.0, 1.0) 111 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 112 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 113 | 114 | # head module 115 | modules_head = [conv(3, n_feats, kernel_size)] 116 | self.head = nn.Sequential(*modules_head) 117 | 118 | # compress 119 | self.compress = nn.Sequential( 120 | nn.Linear(256, 64, bias=False), 121 | nn.LeakyReLU(0.1, True) 122 | ) 123 | 124 | # body 125 | modules_body = [ 126 | RDAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks) 127 | ] 128 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 129 | self.body = nn.Sequential(*modules_body) 130 | 131 | # tail 132 | modules_tail = [common.Upsampler(conv, scale, n_feats, act=False), 133 | conv(n_feats, 3, kernel_size)] 134 | self.tail = nn.Sequential(*modules_tail) 135 | 136 | def forward(self, x, k_v): 137 | k_v = self.compress(k_v) 138 | 139 | # sub mean 140 | x = self.sub_mean(x) 141 | 142 | # head 143 | x = self.head(x) 144 | 145 | # body 146 | res = x 147 | res = self.body[0]([res, k_v]) 148 | res = self.body[-1](res) 149 | res = res + x 150 | 151 | # tail 152 | x = self.tail(res) 153 | 154 | # add mean 155 | x = self.add_mean(x) 156 | 157 | return x 158 | 159 | 160 | 161 | class DEN(nn.Module): 162 | def __init__(self,args): 163 | super(DEN, self).__init__() 164 | n_feats = args.n_feats 165 | self.E = nn.Sequential( 166 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 
167 | nn.LeakyReLU(0.1, True), 168 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 169 | nn.LeakyReLU(0.1, True), 170 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 171 | nn.LeakyReLU(0.1, True), 172 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 173 | nn.LeakyReLU(0.1, True), 174 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 175 | nn.LeakyReLU(0.1, True), 176 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 177 | nn.LeakyReLU(0.1, True), 178 | nn.AdaptiveAvgPool2d(1), 179 | ) 180 | self.mlp = nn.Sequential( 181 | nn.Linear(256, 256), 182 | nn.LeakyReLU(0.1, True), 183 | nn.Linear(256, 256), 184 | ) 185 | 186 | def forward(self, x): 187 | fea = self.E(x).squeeze(-1).squeeze(-1) 188 | S_fea = [] 189 | for i in range(len(self.mlp)): 190 | fea = self.mlp[i](fea) 191 | if i==2: 192 | S_fea.append(fea) 193 | 194 | return fea,S_fea 195 | 196 | 197 | class BlindSR(nn.Module): 198 | def __init__(self, args): 199 | super(BlindSR, self).__init__() 200 | 201 | # Generator 202 | self.G = RDAN(args) 203 | 204 | self.E_st = DEN(args) 205 | 206 | 207 | def forward(self, x): 208 | if self.training: 209 | 210 | # degradation-aware represenetion learning 211 | deg_repre, S_fea = self.E_st(x) 212 | 213 | # degradation-aware SR 214 | sr = self.G(x, deg_repre) 215 | 216 | return sr, S_fea 217 | else: 218 | # degradation-aware represenetion learning 219 | deg_repre, _ = self.E_st(x) 220 | 221 | # degradation-aware SR 222 | sr = self.G(x, deg_repre) 223 | 224 | return sr 225 | -------------------------------------------------------------------------------- /utility2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import scipy.misc as misc 8 | import cv2 9 | import torch 10 | import torch.optim as optim 11 | import torch.optim.lr_scheduler as lrs 12 | 13 | 14 | class AverageMeter(object): 15 | 
"""Computes and stores the average and current value""" 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | class timer(): 33 | def __init__(self): 34 | self.acc = 0 35 | self.tic() 36 | 37 | def tic(self): 38 | self.t0 = time.time() 39 | 40 | def toc(self): 41 | return time.time() - self.t0 42 | 43 | def hold(self): 44 | self.acc += self.toc() 45 | 46 | def release(self): 47 | ret = self.acc 48 | self.acc = 0 49 | 50 | return ret 51 | 52 | def reset(self): 53 | self.acc = 0 54 | 55 | 56 | class checkpoint(): 57 | def __init__(self, args): 58 | self.args = args 59 | self.ok = True 60 | self.log = torch.Tensor() 61 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 62 | 63 | if args.blur_type == 'iso_gaussian': 64 | self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_iso' 65 | elif args.blur_type == 'aniso_gaussian': 66 | self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_aniso' 67 | 68 | def _make_dir(path): 69 | if not os.path.exists(path): os.makedirs(path) 70 | 71 | _make_dir(self.dir) 72 | _make_dir(self.dir + '/model') 73 | _make_dir(self.dir + '/results') 74 | 75 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 76 | self.log_file = open(self.dir + '/log.txt', open_type) 77 | with open(self.dir + '/config.txt', open_type) as f: 78 | f.write(now + '\n\n') 79 | for arg in vars(args): 80 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 81 | f.write('\n') 82 | 83 | def save(self, trainer, epoch, is_best=False): 84 | trainer.model.save(self.dir, epoch, is_best=is_best) 85 | trainer.loss.save(self.dir) 86 | trainer.loss.plot_loss(self.dir, epoch) 87 | 88 | self.plot_psnr(epoch) 89 | torch.save(self.log, 
os.path.join(self.dir, 'psnr_log.pt')) 90 | torch.save( 91 | trainer.optimizer.state_dict(), 92 | os.path.join(self.dir, 'optimizer.pt') 93 | ) 94 | 95 | def add_log(self, log): 96 | self.log = torch.cat([self.log, log]) 97 | 98 | def write_log(self, log, refresh=False): 99 | print(log) 100 | self.log_file.write(log + '\n') 101 | if refresh: 102 | self.log_file.close() 103 | self.log_file = open(self.dir + '/log.txt', 'a') 104 | 105 | def done(self): 106 | self.log_file.close() 107 | 108 | def plot_psnr(self, epoch): 109 | axis = np.linspace(1, epoch, epoch) 110 | label = 'SR on {}'.format(self.args.data_test) 111 | fig = plt.figure() 112 | plt.title(label) 113 | for idx_scale, scale in enumerate(self.args.scale): 114 | plt.plot( 115 | axis, 116 | self.log[:, idx_scale].numpy(), 117 | label='Scale {}'.format(scale) 118 | ) 119 | plt.legend() 120 | plt.xlabel('Epochs') 121 | plt.ylabel('PSNR') 122 | plt.grid(True) 123 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) 124 | plt.close(fig) 125 | 126 | def save_results(self, filename, save_list, scale): 127 | filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale) 128 | 129 | normalized = save_list[0][0].data.mul(255 / self.args.rgb_range) 130 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 131 | misc.imsave('{}{}.png'.format(filename, 'SR'), ndarr) 132 | 133 | 134 | def quantize(img, rgb_range): 135 | pixel_range = 255 / rgb_range 136 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 137 | 138 | 139 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): 140 | diff = (sr - hr).data.div(rgb_range) 141 | if benchmark: 142 | shave = scale 143 | if diff.size(1) > 1: 144 | convert = diff.new(1, 3, 1, 1) 145 | convert[0, 0, 0, 0] = 65.738 146 | convert[0, 1, 0, 0] = 129.057 147 | convert[0, 2, 0, 0] = 25.064 148 | diff.mul_(convert).div_(256) 149 | diff = diff.sum(dim=1, keepdim=True) 150 | else: 151 | shave = scale + 6 152 | import math 153 | shave = 
math.ceil(shave) 154 | valid = diff[:, :, shave:-shave, shave:-shave] 155 | mse = valid.pow(2).mean() 156 | 157 | return -10 * math.log10(mse) 158 | 159 | 160 | def calc_ssim(img1, img2, scale=2, benchmark=False): 161 | '''calculate SSIM 162 | the same outputs as MATLAB's 163 | img1, img2: [0, 255] 164 | ''' 165 | if benchmark: 166 | border = math.ceil(scale) 167 | else: 168 | border = math.ceil(scale) + 6 169 | 170 | img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy() 171 | img1 = np.transpose(img1, (1, 2, 0)) 172 | img2 = img2.data.squeeze().cpu().numpy() 173 | img2 = np.transpose(img2, (1, 2, 0)) 174 | 175 | img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0 176 | img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0 177 | if not img1.shape == img2.shape: 178 | raise ValueError('Input images must have the same dimensions.') 179 | h, w = img1.shape[:2] 180 | img1_y = img1_y[border:h - border, border:w - border] 181 | img2_y = img2_y[border:h - border, border:w - border] 182 | 183 | if img1_y.ndim == 2: 184 | return ssim(img1_y, img2_y) 185 | elif img1.ndim == 3: 186 | if img1.shape[2] == 3: 187 | ssims = [] 188 | for i in range(3): 189 | ssims.append(ssim(img1, img2)) 190 | return np.array(ssims).mean() 191 | elif img1.shape[2] == 1: 192 | return ssim(np.squeeze(img1), np.squeeze(img2)) 193 | else: 194 | raise ValueError('Wrong input image dimensions.') 195 | 196 | 197 | def ssim(img1, img2): 198 | C1 = (0.01 * 255) ** 2 199 | C2 = (0.03 * 255) ** 2 200 | 201 | img1 = img1.astype(np.float64) 202 | img2 = img2.astype(np.float64) 203 | kernel = cv2.getGaussianKernel(11, 1.5) 204 | window = np.outer(kernel, kernel.transpose()) 205 | 206 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 207 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 208 | mu1_sq = mu1 ** 2 209 | mu2_sq = mu2 ** 2 210 | mu1_mu2 = mu1 * mu2 211 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 212 | sigma2_sq = 
cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 213 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 214 | 215 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 216 | (sigma1_sq + sigma2_sq + C2)) 217 | return ssim_map.mean() 218 | 219 | 220 | def make_optimizer(args, my_model): 221 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 222 | 223 | if args.optimizer == 'SGD': 224 | optimizer_function = optim.SGD 225 | kwargs = {'momentum': args.momentum} 226 | elif args.optimizer == 'ADAM': 227 | optimizer_function = optim.Adam 228 | kwargs = { 229 | 'betas': (args.beta1, args.beta2), 230 | 'eps': args.epsilon 231 | } 232 | elif args.optimizer == 'RMSprop': 233 | optimizer_function = optim.RMSprop 234 | kwargs = {'eps': args.epsilon} 235 | 236 | kwargs['weight_decay'] = args.weight_decay 237 | 238 | return optimizer_function(trainable, **kwargs) 239 | 240 | 241 | def make_scheduler(args, my_optimizer): 242 | if args.decay_type == 'step': 243 | scheduler = lrs.StepLR( 244 | my_optimizer, 245 | step_size=args.lr_decay_sr, 246 | gamma=args.gamma_sr, 247 | ) 248 | elif args.decay_type.find('step') >= 0: 249 | milestones = args.decay_type.split('_') 250 | milestones.pop(0) 251 | milestones = list(map(lambda x: int(x), milestones)) 252 | scheduler = lrs.MultiStepLR( 253 | my_optimizer, 254 | milestones=milestones, 255 | gamma=args.gamma 256 | ) 257 | 258 | scheduler.step(args.start_epoch - 1) 259 | 260 | return scheduler 261 | 262 | -------------------------------------------------------------------------------- /utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import scipy.misc as misc 8 | import cv2 9 | import torch 10 | import torch.optim as optim 11 | import torch.optim.lr_scheduler as lrs 12 | import imageio 13 | 14 
class AverageMeter(object):
    """Computes and stores the average and current value of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and mean."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class timer():
    """Wall-clock stopwatch with an accumulator (``hold``/``release``)."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start (or restart) the current interval."""
        self.t0 = time.time()

    def toc(self):
        """Return seconds elapsed since the last ``tic``."""
        return time.time() - self.t0

    def hold(self):
        """Add the current interval to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and reset the accumulator to zero."""
        ret = self.acc
        self.acc = 0

        return ret

    def reset(self):
        """Discard any accumulated time."""
        self.acc = 0


class checkpoint():
    """Owns the experiment directory: log.txt, config.txt, PSNR history and plots."""

    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.blur_type == 'iso_gaussian':
            blur_tag = '_iso'
        elif args.blur_type == 'aniso_gaussian':
            blur_tag = '_aniso'
        else:
            # Previously self.dir was silently left unset and the first
            # directory creation failed with an unrelated AttributeError.
            raise ValueError('Unsupported blur_type: {}'.format(args.blur_type))
        self.dir = ('./experiment/' + args.save + '_x' + str(int(args.scale[0]))
                    + '_' + args.mode + blur_tag)

        # 'optimzer' (sic) is the directory name the trainers write into;
        # keep the spelling so existing checkpoints are still found.
        for sub_dir in ('', '/model', '/optimzer', '/results'):
            os.makedirs(self.dir + sub_dir, exist_ok=True)

        # Append to an existing run's logs instead of truncating them.
        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss state/plot, PSNR history/plot and optimizer state."""
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append rows of evaluation results to the in-memory PSNR history."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it to log.txt.

        With ``refresh`` the file is closed and reopened in append mode, so
        everything written so far is flushed to disk.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close log.txt."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot the PSNR history (one curve per scale) to test_<data_test>.pdf."""
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first image of ``save_list[0]`` as results/<name>_x<scale>_SR.png."""
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)

        normalized = save_list[0][0].data.mul(255 / self.args.rgb_range)
        ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
        imageio.imsave('{}{}.png'.format(filename, 'SR'), ndarr)


def quantize(img, rgb_range):
    """Round ``img`` to the nearest 8-bit level within ``[0, rgb_range]``."""
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Return the PSNR (dB) between ``sr`` and ``hr`` after shaving the border.

    With ``benchmark`` the border is ``ceil(scale)`` pixels and multi-channel
    inputs are first collapsed to a Y channel (same weights as ``calc_ssim``).
    Otherwise the border is ``ceil(scale) + 6`` pixels.
    Returns ``inf`` when the shaved images are identical.
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    shave = math.ceil(shave)
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()
    if mse == 0:
        # math.log10(0) would raise a domain error.
        return float('inf')

    return -10 * math.log10(mse)


def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''calculate SSIM on the Y channel
    the same outputs as MATLAB's (per the original author)
    img1, img2: [0, 255]; expected to squeeze to 3 x H x W
    img1 (the prediction) is clamped and rounded; img2 is used as-is.
    '''
    border = math.ceil(scale) if benchmark else math.ceil(scale) + 6

    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    # H x W x 3 dotted with three weights always yields a 2-D Y image, so the
    # former per-channel / single-channel branches were unreachable.
    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    return ssim(img1_y, img2_y)


def ssim(img1, img2):
    """Mean SSIM of two 2-D arrays using an 11x11, sigma=1.5 Gaussian window."""
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def make_optimizer(args, my_model):
    """Build the optimizer named by ``args.optimizer`` over trainable parameters.

    Raises:
        ValueError: if ``args.optimizer`` is not SGD, ADAM or RMSprop.
    """
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

    kwargs['weight_decay'] = args.weight_decay

    # NOTE: no ``lr`` is passed, so the optimizer starts at torch's default.
    # trainer_stage3 overwrites param_group['lr'] at the start of every epoch.
    return optimizer_function(trainable, **kwargs)


def make_scheduler(args, my_optimizer):
    """Build the LR scheduler described by ``args.decay_type``.

    ``'step'`` gives ``StepLR(step_size=lr_decay_sr, gamma=gamma_sr)``.
    ``'step_<m1>_<m2>...'`` gives ``MultiStepLR`` with those milestones.

    Raises:
        ValueError: for any other ``decay_type``.
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay_sr,
            gamma=args.gamma_sr,
        )
    elif args.decay_type.find('step') >= 0:
        milestones = [int(m) for m in args.decay_type.split('_')[1:]]
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            # option.py defines no --gamma; fall back to gamma_sr like the
            # StepLR branch instead of raising AttributeError.
            gamma=getattr(args, 'gamma', args.gamma_sr)
        )
    else:
        raise ValueError('Unsupported decay_type: {}'.format(args.decay_type))

    # Fast-forward to the resume epoch (the epoch argument to step() is
    # deprecated in recent PyTorch but kept for identical behavior).
    scheduler.step(args.start_epoch - 1)

    return scheduler


# --------------------------------------------------------------------------------
# /option.py:
# --------------------------------------------------------------------------------
import argparse
import template

parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=4,
                    help='number of threads for data loading')
# No ``type=bool``: bool('False') is True, so ``--cpu False`` used to enable
# CPU mode. The raw string is converted by the 'True'/'False' loop at the end
# of option.py, the same way --shift_mean and --save_results are handled.
parser.add_argument('--cpu', default=False,
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=2,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--pre_train_meta', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--pre_train_TA', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--pre_train_ST', type=str, default='.',
                    help='pre-trained model directory')
# NOTE(review): this help text appears copied from --test_only; confirm what
# --is_stage3 actually toggles and update it.
parser.add_argument('--is_stage3', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--temperature', type=float, default=0.15,
                    help='for knowledge distillation')

# Meta-Learning
parser.add_argument('--task_iter', type=int, default=20,
                    help='each task iteration times')
parser.add_argument('--test_iter', type=int, default=20,
                    help='each task iteration times')
parser.add_argument('--meta_batch_size', type=int, default=8,
                    help='each task iteration times')
parser.add_argument('--task_batch_size', type=int, default=16,
                    help='each task iteration times')
parser.add_argument('--lr_task', type=float, default=1e-3,
                    help='learning rate to train the whole network')


# Data specifications
parser.add_argument('--dir_data', type=str, default='D:/LongguangWang/Data',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DF2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='Set5',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-3450/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=36,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=1,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Degradation specifications
parser.add_argument('--blur_kernel', type=int, default=21,
                    help='size of blur kernels')
parser.add_argument('--blur_type', type=str, default='iso_gaussian',
                    help='blur types (iso_gaussian | aniso_gaussian)')
parser.add_argument('--mode', type=str, default='bicubic',
                    help='downsampler (bicubic | s-fold)')
parser.add_argument('--noise', type=float, default=0.0,
                    help='noise level')
## isotropic Gaussian blur
parser.add_argument('--sig_min', type=float, default=0.2,
                    help='minimum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig_max', type=float, default=4.0,
                    help='maximum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig', type=float, default=4.0,
                    help='specific sigma of isotropic Gaussian blurs')
## anisotropic Gaussian blur
parser.add_argument('--lambda_min', type=float, default=0.2,
                    help='minimum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_max', type=float, default=4.0,
                    help='maximum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_1', type=float, default=0.2,
                    help='one eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_2', type=float, default=4.0,
                    help='another eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--theta', type=float, default=0.0,
                    help='rotation angle of anisotropic Gaussian blurs [0, 180]')


# Model specifications
parser.add_argument('--model', default='blindsr',
                    help='model name')
parser.add_argument('--pre_train', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs_encoder', type=int, default=100,
                    help='number of epochs to train the degradation encoder')
parser.add_argument('--epochs_sr', type=int, default=500,
                    help='number of epochs to train the whole network')
parser.add_argument('--st_save_epoch', type=int, default=550,
                    help='number of epochs to save network')
131 | parser.add_argument('--batch_size', type=int, default=32, 132 | help='input batch size for training') 133 | parser.add_argument('--split_batch', type=int, default=1, 134 | help='split the batch into smaller chunks') 135 | parser.add_argument('--self_ensemble', action='store_true', 136 | help='use self-ensemble method for test') 137 | parser.add_argument('--test_only', action='store_true', 138 | help='set this option to test the model') 139 | 140 | # Optimization specifications 141 | parser.add_argument('--lr_encoder', type=float, default=1e-3, 142 | help='learning rate to train the degradation encoder') 143 | parser.add_argument('--lr_sr', type=float, default=1e-4, 144 | help='learning rate to train the whole network') 145 | parser.add_argument('--lr_decay_encoder', type=int, default=60, 146 | help='learning rate decay per N epochs') 147 | parser.add_argument('--lr_decay_sr', type=int, default=125, 148 | help='learning rate decay per N epochs') 149 | parser.add_argument('--decay_type', type=str, default='step', 150 | help='learning rate decay type') 151 | parser.add_argument('--gamma_encoder', type=float, default=0.1, 152 | help='learning rate decay factor for step decay') 153 | parser.add_argument('--gamma_sr', type=float, default=0.5, 154 | help='learning rate decay factor for step decay') 155 | parser.add_argument('--optimizer', default='ADAM', 156 | choices=('SGD', 'ADAM', 'RMSprop'), 157 | help='optimizer to use (SGD | ADAM | RMSprop)') 158 | parser.add_argument('--momentum', type=float, default=0.9, 159 | help='SGD momentum') 160 | parser.add_argument('--beta1', type=float, default=0.9, 161 | help='ADAM beta1') 162 | parser.add_argument('--beta2', type=float, default=0.999, 163 | help='ADAM beta2') 164 | parser.add_argument('--epsilon', type=float, default=1e-8, 165 | help='ADAM epsilon for numerical stability') 166 | parser.add_argument('--weight_decay', type=float, default=0, 167 | help='weight decay') 168 | parser.add_argument('--start_epoch', 
type=int, default=0, 169 | help='resume from the snapshot, and the start_epoch') 170 | 171 | # Loss specifications 172 | parser.add_argument('--loss', type=str, default='1*L1', 173 | help='loss function configuration') 174 | parser.add_argument('--skip_threshold', type=float, default='1e6', 175 | help='skipping batch that has large error') 176 | 177 | # Log specifications 178 | parser.add_argument('--save', type=str, default='blindsr', 179 | help='file name to save') 180 | parser.add_argument('--load', type=str, default='.', 181 | help='file name to load') 182 | parser.add_argument('--resume', type=int, default=0, 183 | help='resume from specific checkpoint') 184 | parser.add_argument('--save_models', action='store_true', 185 | help='save all intermediate models') 186 | parser.add_argument('--print_every', type=int, default=200, 187 | help='how many batches to wait before logging training status') 188 | parser.add_argument('--save_results', default=False, 189 | help='save output results') 190 | 191 | args = parser.parse_args() 192 | template.set_template(args) 193 | 194 | args.scale = list(map(lambda x: float(x), args.scale.split('+'))) 195 | 196 | args.data_train = args.data_train.split('+') 197 | args.data_test = args.data_test.split('+') 198 | 199 | 200 | for arg in vars(args): 201 | if vars(args)[arg] == 'True': 202 | vars(args)[arg] = True 203 | elif vars(args)[arg] == 'False': 204 | vars(args)[arg] = False -------------------------------------------------------------------------------- /trainer_stage3.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utility2 3 | import torch 4 | from decimal import Decimal 5 | import torch.nn.functional as F 6 | from utils import util 7 | from utils import util2 8 | from collections import OrderedDict 9 | import random 10 | import numpy as np 11 | import torch.nn as nn 12 | 13 | 14 | class Trainer(): 15 | def __init__(self, args, loader, model_meta_copy, model_TA, 
my_loss, ckp): 16 | self.test_res_psnr = [] 17 | self.test_res_ssim = [] 18 | self.args = args 19 | self.scale = args.scale 20 | self.loss1= nn.L1Loss() 21 | self.ckp = ckp 22 | self.loader_train = loader.loader_train 23 | self.loader_test = loader.loader_test 24 | self.model_meta_copy = model_meta_copy 25 | 26 | self.model_meta_copy_state = self.model_meta_copy.state_dict() 27 | self.model_TA = model_TA 28 | self.loss = my_loss 29 | self.meta_batch_size = args.meta_batch_size 30 | self.task_batch_size = args.task_batch_size 31 | self.task_iter = args.task_iter 32 | self.test_iter = args.test_iter 33 | self.optimizer = utility2.make_optimizer(args, self.model_TA) 34 | self.scheduler = utility2.make_scheduler(args, self.optimizer) 35 | self.lr_task = args.lr_task 36 | self.taskiter_weight=[1/self.task_iter]*self.task_iter 37 | 38 | 39 | if self.args.load != '.': 40 | self.optimizer.load_state_dict( 41 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 42 | ) 43 | for _ in range(len(ckp.log)): self.scheduler.step() 44 | 45 | def train(self): 46 | self.scheduler.step() 47 | self.loss.step() 48 | epoch = self.scheduler.last_epoch + 1 49 | 50 | lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr)) 51 | for param_group in self.optimizer.param_groups: 52 | param_group['lr'] = lr 53 | 54 | self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))) 55 | self.loss.start_log() 56 | self.model_TA.train() 57 | 58 | degrade = util2.SRMDPreprocessing( 59 | self.scale[0], 60 | kernel_size=self.args.blur_kernel, 61 | blur_type=self.args.blur_type, 62 | sig_min=self.args.sig_min, 63 | sig_max=self.args.sig_max, 64 | lambda_min=self.args.lambda_min, 65 | lambda_max=self.args.lambda_max, 66 | noise=self.args.noise 67 | ) 68 | 69 | timer = utility2.timer() 70 | 71 | for batch, (hr, _) in enumerate(self.loader_train): 72 | hr = hr.cuda() # b, c, h, w 73 | timer.tic() 74 | loss_all = 0 75 | 76 | lr_blur, 
hr_blur = degrade(hr) # b, c, h, w 77 | self.model_meta_copy.get_model().load_state_dict(self.model_meta_copy_state) 78 | 79 | for iter in range(self.test_iter): 80 | if iter == 0: 81 | learning_rate = 1e-2 82 | elif iter < 5: 83 | learning_rate = 5e-3 84 | else: 85 | learning_rate = 1e-3 86 | sr,_ = self.model_meta_copy(lr_blur, hr_blur) 87 | loss = self.loss1(sr, hr) 88 | self.model_meta_copy.zero_grad() 89 | loss.backward() 90 | 91 | for param in self.model_meta_copy.parameters(): 92 | if param.requires_grad: 93 | param.data.sub_(param.grad.data * learning_rate) 94 | 95 | _, deg_repre = self.model_meta_copy(lr_blur, hr_blur) 96 | sr,_ = self.model_TA(lr_blur, deg_repre.detach()) 97 | loss_all += self.loss1(sr,hr) 98 | self.optimizer.zero_grad() 99 | loss_all.backward() 100 | self.optimizer.step() 101 | 102 | # Remove the hooks before next training phase 103 | timer.hold() 104 | 105 | if (batch + 1) % self.args.print_every == 0: 106 | self.ckp.write_log( 107 | 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 108 | 'Loss [SR loss:{:.3f}]\t' 109 | 'Time [{:.1f}s]'.format( 110 | epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 111 | loss_all.item(), 112 | timer.release(), 113 | )) 114 | 115 | self.loss.end_log(len(self.loader_train)) 116 | 117 | # save model 118 | if epoch > self.args.st_save_epoch or (epoch %20 ==0): 119 | target = self.model_TA.get_model() 120 | model_dict = target.state_dict() 121 | torch.save( 122 | model_dict, 123 | os.path.join(self.ckp.dir, 'model', 'model_TA_{}.pt'.format(epoch)) 124 | ) 125 | 126 | optimzer_dict = self.optimizer.state_dict() 127 | torch.save( 128 | optimzer_dict, 129 | os.path.join(self.ckp.dir, 'optimzer', 'optimzer_TA_{}.pt'.format(epoch)) 130 | ) 131 | 132 | target = self.model_TA.get_model() 133 | model_dict = target.state_dict() 134 | torch.save( 135 | model_dict, 136 | os.path.join(self.ckp.dir, 'model', 'model_TA_last.pt') 137 | ) 138 | optimzer_dict = self.optimizer.state_dict() 139 | torch.save( 140 | 
optimzer_dict, 141 | os.path.join(self.ckp.dir, 'optimzer', 'optimzer_TA_last.pt') 142 | ) 143 | 144 | 145 | def test(self): 146 | self.ckp.write_log('\nEvaluation:') 147 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 148 | 149 | timer_test = utility2.timer() 150 | t = 0 151 | self.model_TA.eval() 152 | 153 | degrade = util2.SRMDPreprocessing( 154 | self.scale[0], 155 | kernel_size=self.args.blur_kernel, 156 | blur_type=self.args.blur_type, 157 | sig=self.args.sig, 158 | lambda_1=self.args.lambda_1, 159 | lambda_2=self.args.lambda_2, 160 | theta=self.args.theta, 161 | noise=self.args.noise 162 | ) 163 | 164 | for idx_scale, d in enumerate(self.loader_test): 165 | for idx_scale, scale in enumerate(self.scale): 166 | d.dataset.set_scale(idx_scale) 167 | eval_psnr = 0 168 | eval_ssim = 0 169 | iter_psnr = [0] * (self.test_iter + 1) 170 | iter_ssim = [0] * (self.test_iter + 1) 171 | for idx_img, (hr, filename) in enumerate(d): 172 | self.model_meta_copy.get_model().load_state_dict(self.model_meta_copy_state) 173 | hr = hr.cuda() # b, c, h, w 174 | hr = self.crop_border(hr, scale) 175 | # inference 176 | # timer_test.tic() 177 | lr_blur, hr_blur = degrade(hr, random=False) 178 | 179 | 180 | 181 | for iter in range(self.test_iter): 182 | if iter == 0: 183 | learning_rate = 1e-2 184 | elif iter < 5: 185 | learning_rate = 5e-3 186 | else: 187 | learning_rate = 1e-3 188 | 189 | sr,_ = self.model_meta_copy(lr_blur, hr_blur) 190 | loss = self.loss1(sr, hr) 191 | self.model_meta_copy.zero_grad() 192 | loss.backward() 193 | 194 | for param in self.model_meta_copy.parameters(): 195 | if param.requires_grad: 196 | param.data.sub_(param.grad.data * learning_rate) 197 | 198 | _, deg_repre = self.model_meta_copy(lr_blur, hr_blur) 199 | #print(lr_blur.shape,deg_repre.shape) 200 | torch.cuda.synchronize() 201 | timer_test.tic() 202 | sr = self.model_TA(lr_blur, deg_repre.detach()) 203 | torch.cuda.synchronize() 204 | timer_test.hold() 205 | t0 = timer_test.release() 206 | 207 | 
print("idx:", idx_img, ",time consuming:", t0) 208 | t += t0 209 | 210 | hr = utility2.quantize(hr, self.args.rgb_range) 211 | sr = utility2.quantize(sr, self.args.rgb_range) 212 | 213 | iter_psnr[-1] += utility2.calc_psnr( 214 | sr, hr, scale, self.args.rgb_range, 215 | benchmark=True 216 | ) 217 | iter_ssim[-1] += utility2.calc_ssim( 218 | (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale, 219 | benchmark=True 220 | ) 221 | 222 | timer_test.hold() 223 | 224 | # save results 225 | if self.args.save_results: 226 | save_list = [sr] 227 | filename = filename[0] 228 | self.ckp.save_results(filename, save_list, scale) 229 | 230 | # for t in range(self.test_iter+1): 231 | # print("iter:",t,",PSNR:",iter_psnr[t]/ len(self.loader_test),",SSIM:",iter_ssim[t]/ len(self.loader_test)) 232 | 233 | eval_psnr = iter_psnr[-1] 234 | eval_ssim = iter_ssim[-1] 235 | 236 | if len(self.test_res_psnr)>10: 237 | self.test_res_psnr.pop(0) 238 | self.test_res_ssim.pop(0) 239 | self.test_res_psnr.append(eval_psnr / len(self.loader_test)) 240 | self.test_res_ssim.append(eval_ssim / len(self.loader_test)) 241 | print("All time consuming:", t / len(self.loader_test)) 242 | 243 | self.ckp.log[-1, idx_scale] = eval_psnr / len(self.loader_test) 244 | self.ckp.write_log( 245 | '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format( 246 | self.args.resume, 247 | self.args.data_test, 248 | scale, 249 | eval_psnr / len(self.loader_test), 250 | eval_ssim / len(self.loader_test), 251 | np.mean(self.test_res_psnr), 252 | np.mean(self.test_res_ssim) 253 | )) 254 | 255 | def crop_border(self, img_hr, scale): 256 | b, c, h, w = img_hr.size() 257 | 258 | img_hr = img_hr[:, :, :int(h//scale*scale), :int(w//scale*scale)] 259 | 260 | return img_hr 261 | 262 | def get_patch(self, img, patch_size=48, scale=4): 263 | tb, tc, th, tw = img.shape ## HR image 264 | tp = round(scale * patch_size) 265 | tx = random.randrange(0, (tw - tp)) 266 | ty = 
random.randrange(0, (th - tp)) 267 | 268 | return img[:,:,ty:ty + tp, tx:tx + tp] 269 | 270 | def crop(self, img_hr): 271 | # b, c, h, w = img_hr.size() 272 | tp_hr = [] 273 | for i in range(self.task_batch_size): 274 | tp_hr.append(self.get_patch(img_hr,self.args.patch_size,self.scale[0])) 275 | tp_hr = torch.cat(tp_hr,dim=0) 276 | return tp_hr 277 | 278 | def terminate(self): 279 | if self.args.test_only: 280 | self.test() 281 | return True 282 | else: 283 | epoch = self.scheduler.last_epoch + 1 284 | return epoch >= self.args.epochs_sr -------------------------------------------------------------------------------- /trainer_stage4.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utility 3 | import torch 4 | from decimal import Decimal 5 | import torch.nn.functional as F 6 | from utils import util2 7 | import numpy as np 8 | import torch.nn as nn 9 | 10 | 11 | class Trainer(): 12 | def __init__(self, args, loader, model_ST, model_TA,model_meta_copy, my_loss, ckp): 13 | self.is_first =True 14 | self.args = args 15 | self.scale = args.scale 16 | self.test_res_psnr = [] 17 | self.test_res_ssim = [] 18 | self.ckp = ckp 19 | self.loader_train = loader.loader_train 20 | self.loader_test = loader.loader_test 21 | self.model_ST = model_ST 22 | self.model_TA = model_TA 23 | self.model_meta_copy = model_meta_copy 24 | for k,v in self.model_meta_copy.named_parameters(): 25 | if "tail" in k: 26 | v.requires_grad=False 27 | self.model_meta_copy_state = self.model_meta_copy.state_dict() 28 | self.model_Est = torch.nn.DataParallel(self.model_ST.get_model().E_st, range(self.args.n_GPUs)) 29 | self.model_Eta = torch.nn.DataParallel(self.model_TA.get_model().E, range(self.args.n_GPUs)) 30 | self.loss1 = nn.L1Loss() 31 | self.loss = my_loss 32 | self.optimizer = utility.make_optimizer(args, self.model_ST) 33 | self.scheduler = utility.make_scheduler(args, self.optimizer) 34 | self.task_iter = args.task_iter 35 | 
self.test_iter = args.test_iter 36 | self.lr_task = args.lr_task 37 | self.temperature = args.temperature 38 | 39 | if self.args.load != '.': 40 | self.optimizer.load_state_dict( 41 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 42 | ) 43 | for _ in range(len(ckp.log)): self.scheduler.step() 44 | 45 | def train(self): 46 | self.scheduler.step() 47 | self.loss.step() 48 | epoch = self.scheduler.last_epoch + 1 49 | 50 | if epoch <= self.args.epochs_encoder: 51 | lr = self.args.lr_encoder * (self.args.gamma_encoder ** (epoch // self.args.lr_decay_encoder)) 52 | for param_group in self.optimizer.param_groups: 53 | param_group['lr'] = lr 54 | else: 55 | lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr)) 56 | for param_group in self.optimizer.param_groups: 57 | param_group['lr'] = lr 58 | 59 | self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))) 60 | self.loss.start_log() 61 | self.model_ST.train() 62 | 63 | degrade = util2.SRMDPreprocessing( 64 | self.scale[0], 65 | kernel_size=self.args.blur_kernel, 66 | blur_type=self.args.blur_type, 67 | sig_min=self.args.sig_min, 68 | sig_max=self.args.sig_max, 69 | lambda_min=self.args.lambda_min, 70 | lambda_max=self.args.lambda_max, 71 | noise=self.args.noise 72 | ) 73 | 74 | timer = utility.timer() 75 | losses_sr, losses_distill_distribution, losses_distill_abs = utility.AverageMeter(),utility.AverageMeter(),utility.AverageMeter() 76 | 77 | 78 | for batch, (hr, _) in enumerate(self.loader_train): 79 | hr = hr.cuda() # b, c, h, w 80 | # b,c,h,w = hr.shape 81 | timer.tic() 82 | # loss_all = 0 83 | lr_blur, hr_blur = degrade(hr) # b, c, h, w 84 | self.model_meta_copy.get_model().load_state_dict(self.model_meta_copy_state) 85 | 86 | for iter in range(self.test_iter): 87 | if iter == 0: 88 | learning_rate = 1e-2 89 | elif iter < 5: 90 | learning_rate = 5e-3 91 | else: 92 | learning_rate = 1e-3 93 | sr, _ = self.model_meta_copy(lr_blur, 
hr_blur) 94 | loss = self.loss1(sr, hr) 95 | self.model_meta_copy.zero_grad() 96 | loss.backward() 97 | 98 | for param in self.model_meta_copy.parameters(): 99 | if param.requires_grad: 100 | param.data.sub_(param.grad.data * learning_rate) 101 | 102 | _, deg_repre = self.model_meta_copy(lr_blur, hr_blur) 103 | _, T_fea = self.model_Eta(deg_repre) 104 | 105 | loss_distill_dis = 0 106 | loss_distill_abs = 0 107 | 108 | if epoch <= self.args.epochs_encoder: 109 | if self.is_first: 110 | _, S_fea = self.model_ST(lr_blur) 111 | self.is_first = False 112 | else: 113 | _, S_fea = self.model_Est(lr_blur) 114 | for i in range(len(T_fea)): 115 | student_distance = F.log_softmax(S_fea[i] / self.temperature, dim=1) 116 | teacher_distance = F.softmax(T_fea[i]/ self.temperature, dim=1) 117 | loss_distill_dis += F.kl_div( 118 | student_distance, teacher_distance, reduction='batchmean') 119 | loss_distill_abs += nn.L1Loss()(S_fea[i], T_fea[i]) 120 | losses_distill_distribution.update(loss_distill_dis.item()) 121 | losses_distill_abs.update(loss_distill_abs.item()) 122 | loss = loss_distill_dis + 0.1*loss_distill_abs 123 | else: 124 | sr, S_fea = self.model_ST(lr_blur) 125 | loss_SR = self.loss(sr, hr) 126 | for i in range(len(T_fea)): 127 | student_distance = F.log_softmax(S_fea[i] / self.temperature, dim=1) 128 | teacher_distance = F.softmax(T_fea[i] / self.temperature, dim=1) 129 | loss_distill_dis += F.kl_div( 130 | student_distance, teacher_distance, reduction='batchmean') 131 | loss_distill_abs += nn.L1Loss()(S_fea[i], T_fea[i]) 132 | losses_distill_distribution.update(loss_distill_dis.item()) 133 | losses_distill_abs.update(loss_distill_abs.item()) 134 | loss = loss_SR + loss_distill_dis + 0.1 * loss_distill_abs 135 | losses_sr.update(loss_SR.item()) 136 | 137 | # backward 138 | self.optimizer.zero_grad() 139 | loss.backward() 140 | self.optimizer.step() 141 | timer.hold() 142 | 143 | if epoch <= self.args.epochs_encoder: 144 | if (batch + 1) % self.args.print_every == 0: 
145 | self.ckp.write_log( 146 | 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 147 | 'Loss [ distill_dis loss:{:.3f}, distill_abs loss:{:.3f}]\t' 148 | 'Time [{:.1f}s]'.format( 149 | epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 150 | losses_distill_distribution.avg, losses_distill_abs.avg, 151 | timer.release(), 152 | )) 153 | else: 154 | if (batch + 1) % self.args.print_every == 0: 155 | self.ckp.write_log( 156 | 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 157 | 'Loss [SR loss:{:.3f}, distill_dis loss:{:.3f}, distill_abs loss:{:.3f}]\t' 158 | 'Time [{:.1f}s]'.format( 159 | epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 160 | losses_sr.avg, losses_distill_distribution.avg, losses_distill_abs.avg, 161 | timer.release(), 162 | )) 163 | 164 | self.loss.end_log(len(self.loader_train)) 165 | 166 | # save model 167 | if epoch > self.args.st_save_epoch or epoch%30==0: 168 | target = self.model_ST.get_model() 169 | model_dict = target.state_dict() 170 | torch.save( 171 | model_dict, 172 | os.path.join(self.ckp.dir, 'model', 'model_ST_{}.pt'.format(epoch)) 173 | ) 174 | 175 | target = self.model_ST.get_model() 176 | model_dict = target.state_dict() 177 | torch.save( 178 | model_dict, 179 | os.path.join(self.ckp.dir, 'model', 'model_ST_last.pt') 180 | ) 181 | 182 | def test(self): 183 | self.ckp.write_log('\nEvaluation:') 184 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 185 | self.model_ST.eval() 186 | t = 0 187 | timer_test = utility.timer() 188 | 189 | degrade = util2.SRMDPreprocessing( 190 | self.scale[0], 191 | kernel_size=self.args.blur_kernel, 192 | blur_type=self.args.blur_type, 193 | sig=self.args.sig, 194 | lambda_1=self.args.lambda_1, 195 | lambda_2=self.args.lambda_2, 196 | theta=self.args.theta, 197 | noise=self.args.noise 198 | ) 199 | 200 | with torch.no_grad(): 201 | for idx_scale, d in enumerate(self.loader_test): 202 | for idx_scale, scale in enumerate(self.scale): 203 | d.dataset.set_scale(idx_scale) 204 | 
eval_psnr = 0 205 | eval_ssim = 0 206 | for idx_img, (hr, filename) in enumerate(d): 207 | hr = hr.cuda() # b, c, h, w 208 | hr = self.crop_border(hr, scale) 209 | lr_blur, hr_blur = degrade(hr, random=False) # b, c, h, w 210 | 211 | # inference 212 | torch.cuda.synchronize() 213 | timer_test.tic() 214 | sr = self.model_ST(lr_blur) 215 | torch.cuda.synchronize() 216 | timer_test.hold() 217 | t0 = timer_test.release() 218 | print("idx:", idx_img, ",time consuming:", t0) 219 | t += t0 220 | 221 | sr = utility.quantize(sr, self.args.rgb_range) 222 | hr = utility.quantize(hr, self.args.rgb_range) 223 | 224 | # metrics 225 | eval_psnr += utility.calc_psnr( 226 | sr, hr, scale, self.args.rgb_range, 227 | benchmark=True 228 | ) 229 | eval_ssim += utility.calc_ssim( 230 | (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale, 231 | benchmark=True 232 | ) 233 | 234 | # save results 235 | if self.args.save_results: 236 | save_list = [sr] 237 | filename = filename[0] 238 | self.ckp.save_results(filename, save_list, scale) 239 | 240 | if len(self.test_res_psnr) > 10: 241 | self.test_res_psnr.pop(0) 242 | self.test_res_ssim.pop(0) 243 | self.test_res_psnr.append(eval_psnr / len(d)) 244 | self.test_res_ssim.append(eval_ssim / len(d)) 245 | print("All time consuming:", t / len(d)) 246 | 247 | self.ckp.log[-1, idx_scale] = eval_psnr / len(d) 248 | self.ckp.write_log( 249 | '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format( 250 | self.args.resume, 251 | self.args.data_test, 252 | scale, 253 | eval_psnr / len(d), 254 | eval_ssim / len(d), 255 | np.mean(self.test_res_psnr), 256 | np.mean(self.test_res_ssim) 257 | )) 258 | 259 | def crop_border(self, img_hr, scale): 260 | b, c, h, w = img_hr.size() 261 | 262 | img_hr = img_hr[:, :, :int(h//scale*scale), :int(w//scale*scale)] 263 | 264 | return img_hr 265 | 266 | def terminate(self): 267 | if self.args.test_only: 268 | self.test() 269 | return True 270 | else: 271 | epoch = 
self.scheduler.last_epoch + 1 272 | return epoch >= self.args.epochs_sr 273 | 274 | -------------------------------------------------------------------------------- /utils/util2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import utility 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def cal_sigma(sig_x, sig_y, radians): 10 | sig_x = sig_x.view(-1, 1, 1) 11 | sig_y = sig_y.view(-1, 1, 1) 12 | radians = radians.view(-1, 1, 1) 13 | 14 | D = torch.cat([F.pad(sig_x ** 2, [0, 1, 0, 0]), F.pad(sig_y ** 2, [1, 0, 0, 0])], 1) 15 | U = torch.cat([torch.cat([radians.cos(), -radians.sin()], 2), 16 | torch.cat([radians.sin(), radians.cos()], 2)], 1) 17 | sigma = torch.bmm(U, torch.bmm(D, U.transpose(1, 2))) 18 | 19 | return sigma 20 | 21 | 22 | def anisotropic_gaussian_kernel(batch, kernel_size, covar): 23 | ax = torch.arange(kernel_size).float().cuda() - kernel_size // 2 24 | 25 | xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 26 | yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 27 | xy = torch.stack([xx, yy], -1).view(batch, -1, 2) 28 | 29 | inverse_sigma = torch.inverse(covar) 30 | kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size) 31 | 32 | return kernel / kernel.sum([1, 2], keepdim=True) 33 | 34 | 35 | def isotropic_gaussian_kernel(batch, kernel_size, sigma): 36 | ax = torch.arange(kernel_size).float().cuda() - kernel_size//2 37 | xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 38 | yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 39 | kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. 
* sigma.view(-1, 1, 1) ** 2)) 40 | 41 | return kernel / kernel.sum([1,2], keepdim=True) 42 | 43 | 44 | def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0): 45 | theta = torch.rand(batch).cuda() / 180 * math.pi 46 | lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min 47 | lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min 48 | 49 | covar = cal_sigma(lambda_1, lambda_2, theta) 50 | kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar) 51 | return kernel 52 | 53 | 54 | def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0): 55 | theta = torch.ones(1).cuda() * theta / 180 * math.pi 56 | lambda_1 = torch.ones(1).cuda() * lambda_1 57 | lambda_2 = torch.ones(1).cuda() * lambda_2 58 | 59 | covar = cal_sigma(lambda_1, lambda_2, theta) 60 | kernel = anisotropic_gaussian_kernel(1, kernel_size, covar) 61 | return kernel 62 | 63 | 64 | def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0): 65 | x = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min 66 | k = isotropic_gaussian_kernel(batch, kernel_size, x) 67 | return k 68 | 69 | 70 | def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0): 71 | x = torch.ones(1).cuda() * sig 72 | k = isotropic_gaussian_kernel(1, kernel_size, x) 73 | return k 74 | 75 | 76 | def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0): 77 | if blur_type == 'iso_gaussian': 78 | return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max) 79 | elif blur_type == 'aniso_gaussian': 80 | return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max) 81 | 82 | 83 | def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0): 84 | 
if blur_type == 'iso_gaussian': 85 | return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig) 86 | elif blur_type == 'aniso_gaussian': 87 | return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta) 88 | 89 | 90 | # implementation of matlab bicubic interpolation in pytorch 91 | class bicubic(nn.Module): 92 | def __init__(self): 93 | super(bicubic, self).__init__() 94 | 95 | def cubic(self, x): 96 | absx = torch.abs(x) 97 | absx2 = torch.abs(x) * torch.abs(x) 98 | absx3 = torch.abs(x) * torch.abs(x) * torch.abs(x) 99 | 100 | condition1 = (absx <= 1).to(torch.float32) 101 | condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32) 102 | 103 | f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2 104 | return f 105 | 106 | def contribute(self, in_size, out_size, scale): 107 | kernel_width = 4 108 | if scale < 1: 109 | kernel_width = 4 / scale 110 | x0 = torch.arange(start=1, end=out_size[0] + 1).to(torch.float32).cuda() 111 | x1 = torch.arange(start=1, end=out_size[1] + 1).to(torch.float32).cuda() 112 | 113 | u0 = x0 / scale + 0.5 * (1 - 1 / scale) 114 | u1 = x1 / scale + 0.5 * (1 - 1 / scale) 115 | 116 | left0 = torch.floor(u0 - kernel_width / 2) 117 | left1 = torch.floor(u1 - kernel_width / 2) 118 | 119 | P = np.ceil(kernel_width) + 2 120 | 121 | indice0 = left0.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() 122 | indice1 = left1.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() 123 | 124 | mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0) 125 | mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0) 126 | 127 | if scale < 1: 128 | weight0 = scale * self.cubic(mid0 * scale) 129 | weight1 = scale * self.cubic(mid1 * scale) 130 | else: 131 | weight0 = self.cubic(mid0) 132 | weight1 = self.cubic(mid1) 133 | 134 | weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2)) 135 | weight1 = 
weight1 / (torch.sum(weight1, 2).unsqueeze(2)) 136 | 137 | indice0 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice0), torch.FloatTensor([in_size[0]]).cuda()).unsqueeze(0) 138 | indice1 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice1), torch.FloatTensor([in_size[1]]).cuda()).unsqueeze(0) 139 | 140 | kill0 = torch.eq(weight0, 0)[0][0] 141 | kill1 = torch.eq(weight1, 0)[0][0] 142 | 143 | weight0 = weight0[:, :, kill0 == 0] 144 | weight1 = weight1[:, :, kill1 == 0] 145 | 146 | indice0 = indice0[:, :, kill0 == 0] 147 | indice1 = indice1[:, :, kill1 == 0] 148 | 149 | return weight0, weight1, indice0, indice1 150 | 151 | def forward(self, input, scale=1/4): 152 | b, c, h, w = input.shape 153 | 154 | weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale) 155 | weight0 = weight0[0] 156 | weight1 = weight1[0] 157 | 158 | indice0 = indice0[0].long() 159 | indice1 = indice1[0].long() 160 | 161 | out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4)) 162 | out = (torch.sum(out, dim=3)) 163 | A = out.permute(0, 1, 3, 2) 164 | 165 | out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4)) 166 | out = out.sum(3).permute(0, 1, 3, 2) 167 | 168 | return out 169 | 170 | 171 | class Gaussin_Kernel(object): 172 | def __init__(self, kernel_size=21, blur_type='iso_gaussian', 173 | sig=2.6, sig_min=0.2, sig_max=4.0, 174 | lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0): 175 | self.kernel_size = kernel_size 176 | self.blur_type = blur_type 177 | 178 | self.sig = sig 179 | self.sig_min = sig_min 180 | self.sig_max = sig_max 181 | 182 | self.lambda_1 = lambda_1 183 | self.lambda_2 = lambda_2 184 | self.theta = theta 185 | self.lambda_min = lambda_min 186 | self.lambda_max = lambda_max 187 | 188 | def __call__(self, batch, random): 189 | # random kernel 190 | if random == True: 191 | return random_gaussian_kernel(batch, 
kernel_size=self.kernel_size, blur_type=self.blur_type, 192 | sig_min=self.sig_min, sig_max=self.sig_max, 193 | lambda_min=self.lambda_min, lambda_max=self.lambda_max) 194 | 195 | # stable kernel 196 | else: 197 | return stable_gaussian_kernel(kernel_size=self.kernel_size, blur_type=self.blur_type, 198 | sig=self.sig, 199 | lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta) 200 | 201 | class BatchBlur(nn.Module): 202 | def __init__(self, kernel_size=21): 203 | super(BatchBlur, self).__init__() 204 | self.kernel_size = kernel_size 205 | if kernel_size % 2 == 1: 206 | self.pad = nn.ReflectionPad2d(kernel_size//2) 207 | else: 208 | self.pad = nn.ReflectionPad2d((kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1)) 209 | 210 | def forward(self, input, kernel): 211 | B, C, H, W = input.size() 212 | input_pad = self.pad(input) 213 | H_p, W_p = input_pad.size()[-2:] 214 | 215 | if len(kernel.size()) == 2: 216 | input_CBHW = input_pad.view((C * B, 1, H_p, W_p)) 217 | kernel = kernel.contiguous().view((1, 1, self.kernel_size, self.kernel_size)) 218 | 219 | return F.conv2d(input_CBHW, kernel, padding=0).view((B, C, H, W)) 220 | else: 221 | input_CBHW = input_pad.view((1, C * B, H_p, W_p)) 222 | kernel = kernel.contiguous().view((B, 1, self.kernel_size, self.kernel_size)) 223 | kernel = kernel.repeat(1, C, 1, 1).view((B * C, 1, self.kernel_size, self.kernel_size)) 224 | 225 | return F.conv2d(input_CBHW, kernel, groups=B*C).view((B, C, H, W)) 226 | 227 | 228 | class SRMDPreprocessing(object): 229 | def __init__(self, 230 | scale, 231 | mode='bicubic', 232 | kernel_size=21, 233 | blur_type='iso_gaussian', 234 | sig=2.6, 235 | sig_min=0.2, 236 | sig_max=4.0, 237 | lambda_1=0.2, 238 | lambda_2=4.0, 239 | theta=0, 240 | lambda_min=0.2, 241 | lambda_max=4.0, 242 | noise=0.0, 243 | rgb_range=1 244 | ): 245 | ''' 246 | # sig, sig_min and sig_max are used for isotropic Gaussian blurs 247 | During training phase (random=True): 248 | the width of the blur 
kernel is randomly selected from [sig_min, sig_max] 249 | During test phase (random=False): 250 | the width of the blur kernel is set to sig 251 | 252 | # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs 253 | During training phase (random=True): 254 | the eigenvalues of the covariance is randomly selected from [lambda_min, lambda_max] 255 | the angle value is randomly selected from [0, pi] 256 | During test phase (random=False): 257 | the eigenvalues of the covariance are set to lambda_1 and lambda_2 258 | the angle value is set to theta 259 | ''' 260 | self.kernel_size = kernel_size 261 | self.scale = scale 262 | self.mode = mode 263 | self.noise = noise / (255 / rgb_range) 264 | self.rgb_range = rgb_range 265 | 266 | self.gen_kernel = Gaussin_Kernel( 267 | kernel_size=kernel_size, blur_type=blur_type, 268 | sig=sig, sig_min=sig_min, sig_max=sig_max, 269 | lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max 270 | ) 271 | self.blur = BatchBlur(kernel_size=kernel_size) 272 | self.bicubic = bicubic() 273 | 274 | def __call__(self, hr_tensor, random=True): 275 | with torch.no_grad(): 276 | # only downsampling 277 | if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0: 278 | B, C, H, W = hr_tensor.size() 279 | hr_blured = hr_tensor.view(-1, C, H, W) 280 | b_kernels = None 281 | 282 | # gaussian blur + downsampling 283 | else: 284 | B, C, H, W = hr_tensor.size() 285 | b_kernels = self.gen_kernel(1, random) # B degradations 286 | b_kernels = b_kernels.expand(B,-1,-1) 287 | 288 | # blur 289 | hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels) 290 | hr_blured = hr_blured.view(-1, C, H, W) # B, C, H, W 291 | 292 | # downsampling 293 | if self.mode == 'bicubic': 294 | lr_blured = self.bicubic(hr_blured, scale=1/self.scale) 295 | elif self.mode == 's-fold': 296 | lr_blured = hr_blured.view(-1, C, H//self.scale, self.scale, W//self.scale, self.scale)[:, :, 
:, 0, :, 0] 297 | 298 | 299 | # add noise 300 | if self.noise > 0: 301 | _, C, H_lr, W_lr = lr_blured.size() 302 | noise_level = torch.rand(1, 1, 1, 1).to(lr_blured.device) * self.noise if random else self.noise 303 | noise = torch.randn_like(lr_blured).view(-1, C, H_lr, W_lr).mul_(noise_level).view(-1, C, H_lr, W_lr) 304 | lr_blured.add_(noise) 305 | 306 | hr_blured = self.bicubic(lr_blured, scale= self.scale) 307 | 308 | 309 | lr_blured = utility.quantize(lr_blured, self.rgb_range) 310 | hr_blured = utility.quantize(hr_blured, self.rgb_range) 311 | 312 | 313 | 314 | return lr_blured.view(B, C, H//int(self.scale), W//int(self.scale)), hr_blured.view(B, C, H, W) 315 | 316 | -------------------------------------------------------------------------------- /trainer_stage2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import utility2 3 | import torch 4 | from decimal import Decimal 5 | import torch.nn.functional as F 6 | from utils import util 7 | from utils import util2 8 | from collections import OrderedDict 9 | import random 10 | import numpy as np 11 | import torch.nn as nn 12 | 13 | 14 | class Trainer(): 15 | def __init__(self, args, loader, model_meta, model_meta_copy, my_loss, ckp): 16 | self.test_res_psnr = [] 17 | self.test_res_ssim = [] 18 | self.args = args 19 | self.scale = args.scale 20 | self.loss1= nn.L1Loss() 21 | self.ckp = ckp 22 | self.loader_train = loader.loader_train 23 | self.loader_test = loader.loader_test 24 | self.model_meta = model_meta 25 | self.model_meta_copy = model_meta_copy 26 | 27 | self.loss = my_loss 28 | self.meta_batch_size = args.meta_batch_size 29 | self.task_batch_size = args.task_batch_size 30 | self.task_iter = args.task_iter 31 | self.test_iter = args.test_iter 32 | self.optimizer = utility2.make_optimizer(args, self.model_meta) 33 | self.scheduler = utility2.make_scheduler(args, self.optimizer) 34 | self.lr_task = args.lr_task 35 | 
self.taskiter_weight=[1/self.task_iter]*self.task_iter 36 | 37 | if self.args.load != '.': 38 | self.optimizer.load_state_dict( 39 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 40 | ) 41 | for _ in range(len(ckp.log)): self.scheduler.step() 42 | 43 | def train(self): 44 | self.scheduler.step() 45 | self.loss.step() 46 | epoch = self.scheduler.last_epoch + 1 47 | 48 | lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr)) 49 | for param_group in self.optimizer.param_groups: 50 | param_group['lr'] = lr 51 | 52 | self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))) 53 | self.loss.start_log() 54 | self.model_meta.train() 55 | 56 | degrade = util2.SRMDPreprocessing( 57 | self.scale[0], 58 | kernel_size=self.args.blur_kernel, 59 | blur_type=self.args.blur_type, 60 | sig_min=self.args.sig_min, 61 | sig_max=self.args.sig_max, 62 | lambda_min=self.args.lambda_min, 63 | lambda_max=self.args.lambda_max, 64 | noise=self.args.noise 65 | ) 66 | 67 | timer = utility2.timer() 68 | losses_sr = utility2.AverageMeter() 69 | 70 | for batch, (hr, _,) in enumerate(self.loader_train): 71 | hr = hr.cuda() # b, c, h, w 72 | b,c,h,w =hr.shape 73 | query_label = hr[:b//2,:,:,:] 74 | support_label = hr[b//2:,:,:,:] 75 | weights = OrderedDict( 76 | (name, param) for (name, param) in self.model_meta.model.named_parameters() if param.requires_grad ) 77 | 78 | meta_grads = {name: 0 for (name, param) in weights.items()} 79 | timer.tic() 80 | loss_train = [0]*self.task_iter 81 | for i in range(self.meta_batch_size): 82 | lr_blur, hr_blur = degrade(hr) # b, c, h, w 83 | query_lr = lr_blur[:b // 2, :, :, :] 84 | support_lr = lr_blur[b // 2:, :, :, :] 85 | query_hr = hr_blur[:b//2,:,:,:] 86 | support_hr = hr_blur[b // 2:, :, :, :] 87 | sr = self.model_meta(query_lr,query_hr, weights) 88 | 89 | loss = self.loss1(sr,query_label) 90 | loss_train[0] += loss.item()/self.meta_batch_size 91 | grads = 
torch.autograd.grad(loss, weights.values()) 92 | fast_weights = OrderedDict( 93 | (name, param - self.lr_task * grad) for ((name, param), grad) in zip(weights.items(), grads)) 94 | 95 | 96 | for j in range(1,self.task_iter): 97 | sr = self.model_meta(query_lr,query_hr, fast_weights) 98 | loss = self.loss1(sr,query_label) 99 | loss_train[j] += loss.item() / self.meta_batch_size 100 | grads = torch.autograd.grad(loss, fast_weights.values()) 101 | fast_weights = OrderedDict( 102 | (name, param - self.lr_task * grad) for ((name, param), grad) in zip(fast_weights.items(), grads)) 103 | #***************support*********************** 104 | sr = self.model_meta(support_lr,support_hr, fast_weights) 105 | loss = self.loss1(sr, support_label) 106 | grads = torch.autograd.grad(loss, weights.values()) 107 | for ((name, _), g) in zip(meta_grads.items(), grads): 108 | meta_grads[name] = meta_grads[name] + g/self.meta_batch_size 109 | 110 | 111 | hooks = [] 112 | for (k, v) in self.model_meta.model.named_parameters(): 113 | def get_closure(): 114 | key = k 115 | def replace_grad(grad): 116 | return meta_grads[key] 117 | return replace_grad 118 | if v.requires_grad: 119 | hooks.append(v.register_hook(get_closure())) 120 | 121 | sr = self.model_meta(query_lr,query_hr, weights) 122 | loss = self.loss(sr, query_label) 123 | self.optimizer.zero_grad() 124 | loss.backward() 125 | self.optimizer.step() 126 | 127 | # Remove the hooks before next training phase 128 | for h in hooks: 129 | h.remove() 130 | 131 | timer.hold() 132 | 133 | if (batch + 1) % self.args.print_every == 0: 134 | self.ckp.write_log( 135 | 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 136 | 'Loss [SR0 loss:{:.3f},SR1 loss:{:.3f},SR2 loss:{:.3f},SR3 loss:{:.3f},SR4 loss:{:.3f}]\t' 137 | 'Time [{:.1f}s]'.format( 138 | epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), 139 | loss_train[0],loss_train[1],loss_train[2],loss_train[3],loss_train[4], 140 | timer.release(), 141 | )) 142 | 143 | 
self.loss.end_log(len(self.loader_train)) 144 | 145 | # save model 146 | if epoch > self.args.st_save_epoch or (epoch %10 ==0): 147 | target = self.model_meta.get_model() 148 | model_dict = target.state_dict() 149 | torch.save( 150 | model_dict, 151 | os.path.join(self.ckp.dir, 'model', 'model_meta_{}.pt'.format(epoch)) 152 | ) 153 | 154 | optimzer_dict = self.optimizer.state_dict() 155 | torch.save( 156 | optimzer_dict, 157 | os.path.join(self.ckp.dir, 'optimzer', 'optimzer_meta_{}.pt'.format(epoch)) 158 | ) 159 | 160 | target = self.model_meta.get_model() 161 | model_dict = target.state_dict() 162 | torch.save( 163 | model_dict, 164 | os.path.join(self.ckp.dir, 'model', 'model_meta_last.pt') 165 | ) 166 | optimzer_dict = self.optimizer.state_dict() 167 | torch.save( 168 | optimzer_dict, 169 | os.path.join(self.ckp.dir, 'optimzer', 'optimzer_meta_last.pt') 170 | ) 171 | 172 | 173 | def test(self): 174 | self.ckp.write_log('\nEvaluation:') 175 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 176 | 177 | timer_test = utility2.timer() 178 | self.model_meta.eval() 179 | degrade = util2.SRMDPreprocessing( 180 | self.scale[0], 181 | kernel_size=self.args.blur_kernel, 182 | blur_type=self.args.blur_type, 183 | sig=self.args.sig, 184 | lambda_1=self.args.lambda_1, 185 | lambda_2=self.args.lambda_2, 186 | theta=self.args.theta, 187 | noise=self.args.noise 188 | ) 189 | 190 | for idx_scale, d in enumerate(self.loader_test): 191 | for idx_scale, scale in enumerate(self.scale): 192 | d.dataset.set_scale(idx_scale) 193 | eval_psnr = 0 194 | eval_ssim = 0 195 | iter_psnr = [0]*(self.test_iter+1) 196 | iter_ssim = [0]*(self.test_iter+1) 197 | for idx_img, (hr, filename) in enumerate(d): 198 | self.model_meta_copy.model.load_state_dict(self.model_meta.state_dict()) 199 | 200 | hr = hr.cuda() # b, c, h, w 201 | hr = self.crop_border(hr, scale) 202 | # inference 203 | timer_test.tic() 204 | hr_batch = hr 205 | hr = utility2.quantize(hr, self.args.rgb_range) 206 | lr_blur, 
hr_blur = degrade(hr_batch, random=False) 207 | 208 | for iter in range(self.test_iter): 209 | if iter == 0: 210 | learning_rate = 1e-2 211 | elif iter < 5: 212 | learning_rate = 5e-3 213 | else: 214 | learning_rate = 1e-3 215 | 216 | sr = self.model_meta_copy(lr_blur,hr_blur) 217 | 218 | loss = self.loss1(sr, hr_batch) 219 | 220 | sr = utility2.quantize(sr, self.args.rgb_range) 221 | 222 | iter_psnr[iter] += utility2.calc_psnr( 223 | sr, hr, scale, self.args.rgb_range, 224 | benchmark=True 225 | ) 226 | iter_ssim[iter] += utility2.calc_ssim( 227 | (sr*255).round().clamp(0,255), (hr*255).round().clamp(0,255),scale, 228 | benchmark=True 229 | ) 230 | 231 | self.model_meta_copy.zero_grad() 232 | loss.backward() 233 | 234 | for param in self.model_meta_copy.parameters(): 235 | if param.requires_grad: 236 | param.data.sub_(param.grad.data * learning_rate) 237 | 238 | sr = self.model_meta_copy(lr_blur,hr_blur) 239 | 240 | 241 | timer_test.hold() 242 | 243 | sr = utility2.quantize(sr, self.args.rgb_range) 244 | 245 | 246 | # metrics 247 | iter_psnr[-1] += utility2.calc_psnr( 248 | sr, hr, scale, self.args.rgb_range, 249 | benchmark=True 250 | ) 251 | iter_ssim[-1] += utility2.calc_ssim( 252 | (sr*255).round().clamp(0,255), (hr*255).round().clamp(0,255),scale, 253 | benchmark=True 254 | ) 255 | 256 | # save results 257 | if self.args.save_results: 258 | save_list = [sr] 259 | filename = filename[0] 260 | self.ckp.save_results(filename, save_list, scale) 261 | 262 | 263 | for t in range(self.test_iter+1): 264 | print("iter:",t,",PSNR:",iter_psnr[t]/ len(self.loader_test),",SSIM:",iter_ssim[t]/ len(self.loader_test)) 265 | 266 | eval_psnr = iter_psnr[-1] 267 | eval_ssim = iter_ssim[-1] 268 | 269 | if len(self.test_res_psnr)>10: 270 | self.test_res_psnr.pop(0) 271 | self.test_res_ssim.pop(0) 272 | self.test_res_psnr.append(eval_psnr / len(self.loader_test)) 273 | self.test_res_ssim.append(eval_ssim / len(self.loader_test)) 274 | 275 | self.ckp.log[-1, idx_scale] = eval_psnr 
/ len(self.loader_test) 276 | 277 | self.ckp.write_log( 278 | '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format( 279 | self.args.resume, 280 | self.args.data_test, 281 | scale, 282 | eval_psnr / len(self.loader_test), 283 | eval_ssim / len(self.loader_test), 284 | np.mean(self.test_res_psnr), 285 | np.mean(self.test_res_ssim) 286 | )) 287 | 288 | def crop_border(self, img_hr, scale): 289 | b, c, h, w = img_hr.size() 290 | 291 | img_hr = img_hr[:, :, :int(h//scale*scale), :int(w//scale*scale)] 292 | 293 | return img_hr 294 | 295 | def get_patch(self, img, patch_size=48, scale=4): 296 | tb, tc, th, tw = img.shape ## HR image 297 | tp = round(scale * patch_size) 298 | tx = random.randrange(0, (tw - tp)) 299 | ty = random.randrange(0, (th - tp)) 300 | 301 | return img[:,:,ty:ty + tp, tx:tx + tp] 302 | 303 | def crop(self, img_hr): 304 | tp_hr = [] 305 | for i in range(self.task_batch_size): 306 | tp_hr.append(self.get_patch(img_hr,self.args.patch_size,self.scale[0])) 307 | tp_hr = torch.cat(tp_hr,dim=0) 308 | return tp_hr 309 | 310 | def terminate(self): 311 | if self.args.test_only: 312 | self.test() 313 | return True 314 | else: 315 | epoch = self.scheduler.last_epoch + 1 316 | return epoch >= self.args.epochs_sr -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import utility 7 | 8 | 9 | def cal_sigma(sig_x, sig_y, radians): 10 | sig_x = sig_x.view(-1, 1, 1) 11 | sig_y = sig_y.view(-1, 1, 1) 12 | radians = radians.view(-1, 1, 1) 13 | 14 | D = torch.cat([F.pad(sig_x ** 2, [0, 1, 0, 0]), F.pad(sig_y ** 2, [1, 0, 0, 0])], 1) 15 | U = torch.cat([torch.cat([radians.cos(), -radians.sin()], 2), 16 | torch.cat([radians.sin(), radians.cos()], 2)], 1) 17 | sigma = torch.bmm(U, torch.bmm(D, 
U.transpose(1, 2))) 18 | 19 | return sigma 20 | 21 | 22 | def anisotropic_gaussian_kernel(batch, kernel_size, covar): 23 | ax = torch.arange(kernel_size).float().cuda() - kernel_size // 2 24 | 25 | xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 26 | yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 27 | xy = torch.stack([xx, yy], -1).view(batch, -1, 2) 28 | 29 | inverse_sigma = torch.inverse(covar) 30 | kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size) 31 | 32 | return kernel / kernel.sum([1, 2], keepdim=True) 33 | 34 | 35 | def isotropic_gaussian_kernel(batch, kernel_size, sigma): 36 | ax = torch.arange(kernel_size).float().cuda() - kernel_size//2 37 | xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 38 | yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) 39 | kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. 
* sigma.view(-1, 1, 1) ** 2)) 40 | 41 | return kernel / kernel.sum([1,2], keepdim=True) 42 | 43 | 44 | def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0): 45 | theta = torch.rand(batch).cuda() / 180 * math.pi 46 | lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min 47 | lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min 48 | 49 | covar = cal_sigma(lambda_1, lambda_2, theta) 50 | kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar) 51 | return kernel 52 | 53 | 54 | def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0): 55 | theta = torch.ones(1).cuda() * theta / 180 * math.pi 56 | lambda_1 = torch.ones(1).cuda() * lambda_1 57 | lambda_2 = torch.ones(1).cuda() * lambda_2 58 | 59 | covar = cal_sigma(lambda_1, lambda_2, theta) 60 | kernel = anisotropic_gaussian_kernel(1, kernel_size, covar) 61 | return kernel 62 | 63 | 64 | def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0): 65 | x = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min 66 | k = isotropic_gaussian_kernel(batch, kernel_size, x) 67 | return k 68 | 69 | 70 | def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0): 71 | x = torch.ones(1).cuda() * sig 72 | k = isotropic_gaussian_kernel(1, kernel_size, x) 73 | return k 74 | 75 | 76 | def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0): 77 | if blur_type == 'iso_gaussian': 78 | return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max) 79 | elif blur_type == 'aniso_gaussian': 80 | return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max) 81 | 82 | 83 | def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0): 84 | 
if blur_type == 'iso_gaussian': 85 | return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig) 86 | elif blur_type == 'aniso_gaussian': 87 | return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta) 88 | 89 | 90 | # implementation of matlab bicubic interpolation in pytorch 91 | class bicubic(nn.Module): 92 | def __init__(self): 93 | super(bicubic, self).__init__() 94 | 95 | def cubic(self, x): 96 | absx = torch.abs(x) 97 | absx2 = torch.abs(x) * torch.abs(x) 98 | absx3 = torch.abs(x) * torch.abs(x) * torch.abs(x) 99 | 100 | condition1 = (absx <= 1).to(torch.float32) 101 | condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32) 102 | 103 | f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2 104 | return f 105 | 106 | def contribute(self, in_size, out_size, scale): 107 | kernel_width = 4 108 | if scale < 1: 109 | kernel_width = 4 / scale 110 | x0 = torch.arange(start=1, end=out_size[0] + 1).to(torch.float32).cuda() 111 | x1 = torch.arange(start=1, end=out_size[1] + 1).to(torch.float32).cuda() 112 | 113 | u0 = x0 / scale + 0.5 * (1 - 1 / scale) 114 | u1 = x1 / scale + 0.5 * (1 - 1 / scale) 115 | 116 | left0 = torch.floor(u0 - kernel_width / 2) 117 | left1 = torch.floor(u1 - kernel_width / 2) 118 | 119 | P = np.ceil(kernel_width) + 2 120 | 121 | indice0 = left0.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() 122 | indice1 = left1.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() 123 | 124 | mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0) 125 | mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0) 126 | 127 | if scale < 1: 128 | weight0 = scale * self.cubic(mid0 * scale) 129 | weight1 = scale * self.cubic(mid1 * scale) 130 | else: 131 | weight0 = self.cubic(mid0) 132 | weight1 = self.cubic(mid1) 133 | 134 | weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2)) 135 | weight1 = 
weight1 / (torch.sum(weight1, 2).unsqueeze(2)) 136 | 137 | indice0 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice0), torch.FloatTensor([in_size[0]]).cuda()).unsqueeze(0) 138 | indice1 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice1), torch.FloatTensor([in_size[1]]).cuda()).unsqueeze(0) 139 | 140 | kill0 = torch.eq(weight0, 0)[0][0] 141 | kill1 = torch.eq(weight1, 0)[0][0] 142 | 143 | weight0 = weight0[:, :, kill0 == 0] 144 | weight1 = weight1[:, :, kill1 == 0] 145 | 146 | indice0 = indice0[:, :, kill0 == 0] 147 | indice1 = indice1[:, :, kill1 == 0] 148 | 149 | return weight0, weight1, indice0, indice1 150 | 151 | def forward(self, input, scale=1/4): 152 | b, c, h, w = input.shape 153 | 154 | weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale) 155 | weight0 = weight0[0] 156 | weight1 = weight1[0] 157 | 158 | indice0 = indice0[0].long() 159 | indice1 = indice1[0].long() 160 | 161 | out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4)) 162 | out = (torch.sum(out, dim=3)) 163 | A = out.permute(0, 1, 3, 2) 164 | 165 | out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4)) 166 | out = out.sum(3).permute(0, 1, 3, 2) 167 | 168 | return out 169 | 170 | 171 | class Gaussin_Kernel(object): 172 | def __init__(self, kernel_size=21, blur_type='iso_gaussian', 173 | sig=2.6, sig_min=0.2, sig_max=4.0, 174 | lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0): 175 | self.kernel_size = kernel_size 176 | self.blur_type = blur_type 177 | 178 | self.sig = sig 179 | self.sig_min = sig_min 180 | self.sig_max = sig_max 181 | 182 | self.lambda_1 = lambda_1 183 | self.lambda_2 = lambda_2 184 | self.theta = theta 185 | self.lambda_min = lambda_min 186 | self.lambda_max = lambda_max 187 | 188 | def __call__(self, batch, random): 189 | # random kernel 190 | if random == True: 191 | return random_gaussian_kernel(batch, 
kernel_size=self.kernel_size, blur_type=self.blur_type, 192 | sig_min=self.sig_min, sig_max=self.sig_max, 193 | lambda_min=self.lambda_min, lambda_max=self.lambda_max) 194 | 195 | # stable kernel 196 | else: 197 | return stable_gaussian_kernel(kernel_size=self.kernel_size, blur_type=self.blur_type, 198 | sig=self.sig, 199 | lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta) 200 | 201 | class BatchBlur(nn.Module): 202 | def __init__(self, kernel_size=21): 203 | super(BatchBlur, self).__init__() 204 | self.kernel_size = kernel_size 205 | if kernel_size % 2 == 1: 206 | self.pad = nn.ReflectionPad2d(kernel_size//2) 207 | else: 208 | self.pad = nn.ReflectionPad2d((kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1)) 209 | 210 | def forward(self, input, kernel): 211 | B, C, H, W = input.size() 212 | input_pad = self.pad(input) 213 | H_p, W_p = input_pad.size()[-2:] 214 | 215 | if len(kernel.size()) == 2: 216 | input_CBHW = input_pad.view((C * B, 1, H_p, W_p)) 217 | kernel = kernel.contiguous().view((1, 1, self.kernel_size, self.kernel_size)) 218 | 219 | return F.conv2d(input_CBHW, kernel, padding=0).view((B, C, H, W)) 220 | else: 221 | input_CBHW = input_pad.view((1, C * B, H_p, W_p)) 222 | kernel = kernel.contiguous().view((B, 1, self.kernel_size, self.kernel_size)) 223 | kernel = kernel.repeat(1, C, 1, 1).view((B * C, 1, self.kernel_size, self.kernel_size)) 224 | 225 | return F.conv2d(input_CBHW, kernel, groups=B*C).view((B, C, H, W)) 226 | 227 | 228 | class SRMDPreprocessing(object): 229 | def __init__(self, 230 | scale, 231 | mode='bicubic', 232 | kernel_size=21, 233 | blur_type='iso_gaussian', 234 | sig=2.6, 235 | sig_min=0.2, 236 | sig_max=4.0, 237 | lambda_1=0.2, 238 | lambda_2=4.0, 239 | theta=0, 240 | lambda_min=0.2, 241 | lambda_max=4.0, 242 | noise=0.0, 243 | rgb_range=1 244 | ): 245 | ''' 246 | # sig, sig_min and sig_max are used for isotropic Gaussian blurs 247 | During training phase (random=True): 248 | the width of the blur 
kernel is randomly selected from [sig_min, sig_max] 249 | During test phase (random=False): 250 | the width of the blur kernel is set to sig 251 | 252 | # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs 253 | During training phase (random=True): 254 | the eigenvalues of the covariance is randomly selected from [lambda_min, lambda_max] 255 | the angle value is randomly selected from [0, pi] 256 | During test phase (random=False): 257 | the eigenvalues of the covariance are set to lambda_1 and lambda_2 258 | the angle value is set to theta 259 | ''' 260 | self.kernel_size = kernel_size 261 | self.scale = scale 262 | self.mode = mode 263 | self.noise = noise / (255/rgb_range) 264 | self.rgb_range=rgb_range 265 | 266 | self.gen_kernel = Gaussin_Kernel( 267 | kernel_size=kernel_size, blur_type=blur_type, 268 | sig=sig, sig_min=sig_min, sig_max=sig_max, 269 | lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max 270 | ) 271 | self.blur = BatchBlur(kernel_size=kernel_size) 272 | self.bicubic = bicubic() 273 | 274 | def __call__(self, hr_tensor, random=True): 275 | with torch.no_grad(): 276 | # only downsampling 277 | if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0: 278 | B, C, H, W = hr_tensor.size() 279 | hr_blured = hr_tensor.view(-1, C, H, W) 280 | b_kernels = None 281 | 282 | # gaussian blur + downsampling 283 | else: 284 | B, C, H, W = hr_tensor.size() 285 | b_kernels = self.gen_kernel(B, random) # B degradations 286 | 287 | # blur 288 | hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels) 289 | hr_blured = hr_blured.view(-1, C, H, W) # B, C, H, W 290 | 291 | # downsampling 292 | if self.mode == 'bicubic': 293 | lr_blured = self.bicubic(hr_blured, scale=1/self.scale) 294 | elif self.mode == 's-fold': 295 | lr_blured = hr_blured.view(-1, C, H//self.scale, self.scale, W//self.scale, self.scale)[:, :, :, 0, :, 0] 296 | 297 | 298 | # add noise 299 | 
noise_level = None 300 | if self.noise > 0: 301 | _, C, H_lr, W_lr = lr_blured.size() 302 | noise_level = torch.rand(B, 1, 1, 1).to(lr_blured.device) * self.noise if random else self.noise 303 | noise = torch.randn_like(lr_blured).view(-1, C, H_lr, W_lr).mul_(noise_level).view(-1, C, H_lr, W_lr) 304 | lr_blured.add_(noise) 305 | 306 | lr_blured = utility.quantize(lr_blured, self.rgb_range) 307 | 308 | if isinstance(noise_level, float): 309 | noise_level = torch.Tensor([noise_level]).view(1,1,1,1).to(lr_blured.device) 310 | return lr_blured.view(B, C, H//int(self.scale), W//int(self.scale)), [b_kernels,noise_level] 311 | 312 | 313 | class BicubicPreprocessing(object): 314 | def __init__(self, 315 | scale, 316 | rgb_range=1 317 | ): 318 | ''' 319 | # sig, sig_min and sig_max are used for isotropic Gaussian blurs 320 | During training phase (random=True): 321 | the width of the blur kernel is randomly selected from [sig_min, sig_max] 322 | During test phase (random=False): 323 | the width of the blur kernel is set to sig 324 | 325 | # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs 326 | During training phase (random=True): 327 | the eigenvalues of the covariance is randomly selected from [lambda_min, lambda_max] 328 | the angle value is randomly selected from [0, pi] 329 | During test phase (random=False): 330 | the eigenvalues of the covariance are set to lambda_1 and lambda_2 331 | the angle value is set to theta 332 | ''' 333 | self.scale = scale 334 | self.rgb_range=rgb_range 335 | self.bicubic = bicubic() 336 | 337 | def __call__(self, hr_tensor, random=True): 338 | with torch.no_grad(): 339 | B, C, H, W = hr_tensor.size() 340 | 341 | lr = self.bicubic(hr_tensor, scale=1/self.scale) 342 | lr_bic = self.bicubic(lr, scale=self.scale) 343 | 344 | 345 | lr = utility.quantize(lr, self.rgb_range) 346 | lr_bic = utility.quantize(lr_bic, self.rgb_range) 347 | 348 | return lr.view(B, C, H//int(self.scale), 
W//int(self.scale)),lr_bic.view(B, C, H, W) 349 | --------------------------------------------------------------------------------