├── utils
├── __init__.py
├── util2.py
└── util.py
├── figs
├── process.jpg
├── aniso-quan.jpg
└── iso-quan.jpg
├── data
├── benchmark.py
├── df2k.py
├── common.py
├── __init__.py
└── multiscalesrdata.py
├── main_stage1.py
├── main_stage1.sh
├── main_stage2.py
├── main_stage3.py
├── main_stage4.py
├── test_iso_stage4.sh
├── test_anisoAnoise_stage4.sh
├── loss
├── vgg.py
├── discriminator.py
├── adversarial.py
└── __init__.py
├── template.py
├── model_ST
├── common.py
├── __init__.py
└── blindsr.py
├── model_TA
├── common.py
├── __init__.py
└── blindsr.py
├── model
├── blindsr.py
├── common.py
└── __init__.py
├── model_meta_stage3
├── blindsr.py
├── common.py
└── __init__.py
├── main_stage2.sh
├── main_stage3.sh
├── main_stage4.sh
├── model_meta
├── common.py
├── blindsr.py
└── __init__.py
├── README.md
├── dataloader.py
├── trainer_stage1.py
├── utility2.py
├── utility.py
├── option.py
├── trainer_stage3.py
├── trainer_stage4.py
└── trainer_stage2.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figs/process.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/process.jpg
--------------------------------------------------------------------------------
/figs/aniso-quan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/aniso-quan.jpg
--------------------------------------------------------------------------------
/figs/iso-quan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zj-BinXia/MRDA/HEAD/figs/iso-quan.jpg
--------------------------------------------------------------------------------
/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import common
3 | from data import multiscalesrdata as srdata
4 |
5 |
class Benchmark(srdata.SRData):
    """Benchmark test set backed by ``srdata.SRData``.

    The base constructor is always invoked with ``benchmark=True``; image
    folders are resolved under ``<dir_data>/benchmark/<name>``.
    """

    def __init__(self, args, name='', train=True):
        super().__init__(args, name=name, train=train, benchmark=True)

    def _set_filesystem(self, dir_data):
        """Set ``apath``, ``dir_hr``, ``dir_lr`` and ``ext`` for this set.

        Reads ``self.name`` (presumably assigned by the base class before
        this hook runs -- TODO confirm) and prints both image folders.
        """
        root = os.path.join(dir_data, 'benchmark', self.name)
        self.apath = root
        self.dir_hr = os.path.join(root, 'HR')
        self.dir_lr = os.path.join(root, 'LR_bicubic')
        self.ext = ('.png', '.png')
        for folder in (self.dir_hr, self.dir_lr):
            print(folder)
19 |
--------------------------------------------------------------------------------
/main_stage1.py:
--------------------------------------------------------------------------------
1 | from option import args
2 | import torch
3 | import utility
4 | import data
5 | import model
6 | import loss
7 | from trainer_stage1 import Trainer
8 |
9 |
if __name__ == '__main__':
    # Seed torch before any model or loader is constructed.
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        sr_model = model.Model(args, checkpoint)
        # Evaluation-only runs do not build a loss object.
        criterion = None if args.test_only else loss.Loss(args, checkpoint)
        trainer = Trainer(args, loader, sr_model, criterion, checkpoint)
        # Alternate one train pass and one test pass until the trainer stops.
        while not trainer.terminate():
            trainer.train()
            trainer.test()

        checkpoint.done()
24 |
--------------------------------------------------------------------------------
/main_stage1.sh:
--------------------------------------------------------------------------------
1 | ## noise-free degradations with isotropic Gaussian blurs or anisotropic + noise
2 | # using oracle degradation training
3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage1.py --dir_data='/root/datasets' \
4 | --model='blindsr' \
5 | --scale='4' \
6 | --n_GPUs=1 \
7 | --epochs_encoder 0 \
8 | --epochs_sr 600 \
9 | --data_test Set14 \
10 | --st_save_epoch 590 \
11 | --n_feats 128 \
12 | --batch_size 64 \
13 | --patch_size 64 \
14 | --data_train DF2K \
15 | --save stage1
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/main_stage2.py:
--------------------------------------------------------------------------------
1 | from option import args
2 | import torch
3 | import utility
4 | import data
5 | import model_meta
6 | import model
7 | import loss
8 | from trainer_stage2 import Trainer
9 |
10 |
if __name__ == '__main__':
    # Seed torch before any model or loader is constructed.
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        # Built in this order: the model_meta network first, then a network
        # from the plain `model` package that the trainer receives as the copy.
        meta_net = model_meta.Model(args, checkpoint)
        meta_net_copy = model.Model(args, checkpoint)
        # Evaluation-only runs do not build a loss object.
        criterion = None if args.test_only else loss.Loss(args, checkpoint)
        trainer = Trainer(
            args, loader, meta_net, meta_net_copy, criterion, checkpoint
        )
        # Alternate one train pass and one test pass until the trainer stops.
        while not trainer.terminate():
            trainer.train()
            trainer.test()

        checkpoint.done()
--------------------------------------------------------------------------------
/main_stage3.py:
--------------------------------------------------------------------------------
1 | from option import args
2 | import torch
3 | import utility
4 | import data
5 | import model_TA
6 | import model_meta_stage3
7 | import loss
8 | from trainer_stage3 import Trainer
9 |
10 |
def count_param(model):
    """Return the total number of scalar entries over ``model.parameters()``.

    ``numel()`` is used instead of ``view(-1).size()[0]`` because ``view``
    raises on non-contiguous parameters, while ``numel()`` works for any
    layout. A model without parameters yields ``0``.
    """
    return sum(param.numel() for param in model.parameters())
16 |
if __name__ == '__main__':
    # Seed torch before any model or loader is constructed.
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        ta_net = model_TA.Model(args, checkpoint)
        # Report the parameter count of the model_TA network.
        print(count_param(ta_net))
        meta_net_copy = model_meta_stage3.Model(args, checkpoint)
        # Evaluation-only runs do not build a loss object.
        criterion = None if args.test_only else loss.Loss(args, checkpoint)
        trainer = Trainer(
            args, loader, meta_net_copy, ta_net, criterion, checkpoint
        )
        # Alternate one train pass and one test pass until the trainer stops.
        while not trainer.terminate():
            trainer.train()
            trainer.test()

        checkpoint.done()
--------------------------------------------------------------------------------
/main_stage4.py:
--------------------------------------------------------------------------------
1 | from option import args
2 | import torch
3 | import utility
4 | import data
5 | import model_TA
6 | import model_ST
7 | import model_meta_stage3
8 | import loss
9 | from trainer_stage4 import Trainer
10 |
11 |
12 | # def count_param(model):
13 | # param_count = 0
14 | # for param in model.parameters():
15 | # param_count += param.view(-1).size()[0]
16 | # return param_count
17 |
if __name__ == '__main__':
    # Seed torch before any model or loader is constructed.
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        # Construction order is kept: model_TA, model_meta_stage3, model_ST.
        ta_net = model_TA.Model(args, checkpoint)
        meta_net_copy = model_meta_stage3.Model(args, checkpoint)
        st_net = model_ST.Model(args, checkpoint)
        # Evaluation-only runs do not build a loss object.
        criterion = None if args.test_only else loss.Loss(args, checkpoint)
        trainer = Trainer(
            args, loader, st_net, ta_net, meta_net_copy, criterion, checkpoint
        )
        # Alternate one train pass and one test pass until the trainer stops.
        while not trainer.terminate():
            trainer.train()
            trainer.test()

        checkpoint.done()
--------------------------------------------------------------------------------
/data/df2k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import multiscalesrdata
3 |
4 |
class DF2K(multiscalesrdata.SRData):
    """DF2K dataset restricted to an index range taken from ``args.data_range``.

    ``args.data_range`` has the form ``"a-b"`` or ``"a-b/c-d"``. The first
    range is used for training, and also for evaluation when ``test_only`` is
    set and only one range was given; otherwise evaluation uses the second.
    """

    def __init__(self, args, name='DF2K', train=True, benchmark=False):
        ranges = [part.split('-') for part in args.data_range.split('/')]
        use_first = train or (args.test_only and len(ranges) == 1)
        chosen = ranges[0] if use_first else ranges[1]

        # begin/end must exist before the base constructor runs, because
        # _scan() below reads them.
        self.begin, self.end = [int(bound) for bound in chosen]
        super().__init__(args, name=name, train=train, benchmark=benchmark)

    def _scan(self):
        """Return the base-class HR file list cut to ``[begin, end]`` (1-based)."""
        return super()._scan()[self.begin - 1:self.end]

    def _set_filesystem(self, dir_data):
        """Run the base setup, then point HR/LR at ``HR`` and ``LR_bicubic``."""
        super()._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
29 |
30 |
--------------------------------------------------------------------------------
/test_iso_stage4.sh:
--------------------------------------------------------------------------------
# Evaluation run of main_stage4.py (--test_only) on GPU 1 with an isotropic
# Gaussian blur setting (--blur_type='iso_gaussian', --sig 3.6, --noise=0.0).
# Both --pre_train_ST and --pre_train point at ./experiment/iso_ST.pt.
# Every continuation backslash is preceded by a space so that consecutive
# options can never be glued together, whatever the next line's indentation.
CUDA_VISIBLE_DEVICES=1 python3 main_stage4.py --dir_data='/mnt/bn/xiabinsr/datasets' \
    --model='blindsr' \
    --scale='4' \
    --blur_type='iso_gaussian' \
    --noise=0.0 \
    --sig_min=0.2 \
    --sig_max=4.0 \
    --sig 3.6 \
    --save ours_iso_36 \
    --n_GPUs=1 \
    --epochs_encoder 100 \
    --epochs_sr 500 \
    --data_test Set14 \
    --st_save_epoch 480 \
    --n_feats 128 \
    --patch_size 48 \
    --task_iter 5 \
    --test_iter 5 \
    --meta_batch_size 5 \
    --batch_size 16 \
    --lr_sr 1e-4 \
    --lr_task 1e-2 \
    --lr_encoder 1e-4 \
    --pre_train_ST="./experiment/iso_ST.pt" \
    --pre_train="./experiment/iso_ST.pt" \
    --save_results False \
    --resume 0 \
    --test_only
--------------------------------------------------------------------------------
/test_anisoAnoise_stage4.sh:
--------------------------------------------------------------------------------
# Evaluation run of main_stage4.py (--test_only) on GPU 1 with an anisotropic
# Gaussian blur plus noise setting (--blur_type='aniso_gaussian',
# --lambda_1 3.5, --lambda_2 1.5, --theta 20, --noise=10.0).
# Both --pre_train_ST and --pre_train point at ./experiment/aniso_noise_ST.pt.
# Every continuation backslash is preceded by a space so that consecutive
# options can never be glued together, whatever the next line's indentation.
# NOTE(review): --save uses the same name ("ours_iso_36") as the isotropic
# test script; confirm that sharing the output name is intended.
CUDA_VISIBLE_DEVICES=1 python3 main_stage4.py --dir_data='/mnt/bn/xiabinsr/datasets' \
    --model='blindsr' \
    --scale='4' \
    --blur_type='aniso_gaussian' \
    --noise=10.0 \
    --lambda_1 3.5 \
    --lambda_2 1.5 \
    --theta 20 \
    --save ours_iso_36 \
    --n_GPUs=1 \
    --epochs_encoder 100 \
    --epochs_sr 500 \
    --data_test Set14 \
    --st_save_epoch 480 \
    --n_feats 128 \
    --patch_size 48 \
    --task_iter 5 \
    --test_iter 5 \
    --meta_batch_size 5 \
    --batch_size 16 \
    --lr_sr 1e-4 \
    --lr_task 1e-2 \
    --lr_encoder 1e-4 \
    --pre_train_ST="./experiment/aniso_noise_ST.pt" \
    --pre_train="./experiment/aniso_noise_ST.pt" \
    --save_results False \
    --resume 0 \
    --test_only
--------------------------------------------------------------------------------
/loss/vgg.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 | from torch.autograd import Variable
8 |
class VGG(nn.Module):
    """MSE between VGG-19 feature maps of ``sr`` and ``hr``.

    Args:
        conv_index: ``'22'`` keeps the first 8 modules of the pretrained
            ``vgg19().features``; ``'54'`` keeps the first 35. Any other value
            leaves ``self.vgg`` undefined, so construction fails with
            ``AttributeError`` when the extractor is frozen below.
        rgb_range: scale of the inputs; the normalisation std values are
            multiplied by it before building the ``MeanShift`` layer.
    """

    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = list(vgg_features)
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)

        # Freeze the feature extractor. Assigning `requires_grad` on the
        # Sequential itself only creates a plain attribute and leaves every
        # weight trainable, so the flag is set on each parameter instead.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, sr, hr):
        """Return ``mse_loss(features(sr), features(hr))``.

        Features of ``hr`` are computed under ``torch.no_grad()`` on a
        detached tensor, so gradients flow only through ``sr``.
        """
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
37 |
--------------------------------------------------------------------------------
/template.py:
--------------------------------------------------------------------------------
def set_template(args):
    """Overwrite fields of ``args`` for every keyword found in ``args.template``.

    The keywords are checked independently and in a fixed order ('jpeg',
    'EDSR_paper', 'MDSR', 'DDBPN', 'GAN', 'RCAN'), so a template containing
    several of them gets all matching presets, with later ones winning on
    shared fields. ``args`` is modified in place; nothing is returned.
    """
    template = args.template

    def apply(**overrides):
        # Keyword order is preserved, so attributes are set in listed order.
        for field, value in overrides.items():
            setattr(args, field, value)

    if 'jpeg' in template:
        apply(
            data_train='DIV2K_jpeg',
            data_test='DIV2K_jpeg',
            epochs=200,
            lr_decay=100,
        )

    if 'EDSR_paper' in template:
        apply(model='EDSR', n_resblocks=32, n_feats=256, res_scale=0.1)

    if 'MDSR' in template:
        apply(model='MDSR', patch_size=48, epochs=650)

    if 'DDBPN' in template:
        apply(
            model='DDBPN',
            patch_size=128,
            scale='4',
            data_test='Set5',
            batch_size=20,
            epochs=1000,
            lr_decay=500,
            gamma=0.1,
            weight_decay=1e-4,
            loss='1*MSE',
        )

    if 'GAN' in template:
        apply(epochs=200, lr=5e-5, lr_decay=150)

    if 'RCAN' in template:
        apply(
            model='RCAN',
            n_resgroups=10,
            n_resblocks=20,
            n_feats=64,
            chop=True,
        )
46 |
47 |
--------------------------------------------------------------------------------
/loss/discriminator.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
class Discriminator(nn.Module):
    """Convolutional feature stack followed by a two-layer linear classifier.

    ``gan_type`` is accepted but not used: batch norm is always enabled.
    The classifier input size is derived from ``args.patch_size``.

    NOTE(review): relies on ``common.BasicBlock`` imported from ``model``;
    verify that ``model/common.py`` actually provides it.
    """

    def __init__(self, args, gan_type='GAN'):
        super().__init__()

        depth = 7
        use_bn = True
        # A single LeakyReLU instance is shared by all blocks and the classifier.
        leaky = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        width = 64
        blocks = [
            common.BasicBlock(args.n_colors, width, 3, bn=use_bn, act=leaky)
        ]
        for idx in range(depth):
            prev_width = width
            if idx % 2 == 1:
                # Odd steps keep the resolution and double the channel count.
                stride = 1
                width *= 2
            else:
                # Even steps halve the resolution with a stride-2 block.
                stride = 2
            blocks.append(common.BasicBlock(
                prev_width, width, 3, stride=stride, bn=use_bn, act=leaky
            ))
        self.features = nn.Sequential(*blocks)

        # One factor of two per stride-2 block: (depth + 1) // 2 of them.
        reduced = args.patch_size // (2 ** ((depth + 1) // 2))
        self.classifier = nn.Sequential(
            nn.Linear(width * reduced ** 2, 1024),
            leaky,
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Flatten the feature maps per sample and return the classifier output."""
        feature_maps = self.features(x)
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)
45 |
46 |
--------------------------------------------------------------------------------
/model_ST/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build an ``nn.Conv2d`` padded by half the kernel size.

    With an odd ``kernel_size`` and the default stride of 1, the output keeps
    the spatial size of the input.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
9 |
10 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine map to 3 channels.

    Output channel ``c`` equals
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.
    Weight and bias are excluded from gradient computation.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)
        channel_std = torch.Tensor(rgb_std)

        # Diagonal weight: every output channel reads only its own input channel.
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / channel_std.view(3, 1, 1, 1)

        shift = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data = shift / channel_std

        for param in (self.weight, self.bias):
            param.requires_grad = False
21 |
22 |
class Upsampler(nn.Sequential):
    """Upsampling stack built from ``conv`` + ``nn.PixelShuffle`` stages.

    Args:
        conv: callable ``conv(in_ch, out_ch, kernel_size, bias)`` returning a
            layer.
        scale: upsampling factor. Powers of two are built from repeated x2
            stages, 3 from a single x3 stage. Integral floats such as ``4.0``
            are accepted, matching the ``Upsampler`` in ``model/common.py``.
        n_feat: number of feature channels entering and leaving each stage.
        act: ``False`` for no activation, otherwise a zero-argument callable
            whose result is appended after every stage.
        bias: forwarded to ``conv``.

    Raises:
        NotImplementedError: for any other factor, including fractional ones.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        factor = int(scale)
        if factor != scale:
            # A fractional factor cannot be built from PixelShuffle stages.
            raise NotImplementedError

        m = []
        if (factor & (factor - 1)) == 0:    # Is scale = 2^n?
            # log2 is exact for powers of two, unlike the generic log(x, 2).
            for _ in range(int(math.log2(factor))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if act:
                    m.append(act())
        elif factor == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if act:
                m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
39 |
40 |
--------------------------------------------------------------------------------
/model_TA/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build an ``nn.Conv2d`` padded by half the kernel size.

    With an odd ``kernel_size`` and the default stride of 1, the output keeps
    the spatial size of the input.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
9 |
10 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine map to 3 channels.

    Output channel ``c`` equals
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.
    Weight and bias are excluded from gradient computation.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)
        channel_std = torch.Tensor(rgb_std)

        # Diagonal weight: every output channel reads only its own input channel.
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / channel_std.view(3, 1, 1, 1)

        shift = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data = shift / channel_std

        for param in (self.weight, self.bias):
            param.requires_grad = False
21 |
22 |
class Upsampler(nn.Sequential):
    """Upsampling stack built from ``conv`` + ``nn.PixelShuffle`` stages.

    Args:
        conv: callable ``conv(in_ch, out_ch, kernel_size, bias)`` returning a
            layer.
        scale: upsampling factor. Powers of two are built from repeated x2
            stages, 3 from a single x3 stage. Integral floats such as ``4.0``
            are accepted, matching the ``Upsampler`` in ``model/common.py``.
        n_feat: number of feature channels entering and leaving each stage.
        act: ``False`` for no activation, otherwise a zero-argument callable
            whose result is appended after every stage.
        bias: forwarded to ``conv``.

    Raises:
        NotImplementedError: for any other factor, including fractional ones.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        factor = int(scale)
        if factor != scale:
            # A fractional factor cannot be built from PixelShuffle stages.
            raise NotImplementedError

        m = []
        if (factor & (factor - 1)) == 0:    # Is scale = 2^n?
            # log2 is exact for powers of two, unlike the generic log(x, 2).
            for _ in range(int(math.log2(factor))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if act:
                    m.append(act())
        elif factor == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if act:
                m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
39 |
40 |
--------------------------------------------------------------------------------
/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import skimage.color as sc
5 |
6 | import torch
7 |
8 |
def get_patch(hr, patch_size=48, scale=2, multi=False, input_large=False):
    """Crop a random square window of side ``scale * patch_size`` from ``hr``.

    ``hr`` is indexed as (rows, columns, channels). ``multi`` and
    ``input_large`` are accepted but not used. ``random.randrange`` raises
    ``ValueError`` when the image is smaller than the window.
    """
    height, width = hr.shape[:2]
    side = scale * patch_size

    # The column offset is drawn before the row offset; the order determines
    # which random numbers are consumed.
    left = random.randrange(0, width - side + 1)
    top = random.randrange(0, height - side + 1)

    rows = slice(int(top), int(top + side))
    cols = slice(int(left), int(left + side))
    return hr[rows, cols, :]
20 |
21 |
def set_channel(hr, n_channels=3):
    """Return ``hr`` with a channel axis and, where needed, ``n_channels`` channels.

    * A 2-D input gets a trailing channel axis.
    * 3 channels with ``n_channels == 1``: only channel 0 of
      ``sc.rgb2ycbcr(img)`` is kept.
    * 1 channel with ``n_channels == 3``: the channel is repeated three times.
    * Any other combination is returned as is.
    """
    img = hr
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)

    channels = img.shape[2]
    if n_channels == 1 and channels == 3:
        first_plane = sc.rgb2ycbcr(img)[:, :, 0]
        img = np.expand_dims(first_plane, 2)
    elif n_channels == 3 and channels == 1:
        img = np.concatenate([img] * n_channels, 2)

    return img
36 |
37 |
def np2Tensor(hr, rgb_range=255):
    """Convert an (H, W, C) array into a float (C, H, W) tensor.

    Values are multiplied by ``rgb_range / 255``.
    """
    # from_numpy needs contiguous memory, so copy after moving channels first.
    chw = np.ascontiguousarray(hr.transpose((2, 0, 1)))
    scaled = torch.from_numpy(chw).float()
    return scaled.mul_(rgb_range / 255)
47 |
48 |
def augment(hr, hflip=True, rot=True):
    """Randomly flip and/or transpose an (H, W, C) image.

    Each enabled transform fires with probability 0.5: a horizontal flip when
    ``hflip`` is set; a vertical flip and an H/W transpose when ``rot`` is
    set. A disabled transform consumes no random number.
    """
    # Decisions are drawn in this fixed order: hflip, vflip, transpose.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_transpose = rot and random.random() < 0.5

    out = hr
    if do_hflip:
        out = out[:, ::-1, :]
    if do_vflip:
        out = out[::-1, :, :]
    if do_transpose:
        out = out.transpose(1, 0, 2)
    return out
62 |
63 |
--------------------------------------------------------------------------------
/model/blindsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model.common as common
4 | import torch.nn.functional as F
5 |
6 |
def make_model(args):
    """Factory entry point: build this module's ``MLN`` network from ``args``."""
    network = MLN(args)
    return network
9 |
10 |
class MLN(nn.Module):
    """Plain convolutional network: head conv, six-conv body, upsampling tail.

    ``forward`` adds ``lr_bic`` to the tail output, so the network predicts a
    residual on top of it (presumably ``lr_bic`` is the bicubic-upsampled
    input -- TODO confirm against the trainer).
    """

    def __init__(self, args, conv=common.default_conv):
        super().__init__()

        width = args.n_feats
        tail_kernel = 3
        upscale = args.scale[0]
        # One in-place LeakyReLU instance is reused after every convolution.
        leaky = nn.LeakyReLU(0.1, True)

        self.head = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1),
            leaky,
        )

        body_layers = []
        for _ in range(6):
            body_layers.append(
                nn.Conv2d(width, width, kernel_size=3, padding=1)
            )
            body_layers.append(leaky)

        tail_layers = [
            common.Upsampler(conv, upscale, width, act=False),
            nn.Conv2d(
                width, args.n_colors, tail_kernel,
                padding=tail_kernel // 2,
            ),
        ]

        # tail is registered before body, which fixes the parameter and
        # state_dict ordering; keep this order.
        self.tail = nn.Sequential(*tail_layers)
        self.body = nn.Sequential(*body_layers)

    def forward(self, lr, lr_bic):
        """Return ``tail(body(head(lr))) + lr_bic``."""
        features = self.body(self.head(lr))
        out = self.tail(features)
        out += lr_bic
        return out
56 |
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
/model_meta_stage3/blindsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model.common as common
4 | import torch.nn.functional as F
5 |
6 |
def make_model(args):
    """Factory entry point: build this module's ``MLN`` network from ``args``."""
    network = MLN(args)
    return network
9 |
10 |
class MLN(nn.Module):
    """Plain convolutional network that also returns its body features.

    Same layout as a head conv, six-conv body and upsampling tail;
    ``forward`` returns both the output (tail output plus ``lr_bic``) and the
    body features. ``self.ad`` is created here but not used in ``forward``.
    """

    def __init__(self, args, conv=common.default_conv):
        super().__init__()

        width = args.n_feats
        tail_kernel = 3
        upscale = args.scale[0]
        # One in-place LeakyReLU instance is reused after every convolution.
        leaky = nn.LeakyReLU(0.1, True)

        self.head = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1),
            leaky,
        )

        body_layers = []
        for _ in range(6):
            body_layers.append(
                nn.Conv2d(width, width, kernel_size=3, padding=1)
            )
            body_layers.append(leaky)

        tail_layers = [
            common.Upsampler(conv, upscale, width, act=False),
            nn.Conv2d(
                width, args.n_colors, tail_kernel,
                padding=tail_kernel // 2,
            ),
        ]

        # tail is registered before body, which fixes the parameter and
        # state_dict ordering; keep this order.
        self.tail = nn.Sequential(*tail_layers)
        self.body = nn.Sequential(*body_layers)

        self.ad = nn.AdaptiveAvgPool2d(1)

    def forward(self, lr, lr_bic):
        """Return ``(output, deg_repre)``.

        ``deg_repre`` is the body output; ``output`` is that tensor passed
        through the tail layers one by one, plus ``lr_bic``.
        """
        deg_repre = self.body(self.head(lr))

        out = deg_repre
        for layer in self.tail:
            out = layer(out)
        out += lr_bic

        return out, deg_repre
61 |
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | #from dataloader import MSDataLoader
3 | from torch.utils.data import dataloader
4 | from torch.utils.data import ConcatDataset
5 |
6 | # This is a simple wrapper function for ConcatDataset
class MyConcatDataset(ConcatDataset):
    """``ConcatDataset`` that copies ``train`` from its first member and
    forwards ``set_scale`` to every member that implements it."""

    def __init__(self, datasets):
        super().__init__(datasets)
        first = datasets[0]
        self.train = first.train

    def set_scale(self, idx_scale):
        """Call ``set_scale(idx_scale)`` on each member that defines it."""
        scalable = (ds for ds in self.datasets if hasattr(ds, 'set_scale'))
        for ds in scalable:
            ds.set_scale(idx_scale)
15 |
class Data:
    """Build the training and test ``DataLoader`` objects described by ``args``.

    Attributes:
        loader_train: one shuffled loader over the concatenation of all
            ``args.data_train`` sets, or ``None`` when ``args.test_only``.
        loader_test: one batch-size-1, unshuffled loader per entry of
            ``args.data_test``.
    """

    def __init__(self, args):
        def dataset_class(name):
            # Names containing 'DIV2K-Q' map to the DIV2KJPEG class; any other
            # name is used directly as the class name, and lower-cased as the
            # module name inside the `data` package.
            cls_name = 'DIV2KJPEG' if 'DIV2K-Q' in name else name
            module = import_module('data.' + cls_name.lower())
            return getattr(module, cls_name)

        self.loader_train = None
        if not args.test_only:
            train_sets = []
            for d in args.data_train:
                print(d)
                train_sets.append(dataset_class(d)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(train_sets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ('Set5', 'Set14', 'B100', 'Urban100'):
                # The standard benchmark sets share one dataset class.
                benchmark_module = import_module('data.benchmark')
                benchmark_cls = getattr(benchmark_module, 'Benchmark')
                testset = benchmark_cls(args, train=False, name=d)
            else:
                testset = dataset_class(d)(args, train=False, name=d)

            test_loader = dataloader.DataLoader(
                testset,
                batch_size=1,
                shuffle=False,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )
            self.loader_test.append(test_loader)
54 |
55 |
--------------------------------------------------------------------------------
/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build an ``nn.Conv2d`` padded by half the kernel size.

    With an odd ``kernel_size`` and the default stride of 1, the output keeps
    the spatial size of the input.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
9 |
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> act -> conv -> [BN], plus the input.

    The branch output is multiplied by ``res_scale`` before the input is added.

    NOTE(review): the default ``act`` is a single ``nn.PReLU()`` created when
    the class is defined, so every block built without an explicit ``act``
    shares that one module; confirm the sharing is intended.
    """

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super().__init__()

        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        # The activation follows the first convolution only.
        layers.append(act)
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        branch = self.body(x).mul(self.res_scale)
        return branch.add_(x)
32 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine map to 3 channels.

    Output channel ``c`` equals
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.
    Weight and bias are excluded from gradient computation.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)
        channel_std = torch.Tensor(rgb_std)

        # Diagonal weight: every output channel reads only its own input channel.
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / channel_std.view(3, 1, 1, 1)

        shift = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data = shift / channel_std

        for param in (self.weight, self.bias):
            param.requires_grad = False
43 |
44 |
class Upsampler(nn.Sequential):
    """Upsampling stack built from ``conv`` + ``nn.PixelShuffle`` stages.

    A scale whose integer part is a power of two yields
    ``int(log(scale, 2))`` x2 stages; a scale equal to 3 yields one x3 stage;
    anything else raises ``NotImplementedError``. When ``act`` is truthy it is
    called with no arguments and the result is appended after every stage.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        layers = []

        def add_stage(ratio):
            # Expand channels by ratio**2, then trade them for resolution.
            layers.append(conv(n_feat, ratio * ratio * n_feat, 3, bias))
            layers.append(nn.PixelShuffle(ratio))
            if act:
                layers.append(act())

        truncated = int(scale)
        is_power_of_two = (truncated & (truncated - 1)) == 0
        if is_power_of_two:
            stage_count = int(math.log(scale, 2))
            for _ in range(stage_count):
                add_stage(2)
        elif scale == 3:
            add_stage(3)
        else:
            raise NotImplementedError

        super().__init__(*layers)
61 |
62 |
--------------------------------------------------------------------------------
/model_meta_stage3/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build an ``nn.Conv2d`` padded by half the kernel size.

    With an odd ``kernel_size`` and the default stride of 1, the output keeps
    the spatial size of the input.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
9 |
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> act -> conv -> [BN], plus the input.

    The branch output is multiplied by ``res_scale`` before the input is added.

    NOTE(review): the default ``act`` is a single ``nn.PReLU()`` created when
    the class is defined, so every block built without an explicit ``act``
    shares that one module; confirm the sharing is intended.
    """

    def __init__(
            self, conv, n_feats, kernel_size,
            bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super().__init__()

        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        # The activation follows the first convolution only.
        layers.append(act)
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        branch = self.body(x).mul(self.res_scale)
        return branch.add_(x)
32 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution applying a per-channel affine map to 3 channels.

    Output channel ``c`` equals
    ``x[c] / rgb_std[c] + sign * rgb_range * rgb_mean[c] / rgb_std[c]``.
    Weight and bias are excluded from gradient computation.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super().__init__(3, 3, kernel_size=1)
        channel_std = torch.Tensor(rgb_std)

        # Diagonal weight: every output channel reads only its own input channel.
        identity = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity / channel_std.view(3, 1, 1, 1)

        shift = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data = shift / channel_std

        for param in (self.weight, self.bias):
            param.requires_grad = False
43 |
44 |
class Upsampler(nn.Sequential):
    """Upsampling stack built from ``conv`` + ``nn.PixelShuffle`` stages.

    A scale whose integer part is a power of two yields
    ``int(log(scale, 2))`` x2 stages; a scale equal to 3 yields one x3 stage;
    anything else raises ``NotImplementedError``. When ``act`` is truthy it is
    called with no arguments and the result is appended after every stage.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        layers = []

        def add_stage(ratio):
            # Expand channels by ratio**2, then trade them for resolution.
            layers.append(conv(n_feat, ratio * ratio * n_feat, 3, bias))
            layers.append(nn.PixelShuffle(ratio))
            if act:
                layers.append(act())

        truncated = int(scale)
        is_power_of_two = (truncated & (truncated - 1)) == 0
        if is_power_of_two:
            stage_count = int(math.log(scale, 2))
            for _ in range(stage_count):
                add_stage(2)
        elif scale == 3:
            add_stage(3)
        else:
            raise NotImplementedError

        super().__init__(*layers)
61 |
62 |
--------------------------------------------------------------------------------
/main_stage2.sh:
--------------------------------------------------------------------------------
1 | ## noise-free degradations with isotropic Gaussian blurs
2 | # training knowledge distillation
3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage2.py --dir_data='/root/datasets' \
4 | --model='blindsr' \
5 | --scale='4' \
6 | --blur_type='iso_gaussian' \
7 | --noise=0.0 \
8 | --sig_min=0.2 \
9 | --sig_max=4.0 \
10 | --sig 2.6 \
11 | --n_GPUs=1 \
12 | --epochs_encoder 0 \
13 | --epochs_sr 600 \
14 | --data_test Set5 \
15 | --st_save_epoch 95 \
16 | --n_feats 128 \
17 | --patch_size 64 \
18 | --task_iter 5 \
19 | --test_iter 5 \
20 | --meta_batch_size 5 \
21 | --task_batch_size 16 \
22 | --lr_sr 1e-4 \
23 | --lr_task 1e-2 \
24 | --pre_train="./experiment/stage1.pt" \
25 | --resume 0 \
26 | --test_every 240 \
27 | --print_every 40 \
28 | --lr_decay_sr 150 \
29 | --data_train DF2K \
30 | --save stage2
31 |
32 | #--batch_size 32 \
33 |
34 | # anisotropic + noise
35 | # CUDA_VISIBLE_DEVICES=0 python3 main_stage2.py --dir_data='/root/datasets' \
36 | # --model='blindsr' \
37 | # --scale='4' \
38 | # --blur_type='aniso_gaussian' \
39 | # --noise=25.0 \
40 | # --lambda_min=0.2 \
41 | # --lambda_max=4.0 \
42 | # --n_GPUs=1 \
43 | # --epochs_encoder 0 \
44 | # --epochs_sr 600 \
45 | # --data_test Set5 \
46 | # --st_save_epoch 95 \
47 | # --n_feats 128 \
48 | # --patch_size 64 \
49 | # --task_iter 5 \
50 | # --test_iter 5 \
51 | # --meta_batch_size 5 \
52 | # --task_batch_size 16 \
53 | # --lr_sr 1e-4 \
54 | # --lr_task 1e-2 \
55 | # --pre_train="./experiment/stage1.pt" \
56 | # --resume 0 \
57 | # --test_every 240 \
58 | # --print_every 40 \
59 | # --lr_decay_sr 150 \
60 | # --data_train DF2K \
61 | # --save stage2
--------------------------------------------------------------------------------
/main_stage3.sh:
--------------------------------------------------------------------------------
1 | ## noise-free degradations with isotropic Gaussian blurs
2 | # training knowledge distillation
3 | CUDA_VISIBLE_DEVICES=0 python3 main_stage3.py --dir_data='/root/datasets' \
4 | --model='blindsr' \
5 | --scale='4' \
6 | --blur_type='iso_gaussian' \
7 | --noise=0.0 \
8 | --sig_min=0.2 \
9 | --sig_max=4.0 \
10 | --sig 2.6 \
11 | --n_GPUs=1 \
12 | --epochs_encoder 0 \
13 | --epochs_sr 500 \
14 | --data_test Set5 \
15 | --st_save_epoch 480 \
16 | --n_feats 128 \
17 | --patch_size 64 \
18 | --task_iter 5 \
19 | --test_iter 5 \
20 | --meta_batch_size 5 \
21 | --batch_size 16 \
22 | --lr_sr 1e-4 \
23 | --lr_task 1e-2 \
24 | --pre_train_meta="./experiment/stage2.pt" \
25 | --resume 0 \
26 | --test_every 1500 \
27 | --print_every 300 \
28 | --lr_decay_sr 125 \
29 | --data_train DF2K \
30 | --save stage3
31 |
32 | #--batch_size 32 \
33 |
34 | # anisotropic + noise
35 | # CUDA_VISIBLE_DEVICES=0 python3 main_stage3.py --dir_data='/root/datasets' \
36 | # --model='blindsr' \
37 | # --scale='4' \
38 | # --blur_type='aniso_gaussian' \
39 | # --noise=25.0 \
40 | # --lambda_min=0.2 \
41 | # --lambda_max=4.0 \
42 | # --n_GPUs=1 \
43 | # --epochs_encoder 0 \
44 | # --epochs_sr 500 \
45 | # --data_test Set14 \
46 | # --st_save_epoch 480 \
47 | # --n_feats 128 \
48 | # --patch_size 64 \
49 | # --task_iter 5 \
50 | # --test_iter 5 \
51 | # --meta_batch_size 5 \
52 | # --batch_size 16 \
53 | # --lr_sr 1e-4 \
54 | # --lr_task 1e-2 \
55 | # --pre_train_meta="./experiment/stage2.pt" \
56 | # --resume 0 \
57 | # --test_every 1500 \
58 | # --print_every 300 \
59 | # --lr_decay_sr 125 \
60 | # --data_train DF2K \
61 | # --save stage3
62 |
--------------------------------------------------------------------------------
/main_stage4.sh:
--------------------------------------------------------------------------------
## noise-free degradations with isotropic Gaussian blurs
# Stage 4 training (knowledge distillation): launches main_stage4.py with the
# stage-3 checkpoint as --pre_train_TA / --pre_train_ST and the stage-2
# checkpoint as --pre_train_meta (README, training step 4).
CUDA_VISIBLE_DEVICES=2 python3 main_stage4.py --dir_data='/root/datasets' \
    --model='blindsr' \
    --scale='4' \
    --blur_type='iso_gaussian' \
    --noise=0.0 \
    --sig_min=0.2 \
    --sig_max=4.0 \
    --sig 2.6 \
    --n_GPUs=1 \
    --epochs_encoder 100 \
    --epochs_sr 500 \
    --data_test Set5 \
    --st_save_epoch 480 \
    --n_feats 128 \
    --patch_size 64 \
    --task_iter 5 \
    --test_iter 5 \
    --meta_batch_size 5 \
    --batch_size 16 \
    --lr_sr 1e-4 \
    --lr_task 1e-2 \
    --lr_encoder 1e-4 \
    --pre_train_TA="./experiment/stage3.pt" \
    --pre_train_ST="./experiment/stage3.pt" \
    --pre_train_meta="./experiment/stage2.pt" \
    --resume 0 \
    --test_every 800 \
    --print_every 250 \
    --lr_decay_sr 125 \
    --data_train DF2K \
    --save stage4


# anisotropic Gaussian kernel + noise: same stage-4 run with the anisotropic
# settings; uncomment this block (and comment out the command above) to use it.
# CUDA_VISIBLE_DEVICES=0 python3 main_stage4.py --dir_data='/root/datasets' \
# --model='blindsr' \
# --scale='4' \
# --blur_type='aniso_gaussian' \
# --noise=25.0 \
# --lambda_min=0.2 \
# --lambda_max=4.0 \
# --n_GPUs=1 \
# --epochs_encoder 100 \
# --epochs_sr 500 \
# --data_test Set5 \
# --st_save_epoch 480 \
# --n_feats 128 \
# --patch_size 64 \
# --task_iter 5 \
# --test_iter 5 \
# --meta_batch_size 5 \
# --batch_size 16 \
# --lr_sr 1e-4 \
# --lr_task 1e-2 \
# --lr_encoder 1e-4 \
# --pre_train_TA="./experiment/stage3.pt" \
# --pre_train_ST="./experiment/stage3.pt" \
# --pre_train_meta="./experiment/stage2.pt" \
# --resume 0 \
# --test_every 800 \
# --print_every 250 \
# --lr_decay_sr 125 \
# --data_train DF2K \
# --save stage4
--------------------------------------------------------------------------------
/model_meta/common.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
def linear(input, weight, bias=None):
    """Apply ``F.linear`` to ``input`` with explicitly supplied parameters.

    Args:
        input: tensor to transform.
        weight: weight matrix handed to ``F.linear``.
        bias: optional bias; it is only forwarded when one is given.

    Returns:
        The result of ``F.linear``.
    """
    call_args = (input, weight) if bias is None else (input, weight, bias)
    return F.linear(*call_args)
12 |
def conv2d(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    """Functional 2-D convolution with explicitly supplied parameters.

    Thin wrapper over ``F.conv2d``; note that ``padding`` defaults to 1 here
    rather than to ``F.conv2d``'s own default.
    """
    return F.conv2d(
        input,
        weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
15 |
def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True, eps=1e-5, momentum=0.1):
    """Functional batch normalisation with explicitly supplied parameters.

    All arguments are forwarded unchanged to ``F.batch_norm``. The running
    statistics default to ``None``, so no dummy statistics tensors have to be
    allocated by the caller.
    """
    return F.batch_norm(
        input,
        running_mean,
        running_var,
        weight=weight,
        bias=bias,
        training=training,
        momentum=momentum,
        eps=eps,
    )
23 |
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build an ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side."""
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
26 |
class ResBlock(nn.Module):
    """Residual block: conv -> [BN] -> act -> conv -> [BN], plus a skip connection.

    Args:
        conv: factory called as ``conv(n_feats, n_feats, kernel_size, bias=bias)``.
        n_feats: channel count of both the input and the output.
        kernel_size: kernel size handed to ``conv``.
        bias: whether the convolutions carry a bias term.
        bn: insert ``nn.BatchNorm2d`` after each convolution when true.
        act: activation module placed after the first convolution. When left
            as ``None`` a fresh ``nn.PReLU()`` is created for this block.
        res_scale: factor applied to the residual branch before the skip add.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=None, res_scale=1):

        super(ResBlock, self).__init__()
        # A module used as a default argument is built once at definition time
        # and would then be shared (learnable PReLU weight included) by every
        # block relying on the default; create it per instance instead.
        if act is None:
            act = nn.PReLU()

        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                # Only the first convolution is followed by the activation.
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) * res_scale + x``."""
        res = self.body(x).mul(self.res_scale)
        res += x

        return res
49 |
class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution over 3 channels that scales by ``1 / rgb_std``
    and adds ``sign * rgb_range * rgb_mean / rgb_std`` as bias.

    Args:
        rgb_range: scalar multiplied into the mean.
        rgb_mean: three per-channel means.
        rgb_std: three per-channel standard deviations.
        sign: sign applied to the bias (``-1`` by default).
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        mean = torch.Tensor(rgb_mean)

        # Diagonal weight: each channel is only divided by its own std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * mean / std

        # The layer is a fixed transform; keep it out of optimisation.
        for param in (self.weight, self.bias):
            param.requires_grad = False
60 |
61 |
class Upsampler(nn.Sequential):
    """Sub-pixel upsampling stack built from conv + ``nn.PixelShuffle`` stages.

    Args:
        conv: factory called as ``conv(n_feat, out_channels, 3, bias)``.
        scale: upsampling factor; powers of two (repeated x2 stages) and 3
            (a single x3 stage) are handled.
        n_feat: channel count entering and leaving every stage.
        act: falsy for no activation, otherwise a callable returning the
            activation module appended after each stage.
        bias: forwarded positionally to ``conv``.

    Raises:
        NotImplementedError: for any other ``scale``.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        def make_stage(factor):
            # The conv expands channels by factor**2 so that PixelShuffle can
            # fold them back into a factor-times larger spatial grid.
            stage = [
                conv(n_feat, factor * factor * n_feat, 3, bias),
                nn.PixelShuffle(factor),
            ]
            if act:
                stage.append(act())
            return stage

        power_of_two = (int(scale) & (int(scale) - 1)) == 0
        if power_of_two:
            n_stages = int(math.log(scale, 2))
            m = [layer for _ in range(n_stages) for layer in make_stage(2)]
        elif scale == 3:
            m = make_stage(3)
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
78 |
79 |
--------------------------------------------------------------------------------
/model_meta/blindsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model_meta.common as common
4 | import torch.nn.functional as F
5 |
6 |
def make_model(args):
    """Factory entry point: build the ``MLN`` network from ``args``."""
    network = MLN(args)
    return network
9 |
10 |
class MLN(nn.Module):
    """Plain convolutional SR network whose forward pass runs on externally
    supplied weights.

    The ``head``/``body``/``tail`` modules created in ``__init__`` hold the
    parameters; ``forward`` itself never calls them and instead applies the
    same layer sequence functionally, reading every tensor from the
    ``weights`` mapping under the corresponding parameter names
    (``head.0.weight``, ``body.0.weight``, ..., ``tail.1.bias``).
    """

    def __init__(self, args, conv=common.default_conv):
        super(MLN, self).__init__()

        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        # A single LeakyReLU instance is reused after every convolution.
        act = nn.LeakyReLU(0.1, True)

        self.head = nn.Sequential(
            nn.Conv2d(3, n_feats, kernel_size=3, padding=1),
            act
        )

        # Six conv + activation pairs; the convs sit at even indices 0..10.
        m_body = []
        for _ in range(6):
            m_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, padding=1))
            m_body.append(act)

        upsampler = common.Upsampler(conv, scale, n_feats, act=False)
        to_image = nn.Conv2d(
            n_feats, args.n_colors, kernel_size,
            padding=(kernel_size // 2)
        )
        self.tail = nn.Sequential(upsampler, to_image)

        self.body = nn.Sequential(*m_body)

    def forward(self, lr,lr_bic,weights,base=''):
        """Run the network functionally with the tensors in ``weights``.

        Args:
            lr: input fed to the head convolution.
            lr_bic: tensor added in place to the final output.
            weights: mapping from parameter name to tensor.
            base: prefix prepended to every parameter name looked up.

        Returns:
            The tail output with ``lr_bic`` added.
        """
        def conv(x, layer):
            return common.conv2d(
                x,
                weights[base + layer + '.weight'],
                weights[base + layer + '.bias'],
                stride=1, padding=1
            )

        # head
        res = F.leaky_relu(conv(lr, 'head.0'), 0.1, True)

        # body: convolutions live at the even indices of the Sequential
        for idx in range(0, 12, 2):
            res = F.leaky_relu(conv(res, 'body.{}'.format(idx)), 0.1, True)

        # tail (x4): two conv + pixel-shuffle(2) stages, then the output conv.
        # NOTE(review): this path is fixed to x4 whatever args.scale was used
        # to build self.tail - confirm other scales are never requested.
        for layer in ('tail.0.0', 'tail.0.2'):
            res = F.pixel_shuffle(conv(res, layer), 2)
        res = conv(res, 'tail.1')

        res += lr_bic

        return res
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/loss/adversarial.py:
--------------------------------------------------------------------------------
1 | import utility
2 | from model import common
3 | from loss import discriminator
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch.autograd import Variable
10 |
class Adversarial(nn.Module):
    """Adversarial loss term that owns and trains its own discriminator.

    Each call to ``forward`` first runs ``gan_k`` discriminator updates on the
    detached generator output and then returns the generator loss computed
    with the updated discriminator.

    Args:
        args: option namespace; ``args.gan_k`` is read here and ``args`` is
            forwarded to the discriminator / optimizer / scheduler factories.
        gan_type: ``'GAN'``, or a name containing ``'WGAN'`` (optionally with
            ``'GP'`` for the gradient-penalty variant).
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utility.make_optimizer(args, self.discriminator)
        else:
            # WGAN-GP uses its own fixed Adam configuration.
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        self.scheduler = utility.make_scheduler(args, self.optimizer)

    def forward(self, fake, real):
        """Update the discriminator and return the generator loss.

        Args:
            fake: generator output (gradients flow back through it only for
                the returned generator loss).
            real: ground-truth batch.

        Returns:
            Scalar generator loss tensor. The mean discriminator loss of this
            call is stored as a float in ``self.loss``.
        """
        fake_detach = fake.detach()

        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d \
                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # One interpolation coefficient per sample, broadcast over
                    # the remaining dimensions. (rand_like(fake).view(-1,1,1,1)
                    # produced one value per element and could not broadcast
                    # against the batch.)
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        dtype=fake.dtype, device=fake.device
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Plain WGAN: clip the critic weights after every step.
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            # Build the target from this prediction rather than reusing the
            # label tensor left over from the discriminator loop.
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, torch.ones_like(d_fake_for_g)
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        """Return the discriminator and optimizer state merged into one dict."""
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)
83 |
84 | # Some references
85 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
86 | # OR
87 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
88 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MRDA
2 |
3 |
4 | This project is the official implementation of 'Meta-Learning based Degradation Representation for Blind Super-Resolution', TIP2023
5 | > **Meta-Learning based Degradation Representation for Blind Super-Resolution [[Paper](https://arxiv.org/pdf/2207.13963.pdf)] [[Project](https://github.com/Zj-BinXia/MRDA)]**
6 |
7 | This is the code for MRDA (for the classic degradation model, i.e., y = kx + n).
8 |
9 |
10 |
11 |
12 |
13 | ---
14 |
15 | ## Dependencies and Installation
16 |
17 | - Python >= 3.8 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
18 | - [PyTorch >= 1.10](https://pytorch.org/)
19 |
20 | ## Dataset Preparation
21 |
22 | We use DF2K, which combines [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (800 images) and [Flickr2K](https://github.com/LimBee/NTIRE2017) (2650 images).
23 |
24 | ---
25 |
26 | ## Training
27 |
28 | 1. Pretrain the Meta-Learning Network (MLN) on bicubic degradation:
29 |
30 | ```bash
31 | sh main_stage1.sh
32 | ```
33 |
34 | ### Isotropic Gaussian Kernels
35 |
36 |
37 | 2. Train the MLN with the meta-learning scheme. **Note: first set ''pre_train'' in main_stage2.sh to the path of the trained main_stage1 checkpoint.** Then run
38 |
39 | ```bash
40 | sh main_stage2.sh
41 | ```
42 |
43 | 3. Train the MLN together with the teacher MRDA_T. **Note: first set ''pre_train_meta'' in main_stage3.sh to the path of the trained main_stage2 checkpoint.** Then run
44 |
45 | ```bash
46 | sh main_stage3.sh
47 | ```
48 |
49 |
50 | 4. Train the student MRDA_S. **Note: first set ''pre_train_meta'' in main_stage4.sh to the path of the trained main_stage2 checkpoint, and set both ''pre_train_TA'' and ''pre_train_ST'' to the path of the trained main_stage3 checkpoint.** Then run
51 |
52 | ```bash
53 | sh main_stage4.sh
54 | ```
55 |
56 | ### Anisotropic Gaussian Kernels plus noise
57 |
58 | The training process is the same as for isotropic Gaussian kernels, except that the anisotropic Gaussian kernel settings in main_stage2.sh, main_stage3.sh, and main_stage4.sh are used.
59 |
60 | 2. Train the MLN with the meta-learning scheme. **Note: first set ''pre_train'' in main_stage2.sh to the path of the trained main_stage1 checkpoint.** Then run
61 |
62 | ```bash
63 | sh main_stage2.sh
64 | ```
65 |
66 | 3. Train the MLN together with the teacher MRDA_T. **Note: first set ''pre_train_meta'' in main_stage3.sh to the path of the trained main_stage2 checkpoint.** Then run
67 |
68 | ```bash
69 | sh main_stage3.sh
70 | ```
71 |
72 |
73 | 4. Train the student MRDA_S. **Note: first set ''pre_train_meta'' in main_stage4.sh to the path of the trained main_stage2 checkpoint, and set both ''pre_train_TA'' and ''pre_train_ST'' to the path of the trained main_stage3 checkpoint.** Then run
74 |
75 | ```bash
76 | sh main_stage4.sh
77 | ```
78 |
79 | ---
80 |
81 | ## :european_castle: Model Zoo
82 |
83 | Please download checkpoints from [Google Drive](https://drive.google.com/drive/folders/1gB-q3k_e_XeZbvOFXMNlS4aHDkpDqMtU?usp=sharing).
84 |
85 | ---
86 |
87 | ## Testing
88 |
89 | ### Isotropic Gaussian Kernels
90 |
91 | ```bash
92 | sh test_iso_stage4.sh
93 | ```
94 |
95 | ### Anisotropic Gaussian Kernels plus noise
96 |
97 | ```bash
98 | sh test_anisoAnoise_stage4.sh
99 | ```
100 |
101 | ---
102 |
103 | ## Results
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 | ---
114 |
115 | ## BibTeX
116 |
117 | @article{xia2022meta,
118 | title={Meta-learning based degradation representation for blind super-resolution},
119 | author={Xia, Bin and Tian, Yapeng and Zhang, Yulun and Hang, Yucheng and Yang, Wenming and Liao, Qingmin},
120 | journal={IEEE Transactions on Image Processing},
121 | year={2023}
122 | }
123 |
124 | ## 📧 Contact
125 |
126 | If you have any questions, please email `zjbinxia@gmail.com`.
127 |
128 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import matplotlib.pyplot as plt
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 |
class Loss(nn.modules.loss._Loss):
    """Weighted sum of loss terms parsed from ``args.loss``.

    ``args.loss`` is a ``+``-separated list of ``weight*type`` items, e.g.
    ``'1*L1+0.1*VGG54'``. Supported types are ``MSE``, ``L1``, ``CE``, any
    name containing ``VGG`` (loaded from ``loss.vgg``) and any name containing
    ``GAN`` (loaded from ``loss.adversarial``). A per-epoch running log of
    every term is kept in ``self.log`` (one row per epoch, one column per
    entry of ``self.loss``).
    """

    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type == 'CE':
                loss_function = nn.CrossEntropyLoss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            else:
                # Without this branch an unknown type either raised NameError
                # or silently reused the previous term's function.
                raise ValueError(
                    'Unsupported loss type: {}'.format(loss_type)
                )

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                # Log-only column for the discriminator loss of the GAN term.
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            # Log-only column holding the weighted sum of all terms.
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        """Return the weighted loss sum for ``(sr, hr)`` and update the log."""
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                # The DIS column directly follows its GAN term, whose module
                # exposes the last discriminator loss as ``.loss``.
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        """Advance the scheduler of every loss module that owns one."""
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        """Append a zeroed row to the log for a new epoch."""
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        """Turn the accumulated sums of the current row into means."""
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        """Format the running mean of every term after ``batch + 1`` batches."""
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        """Write one ``loss_<type>.pdf`` curve per logged term into ``apath``."""
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)

    def get_loss_module(self):
        """Return the underlying ``ModuleList``, unwrapping ``DataParallel``.

        The wrapper is only created when running on more than one GPU, so
        check for it directly instead of inferring it from ``n_GPUs`` (which
        is also greater than 1 for CPU runs configured with several GPUs).
        """
        if isinstance(self.loss_module, nn.DataParallel):
            return self.loss_module.module
        return self.loss_module

    def save(self, apath):
        """Save the module state and the loss log into ``apath``."""
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        """Restore state and log from ``apath`` and fast-forward schedulers."""
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        # Iterate the unwrapped list: a DataParallel wrapper is not iterable.
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                # One scheduler step per epoch already recorded in the log.
                for _ in range(len(self.log)): l.scheduler.step()
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import threading
3 | import queue
4 | import random
5 | import collections
6 |
7 | import torch
8 | import torch.multiprocessing as multiprocessing
9 |
10 | from torch._C import _set_worker_signal_handlers
11 | from torch.utils.data.dataloader import DataLoader
12 | from torch.utils.data.dataloader import _DataLoaderIter
13 | from torch.utils.data import _utils
14 |
15 | if sys.version_info[0] == 2:
16 | import Queue as queue
17 | else:
18 | import queue
19 |
20 | def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
21 | global _use_shared_memory
22 | _use_shared_memory = True
23 | _set_worker_signal_handlers()
24 |
25 | torch.set_num_threads(1)
26 | torch.manual_seed(seed)
27 | while True:
28 | r = index_queue.get()
29 | if r is None:
30 | break
31 | idx, batch_indices = r
32 | try:
33 | idx_scale = 0
34 | if len(scale) > 1 and dataset.train:
35 | idx_scale = random.randrange(0, len(scale))
36 | dataset.set_scale(idx_scale)
37 |
38 | samples = collate_fn([dataset[i] for i in batch_indices])
39 | samples.append(idx_scale)
40 |
41 | except Exception:
42 | data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
43 | else:
44 | data_queue.put((idx, samples))
45 |
class _MSDataLoaderIter(_DataLoaderIter):
    """Loader iterator whose worker processes run ``_ms_loop``.

    ``__init__`` copies the loader configuration onto the iterator and, when
    ``num_workers > 0``, creates the queues, starts the worker processes (and
    an optional pin-memory thread) and pre-queues index batches.
    """

    def __init__(self, loader):
        # NOTE(review): the parent constructor is not called; every attribute
        # set here is presumably what the inherited methods rely on - confirm
        # against the installed PyTorch version.
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # One index queue per worker, one shared result queue.
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            # Worker i is seeded with base_seed + i.
            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                # Results pass through an in-process queue fed by a helper
                # thread running torch's _pin_memory_loop.
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue, maybe_device_id, self.done_event))
                self.pin_memory_thread.daemon = True
                self.pin_memory_thread.start()
            else:
                # No helper thread: read straight from the worker result queue.
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            # (_put_indices is not defined in this class; presumably inherited
            # from _DataLoaderIter - confirm.)
            for _ in range(2 * self.num_workers):
                self._put_indices()
116 |
class MSDataLoader(DataLoader):
    """``DataLoader`` that carries the scale list and uses ``_MSDataLoaderIter``.

    The worker count comes from ``args.n_threads`` and the scale list from
    ``args.scale``; every other argument is forwarded to ``DataLoader``. Note
    that ``drop_last`` defaults to ``True`` here.
    """

    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=True,
        timeout=0, worker_init_fn=None):

        loader_options = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=args.n_threads,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
        super(MSDataLoader, self).__init__(dataset, **loader_options)

        self.scale = args.scale

    def __iter__(self):
        """Return a fresh multi-scale iterator over this loader."""
        iterator = _MSDataLoaderIter(self)
        return iterator
135 |
--------------------------------------------------------------------------------
/data/multiscalesrdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import pickle
5 |
6 | from data import common
7 |
8 | import numpy as np
9 | import imageio
10 | import torch
11 | import torch.utils.data as data
12 |
13 |
class SRData(data.Dataset):
    """Dataset yielding HR images (or training patches) and their file names.

    Images are discovered under ``<dir_data>/<name>/HR``. Depending on
    ``args.ext`` they are read directly (``'img'``) or through pickled binary
    copies cached under ``<dir_data>/<name>/bin`` (``'sep'``). Only the HR
    image is returned by ``__getitem__``.

    Args:
        args: option namespace (reads ``model``, ``scale``, ``dir_data``,
            ``ext``, ``batch_size``, ``test_every``, ``data_train``,
            ``patch_size``, ``no_augment``, ``n_colors``, ``rgb_range``).
        name: dataset directory name below ``args.dir_data``.
        train: whether this is the training split.
        benchmark: when true, images are always read directly from disk.
    """

    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            # Mirror the HR / LR directory layout below the binary cache root.
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(int(s))
                    ),
                    exist_ok=True
                )

            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)
        if train:
            # Repeat the image list so one epoch yields roughly
            # batch_size * test_every patches.
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            if n_images == 0:
                self.repeat = 0
            else:
                self.repeat = max(n_patches // n_images, 1)

    # Below functions as used to prepare images
    def _scan(self):
        """Return the sorted HR paths and, per scale, the derived LR paths."""
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        int(s), filename, int(s), self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        """Derive the dataset root, HR / LR directories and file extensions."""
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        """Create the pickled copy ``f`` of image ``img`` if missing or on reset."""
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        """Return ``(hr_tensor, filename)`` for sample ``idx``."""
        hr, filename = self._load_file(idx)
        hr = self.get_patch(hr)
        hr = common.set_channel(hr, n_channels=self.args.n_colors)
        hr = common.np2Tensor(hr, rgb_range=self.args.rgb_range)

        return hr, filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        """Map a (possibly repeated) training index onto the image list."""
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        """Load the HR image for ``idx`` and return ``(image, basename)``.

        Raises:
            ValueError: if ``args.ext`` selects neither the ``'img'`` nor the
                ``'sep'`` loading mode and this is not a benchmark set.
        """
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        # Use the same substring test as __init__, so that every ext value
        # that made __init__ keep raw image paths is also read as an image.
        if self.args.ext.find('img') >= 0 or self.benchmark:
            hr = imageio.imread(f_hr)
        elif self.args.ext.find('sep') >= 0:
            # NOTE: pickle must only be used on trusted files; these are the
            # caches written by _check_and_load.
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
        else:
            raise ValueError(
                'Unsupported ext option: {}'.format(self.args.ext)
            )

        return hr, filename

    def get_patch(self, hr):
        """Crop a training patch, or trim a test image to a multiple of the scale."""
        scale = self.scale[self.idx_scale]
        if self.train:
            hr = common.get_patch(
                hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi=(len(self.scale) > 1),
                input_large=self.input_large
            )
            if not self.args.no_augment: hr = common.augment(hr)
        else:
            # Drop trailing rows / columns so both sides divide by the scale.
            ih, iw = hr.shape[:2]
            ih = ih // scale
            iw = iw // scale
            hr = hr[0:int(ih * scale), 0:int(iw * scale)]

        return hr

    def set_scale(self, idx_scale):
        """Select the active scale index (chosen at random when ``input_large``)."""
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
157 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class Model(nn.Module):
def __init__(self, args, ckp):
    """Build the network named by ``args.model`` and load its weights.

    Args:
        args: option namespace; the fields copied below plus ``model``,
            ``pre_train`` and ``resume`` are read.
        ckp: checkpoint object; only ``ckp.dir`` is used here.
    """
    super(Model, self).__init__()
    print('Making model...')
    self.args = args
    self.scale = args.scale
    self.idx_scale = 0
    # NOTE(review): the 'save' attribute copied here is stored on the
    # instance under the same name as the save() method and therefore
    # shadows it - confirm nothing calls model.save(...) on an instance.
    for field in ('self_ensemble', 'chop', 'precision', 'cpu',
                  'n_GPUs', 'save_models', 'save'):
        setattr(self, field, getattr(args, field))
    self.device = torch.device('cuda' if not args.cpu else 'cpu')

    builder = import_module('model.' + args.model)
    network = builder.make_model(args).to(self.device)
    if args.precision == 'half':
        network.half()
    # Wrap for multi-GPU execution only when actually running on GPUs.
    if args.n_GPUs > 1 and not args.cpu:
        network = nn.DataParallel(network, range(args.n_GPUs))
    self.model = network

    self.load(
        ckp.dir,
        pre_train=args.pre_train,
        resume=args.resume,
        cpu=args.cpu
    )
37 |
38 | def forward(self, lr,lr_bic):
39 | if self.self_ensemble and not self.training:
40 | if self.chop:
41 | forward_function = self.forward_chop
42 | else:
43 | forward_function = self.model.forward
44 |
45 | return self.forward_x8(lr, forward_function)
46 | elif self.chop and not self.training:
47 | return self.forward_chop(lr)
48 | else:
49 | return self.model(lr,lr_bic)
50 |
51 | def get_model(self):
52 | if self.n_GPUs <= 1 or self.cpu:
53 | return self.model
54 | else:
55 | return self.model.module
56 |
57 | def state_dict(self, **kwargs):
58 | target = self.get_model()
59 | return target.state_dict(**kwargs)
60 |
61 | def save(self, apath, epoch, is_best=False):
62 | target = self.get_model()
63 | torch.save(
64 | target.state_dict(),
65 | os.path.join(apath, 'model', 'model_latest.pt')
66 | )
67 | if is_best:
68 | torch.save(
69 | target.state_dict(),
70 | os.path.join(apath, 'model', 'model_best.pt')
71 | )
72 |
73 | if self.save_models:
74 | torch.save(
75 | target.state_dict(),
76 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
77 | )
78 |
79 | def load(self, apath, pre_train='.', resume=-1, cpu=False):
80 | if cpu:
81 | kwargs = {'map_location': lambda storage, loc: storage}
82 | else:
83 | kwargs = {}
84 |
85 | if resume == -1:
86 | self.get_model().load_state_dict(
87 | torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs),
88 | strict=True
89 | )
90 |
91 | elif resume == 0:
92 | if pre_train != '.':
93 | self.get_model().load_state_dict(
94 | torch.load(pre_train, **kwargs),
95 | strict=True
96 | )
97 |
98 | elif resume > 0:
99 | self.get_model().load_state_dict(
100 | torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs),
101 | strict=False
102 | )
103 |
def forward_chop(self, x, shave=10, min_size=160000):
    """Run the model on four overlapping quadrants of ``x`` and stitch the results.

    Args:
        x: 4-D input unpacked as ``(b, c, h, w)``.
        shave: overlap added to each half-size quadrant.
        min_size: quadrants whose area is at least this value are chopped
            again recursively instead of being fed to the model.

    Returns:
        Tensor of size ``(b, c, scale * h, scale * w)`` assembled from the
        quadrant outputs.
    """
    # NOTE(review): the model is called with a single argument below, whereas
    # forward() calls self.model(lr, lr_bic) - confirm this path is usable.
    scale = self.scale[self.idx_scale]
    n_GPUs = min(self.n_GPUs, 4)
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    # Quadrant order: top-left, top-right, bottom-left, bottom-right.
    lr_list = [
        x[:, :, 0:h_size, 0:w_size],
        x[:, :, 0:h_size, (w - w_size):w],
        x[:, :, (h - h_size):h, 0:w_size],
        x[:, :, (h - h_size):h, (w - w_size):w]]

    if w_size * h_size < min_size:
        # Small enough: process the quadrants in batches of n_GPUs.
        sr_list = []
        for i in range(0, 4, n_GPUs):
            lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
            sr_batch = self.model(lr_batch)
            sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
    else:
        # Still too large: chop every quadrant again.
        sr_list = [
            self.forward_chop(patch, shave=shave, min_size=min_size) \
            for patch in lr_list
        ]

    # Switch all geometry to output resolution.
    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    h_size, w_size = scale * h_size, scale * w_size
    shave *= scale

    # The output buffer reuses the input channel count c.
    output = x.new(b, c, h, w)
    # Copy the non-overlapping part of each quadrant into place.
    output[:, :, 0:h_half, 0:w_half] \
        = sr_list[0][:, :, 0:h_half, 0:w_half]
    output[:, :, 0:h_half, w_half:w] \
        = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
    output[:, :, h_half:h, 0:w_half] \
        = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
    output[:, :, h_half:h, w_half:w] \
        = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

    return output
144 |
def forward_x8(self, x, forward_function):
    """Geometric self-ensemble: average predictions over 8 flips/transposes.

    ``x`` is expanded into 8 variants (every combination of a flip along
    the last axis, a flip along the third axis, and a transpose of the last
    two axes). ``forward_function`` is applied to each variant, each result
    is mapped back with the inverse transforms, and the 8 aligned results
    are averaged.

    Args:
        x: Input batch of shape (b, c, h, w).
        forward_function: Callable taking one tensor and returning one tensor.

    Returns:
        The element-wise mean of the 8 aligned results; its batch dimension
        equals that of ``x`` (the original averaged over the batch too, which
        was only correct for a batch of one).
    """
    def _transform(v, op):
        # Pure tensor ops replace the earlier NumPy round trip through the
        # CPU; ``detach`` keeps the no-gradient behaviour of ``.data``.
        v = v.detach()
        if op == 'v':
            v = v.flip(3)
        elif op == 'h':
            v = v.flip(2)
        elif op == 't':
            v = v.transpose(2, 3)
        return v.contiguous().to(self.device)

    lr_list = [x]
    for tf in 'v', 'h', 't':
        lr_list.extend([_transform(t, tf) for t in lr_list])

    sr_list = [forward_function(aug) for aug in lr_list]
    for i, sr in enumerate(sr_list):
        # Undo the augmentations in reverse order of application.
        if i > 3:
            sr = _transform(sr, 't')
        if i % 4 > 1:
            sr = _transform(sr, 'h')
        if (i % 4) % 2 == 1:
            sr = _transform(sr, 'v')
        sr_list[i] = sr

    # Stack on a new leading axis so that only the 8 variants are averaged,
    # never different images of the batch.
    return torch.stack(sr_list, dim=0).mean(dim=0)
179 |
180 |
--------------------------------------------------------------------------------
/model_ST/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class Model(nn.Module):
    """Wrapper around a ``model_ST`` network adding checkpoint I/O and test-time helpers.

    The network itself is created by ``model_ST.<args.model>.make_model(args)``
    and kept in ``self.model``; it is wrapped in ``nn.DataParallel`` when CPU
    mode is off and more than one GPU is requested.
    """

    def __init__(self, args, ckp):
        """Build the network and load its initial weights.

        Args:
            args: Option namespace. Attributes read here: ``scale``,
                ``self_ensemble``, ``chop``, ``precision``, ``cpu``,
                ``n_GPUs``, ``save_models``, ``model``, ``pre_train_ST``
                and ``resume``.
            ckp: Checkpoint helper; only ``ckp.dir`` is used, as the
                directory checkpoints are resumed from.
        """
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        # ``args.save`` is intentionally NOT mirrored as ``self.save``: an
        # instance attribute of that name shadows the ``save`` method below
        # and makes it uncallable. The value stays reachable as
        # ``self.args.save``.

        module = import_module('model_ST.' + args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train_ST,
            resume=args.resume,
            cpu=args.cpu
        )

    def forward(self, x):
        """Run the network on ``x``.

        Outside training mode, ``self_ensemble`` selects the x8 geometric
        ensemble (applied on top of chopping when ``chop`` is also set) and
        ``chop`` alone selects quadrant-wise inference. In every other case
        the network is called directly.
        """
        if not self.training:
            if self.self_ensemble:
                forward_function = self.forward_chop if self.chop else self.model.forward
                return self.forward_x8(x, forward_function)
            if self.chop:
                return self.forward_chop(x)
        return self.model(x)

    def get_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if it is in use."""
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        return self.model.module

    def state_dict(self, **kwargs):
        """Return the state dict of the bare network; ``kwargs`` are forwarded."""
        return self.get_model().state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        """Write the bare network's weights into ``<apath>/model``.

        ``model_ST_latest.pt`` is always written, ``model_ST_best.pt`` when
        ``is_best`` is true and ``model_ST_<epoch>.pt`` when
        ``self.save_models`` is true.
        """
        target = self.get_model()
        file_names = ['model_ST_latest.pt']
        if is_best:
            file_names.append('model_ST_best.pt')
        if self.save_models:
            file_names.append('model_ST_{}.pt'.format(epoch))

        for file_name in file_names:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', file_name)
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """Load weights into the bare network according to ``resume``.

        * ``resume == -1``: ``<apath>/model/model_ST_latest.pt`` (strict).
        * ``resume == 0``: the file ``pre_train`` (non-strict), or nothing
          when ``pre_train`` is ``'.'``.
        * ``resume > 0``: ``<apath>/model/model_ST_<resume>.pt`` (non-strict).

        With ``cpu`` set, storages are kept as deserialized (identity
        ``map_location``).
        """
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_ST_latest.pt'), **kwargs),
                strict=True
            )
        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=False
                )
        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_ST_{}.pt'.format(resume)), **kwargs),
                strict=False
            )

    def forward_chop(self, x, shave=10, min_size=160000):
        """Super-resolve ``x`` in four overlapping quadrants to limit memory use.

        Each quadrant covers half the height/width plus ``shave`` extra
        pixels. Quadrants whose area is below ``min_size`` are sent through
        ``self.model`` in batches of up to ``min(n_GPUs, 4)`` patches; larger
        ones are chopped again recursively. The result is stitched together
        from the non-overlapping part of each quadrant.

        Args:
            x: Input batch of shape (b, c, h, w).
            shave: Overlap, in input pixels, added to every quadrant.
            min_size: Quadrant area (input pixels) below which recursion stops.

        Returns:
            Tensor of shape (b, c, h * scale, w * scale), where ``scale`` is
            ``self.scale[self.idx_scale]``.
        """
        scale = self.scale[self.idx_scale]
        # At least one patch per step: ``range`` rejects a zero step, which
        # the original hit whenever ``n_GPUs`` was 0.
        n_GPUs = max(1, min(self.n_GPUs, 4))
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                # Split by the per-patch batch size rather than
                # ``chunk(n_GPUs)``: the last step may hold fewer than
                # ``n_GPUs`` patches, and chunk would then cut one patch's
                # batch into pieces that broadcast silently below.
                sr_list.extend(sr_batch.split(b, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        # Switch all geometry to output (super-resolved) coordinates.
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size

        output = x.new_empty((b, c, h, w))
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        """Geometric self-ensemble: average predictions over 8 flips/transposes.

        ``x`` is expanded into 8 variants (every combination of a flip along
        the last axis, a flip along the third axis and a transpose of the
        last two axes). ``forward_function`` is applied to each variant, each
        result is mapped back with the inverse transforms, and the 8 aligned
        results are averaged.

        Returns:
            The element-wise mean of the 8 aligned results, with the same
            batch size as ``x`` (the original averaged over the batch too,
            which was only correct for a batch of one).
        """
        def _transform(v, op):
            # Pure tensor ops replace the earlier NumPy round trip through
            # the CPU; ``detach`` keeps the no-gradient behaviour of ``.data``.
            v = v.detach()
            if op == 'v':
                v = v.flip(3)
            elif op == 'h':
                v = v.flip(2)
            elif op == 't':
                v = v.transpose(2, 3)
            return v.contiguous().to(self.device)

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        for i, sr in enumerate(sr_list):
            # Undo the augmentations in reverse order of application.
            if i > 3:
                sr = _transform(sr, 't')
            if i % 4 > 1:
                sr = _transform(sr, 'h')
            if (i % 4) % 2 == 1:
                sr = _transform(sr, 'v')
            sr_list[i] = sr

        # Stack on a new leading axis so only the 8 variants are averaged,
        # never different images of the batch.
        return torch.stack(sr_list, dim=0).mean(dim=0)
179 |
180 |
--------------------------------------------------------------------------------
/model_TA/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class Model(nn.Module):
    """Wrapper around a ``model_TA`` network adding checkpoint I/O and test-time helpers.

    The network itself is created by ``model_TA.<args.model>.make_model(args)``
    and kept in ``self.model``; it is wrapped in ``nn.DataParallel`` when CPU
    mode is off and more than one GPU is requested.
    """

    def __init__(self, args, ckp):
        """Build the network and load its initial weights.

        Args:
            args: Option namespace. Attributes read here: ``scale``,
                ``self_ensemble``, ``chop``, ``precision``, ``cpu``,
                ``n_GPUs``, ``save_models``, ``model``, ``pre_train_TA``
                and ``resume``.
            ckp: Checkpoint helper; only ``ckp.dir`` is used, as the
                directory checkpoints are resumed from.
        """
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        # ``args.save`` is intentionally NOT mirrored as ``self.save``: an
        # instance attribute of that name shadows the ``save`` method below
        # and makes it uncallable. The value stays reachable as
        # ``self.args.save``.

        module = import_module('model_TA.' + args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train_TA,
            resume=args.resume,
            cpu=args.cpu
        )

    def forward(self, x, deg_repre):
        """Run the network on ``x`` with the extra input ``deg_repre``.

        Outside training mode, ``self_ensemble`` selects the x8 geometric
        ensemble (applied on top of chopping when ``chop`` is also set) and
        ``chop`` alone selects quadrant-wise inference. In every other case
        the network is called directly as ``self.model(x, deg_repre)``.

        NOTE(review): the self-ensemble and chop paths pass only ``x`` on;
        ``deg_repre`` is dropped there. Confirm those paths are valid for the
        wrapped network before enabling ``self_ensemble``/``chop``.
        """
        if not self.training:
            if self.self_ensemble:
                forward_function = self.forward_chop if self.chop else self.model.forward
                return self.forward_x8(x, forward_function)
            if self.chop:
                return self.forward_chop(x)
        return self.model(x, deg_repre)

    def get_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if it is in use."""
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        return self.model.module

    def state_dict(self, **kwargs):
        """Return the state dict of the bare network; ``kwargs`` are forwarded."""
        return self.get_model().state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        """Write the bare network's weights into ``<apath>/model``.

        ``model_latest.pt`` is always written, ``model_best.pt`` when
        ``is_best`` is true and ``model_<epoch>.pt`` when
        ``self.save_models`` is true.
        """
        target = self.get_model()
        file_names = ['model_latest.pt']
        if is_best:
            file_names.append('model_best.pt')
        if self.save_models:
            file_names.append('model_{}.pt'.format(epoch))

        for file_name in file_names:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', file_name)
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """Load weights into the bare network according to ``resume``.

        * ``resume == -1``: ``<apath>/model/model_latest.pt`` (strict).
        * ``resume == 0``: the file ``pre_train`` (strict), or nothing when
          ``pre_train`` is ``'.'``.
        * ``resume > 0``: ``<apath>/model/model_<resume>.pt`` (non-strict).

        With ``cpu`` set, storages are kept as deserialized (identity
        ``map_location``).
        """
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs),
                strict=True
            )
        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=True
                )
        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs),
                strict=False
            )

    def forward_chop(self, x, shave=10, min_size=160000):
        """Super-resolve ``x`` in four overlapping quadrants to limit memory use.

        Each quadrant covers half the height/width plus ``shave`` extra
        pixels. Quadrants whose area is below ``min_size`` are sent through
        ``self.model`` in batches of up to ``min(n_GPUs, 4)`` patches; larger
        ones are chopped again recursively. The result is stitched together
        from the non-overlapping part of each quadrant.

        Args:
            x: Input batch of shape (b, c, h, w).
            shave: Overlap, in input pixels, added to every quadrant.
            min_size: Quadrant area (input pixels) below which recursion stops.

        Returns:
            Tensor of shape (b, c, h * scale, w * scale), where ``scale`` is
            ``self.scale[self.idx_scale]``.

        NOTE(review): ``self.model`` is called with the patch batch as its
        only argument - confirm the wrapped network accepts a single input.
        """
        scale = self.scale[self.idx_scale]
        # At least one patch per step: ``range`` rejects a zero step, which
        # the original hit whenever ``n_GPUs`` was 0.
        n_GPUs = max(1, min(self.n_GPUs, 4))
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                # Split by the per-patch batch size rather than
                # ``chunk(n_GPUs)``: the last step may hold fewer than
                # ``n_GPUs`` patches, and chunk would then cut one patch's
                # batch into pieces that broadcast silently below.
                sr_list.extend(sr_batch.split(b, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        # Switch all geometry to output (super-resolved) coordinates.
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size

        output = x.new_empty((b, c, h, w))
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        """Geometric self-ensemble: average predictions over 8 flips/transposes.

        ``x`` is expanded into 8 variants (every combination of a flip along
        the last axis, a flip along the third axis and a transpose of the
        last two axes). ``forward_function`` is applied to each variant, each
        result is mapped back with the inverse transforms, and the 8 aligned
        results are averaged.

        Returns:
            The element-wise mean of the 8 aligned results, with the same
            batch size as ``x`` (the original averaged over the batch too,
            which was only correct for a batch of one).
        """
        def _transform(v, op):
            # Pure tensor ops replace the earlier NumPy round trip through
            # the CPU; ``detach`` keeps the no-gradient behaviour of ``.data``.
            v = v.detach()
            if op == 'v':
                v = v.flip(3)
            elif op == 'h':
                v = v.flip(2)
            elif op == 't':
                v = v.transpose(2, 3)
            return v.contiguous().to(self.device)

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        for i, sr in enumerate(sr_list):
            # Undo the augmentations in reverse order of application.
            if i > 3:
                sr = _transform(sr, 't')
            if i % 4 > 1:
                sr = _transform(sr, 'h')
            if (i % 4) % 2 == 1:
                sr = _transform(sr, 'v')
            sr_list[i] = sr

        # Stack on a new leading axis so only the 8 variants are averaged,
        # never different images of the batch.
        return torch.stack(sr_list, dim=0).mean(dim=0)
179 |
180 |
--------------------------------------------------------------------------------
/model_meta_stage3/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class Model(nn.Module):
    """Wrapper around a ``model_meta_stage3`` network adding checkpoint I/O and test-time helpers.

    The network itself is created by
    ``model_meta_stage3.<args.model>.make_model(args)`` and kept in
    ``self.model``; it is wrapped in ``nn.DataParallel`` when CPU mode is off
    and more than one GPU is requested.
    """

    def __init__(self, args, ckp):
        """Build the network and load its initial weights.

        Args:
            args: Option namespace. Attributes read here: ``scale``,
                ``self_ensemble``, ``chop``, ``precision``, ``cpu``,
                ``n_GPUs``, ``save_models``, ``model``, ``pre_train_meta``
                and ``resume``.
            ckp: Checkpoint helper; only ``ckp.dir`` is used, as the
                directory checkpoints are resumed from.
        """
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        # ``args.save`` is intentionally NOT mirrored as ``self.save``: an
        # instance attribute of that name shadows the ``save`` method below
        # and makes it uncallable. The value stays reachable as
        # ``self.args.save``.

        module = import_module('model_meta_stage3.' + args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train_meta,
            resume=args.resume,
            cpu=args.cpu
        )

    def forward(self, lr, lr_bic):
        """Run the network on ``lr`` with the extra input ``lr_bic``.

        Outside training mode, ``self_ensemble`` selects the x8 geometric
        ensemble (applied on top of chopping when ``chop`` is also set) and
        ``chop`` alone selects quadrant-wise inference. In every other case
        the network is called directly as ``self.model(lr, lr_bic)``.

        NOTE(review): the self-ensemble and chop paths pass only ``lr`` on;
        ``lr_bic`` is dropped there. Confirm those paths are valid for the
        wrapped network before enabling ``self_ensemble``/``chop``.
        """
        if not self.training:
            if self.self_ensemble:
                forward_function = self.forward_chop if self.chop else self.model.forward
                return self.forward_x8(lr, forward_function)
            if self.chop:
                return self.forward_chop(lr)
        return self.model(lr, lr_bic)

    def get_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if it is in use."""
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        return self.model.module

    def state_dict(self, **kwargs):
        """Return the state dict of the bare network; ``kwargs`` are forwarded."""
        return self.get_model().state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        """Write the bare network's weights into ``<apath>/model``.

        ``model_latest.pt`` is always written, ``model_best.pt`` when
        ``is_best`` is true and ``model_<epoch>.pt`` when
        ``self.save_models`` is true.
        """
        target = self.get_model()
        file_names = ['model_latest.pt']
        if is_best:
            file_names.append('model_best.pt')
        if self.save_models:
            file_names.append('model_{}.pt'.format(epoch))

        for file_name in file_names:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', file_name)
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """Load weights into the bare network according to ``resume``.

        * ``resume == -1``: ``<apath>/model/model_latest.pt`` (strict).
        * ``resume == 0``: the file ``pre_train`` (strict), or nothing when
          ``pre_train`` is ``'.'``.
        * ``resume > 0``: ``<apath>/model/model_<resume>.pt`` (non-strict).

        With ``cpu`` set, storages are kept as deserialized (identity
        ``map_location``).
        """
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs),
                strict=True
            )
        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=True
                )
        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs),
                strict=False
            )

    def forward_chop(self, x, shave=10, min_size=160000):
        """Super-resolve ``x`` in four overlapping quadrants to limit memory use.

        Each quadrant covers half the height/width plus ``shave`` extra
        pixels. Quadrants whose area is below ``min_size`` are sent through
        ``self.model`` in batches of up to ``min(n_GPUs, 4)`` patches; larger
        ones are chopped again recursively. The result is stitched together
        from the non-overlapping part of each quadrant.

        Args:
            x: Input batch of shape (b, c, h, w).
            shave: Overlap, in input pixels, added to every quadrant.
            min_size: Quadrant area (input pixels) below which recursion stops.

        Returns:
            Tensor of shape (b, c, h * scale, w * scale), where ``scale`` is
            ``self.scale[self.idx_scale]``.

        NOTE(review): ``self.model`` is called with the patch batch as its
        only argument - confirm the wrapped network accepts a single input.
        """
        scale = self.scale[self.idx_scale]
        # At least one patch per step: ``range`` rejects a zero step, which
        # the original hit whenever ``n_GPUs`` was 0.
        n_GPUs = max(1, min(self.n_GPUs, 4))
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                # Split by the per-patch batch size rather than
                # ``chunk(n_GPUs)``: the last step may hold fewer than
                # ``n_GPUs`` patches, and chunk would then cut one patch's
                # batch into pieces that broadcast silently below.
                sr_list.extend(sr_batch.split(b, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        # Switch all geometry to output (super-resolved) coordinates.
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size

        output = x.new_empty((b, c, h, w))
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        """Geometric self-ensemble: average predictions over 8 flips/transposes.

        ``x`` is expanded into 8 variants (every combination of a flip along
        the last axis, a flip along the third axis and a transpose of the
        last two axes). ``forward_function`` is applied to each variant, each
        result is mapped back with the inverse transforms, and the 8 aligned
        results are averaged.

        Returns:
            The element-wise mean of the 8 aligned results, with the same
            batch size as ``x`` (the original averaged over the batch too,
            which was only correct for a batch of one).
        """
        def _transform(v, op):
            # Pure tensor ops replace the earlier NumPy round trip through
            # the CPU; ``detach`` keeps the no-gradient behaviour of ``.data``.
            v = v.detach()
            if op == 'v':
                v = v.flip(3)
            elif op == 'h':
                v = v.flip(2)
            elif op == 't':
                v = v.transpose(2, 3)
            return v.contiguous().to(self.device)

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        for i, sr in enumerate(sr_list):
            # Undo the augmentations in reverse order of application.
            if i > 3:
                sr = _transform(sr, 't')
            if i % 4 > 1:
                sr = _transform(sr, 'h')
            if (i % 4) % 2 == 1:
                sr = _transform(sr, 'v')
            sr_list[i] = sr

        # Stack on a new leading axis so only the 8 variants are averaged,
        # never different images of the batch.
        return torch.stack(sr_list, dim=0).mean(dim=0)
179 |
180 |
--------------------------------------------------------------------------------
/model_meta/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
class Model(nn.Module):
    """Wrapper around a ``model_meta`` network adding checkpoint I/O and test-time helpers.

    The network itself is created by
    ``model_meta.<args.model>.make_model(args)`` and kept in ``self.model``;
    it is wrapped in ``nn.DataParallel`` when CPU mode is off and more than
    one GPU is requested.
    """

    def __init__(self, args, ckp):
        """Build the network and load its initial weights.

        Args:
            args: Option namespace. Attributes read here: ``scale``,
                ``self_ensemble``, ``chop``, ``precision``, ``cpu``,
                ``n_GPUs``, ``save_models``, ``model``, ``pre_train``
                and ``resume``.
            ckp: Checkpoint helper; only ``ckp.dir`` is used, as the
                directory checkpoints are resumed from.
        """
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        # ``args.save`` is intentionally NOT mirrored as ``self.save``: an
        # instance attribute of that name shadows the ``save`` method below
        # and makes it uncallable. The value stays reachable as
        # ``self.args.save``.

        module = import_module('model_meta.' + args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )

    def forward(self, lr, lr_bic, weights):
        """Run the network on ``lr`` with the extra inputs ``lr_bic`` and ``weights``.

        Outside training mode, ``self_ensemble`` selects the x8 geometric
        ensemble (applied on top of chopping when ``chop`` is also set) and
        ``chop`` alone selects quadrant-wise inference. In every other case
        the network is called directly as ``self.model(lr, lr_bic, weights)``.

        NOTE(review): the self-ensemble and chop paths pass only ``lr`` on;
        ``lr_bic`` and ``weights`` are dropped there. Confirm those paths are
        valid for the wrapped network before enabling
        ``self_ensemble``/``chop``.
        """
        if not self.training:
            if self.self_ensemble:
                forward_function = self.forward_chop if self.chop else self.model.forward
                return self.forward_x8(lr, forward_function)
            if self.chop:
                return self.forward_chop(lr)
        return self.model(lr, lr_bic, weights)

    def get_model(self):
        """Return the bare network, unwrapping ``nn.DataParallel`` if it is in use."""
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        return self.model.module

    def state_dict(self, **kwargs):
        """Return the state dict of the bare network; ``kwargs`` are forwarded."""
        return self.get_model().state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        """Write the bare network's weights into ``<apath>/model``.

        ``model_meta_latest.pt`` is always written, ``model_meta_best.pt``
        when ``is_best`` is true and ``model_meta_<epoch>.pt`` when
        ``self.save_models`` is true.
        """
        target = self.get_model()
        file_names = ['model_meta_latest.pt']
        if is_best:
            file_names.append('model_meta_best.pt')
        if self.save_models:
            file_names.append('model_meta_{}.pt'.format(epoch))

        for file_name in file_names:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', file_name)
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """Load weights into the bare network according to ``resume``.

        * ``resume == -1``: ``<apath>/model/model_meta_latest.pt`` (strict).
        * ``resume == 0``: the file ``pre_train`` (strict), or nothing when
          ``pre_train`` is ``'.'``.
        * ``resume > 0``: ``<apath>/model/model_meta_<resume>.pt``
          (non-strict).

        With ``cpu`` set, storages are kept as deserialized (identity
        ``map_location``).
        """
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_meta_latest.pt'), **kwargs),
                strict=True
            )
        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=True
                )
        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_meta_{}.pt'.format(resume)), **kwargs),
                strict=False
            )

    def forward_chop(self, x, shave=10, min_size=160000):
        """Super-resolve ``x`` in four overlapping quadrants to limit memory use.

        Each quadrant covers half the height/width plus ``shave`` extra
        pixels. Quadrants whose area is below ``min_size`` are sent through
        ``self.model`` in batches of up to ``min(n_GPUs, 4)`` patches; larger
        ones are chopped again recursively. The result is stitched together
        from the non-overlapping part of each quadrant.

        Args:
            x: Input batch of shape (b, c, h, w).
            shave: Overlap, in input pixels, added to every quadrant.
            min_size: Quadrant area (input pixels) below which recursion stops.

        Returns:
            Tensor of shape (b, c, h * scale, w * scale), where ``scale`` is
            ``self.scale[self.idx_scale]``.

        NOTE(review): ``self.model`` is called with the patch batch as its
        only argument - confirm the wrapped network accepts a single input.
        """
        scale = self.scale[self.idx_scale]
        # At least one patch per step: ``range`` rejects a zero step, which
        # the original hit whenever ``n_GPUs`` was 0.
        n_GPUs = max(1, min(self.n_GPUs, 4))
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                # Split by the per-patch batch size rather than
                # ``chunk(n_GPUs)``: the last step may hold fewer than
                # ``n_GPUs`` patches, and chunk would then cut one patch's
                # batch into pieces that broadcast silently below.
                sr_list.extend(sr_batch.split(b, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        # Switch all geometry to output (super-resolved) coordinates.
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size

        output = x.new_empty((b, c, h, w))
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        """Geometric self-ensemble: average predictions over 8 flips/transposes.

        ``x`` is expanded into 8 variants (every combination of a flip along
        the last axis, a flip along the third axis and a transpose of the
        last two axes). ``forward_function`` is applied to each variant, each
        result is mapped back with the inverse transforms, and the 8 aligned
        results are averaged.

        Returns:
            The element-wise mean of the 8 aligned results, with the same
            batch size as ``x`` (the original averaged over the batch too,
            which was only correct for a batch of one).
        """
        def _transform(v, op):
            # Pure tensor ops replace the earlier NumPy round trip through
            # the CPU; ``detach`` keeps the no-gradient behaviour of ``.data``.
            v = v.detach()
            if op == 'v':
                v = v.flip(3)
            elif op == 'h':
                v = v.flip(2)
            elif op == 't':
                v = v.transpose(2, 3)
            return v.contiguous().to(self.device)

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        for i, sr in enumerate(sr_list):
            # Undo the augmentations in reverse order of application.
            if i > 3:
                sr = _transform(sr, 't')
            if i % 4 > 1:
                sr = _transform(sr, 'h')
            if (i % 4) % 2 == 1:
                sr = _transform(sr, 'v')
            sr_list[i] = sr

        # Stack on a new leading axis so only the 8 variants are averaged,
        # never different images of the batch.
        return torch.stack(sr_list, dim=0).mean(dim=0)
179 |
180 |
--------------------------------------------------------------------------------
/trainer_stage1.py:
--------------------------------------------------------------------------------
1 | import os
2 | import utility
3 | import torch
4 | from decimal import Decimal
5 | import torch.nn.functional as F
6 | from utils import util
7 | import numpy as np
8 |
9 |
class Trainer():
    """Training and evaluation driver (this file is ``trainer_stage1.py``).

    Low-resolution inputs are produced on the fly from the HR batches with
    ``util.BicubicPreprocessing``; the model is called as
    ``model(lr, lr_bic)`` and optimised against the HR images with the
    supplied loss.
    """

    def __init__(self, args, loader, my_model, my_loss, ckp):
        """Store collaborators and build the optimizer and scheduler.

        When ``args.load`` is not ``'.'`` the optimizer state is restored
        from ``<ckp.dir>/optimizer.pt`` and the scheduler is stepped once per
        entry already present in ``ckp.log``.
        """
        self.args = args
        self.scale = args.scale
        self.test_res_psnr = []
        self.test_res_ssim = []
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load == '.':
            return

        optimizer_state = torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
        self.optimizer.load_state_dict(optimizer_state)
        for _ in range(len(ckp.log)):
            self.scheduler.step()

    def train(self):
        """Run one training epoch and write the resulting checkpoints.

        The learning rate is set by hand on every parameter group from
        ``lr_sr``, ``gamma_sr``, ``epochs_encoder`` and ``lr_decay_sr``.
        ``model_last.pt`` is written after every epoch; ``model_<epoch>.pt``
        is written when the epoch exceeds ``st_save_epoch`` or is a multiple
        of 30.
        """
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1

        decay_steps = (epoch - self.args.epochs_encoder) // self.args.lr_decay_sr
        learning_rate = self.args.lr_sr * (self.args.gamma_sr ** decay_steps)
        for group in self.optimizer.param_groups:
            group['lr'] = learning_rate

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(learning_rate)))
        self.loss.start_log()
        self.model.train()

        degrade = util.BicubicPreprocessing(
            self.scale[0],
            rgb_range=self.args.rgb_range
        )

        timer = utility.timer()
        losses_sr = utility.AverageMeter()

        for batch, (hr, _,) in enumerate(self.loader_train):
            hr = hr.cuda()  # b, c, h, w
            lr_img, lr_bic = degrade(hr)  # b, c, h, w
            self.optimizer.zero_grad()

            timer.tic()
            # forward pass through the whole network
            sr = self.model(lr_img, lr_bic)
            loss_SR = self.loss(sr, hr)
            losses_sr.update(loss_SR.item())

            # backward pass and parameter update
            loss_SR.backward()
            self.optimizer.step()
            timer.hold()

            if (batch + 1) % self.args.print_every != 0:
                continue
            self.ckp.write_log(
                'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                'Loss [SR loss:{:.3f}]\t'
                'Time [{:.1f}s]'.format(
                    epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                    losses_sr.avg,
                    timer.release(),
                ))

        self.loss.end_log(len(self.loader_train))

        def _write_weights(file_name):
            # Serialise the bare network's current weights under <ckp.dir>/model.
            weights = self.model.get_model().state_dict()
            torch.save(weights, os.path.join(self.ckp.dir, 'model', file_name))

        if epoch > self.args.st_save_epoch or epoch % 30 == 0:
            _write_weights('model_{}.pt'.format(epoch))
        _write_weights('model_last.pt')

    def test(self):
        """Evaluate PSNR/SSIM on every test loader and log the results.

        The last 11 results are kept in ``test_res_psnr``/``test_res_ssim``
        so a running mean can be reported alongside the current value.
        """
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        degrade = util.BicubicPreprocessing(
            self.scale[0],
            rgb_range=self.args.rgb_range
        )
        rgb_range = self.args.rgb_range

        with torch.no_grad():
            # The original bound the loader index to ``idx_scale`` as well;
            # it was immediately shadowed by the inner loop and never read.
            for d in self.loader_test:
                for idx_scale, scale in enumerate(self.scale):
                    d.dataset.set_scale(idx_scale)
                    eval_psnr = 0
                    eval_ssim = 0
                    for hr, filename in d:
                        hr = self.crop_border(hr.cuda(), scale)  # b, c, h, w
                        lr_img, lr_bic = degrade(hr)  # b, c, h, w

                        # inference
                        timer_test.tic()
                        sr = self.model(lr_img, lr_bic)
                        timer_test.hold()

                        sr = utility.quantize(sr, rgb_range)
                        hr = utility.quantize(hr, rgb_range)

                        # metrics
                        eval_psnr += utility.calc_psnr(
                            sr, hr, scale, rgb_range,
                            benchmark=True
                        )
                        eval_ssim += utility.calc_ssim(
                            (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale,
                            benchmark=True
                        )

                        # save results
                        if self.args.save_results:
                            self.ckp.save_results(filename[0], [sr], scale)

                    # NOTE(review): the sums are divided by the number of
                    # test loaders, not by the number of images in ``d`` -
                    # confirm this is intended.
                    n_loaders = len(self.loader_test)
                    mean_psnr = eval_psnr / n_loaders
                    mean_ssim = eval_ssim / n_loaders

                    if len(self.test_res_psnr) > 10:
                        self.test_res_psnr.pop(0)
                        self.test_res_ssim.pop(0)
                    self.test_res_psnr.append(mean_psnr)
                    self.test_res_ssim.append(mean_ssim)

                    self.ckp.log[-1, idx_scale] = mean_psnr
                    self.ckp.write_log(
                        '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format(
                            self.args.resume,
                            self.args.data_test,
                            scale,
                            mean_psnr,
                            mean_ssim,
                            np.mean(self.test_res_psnr),
                            np.mean(self.test_res_ssim)
                        ))

    def crop_border(self, img_hr, scale):
        """Crop ``img_hr`` (b, c, h, w) so that h and w are multiples of ``scale``."""
        _, _, height, width = img_hr.size()
        keep_h = int(height // scale * scale)
        keep_w = int(width // scale * scale)
        return img_hr[:, :, :keep_h, :keep_w]

    def terminate(self):
        """Tell the caller whether to stop.

        In ``test_only`` mode a single evaluation is run and ``True`` is
        returned; otherwise the result is whether the next epoch number has
        reached ``args.epochs_sr``.
        """
        if not self.args.test_only:
            return self.scheduler.last_epoch + 1 >= self.args.epochs_sr
        self.test()
        return True
176 |
177 |
--------------------------------------------------------------------------------
/model_TA/blindsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model.common as common
4 | import torch.nn.functional as F
5 |
6 |
def make_model(args):
    """Factory entry point: build this module's ``BlindSR`` network from ``args``."""
    network = BlindSR(args)
    return network
9 |
class RDA_conv(nn.Module):
    """Depth-wise convolution whose kernels are predicted from a degradation vector.

    A small MLP maps the degradation representation to one
    ``kernel_size x kernel_size`` filter per feature channel. The filtered
    features are gated by a sigmoid map computed from the input features and
    added back to the input.

    The MLP sizes are hard-coded to 64, so both the degradation vector and
    the feature map must have 64 channels. ``channels_out`` and ``reduction``
    are accepted but not used by the computation.
    """

    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        super(RDA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        taps = self.kernel_size * self.kernel_size
        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * taps, bias=False),
        )
        # presumably a 1x1 convolution producing a single-channel map
        # (depends on ``common.default_conv``) - TODO confirm
        self.conv = common.default_conv(channels_in, 1, 1)
        self.relu = nn.LeakyReLU(0.1, True)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        :return: modulated feature map, same shape as x[0]
        '''
        feat, deg = x[0], x[1]
        b, c, h, w = feat.size()
        k = self.kernel_size

        # One predicted k x k filter per (sample, channel) pair; folding the
        # batch into the channel axis lets a single grouped convolution apply
        # them all.
        filters = self.kernel(deg).view(-1, 1, k, k)
        filtered = F.conv2d(
            feat.view(1, -1, h, w),
            filters,
            groups=b * c,
            padding=(k - 1) // 2,
        )
        filtered = self.relu(filtered).view(b, -1, h, w)

        # Sigmoid gate computed from the input features.
        gate = self.sig(self.conv(feat))

        return filtered * gate + feat
41 |
42 |
class RDAB(nn.Module):
    """Residual block alternating degradation-aware and plain convolutions.

    Layout: ``RDA_conv -> conv -> RDA_conv -> conv`` with a LeakyReLU after
    each of the first three layers and a skip connection from the block
    input to its output.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction):
        super(RDAB, self).__init__()

        self.da_conv1 = RDA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = RDA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        :return: feature map produced by the block
        '''
        identity, deg = x[0], x[1]

        hidden = self.relu(self.da_conv1(x))
        hidden = self.relu(self.conv1(hidden))
        hidden = self.relu(self.da_conv2([hidden, deg]))

        return self.conv2(hidden) + identity
66 |
67 |
class RDAG(nn.Module):
    """Chain of ``n_blocks`` RDAB blocks that all share one degradation vector."""

    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super(RDAG, self).__init__()
        self.n_blocks = n_blocks
        blocks = [
            RDAB(conv, n_feat, kernel_size, reduction)
            for _ in range(n_blocks)
        ]
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        :return: feature map after the first ``n_blocks`` entries of ``body``
        '''
        feat, deg = x[0], x[1]
        # Blocks are indexed one by one because each needs the
        # [features, degradation] pair as its input.
        for idx in range(self.n_blocks):
            feat = self.body[idx]([feat, deg])
        return feat
91 |
92 |
class RDAN(nn.Module):
    '''
    Super-resolution generator: head conv, one RDAG group conditioned on a
    compressed degradation vector, a fusion conv, and an upsampling tail.
    '''

    def __init__(self, args, conv=common.default_conv):
        '''
        :param args: options object; ``args.scale[0]`` and ``args.rgb_range`` are read
        :param conv: factory called as ``conv(in_channels, out_channels, kernel_size)``
        '''
        super().__init__()
        n_blocks, n_feats, kernel_size, reduction = 27, 64, 3, 8
        scale = int(args.scale[0])

        # RGB mean for DIV2K
        div2k_mean = (0.4488, 0.4371, 0.4040)
        unit_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, div2k_mean, unit_std)
        self.add_mean = common.MeanShift(args.rgb_range, div2k_mean, unit_std, 1)

        # head module
        self.head = nn.Sequential(conv(3, n_feats, kernel_size))

        # compress: 256-d degradation vector -> 64-d
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True),
        )

        # body: index 0 is the RDAG group, index 1 (the last) the fusion conv.
        # The group is always built with common.default_conv, as before.
        self.body = nn.Sequential(
            RDAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks),
            conv(n_feats, n_feats, kernel_size),
        )

        # tail
        self.tail = nn.Sequential(
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        )

    def forward(self, x, k_v):
        '''
        :param x: input image batch
        :param k_v: degradation representation fed to ``self.compress``
        :return: output of the tail after ``add_mean``
        '''
        degradation = self.compress(k_v)

        shallow = self.head(self.sub_mean(x))

        # The group needs the [features, degradation] pair, so the body is
        # applied element by element rather than as a plain Sequential.
        group, fuse = self.body[0], self.body[-1]
        deep = fuse(group([shallow, degradation])) + shallow

        return self.add_mean(self.tail(deep))
152 |
153 |
class DEN(nn.Module):
    '''
    Degradation encoder: six 3x3 conv + LeakyReLU stages (two of them with
    stride 2), global average pooling, then a two-layer MLP producing a
    256-d vector.
    '''

    def __init__(self, args):
        '''
        :param args: options object; ``args.n_feats`` is the number of input channels
        '''
        super().__init__()
        # (in_channels, out_channels, stride) for each conv stage.
        stages = [
            (args.n_feats, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
        ]
        layers = []
        for c_in, c_out, stride in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.LeakyReLU(0.1, True))
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.E = nn.Sequential(*layers)

        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        '''
        :param x: input batch: B * n_feats * H * W
        :return: ``(fea, T_fea)`` where ``fea`` is B * 256 and ``T_fea`` is a
                 list holding the output of MLP layer index 2
        '''
        fea = self.E(x).squeeze(-1).squeeze(-1)
        T_fea = []
        for idx, layer in enumerate(self.mlp):
            fea = layer(fea)
            if idx == 2:
                T_fea.append(fea)
        return fea, T_fea
188 |
189 |
class BlindSR(nn.Module):
    '''
    Wraps the degradation encoder ``E`` (DEN) and the generator ``G`` (RDAN).
    '''

    def __init__(self, args):
        super().__init__()

        # Generator
        self.G = RDAN(args)

        self.E = DEN(args)

    def forward(self, x, deg_repre):
        '''
        :param x: image batch passed to the generator
        :param deg_repre: input of the degradation encoder
        :return: ``(sr, T_fea)`` in training mode, ``sr`` alone otherwise
        '''
        # degradation-aware representation learning
        deg_repre, T_fea = self.E(deg_repre)

        # degradation-aware SR
        sr = self.G(x, deg_repre)

        # The encoder features are only handed back while training.
        if self.training:
            return sr, T_fea
        return sr
218 |
--------------------------------------------------------------------------------
/model_ST/blindsr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import model.common as common
4 | import torch.nn.functional as F
5 |
6 |
7 | import torch
8 | from torch import nn
9 | import model_ST.common as common
10 | import torch.nn.functional as F
11 |
12 |
def make_model(args):
    '''Build the BlindSR network from the options object ``args``.'''
    return BlindSR(args)
15 |
class RDA_conv(nn.Module):
    '''
    Degradation-aware convolution: a small MLP turns the degradation vector
    into per-sample depthwise filters; the filtered features are gated by a
    sigmoid mask and added to the input.
    '''

    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        '''
        :param channels_in: input channels of the 1x1 mask convolution
        :param channels_out: stored on the instance only
        :param kernel_size: side length of the predicted depthwise filters
        :param reduction: accepted but not used by this layer
        '''
        super().__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        # 64-d degradation vector -> 64 filters of kernel_size x kernel_size.
        taps = self.kernel_size * self.kernel_size
        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * taps, bias=False),
        )
        self.conv = common.default_conv(channels_in, 1, 1)
        self.relu = nn.LeakyReLU(0.1, True)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        feat, deg = x[0], x[1]
        b, c, h, w = feat.size()
        k = self.kernel_size

        # Branch 1: one k x k filter per (sample, channel) pair, applied in a
        # single grouped convolution with the batch folded into the channels.
        filters = self.kernel(deg).view(-1, 1, k, k)
        folded = feat.view(1, -1, h, w)
        filtered = F.conv2d(folded, filters, groups=b * c, padding=(k - 1) // 2)
        filtered = self.relu(filtered).view(b, -1, h, w)

        # Branch 2: sigmoid mask from the input, then the residual sum.
        gate = self.sig(self.conv(feat))
        return filtered * gate + feat
47 |
48 |
class RDAB(nn.Module):
    '''
    Residual block that alternates two RDA_conv layers with two ordinary
    convolutions and adds the block input back at the end.
    '''

    def __init__(self, conv, n_feat, kernel_size, reduction):
        '''
        :param conv: factory called as ``conv(in_channels, out_channels, kernel_size)``
        :param n_feat: channel count used for every layer of the block
        :param kernel_size: kernel size shared by all four layers
        :param reduction: forwarded unchanged to RDA_conv
        '''
        super().__init__()

        # Sub-modules are created in the original order so that parameter
        # registration (and seeded initialisation) is unchanged.
        self.da_conv1 = RDA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = RDA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        degradation = x[1]

        hidden = self.da_conv1(x)
        hidden = self.relu(hidden)
        hidden = self.relu(self.conv1(hidden))
        hidden = self.relu(self.da_conv2([hidden, degradation]))

        # Skip connection around the whole block.
        return self.conv2(hidden) + x[0]
72 |
73 |
class RDAG(nn.Module):
    '''
    A chain of ``n_blocks`` RDAB blocks that all receive the same degradation
    representation.
    '''

    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super().__init__()
        self.n_blocks = n_blocks
        blocks = []
        for _ in range(n_blocks):
            blocks.append(RDAB(conv, n_feat, kernel_size, reduction))

        # Sequential is used only as an indexable container; the blocks take
        # a two-element input, so they are invoked one by one in forward().
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        feat, degradation = x[0], x[1]
        # self.body holds exactly n_blocks modules (see __init__).
        for block in self.body:
            feat = block([feat, degradation])
        return feat
97 |
98 |
class RDAN(nn.Module):
    '''
    Super-resolution generator: head conv, one RDAG group conditioned on a
    compressed degradation vector, a fusion conv, and an upsampling tail.
    '''

    def __init__(self, args, conv=common.default_conv):
        '''
        :param args: options object; ``args.scale[0]`` and ``args.rgb_range`` are read
        :param conv: factory called as ``conv(in_channels, out_channels, kernel_size)``
        '''
        super().__init__()
        n_blocks, n_feats, kernel_size, reduction = 27, 64, 3, 8
        scale = int(args.scale[0])

        # RGB mean for DIV2K
        div2k_mean = (0.4488, 0.4371, 0.4040)
        unit_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, div2k_mean, unit_std)
        self.add_mean = common.MeanShift(args.rgb_range, div2k_mean, unit_std, 1)

        # head module
        self.head = nn.Sequential(conv(3, n_feats, kernel_size))

        # compress: 256-d degradation vector -> 64-d
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True),
        )

        # body: index 0 is the RDAG group, index 1 (the last) the fusion conv.
        # The group is always built with common.default_conv, as before.
        self.body = nn.Sequential(
            RDAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks),
            conv(n_feats, n_feats, kernel_size),
        )

        # tail
        self.tail = nn.Sequential(
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        )

    def forward(self, x, k_v):
        '''
        :param x: input image batch
        :param k_v: degradation representation fed to ``self.compress``
        :return: output of the tail after ``add_mean``
        '''
        degradation = self.compress(k_v)

        shallow = self.head(self.sub_mean(x))

        # The group needs the [features, degradation] pair, so the body is
        # applied element by element rather than as a plain Sequential.
        group, fuse = self.body[0], self.body[-1]
        deep = fuse(group([shallow, degradation])) + shallow

        return self.add_mean(self.tail(deep))
158 |
159 |
160 |
class DEN(nn.Module):
    '''
    Degradation encoder operating on 3-channel input: six 3x3 conv +
    LeakyReLU stages (two of them with stride 2), global average pooling,
    then a two-layer MLP producing a 256-d vector.
    '''

    def __init__(self, args):
        '''
        :param args: options object; kept for interface compatibility. No
                     attribute is read, because the input width is fixed at 3.
        '''
        super().__init__()
        # (in_channels, out_channels, stride) for each conv stage.
        stages = [
            (3, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
        ]
        layers = []
        for c_in, c_out, stride in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))
            layers.append(nn.LeakyReLU(0.1, True))
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.E = nn.Sequential(*layers)

        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        '''
        :param x: input batch: B * 3 * H * W
        :return: ``(fea, S_fea)`` where ``fea`` is B * 256 and ``S_fea`` is a
                 list holding the output of MLP layer index 2
        '''
        fea = self.E(x).squeeze(-1).squeeze(-1)
        S_fea = []
        for idx, layer in enumerate(self.mlp):
            fea = layer(fea)
            if idx == 2:
                S_fea.append(fea)
        return fea, S_fea
195 |
196 |
class BlindSR(nn.Module):
    '''
    Wraps the degradation encoder ``E_st`` (DEN) and the generator ``G``
    (RDAN); the encoder reads the same input as the generator.
    '''

    def __init__(self, args):
        super().__init__()

        # Generator
        self.G = RDAN(args)

        self.E_st = DEN(args)

    def forward(self, x):
        '''
        :param x: image batch, fed to both the encoder and the generator
        :return: ``(sr, S_fea)`` in training mode, ``sr`` alone otherwise
        '''
        # degradation-aware representation learning
        deg_repre, S_fea = self.E_st(x)

        # degradation-aware SR
        sr = self.G(x, deg_repre)

        # The encoder features are only handed back while training.
        if self.training:
            return sr, S_fea
        return sr
225 |
--------------------------------------------------------------------------------
/utility2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import scipy.misc as misc
8 | import cv2
9 | import torch
10 | import torch.optim as optim
11 | import torch.optim.lr_scheduler as lrs
12 |
13 |
class AverageMeter(object):
    """Keeps the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
30 |
31 |
class timer():
    """Wall-clock stopwatch with an accumulator for summing several intervals."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Mark the start of a new interval."""
        self.t0 = time.time()

    def toc(self):
        """Return the seconds elapsed since the last ``tic``."""
        elapsed = time.time() - self.t0
        return elapsed

    def hold(self):
        """Add the current interval to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and clear the accumulator."""
        ret, self.acc = self.acc, 0
        return ret

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
54 |
55 |
class checkpoint():
    """Owns the experiment directory: text log, config dump, PSNR log/plot,
    model/optimizer snapshots and saved result images."""

    # Directory suffix for each supported blur type.
    _BLUR_SUFFIX = {'iso_gaussian': '_iso', 'aniso_gaussian': '_aniso'}

    def __init__(self, args):
        """Create the experiment directory tree and open the log file.

        :param args: options object; ``blur_type``, ``save``, ``scale`` and
                     ``mode`` determine the directory name
        :raises ValueError: if ``args.blur_type`` is not a supported blur type
        """
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        # Validate before touching the file system; previously an unknown
        # blur type left self.dir unset and failed later with AttributeError.
        if args.blur_type not in self._BLUR_SUFFIX:
            raise ValueError(
                'Unsupported blur_type {!r}; expected one of {}'.format(
                    args.blur_type, sorted(self._BLUR_SUFFIX)))
        self.dir = ('./experiment/' + args.save + '_x' + str(int(args.scale[0]))
                    + '_' + args.mode + self._BLUR_SUFFIX[args.blur_type])

        for sub in ('', '/model', '/results'):
            os.makedirs(self.dir + sub, exist_ok=True)

        # Append when a previous run already left a log, otherwise start fresh.
        # The same mode is applied to config.txt.
        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss state, PSNR log/plot and optimizer state."""
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append ``log`` (a tensor) to the in-memory PSNR log."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it to log.txt.

        With ``refresh`` the file is closed and reopened so the text reaches disk.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot one PSNR curve per scale over epochs 1..``epoch`` into a PDF."""
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first image of ``save_list[0]`` as ``<filename>_x<scale>_SR.png``."""
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)

        normalized = save_list[0][0].data.mul(255 / self.args.rgb_range)
        ndarr = np.ascontiguousarray(
            normalized.byte().permute(1, 2, 0).cpu().numpy())
        # scipy.misc.imsave no longer exists in current SciPy releases, so the
        # image is written with OpenCV, which this module already imports.
        if ndarr.ndim == 3 and ndarr.shape[2] == 3:
            # OpenCV stores colour images as BGR; the tensor is assumed to be RGB.
            ndarr = cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR)
        cv2.imwrite('{}{}.png'.format(filename, 'SR'), ndarr)
132 |
133 |
def quantize(img, rgb_range):
    """Snap ``img`` to the 256 levels of an 8-bit image.

    The tensor is scaled to the 0..255 range, clamped, rounded, and scaled
    back to ``rgb_range`` units.
    """
    pixel_range = 255 / rgb_range
    as_8bit = img.mul(pixel_range)
    as_8bit = as_8bit.clamp(0, 255).round()
    return as_8bit.div(pixel_range)
137 |
138 |
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Compute the PSNR (in dB) between ``sr`` and ``hr``.

    :param sr: predicted image batch, B * C * H * W
    :param hr: reference batch of the same shape
    :param scale: border of ``ceil(scale)`` pixels (benchmark) or
                  ``ceil(scale + 6)`` pixels (otherwise) is cropped first
    :param rgb_range: value range of the images; the difference is divided by it
    :param benchmark: when true and the input has more than one channel, the
                      difference is collapsed to one channel with the weights
                      65.738 / 129.057 / 25.064 over 256 (presumably a luma
                      conversion)
    :return: PSNR as a float; ``inf`` when the cropped images are identical
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    shave = math.ceil(shave)
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()

    # log10(0) raises; identical images have infinite PSNR by definition.
    if mse == 0:
        return float('inf')
    return -10 * math.log10(mse)
158 |
159 |
def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    Both inputs are tensors that squeeze to C * H * W with three channels.
    ``img1`` is clamped to [0, 255] and rounded; ``img2`` is used as given.
    The score is computed on a single weighted-sum channel after cropping a
    border of ``ceil(scale)`` pixels (benchmark) or ``ceil(scale) + 6``
    pixels (otherwise).

    :raises ValueError: if the two images differ in shape
    '''
    border = math.ceil(scale)
    if not benchmark:
        border += 6

    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))

    # Validate before the arrays are used, so a mismatch is reported with
    # this message rather than by a NumPy shape error.
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    weights = [65.738, 129.057, 25.064]
    img1_y = np.dot(img1, weights) / 255.0 + 16.0
    img2_y = np.dot(img2, weights) / 255.0 + 16.0

    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    # The weighted sum always yields a 2-D plane, so the former per-channel
    # and single-channel fallbacks could never run and were removed.
    return ssim(img1_y, img2_y)
195 |
196 |
def ssim(img1, img2):
    """Mean SSIM of two arrays, using an 11x11 Gaussian window (sigma 1.5)
    and constants for a 255-level dynamic range."""
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def _local_mean(arr):
        # Drop 5 pixels per side: only the region fully covered by the window.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_a = _local_mean(a)
    mu_b = _local_mean(b)
    mu_a_sq = mu_a ** 2
    mu_b_sq = mu_b ** 2
    mu_ab = mu_a * mu_b
    var_a = _local_mean(a ** 2) - mu_a_sq
    var_b = _local_mean(b ** 2) - mu_b_sq
    cov_ab = _local_mean(a * b) - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
218 |
219 |
def make_optimizer(args, my_model):
    """Build the optimizer named by ``args.optimizer`` over the trainable
    parameters of ``my_model``.

    No learning rate is passed here, so the optimizer class's own default
    applies until the caller sets one.

    :param args: options object; reads ``optimizer``, ``weight_decay`` and,
                 depending on the choice, ``momentum`` or ``beta1``/``beta2``/``epsilon``
    :param my_model: module whose parameters with ``requires_grad`` are optimized
    :raises ValueError: if ``args.optimizer`` is not SGD, ADAM or RMSprop
    """
    trainable = filter(lambda p: p.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('Unsupported optimizer: {!r}'.format(args.optimizer))

    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)
239 |
240 |
def make_scheduler(args, my_optimizer):
    """Build the learning-rate scheduler selected by ``args.decay_type``.

    ``'step'`` gives a StepLR driven by ``args.lr_decay_sr`` and
    ``args.gamma_sr``. Any other value containing ``'step'`` (for example
    ``'step_100_200'``) gives a MultiStepLR whose milestones are the
    underscore-separated integers after the first token, with ``args.gamma``.
    The scheduler is then stepped to ``args.start_epoch - 1``.

    :raises ValueError: if ``args.decay_type`` does not contain ``'step'``
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay_sr,
            gamma=args.gamma_sr,
        )
    elif 'step' in args.decay_type:
        milestones = [int(m) for m in args.decay_type.split('_')[1:]]
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('Unsupported decay_type: {!r}'.format(args.decay_type))

    scheduler.step(args.start_epoch - 1)

    return scheduler
261 |
262 |
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import scipy.misc as misc
8 | import cv2
9 | import torch
10 | import torch.optim as optim
11 | import torch.optim.lr_scheduler as lrs
12 | import imageio
13 |
14 |
class AverageMeter(object):
    """Keeps the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
31 |
32 |
class timer():
    """Wall-clock stopwatch with an accumulator for summing several intervals."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Mark the start of a new interval."""
        self.t0 = time.time()

    def toc(self):
        """Return the seconds elapsed since the last ``tic``."""
        elapsed = time.time() - self.t0
        return elapsed

    def hold(self):
        """Add the current interval to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and clear the accumulator."""
        ret, self.acc = self.acc, 0
        return ret

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
55 |
56 |
class checkpoint():
    """Owns the experiment directory: text log, config dump, PSNR log/plot,
    model/optimizer snapshots and saved result images."""

    # Directory suffix for each supported blur type.
    _BLUR_SUFFIX = {'iso_gaussian': '_iso', 'aniso_gaussian': '_aniso'}

    def __init__(self, args):
        """Create the experiment directory tree and open the log file.

        :param args: options object; ``blur_type``, ``save``, ``scale`` and
                     ``mode`` determine the directory name
        :raises ValueError: if ``args.blur_type`` is not a supported blur type
        """
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        # Validate before touching the file system; previously an unknown
        # blur type left self.dir unset and failed later with AttributeError.
        if args.blur_type not in self._BLUR_SUFFIX:
            raise ValueError(
                'Unsupported blur_type {!r}; expected one of {}'.format(
                    args.blur_type, sorted(self._BLUR_SUFFIX)))
        self.dir = ('./experiment/' + args.save + '_x' + str(int(args.scale[0]))
                    + '_' + args.mode + self._BLUR_SUFFIX[args.blur_type])

        # Sub-directory names are kept exactly as they were.
        for sub in ('', '/model', '/optimzer', '/results'):
            os.makedirs(self.dir + sub, exist_ok=True)

        # Append when a previous run already left a log, otherwise start fresh.
        # The same mode is applied to config.txt.
        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss state, PSNR log/plot and optimizer state."""
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append ``log`` (a tensor) to the in-memory PSNR log."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it to log.txt.

        With ``refresh`` the file is closed and reopened so the text reaches disk.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot one PSNR curve per scale over epochs 1..``epoch`` into a PDF."""
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first image of ``save_list[0]`` as ``<filename>_x<scale>_SR.png``."""
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)

        # Rescale to 0..255 and reorder C * H * W -> H * W * C for the writer.
        normalized = save_list[0][0].data.mul(255 / self.args.rgb_range)
        ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
        imageio.imsave('{}{}.png'.format(filename, 'SR'), ndarr)
134 |
135 |
def quantize(img, rgb_range):
    """Snap ``img`` to the 256 levels of an 8-bit image.

    The tensor is scaled to the 0..255 range, clamped, rounded, and scaled
    back to ``rgb_range`` units.
    """
    pixel_range = 255 / rgb_range
    as_8bit = img.mul(pixel_range)
    as_8bit = as_8bit.clamp(0, 255).round()
    return as_8bit.div(pixel_range)
139 |
140 |
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Compute the PSNR (in dB) between ``sr`` and ``hr``.

    :param sr: predicted image batch, B * C * H * W
    :param hr: reference batch of the same shape
    :param scale: border of ``ceil(scale)`` pixels (benchmark) or
                  ``ceil(scale + 6)`` pixels (otherwise) is cropped first
    :param rgb_range: value range of the images; the difference is divided by it
    :param benchmark: when true and the input has more than one channel, the
                      difference is collapsed to one channel with the weights
                      65.738 / 129.057 / 25.064 over 256 (presumably a luma
                      conversion)
    :return: PSNR as a float; ``inf`` when the cropped images are identical
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1)
            diff = diff.mul(convert).div(256).sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    shave = math.ceil(shave)
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean().item()

    # log10(0) raises; identical images have infinite PSNR by definition.
    if mse == 0:
        return float('inf')
    return -10 * math.log10(mse)
160 |
161 |
def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    Both inputs are tensors that squeeze to C * H * W with three channels.
    ``img1`` is clamped to [0, 255] and rounded; ``img2`` is used as given.
    The score is computed on a single weighted-sum channel after cropping a
    border of ``ceil(scale)`` pixels (benchmark) or ``ceil(scale) + 6``
    pixels (otherwise).

    :raises ValueError: if the two images differ in shape
    '''
    border = math.ceil(scale)
    if not benchmark:
        border += 6

    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))

    # Validate before the arrays are used, so a mismatch is reported with
    # this message rather than by a NumPy shape error.
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    weights = [65.738, 129.057, 25.064]
    img1_y = np.dot(img1, weights) / 255.0 + 16.0
    img2_y = np.dot(img2, weights) / 255.0 + 16.0

    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    # The weighted sum always yields a 2-D plane, so the former per-channel
    # and single-channel fallbacks could never run and were removed.
    return ssim(img1_y, img2_y)
197 |
198 |
def ssim(img1, img2):
    """Mean SSIM of two arrays, using an 11x11 Gaussian window (sigma 1.5)
    and constants for a 255-level dynamic range."""
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def _local_mean(arr):
        # Drop 5 pixels per side: only the region fully covered by the window.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_a = _local_mean(a)
    mu_b = _local_mean(b)
    mu_a_sq = mu_a ** 2
    mu_b_sq = mu_b ** 2
    mu_ab = mu_a * mu_b
    var_a = _local_mean(a ** 2) - mu_a_sq
    var_b = _local_mean(b ** 2) - mu_b_sq
    cov_ab = _local_mean(a * b) - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
220 |
221 |
def make_optimizer(args, my_model):
    """Build the optimizer named by ``args.optimizer`` over the trainable
    parameters of ``my_model``.

    No learning rate is passed here, so the optimizer class's own default
    applies until the caller sets one.

    :param args: options object; reads ``optimizer``, ``weight_decay`` and,
                 depending on the choice, ``momentum`` or ``beta1``/``beta2``/``epsilon``
    :param my_model: module whose parameters with ``requires_grad`` are optimized
    :raises ValueError: if ``args.optimizer`` is not SGD, ADAM or RMSprop
    """
    trainable = filter(lambda p: p.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('Unsupported optimizer: {!r}'.format(args.optimizer))

    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)
241 |
242 |
def make_scheduler(args, my_optimizer):
    """Build the learning-rate scheduler selected by ``args.decay_type``.

    ``'step'`` gives a StepLR driven by ``args.lr_decay_sr`` and
    ``args.gamma_sr``. Any other value containing ``'step'`` (for example
    ``'step_100_200'``) gives a MultiStepLR whose milestones are the
    underscore-separated integers after the first token, with ``args.gamma``.
    The scheduler is then stepped to ``args.start_epoch - 1``.

    :raises ValueError: if ``args.decay_type`` does not contain ``'step'``
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay_sr,
            gamma=args.gamma_sr,
        )
    elif 'step' in args.decay_type:
        milestones = [int(m) for m in args.decay_type.split('_')[1:]]
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('Unsupported decay_type: {!r}'.format(args.decay_type))

    scheduler.step(args.start_epoch - 1)

    return scheduler
263 |
264 |
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import template
3 |
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string (including ``'False'``) as
    True, so boolean options are parsed explicitly here. The empty string
    stays False, as it was under ``bool('')``.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=4,
                    help='number of threads for data loading')
# A bare `--cpu` now means True; `--cpu False` is honoured instead of
# being read as a truthy string.
parser.add_argument('--cpu', type=_str2bool, nargs='?', const=True, default=False,
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=2,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--pre_train_meta', type=str, default= '.',
                    help='pre-trained model directory')
parser.add_argument('--pre_train_TA', type=str, default= '.',
                    help='pre-trained model directory')
parser.add_argument('--pre_train_ST', type=str, default= '.',
                    help='pre-trained model directory')
parser.add_argument('--is_stage3', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--temperature', type=float, default=0.15,
                    help='for konwledge distillation')

# Meta-Learning
parser.add_argument('--task_iter', type=int, default=20,
                    help='each task iteration times')
parser.add_argument('--test_iter', type=int, default=20,
                    help='each task iteration times')
parser.add_argument('--meta_batch_size', type=int, default=8,
                    help='each task iteration times')
parser.add_argument('--task_batch_size', type=int, default=16,
                    help='each task iteration times')
parser.add_argument('--lr_task', type=float, default=1e-3,
                    help='learning rate to train the whole network')


# Data specifications
parser.add_argument('--dir_data', type=str, default='D:/LongguangWang/Data',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DF2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='Set5',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-3450/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=36,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=1,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Degradation specifications
parser.add_argument('--blur_kernel', type=int, default=21,
                    help='size of blur kernels')
parser.add_argument('--blur_type', type=str, default='iso_gaussian',
                    help='blur types (iso_gaussian | aniso_gaussian)')
parser.add_argument('--mode', type=str, default='bicubic',
                    help='downsampler (bicubic | s-fold)')
parser.add_argument('--noise', type=float, default=0.0,
                    help='noise level')
## isotropic Gaussian blur
parser.add_argument('--sig_min', type=float, default=0.2,
                    help='minimum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig_max', type=float, default=4.0,
                    help='maximum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig', type=float, default=4.0,
                    help='specific sigma of isotropic Gaussian blurs')
## anisotropic Gaussian blur
parser.add_argument('--lambda_min', type=float, default=0.2,
                    help='minimum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_max', type=float, default=4.0,
                    help='maximum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_1', type=float, default=0.2,
                    help='one eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_2', type=float, default=4.0,
                    help='another eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--theta', type=float, default=0.0,
                    help='rotation angle of anisotropic Gaussian blurs [0, 180]')


# Model specifications
parser.add_argument('--model', default='blindsr',
                    help='model name')
parser.add_argument('--pre_train', type=str, default= '.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
# Without a type, any value passed on the command line was kept as a
# (truthy) string, so the option could never be switched off.
parser.add_argument('--shift_mean', type=_str2bool, default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs_encoder', type=int, default=100,
                    help='number of epochs to train the degradation encoder')
parser.add_argument('--epochs_sr', type=int, default=500,
                    help='number of epochs to train the whole network')
parser.add_argument('--st_save_epoch', type=int, default=550,
                    help='number of epochs to save network')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')

# Optimization specifications
parser.add_argument('--lr_encoder', type=float, default=1e-3,
                    help='learning rate to train the degradation encoder')
parser.add_argument('--lr_sr', type=float, default=1e-4,
                    help='learning rate to train the whole network')
parser.add_argument('--lr_decay_encoder', type=int, default=60,
                    help='learning rate decay per N epochs')
parser.add_argument('--lr_decay_sr', type=int, default=125,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma_encoder', type=float, default=0.1,
                    help='learning rate decay factor for step decay')
parser.add_argument('--gamma_sr', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
162 | parser.add_argument('--beta2', type=float, default=0.999,
163 | help='ADAM beta2')
164 | parser.add_argument('--epsilon', type=float, default=1e-8,
165 | help='ADAM epsilon for numerical stability')
166 | parser.add_argument('--weight_decay', type=float, default=0,
167 | help='weight decay')
168 | parser.add_argument('--start_epoch', type=int, default=0,
169 | help='resume from the snapshot, and the start_epoch')
170 |
171 | # Loss specifications
172 | parser.add_argument('--loss', type=str, default='1*L1',
173 | help='loss function configuration')
174 | parser.add_argument('--skip_threshold', type=float, default='1e6',
175 | help='skipping batch that has large error')
176 |
177 | # Log specifications
178 | parser.add_argument('--save', type=str, default='blindsr',
179 | help='file name to save')
180 | parser.add_argument('--load', type=str, default='.',
181 | help='file name to load')
182 | parser.add_argument('--resume', type=int, default=0,
183 | help='resume from specific checkpoint')
184 | parser.add_argument('--save_models', action='store_true',
185 | help='save all intermediate models')
186 | parser.add_argument('--print_every', type=int, default=200,
187 | help='how many batches to wait before logging training status')
188 | parser.add_argument('--save_results', default=False,
189 | help='save output results')
190 |
# Parse the command line and hand the namespace to the template module.
args = parser.parse_args()
template.set_template(args)

# '+'-separated options become lists, e.g. '2+4' -> [2.0, 4.0].
args.scale = [float(factor) for factor in args.scale.split('+')]

args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')


# Options declared without a type arrive as the strings 'True' / 'False' when
# given on the command line; replace those strings with real booleans.
_bool_literals = {'True': True, 'False': False}
for arg in vars(args):
    _raw = getattr(args, arg)
    if isinstance(_raw, str) and _raw in _bool_literals:
        setattr(args, arg, _bool_literals[_raw])
--------------------------------------------------------------------------------
/trainer_stage3.py:
--------------------------------------------------------------------------------
1 | import os
2 | import utility2
3 | import torch
4 | from decimal import Decimal
5 | import torch.nn.functional as F
6 | from utils import util
7 | from utils import util2
8 | from collections import OrderedDict
9 | import random
10 | import numpy as np
11 | import torch.nn as nn
12 |
13 |
class Trainer():
    """Stage-3 trainer: fits ``model_TA`` on degradation representations.

    For every batch a copy of the meta model is reset to its initial weights,
    adapted for ``args.test_iter`` manual gradient steps on an L1 loss, and
    its second output (``deg_repre``) is fed, detached, to ``model_TA``
    together with the degraded low-resolution input.  Only ``model_TA`` is
    updated by the optimizer.
    """

    def __init__(self, args, loader, model_meta_copy, model_TA, my_loss, ckp):
        """Store collaborators and build the optimizer/scheduler for model_TA.

        Args:
            args: parsed options namespace.
            loader: object exposing ``loader_train`` and ``loader_test``.
            model_meta_copy: meta model adapted per batch (never optimized
                by ``self.optimizer``).
            model_TA: the model being trained in this stage.
            my_loss: loss object providing ``step``/``start_log``/``end_log``.
            ckp: checkpoint/logger object (``dir``, ``log``, ``write_log``...).
        """
        self.test_res_psnr = []
        self.test_res_ssim = []
        self.args = args
        self.scale = args.scale
        self.loss1 = nn.L1Loss()
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model_meta_copy = model_meta_copy

        # state_dict() returns references to the live parameters, and the
        # adaptation loop updates those parameters in place.  Keep a deep
        # copy so that restoring it really resets the meta model.
        self.model_meta_copy_state = self._snapshot_state(self.model_meta_copy)
        self.model_TA = model_TA
        self.loss = my_loss
        self.meta_batch_size = args.meta_batch_size
        self.task_batch_size = args.task_batch_size
        self.task_iter = args.task_iter
        self.test_iter = args.test_iter
        self.optimizer = utility2.make_optimizer(args, self.model_TA)
        self.scheduler = utility2.make_scheduler(args, self.optimizer)
        self.lr_task = args.lr_task
        self.taskiter_weight = [1 / self.task_iter] * self.task_iter

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            # Fast-forward the scheduler to the epoch recorded in the log.
            for _ in range(len(ckp.log)):
                self.scheduler.step()

    @staticmethod
    def _snapshot_state(model):
        """Return a deep copy of ``model.state_dict()`` (independent tensors)."""
        import copy
        return copy.deepcopy(model.state_dict())

    @staticmethod
    def _inner_lr(step):
        """Step size of the manual adaptation update for inner step ``step``."""
        if step == 0:
            return 1e-2
        if step < 5:
            return 5e-3
        return 1e-3

    def _adapt_meta_copy(self, lr_blur, hr_blur, hr):
        """Reset the meta model, then adapt it in place for ``test_iter`` steps."""
        self.model_meta_copy.get_model().load_state_dict(self.model_meta_copy_state)

        for step in range(self.test_iter):
            learning_rate = self._inner_lr(step)
            sr, _ = self.model_meta_copy(lr_blur, hr_blur)
            loss = self.loss1(sr, hr)
            self.model_meta_copy.zero_grad()
            loss.backward()

            # Plain gradient-descent update applied directly to the weights.
            for param in self.model_meta_copy.parameters():
                if param.requires_grad and param.grad is not None:
                    param.data.sub_(param.grad.data * learning_rate)

    def train(self):
        """Run one training epoch of model_TA and write checkpoints."""
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1

        # The learning rate is set by hand from the step-decay options.
        lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model_TA.train()

        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig_min=self.args.sig_min,
            sig_max=self.args.sig_max,
            lambda_min=self.args.lambda_min,
            lambda_max=self.args.lambda_max,
            noise=self.args.noise
        )

        timer = utility2.timer()

        for batch, (hr, _) in enumerate(self.loader_train):
            hr = hr.cuda()  # b, c, h, w
            timer.tic()

            lr_blur, hr_blur = degrade(hr)  # b, c, h, w
            self._adapt_meta_copy(lr_blur, hr_blur, hr)

            # Representation from the adapted meta model; detached so that
            # only model_TA receives gradients.
            _, deg_repre = self.model_meta_copy(lr_blur, hr_blur)
            sr, _ = self.model_TA(lr_blur, deg_repre.detach())
            loss_all = self.loss1(sr, hr)
            self.optimizer.zero_grad()
            loss_all.backward()
            self.optimizer.step()

            timer.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log(
                    'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                    'Loss [SR loss:{:.3f}]\t'
                    'Time [{:.1f}s]'.format(
                        epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                        loss_all.item(),
                        timer.release(),
                    ))

        self.loss.end_log(len(self.loader_train))

        # Per-epoch snapshot (late epochs and every 20th epoch).
        if epoch > self.args.st_save_epoch or (epoch % 20 == 0):
            torch.save(
                self.model_TA.get_model().state_dict(),
                os.path.join(self.ckp.dir, 'model', 'model_TA_{}.pt'.format(epoch))
            )
            torch.save(
                self.optimizer.state_dict(),
                os.path.join(self.ckp.dir, 'optimzer', 'optimzer_TA_{}.pt'.format(epoch))
            )

        # "last" snapshot, overwritten every epoch.
        torch.save(
            self.model_TA.get_model().state_dict(),
            os.path.join(self.ckp.dir, 'model', 'model_TA_last.pt')
        )
        torch.save(
            self.optimizer.state_dict(),
            os.path.join(self.ckp.dir, 'optimzer', 'optimzer_TA_last.pt')
        )

    def test(self):
        """Evaluate model_TA on every test loader and log mean PSNR/SSIM."""
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))

        timer_test = utility2.timer()
        t = 0
        self.model_TA.eval()

        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig=self.args.sig,
            lambda_1=self.args.lambda_1,
            lambda_2=self.args.lambda_2,
            theta=self.args.theta,
            noise=self.args.noise
        )

        for d in self.loader_test:
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                eval_psnr = 0
                eval_ssim = 0
                for idx_img, (hr, filename) in enumerate(d):
                    hr = hr.cuda()  # b, c, h, w
                    hr = self.crop_border(hr, scale)
                    lr_blur, hr_blur = degrade(hr, random=False)

                    # Per-image adaptation needs gradients, so this method
                    # is not wrapped in torch.no_grad().
                    self._adapt_meta_copy(lr_blur, hr_blur, hr)
                    _, deg_repre = self.model_meta_copy(lr_blur, hr_blur)

                    # Time only the model_TA forward pass.
                    torch.cuda.synchronize()
                    timer_test.tic()
                    sr = self.model_TA(lr_blur, deg_repre.detach())
                    torch.cuda.synchronize()
                    timer_test.hold()
                    t0 = timer_test.release()

                    print("idx:", idx_img, ",time consuming:", t0)
                    t += t0

                    hr = utility2.quantize(hr, self.args.rgb_range)
                    sr = utility2.quantize(sr, self.args.rgb_range)

                    eval_psnr += utility2.calc_psnr(
                        sr, hr, scale, self.args.rgb_range,
                        benchmark=True
                    )
                    eval_ssim += utility2.calc_ssim(
                        (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale,
                        benchmark=True
                    )

                    # NOTE(review): second hold() after release() is kept from
                    # the original; if utility2.timer accumulates on hold()
                    # it inflates the next image's reported time -- confirm.
                    timer_test.hold()

                    # save results
                    if self.args.save_results:
                        save_list = [sr]
                        filename = filename[0]
                        self.ckp.save_results(filename, save_list, scale)

                # Average over the images of this loader (len(d)), not over
                # the number of test loaders.
                n_images = len(d)
                mean_psnr = eval_psnr / n_images
                mean_ssim = eval_ssim / n_images

                # Keep a short running history of recent evaluations.
                if len(self.test_res_psnr) > 10:
                    self.test_res_psnr.pop(0)
                    self.test_res_ssim.pop(0)
                self.test_res_psnr.append(mean_psnr)
                self.test_res_ssim.append(mean_ssim)
                print("All time consuming:", t / n_images)

                self.ckp.log[-1, idx_scale] = mean_psnr
                self.ckp.write_log(
                    '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format(
                        self.args.resume,
                        self.args.data_test,
                        scale,
                        mean_psnr,
                        mean_ssim,
                        np.mean(self.test_res_psnr),
                        np.mean(self.test_res_ssim)
                    ))

    def crop_border(self, img_hr, scale):
        """Crop H and W of a (b, c, h, w) tensor down to multiples of ``scale``."""
        b, c, h, w = img_hr.size()

        img_hr = img_hr[:, :, :int(h // scale * scale), :int(w // scale * scale)]

        return img_hr

    def get_patch(self, img, patch_size=48, scale=4):
        """Return a random square crop of side ``round(scale * patch_size)``.

        ``img`` is a 4-D (b, c, h, w) tensor.  The upper bound of the random
        offset is inclusive of the last valid position, so an image whose
        side equals the crop size is returned whole instead of raising.
        """
        tb, tc, th, tw = img.shape  ## HR image
        tp = round(scale * patch_size)
        tx = random.randrange(0, tw - tp + 1)
        ty = random.randrange(0, th - tp + 1)

        return img[:, :, ty:ty + tp, tx:tx + tp]

    def crop(self, img_hr):
        """Stack ``task_batch_size`` random patches of ``img_hr`` along dim 0."""
        patches = [
            self.get_patch(img_hr, self.args.patch_size, self.scale[0])
            for _ in range(self.task_batch_size)
        ]
        return torch.cat(patches, dim=0)

    def terminate(self):
        """Return True when the run should stop.

        In test-only mode a single evaluation is run first; otherwise the
        run stops once the next epoch index reaches ``args.epochs_sr``.
        """
        if self.args.test_only:
            self.test()
            return True

        epoch = self.scheduler.last_epoch + 1
        return epoch >= self.args.epochs_sr
--------------------------------------------------------------------------------
/trainer_stage4.py:
--------------------------------------------------------------------------------
1 | import os
2 | import utility
3 | import torch
4 | from decimal import Decimal
5 | import torch.nn.functional as F
6 | from utils import util2
7 | import numpy as np
8 | import torch.nn as nn
9 |
10 |
class Trainer():
    """Stage-4 trainer: fits ``model_ST`` with feature distillation.

    For every batch a copy of the meta model is reset to its initial weights
    and adapted for ``args.test_iter`` manual gradient steps; its second
    output is passed through the ``E`` sub-module of ``model_TA`` to obtain
    teacher features.  ``model_ST`` (student) is trained to match those
    features (KL on temperature-softened distributions plus L1) and, after
    ``args.epochs_encoder`` epochs, additionally on the SR loss.
    """

    def __init__(self, args, loader, model_ST, model_TA, model_meta_copy, my_loss, ckp):
        """Store collaborators and build the optimizer/scheduler for model_ST.

        Args:
            args: parsed options namespace.
            loader: object exposing ``loader_train`` and ``loader_test``.
            model_ST: student model optimized in this stage.
            model_TA: teacher model whose ``E`` sub-module yields the target
                features.
            model_meta_copy: meta model adapted per batch.
            my_loss: SR loss object (also provides logging hooks).
            ckp: checkpoint/logger object.
        """
        self.is_first = True
        self.args = args
        self.scale = args.scale
        self.test_res_psnr = []
        self.test_res_ssim = []
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model_ST = model_ST
        self.model_TA = model_TA
        self.model_meta_copy = model_meta_copy
        # Parameters whose name contains "tail" are excluded from adaptation.
        for k, v in self.model_meta_copy.named_parameters():
            if "tail" in k:
                v.requires_grad = False
        # state_dict() returns references to the live parameters, and the
        # adaptation loop updates those parameters in place.  Keep a deep
        # copy so that restoring it really resets the meta model.
        self.model_meta_copy_state = self._snapshot_state(self.model_meta_copy)
        self.model_Est = torch.nn.DataParallel(self.model_ST.get_model().E_st, range(self.args.n_GPUs))
        self.model_Eta = torch.nn.DataParallel(self.model_TA.get_model().E, range(self.args.n_GPUs))
        self.loss1 = nn.L1Loss()
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model_ST)
        self.scheduler = utility.make_scheduler(args, self.optimizer)
        self.task_iter = args.task_iter
        self.test_iter = args.test_iter
        self.lr_task = args.lr_task
        self.temperature = args.temperature

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            # Fast-forward the scheduler to the epoch recorded in the log.
            for _ in range(len(ckp.log)):
                self.scheduler.step()

    @staticmethod
    def _snapshot_state(model):
        """Return a deep copy of ``model.state_dict()`` (independent tensors)."""
        import copy
        return copy.deepcopy(model.state_dict())

    @staticmethod
    def _inner_lr(step):
        """Step size of the manual adaptation update for inner step ``step``."""
        if step == 0:
            return 1e-2
        if step < 5:
            return 5e-3
        return 1e-3

    def _adapt_meta_copy(self, lr_blur, hr_blur, hr):
        """Reset the meta model, then adapt it in place for ``test_iter`` steps."""
        self.model_meta_copy.get_model().load_state_dict(self.model_meta_copy_state)

        for step in range(self.test_iter):
            learning_rate = self._inner_lr(step)
            sr, _ = self.model_meta_copy(lr_blur, hr_blur)
            loss = self.loss1(sr, hr)
            self.model_meta_copy.zero_grad()
            loss.backward()

            # Plain gradient-descent update applied directly to the weights.
            for param in self.model_meta_copy.parameters():
                if param.requires_grad and param.grad is not None:
                    param.data.sub_(param.grad.data * learning_rate)

    def _distill_losses(self, S_fea, T_fea):
        """Sum the KL and L1 distillation terms over all teacher features.

        Returns ``(kl_sum, l1_sum)``; student features are indexed by the
        teacher feature positions.
        """
        loss_distill_dis = 0
        loss_distill_abs = 0
        for i, teacher in enumerate(T_fea):
            student = S_fea[i]
            student_distance = F.log_softmax(student / self.temperature, dim=1)
            teacher_distance = F.softmax(teacher / self.temperature, dim=1)
            loss_distill_dis += F.kl_div(
                student_distance, teacher_distance, reduction='batchmean')
            loss_distill_abs += self.loss1(student, teacher)
        return loss_distill_dis, loss_distill_abs

    def train(self):
        """Run one training epoch of model_ST and write checkpoints."""
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        encoder_phase = epoch <= self.args.epochs_encoder

        # The learning rate is set by hand; the phase selects the schedule.
        if encoder_phase:
            lr = self.args.lr_encoder * (self.args.gamma_encoder ** (epoch // self.args.lr_decay_encoder))
        else:
            lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model_ST.train()

        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig_min=self.args.sig_min,
            sig_max=self.args.sig_max,
            lambda_min=self.args.lambda_min,
            lambda_max=self.args.lambda_max,
            noise=self.args.noise
        )

        timer = utility.timer()
        losses_sr = utility.AverageMeter()
        losses_distill_distribution = utility.AverageMeter()
        losses_distill_abs = utility.AverageMeter()

        for batch, (hr, _) in enumerate(self.loader_train):
            hr = hr.cuda()  # b, c, h, w
            timer.tic()
            lr_blur, hr_blur = degrade(hr)  # b, c, h, w

            self._adapt_meta_copy(lr_blur, hr_blur, hr)

            # Teacher features from the adapted meta model's representation.
            _, deg_repre = self.model_meta_copy(lr_blur, hr_blur)
            _, T_fea = self.model_Eta(deg_repre)

            if encoder_phase:
                # The very first batch goes through the full student model;
                # afterwards only its E_st sub-module is evaluated.
                if self.is_first:
                    _, S_fea = self.model_ST(lr_blur)
                    self.is_first = False
                else:
                    _, S_fea = self.model_Est(lr_blur)
                loss_distill_dis, loss_distill_abs = self._distill_losses(S_fea, T_fea)
                losses_distill_distribution.update(loss_distill_dis.item())
                losses_distill_abs.update(loss_distill_abs.item())
                loss = loss_distill_dis + 0.1 * loss_distill_abs
            else:
                sr, S_fea = self.model_ST(lr_blur)
                loss_SR = self.loss(sr, hr)
                loss_distill_dis, loss_distill_abs = self._distill_losses(S_fea, T_fea)
                losses_distill_distribution.update(loss_distill_dis.item())
                losses_distill_abs.update(loss_distill_abs.item())
                loss = loss_SR + loss_distill_dis + 0.1 * loss_distill_abs
                losses_sr.update(loss_SR.item())

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            timer.hold()

            if (batch + 1) % self.args.print_every == 0:
                if encoder_phase:
                    self.ckp.write_log(
                        'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                        'Loss [ distill_dis loss:{:.3f}, distill_abs loss:{:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_distill_distribution.avg, losses_distill_abs.avg,
                            timer.release(),
                        ))
                else:
                    self.ckp.write_log(
                        'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                        'Loss [SR loss:{:.3f}, distill_dis loss:{:.3f}, distill_abs loss:{:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_sr.avg, losses_distill_distribution.avg, losses_distill_abs.avg,
                            timer.release(),
                        ))

        self.loss.end_log(len(self.loader_train))

        # Per-epoch snapshot (late epochs and every 30th epoch).
        if epoch > self.args.st_save_epoch or epoch % 30 == 0:
            torch.save(
                self.model_ST.get_model().state_dict(),
                os.path.join(self.ckp.dir, 'model', 'model_ST_{}.pt'.format(epoch))
            )

        # "last" snapshot, overwritten every epoch.
        torch.save(
            self.model_ST.get_model().state_dict(),
            os.path.join(self.ckp.dir, 'model', 'model_ST_last.pt')
        )

    def test(self):
        """Evaluate model_ST on every test loader and log mean PSNR/SSIM."""
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model_ST.eval()
        t = 0
        timer_test = utility.timer()

        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig=self.args.sig,
            lambda_1=self.args.lambda_1,
            lambda_2=self.args.lambda_2,
            theta=self.args.theta,
            noise=self.args.noise
        )

        with torch.no_grad():
            for d in self.loader_test:
                for idx_scale, scale in enumerate(self.scale):
                    d.dataset.set_scale(idx_scale)
                    eval_psnr = 0
                    eval_ssim = 0
                    for idx_img, (hr, filename) in enumerate(d):
                        hr = hr.cuda()  # b, c, h, w
                        hr = self.crop_border(hr, scale)
                        lr_blur, hr_blur = degrade(hr, random=False)  # b, c, h, w

                        # Time only the student forward pass.
                        torch.cuda.synchronize()
                        timer_test.tic()
                        sr = self.model_ST(lr_blur)
                        torch.cuda.synchronize()
                        timer_test.hold()
                        t0 = timer_test.release()
                        print("idx:", idx_img, ",time consuming:", t0)
                        t += t0

                        sr = utility.quantize(sr, self.args.rgb_range)
                        hr = utility.quantize(hr, self.args.rgb_range)

                        # metrics
                        eval_psnr += utility.calc_psnr(
                            sr, hr, scale, self.args.rgb_range,
                            benchmark=True
                        )
                        eval_ssim += utility.calc_ssim(
                            (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale,
                            benchmark=True
                        )

                        # save results
                        if self.args.save_results:
                            save_list = [sr]
                            filename = filename[0]
                            self.ckp.save_results(filename, save_list, scale)

                    n_images = len(d)
                    mean_psnr = eval_psnr / n_images
                    mean_ssim = eval_ssim / n_images

                    # Keep a short running history of recent evaluations.
                    if len(self.test_res_psnr) > 10:
                        self.test_res_psnr.pop(0)
                        self.test_res_ssim.pop(0)
                    self.test_res_psnr.append(mean_psnr)
                    self.test_res_ssim.append(mean_ssim)
                    print("All time consuming:", t / n_images)

                    self.ckp.log[-1, idx_scale] = mean_psnr
                    self.ckp.write_log(
                        '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format(
                            self.args.resume,
                            self.args.data_test,
                            scale,
                            mean_psnr,
                            mean_ssim,
                            np.mean(self.test_res_psnr),
                            np.mean(self.test_res_ssim)
                        ))

    def crop_border(self, img_hr, scale):
        """Crop H and W of a (b, c, h, w) tensor down to multiples of ``scale``."""
        b, c, h, w = img_hr.size()

        img_hr = img_hr[:, :, :int(h // scale * scale), :int(w // scale * scale)]

        return img_hr

    def terminate(self):
        """Return True when the run should stop.

        In test-only mode a single evaluation is run first; otherwise the
        run stops once the next epoch index reaches ``args.epochs_sr``.
        """
        if self.args.test_only:
            self.test()
            return True

        epoch = self.scheduler.last_epoch + 1
        return epoch >= self.args.epochs_sr
273 |
274 |
--------------------------------------------------------------------------------
/utils/util2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import utility
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
def cal_sigma(sig_x, sig_y, radians):
    """Build batched 2x2 covariance matrices ``U diag(sig_x^2, sig_y^2) U^T``.

    Each argument is flattened to one value per batch entry; ``U`` is the
    rotation matrix for ``radians``.  Returns a tensor of shape (B, 2, 2).
    """
    var_x = sig_x.view(-1, 1, 1) ** 2
    var_y = sig_y.view(-1, 1, 1) ** 2
    angle = radians.view(-1, 1, 1)

    # Diagonal matrix: [var_x, 0] stacked over [0, var_y].
    diag_top = F.pad(var_x, [0, 1, 0, 0])
    diag_bottom = F.pad(var_y, [1, 0, 0, 0])
    D = torch.cat([diag_top, diag_bottom], 1)

    # Rotation matrix [[cos, -sin], [sin, cos]].
    cos_a = angle.cos()
    sin_a = angle.sin()
    rot_top = torch.cat([cos_a, -sin_a], 2)
    rot_bottom = torch.cat([sin_a, cos_a], 2)
    U = torch.cat([rot_top, rot_bottom], 1)

    return torch.bmm(U, torch.bmm(D, U.transpose(1, 2)))
20 |
21 |
def anisotropic_gaussian_kernel(batch, kernel_size, covar):
    """Return ``batch`` normalized Gaussian kernels for the given covariances.

    Args:
        batch: number of kernels.
        kernel_size: side length of each (square) kernel.
        covar: tensor of shape (batch, 2, 2) holding the covariance matrices.

    Returns:
        Tensor of shape (batch, kernel_size, kernel_size); every kernel sums
        to one.  The result lives on ``covar``'s device, so the function
        works on CPU as well as on GPU tensors.
    """
    # Pixel offsets relative to the kernel centre.
    ax = torch.arange(kernel_size, device=covar.device).float() - kernel_size // 2

    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    xy = torch.stack([xx, yy], -1).view(batch, -1, 2)

    # Quadratic form x^T Sigma^-1 x evaluated at every grid position.
    inverse_sigma = torch.inverse(covar)
    kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size)

    return kernel / kernel.sum([1, 2], keepdim=True)
33 |
34 |
def isotropic_gaussian_kernel(batch, kernel_size, sigma):
    """Return ``batch`` normalized isotropic Gaussian kernels.

    Args:
        batch: number of kernels.
        kernel_size: side length of each (square) kernel.
        sigma: tensor with one standard deviation per kernel.

    Returns:
        Tensor of shape (batch, kernel_size, kernel_size); every kernel sums
        to one.  The result lives on ``sigma``'s device, so the function
        works on CPU as well as on GPU tensors.
    """
    # Pixel offsets relative to the kernel centre.
    ax = torch.arange(kernel_size, device=sigma.device).float() - kernel_size // 2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2))

    return kernel / kernel.sum([1, 2], keepdim=True)
42 |
43 |
def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0):
    """Sample ``batch`` random anisotropic Gaussian kernels on the GPU.

    The two eigenvalue parameters are drawn uniformly from
    [lambda_min, lambda_max) and the rotation angle uniformly from [0, pi)
    radians, matching the range documented in ``SRMDPreprocessing``.
    """
    # torch.rand is already in [0, 1): scale straight to radians.  (Dividing
    # by 180 first, as a degree conversion would, limits the angle to < 1 deg.)
    theta = torch.rand(batch).cuda() * math.pi
    lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min
    lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar)
    return kernel
52 |
53 |
def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0):
    """Build one fixed anisotropic Gaussian kernel on the GPU.

    ``theta`` is given in degrees and converted to radians; ``lambda_1`` and
    ``lambda_2`` are forwarded to ``cal_sigma`` as the two axis parameters.
    Returns a tensor of shape (1, kernel_size, kernel_size).
    """
    angle = torch.ones(1).cuda() * theta / 180 * math.pi
    axis_1 = torch.ones(1).cuda() * lambda_1
    axis_2 = torch.ones(1).cuda() * lambda_2

    return anisotropic_gaussian_kernel(1, kernel_size, cal_sigma(axis_1, axis_2, angle))
62 |
63 |
def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0):
    """Sample ``batch`` isotropic Gaussian kernels with sigma in [sig_min, sig_max)."""
    sigmas = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min
    return isotropic_gaussian_kernel(batch, kernel_size, sigmas)
68 |
69 |
def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
    """Build one isotropic Gaussian kernel with the fixed width ``sig``."""
    sigma = torch.ones(1).cuda() * sig
    return isotropic_gaussian_kernel(1, kernel_size, sigma)
74 |
75 |
def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0):
    """Sample ``batch`` random blur kernels of the requested type.

    Args:
        batch: number of kernels to sample.
        kernel_size: side length of each kernel.
        blur_type: ``'iso_gaussian'`` or ``'aniso_gaussian'``.
        sig_min, sig_max: width range used for isotropic kernels.
        lambda_min, lambda_max: eigenvalue range used for anisotropic kernels.

    Raises:
        ValueError: if ``blur_type`` is not one of the two supported names
            (previously this silently returned ``None``).
    """
    if blur_type == 'iso_gaussian':
        return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max)
    if blur_type == 'aniso_gaussian':
        return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max)
    raise ValueError(
        "unsupported blur_type {!r}; expected 'iso_gaussian' or 'aniso_gaussian'".format(blur_type))
81 |
82 |
def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0):
    """Build one fixed blur kernel of the requested type.

    Args:
        kernel_size: side length of the kernel.
        blur_type: ``'iso_gaussian'`` or ``'aniso_gaussian'``.
        sig: width used for the isotropic kernel.
        lambda_1, lambda_2, theta: parameters of the anisotropic kernel
            (``theta`` in degrees).

    Raises:
        ValueError: if ``blur_type`` is not one of the two supported names
            (previously this silently returned ``None``).
    """
    if blur_type == 'iso_gaussian':
        return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig)
    if blur_type == 'aniso_gaussian':
        return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta)
    raise ValueError(
        "unsupported blur_type {!r}; expected 'iso_gaussian' or 'aniso_gaussian'".format(blur_type))
88 |
89 |
90 | # implementation of matlab bicubic interpolation in pytorch
class bicubic(nn.Module):
    """Separable bicubic resize of a (b, c, h, w) tensor.

    Interpolation weights and source indices are computed per axis by
    ``contribute`` and applied first along the height, then along the width.
    All intermediate tensors follow the device of the input, so the module
    works on CPU as well as on GPU tensors.
    """

    def __init__(self):
        super(bicubic, self).__init__()

    def cubic(self, x):
        """Evaluate the piecewise-cubic interpolation kernel at ``x``."""
        absx = torch.abs(x)
        absx2 = absx * absx
        absx3 = absx2 * absx

        # Indicator masks for the two non-zero pieces (|x| <= 1, 1 < |x| <= 2).
        condition1 = (absx <= 1).to(torch.float32)
        condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32)

        f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return f

    def contribute(self, in_size, out_size, scale, device=None):
        """Compute per-axis interpolation weights and 1-based source indices.

        Args:
            in_size: ``[h, w]`` of the input.
            out_size: ``[h, w]`` of the output.
            scale: resize factor (values below 1 widen the kernel).
            device: device for the returned tensors; defaults to CUDA, which
                is what this method always used before the parameter existed.

        Returns:
            ``(weight0, weight1, indice0, indice1)``, each with a leading
            singleton dimension; axis 0 is height, axis 1 is width.
        """
        if device is None:
            device = torch.device('cuda')

        kernel_width = 4
        if scale < 1:
            kernel_width = 4 / scale

        # Output coordinates (1-based) and their positions in input space.
        x0 = torch.arange(1, out_size[0] + 1, dtype=torch.float32, device=device)
        x1 = torch.arange(1, out_size[1] + 1, dtype=torch.float32, device=device)

        u0 = x0 / scale + 0.5 * (1 - 1 / scale)
        u1 = x1 / scale + 0.5 * (1 - 1 / scale)

        # Left-most input sample that can contribute to each output sample.
        left0 = torch.floor(u0 - kernel_width / 2)
        left1 = torch.floor(u1 - kernel_width / 2)

        # Number of candidate taps per output sample.
        P = int(np.ceil(kernel_width) + 2)
        taps = torch.arange(0, P, dtype=torch.float32, device=device).unsqueeze(0)

        indice0 = left0.unsqueeze(1) + taps
        indice1 = left1.unsqueeze(1) + taps

        mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0)
        mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0)

        if scale < 1:
            weight0 = scale * self.cubic(mid0 * scale)
            weight1 = scale * self.cubic(mid1 * scale)
        else:
            weight0 = self.cubic(mid0)
            weight1 = self.cubic(mid1)

        # Normalize so the taps of every output sample sum to one.
        weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2))
        weight1 = weight1 / (torch.sum(weight1, 2).unsqueeze(2))

        # Clamp indices into the valid 1-based range (border replication).
        indice0 = indice0.clamp(min=1, max=in_size[0]).unsqueeze(0)
        indice1 = indice1.clamp(min=1, max=in_size[1]).unsqueeze(0)

        # Drop tap columns whose weight is zero for the first output sample.
        kill0 = torch.eq(weight0, 0)[0][0]
        kill1 = torch.eq(weight1, 0)[0][0]

        weight0 = weight0[:, :, kill0 == 0]
        weight1 = weight1[:, :, kill1 == 0]

        indice0 = indice0[:, :, kill0 == 0]
        indice1 = indice1[:, :, kill1 == 0]

        return weight0, weight1, indice0, indice1

    def forward(self, input, scale=1/4):
        """Resize ``input`` (b, c, h, w) by ``scale`` along both spatial axes."""
        b, c, h, w = input.shape

        weight0, weight1, indice0, indice1 = self.contribute(
            [h, w], [int(h * scale), int(w * scale)], scale, device=input.device)
        weight0 = weight0[0]
        weight1 = weight1[0]

        indice0 = indice0[0].long()
        indice1 = indice1[0].long()

        # Height pass: gather taps (indices are 1-based), weight and sum.
        out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = (torch.sum(out, dim=3))
        A = out.permute(0, 1, 3, 2)

        # Width pass on the transposed result, then transpose back.
        out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = out.sum(3).permute(0, 1, 3, 2)

        return out
169 |
170 |
class Gaussin_Kernel(object):
    """Callable that produces Gaussian blur kernels from stored settings.

    Called with ``random=True`` it samples kernels from the configured
    ranges; otherwise it builds the single fixed kernel described by
    ``sig`` (isotropic) or ``lambda_1`` / ``lambda_2`` / ``theta``
    (anisotropic).
    """

    def __init__(self, kernel_size=21, blur_type='iso_gaussian',
                 sig=2.6, sig_min=0.2, sig_max=4.0,
                 lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0):
        self.kernel_size, self.blur_type = kernel_size, blur_type

        # Isotropic settings: fixed width and sampling range.
        self.sig, self.sig_min, self.sig_max = sig, sig_min, sig_max

        # Anisotropic settings: fixed parameters and sampling range.
        self.lambda_1, self.lambda_2, self.theta = lambda_1, lambda_2, theta
        self.lambda_min, self.lambda_max = lambda_min, lambda_max

    def __call__(self, batch, random):
        # The equality test against True is kept as-is so that only values
        # equal to True select the random branch.
        if random == True:
            return random_gaussian_kernel(
                batch,
                kernel_size=self.kernel_size,
                blur_type=self.blur_type,
                sig_min=self.sig_min,
                sig_max=self.sig_max,
                lambda_min=self.lambda_min,
                lambda_max=self.lambda_max,
            )

        # Fixed kernel; ``batch`` is not used on this path.
        return stable_gaussian_kernel(
            kernel_size=self.kernel_size,
            blur_type=self.blur_type,
            sig=self.sig,
            lambda_1=self.lambda_1,
            lambda_2=self.lambda_2,
            theta=self.theta,
        )
200 |
class BatchBlur(nn.Module):
    """Blur a (B, C, H, W) batch with one shared kernel or one kernel per sample.

    The input is reflection-padded so the output keeps the input's spatial
    size.
    """

    def __init__(self, kernel_size=21):
        super(BatchBlur, self).__init__()
        self.kernel_size = kernel_size
        half = kernel_size // 2
        if kernel_size % 2 == 1:
            self.pad = nn.ReflectionPad2d(half)
        else:
            # Even kernels pad one pixel less on the right/bottom.
            self.pad = nn.ReflectionPad2d((half, half - 1, half, half - 1))

    def forward(self, input, kernel):
        B, C, H, W = input.size()
        padded = self.pad(input)
        H_p, W_p = padded.size()[-2:]
        k = self.kernel_size

        if len(kernel.size()) != 2:
            # One kernel per sample: fold batch and channels into the channel
            # axis and run a grouped convolution with the kernel repeated
            # for every channel of its sample.
            stacked = padded.view((1, C * B, H_p, W_p))
            per_sample = kernel.contiguous().view((B, 1, k, k))
            per_channel = per_sample.repeat(1, C, 1, 1).view((B * C, 1, k, k))
            return F.conv2d(stacked, per_channel, groups=B*C).view((B, C, H, W))

        # Single 2-D kernel shared by every sample and channel.
        planes = padded.view((C * B, 1, H_p, W_p))
        shared = kernel.contiguous().view((1, 1, k, k))
        return F.conv2d(planes, shared, padding=0).view((B, C, H, W))
226 |
227 |
class SRMDPreprocessing(object):
    """Degrade an HR batch (blur, downsample, noise) and re-upsample the result.

    Calling the object returns the degraded low-resolution batch together with
    its bicubic upsampling back to the HR size.
    """

    def __init__(self,
                 scale,
                 mode='bicubic',
                 kernel_size=21,
                 blur_type='iso_gaussian',
                 sig=2.6,
                 sig_min=0.2,
                 sig_max=4.0,
                 lambda_1=0.2,
                 lambda_2=4.0,
                 theta=0,
                 lambda_min=0.2,
                 lambda_max=4.0,
                 noise=0.0,
                 rgb_range=1
                 ):
        '''
        # sig, sig_min and sig_max are used for isotropic Gaussian blurs
        During training phase (random=True):
            the width of the blur kernel is randomly selected from [sig_min, sig_max]
        During test phase (random=False):
            the width of the blur kernel is set to sig

        # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs
        During training phase (random=True):
            the eigenvalues of the covariance are randomly selected from [lambda_min, lambda_max]
            the angle value is randomly selected from [0, pi]
        During test phase (random=False):
            the eigenvalues of the covariance are set to lambda_1 and lambda_2
            the angle value is set to theta

        # scale is the down-sampling factor, mode is 'bicubic' or 's-fold'
        # noise is given on a 0-255 scale and converted to rgb_range here
        '''
        self.kernel_size = kernel_size
        self.scale = scale
        self.mode = mode
        self.noise = noise / (255 / rgb_range)
        self.rgb_range = rgb_range

        self.gen_kernel = Gaussin_Kernel(
            kernel_size=kernel_size, blur_type=blur_type,
            sig=sig, sig_min=sig_min, sig_max=sig_max,
            lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max
        )
        self.blur = BatchBlur(kernel_size=kernel_size)
        self.bicubic = bicubic()

    def __call__(self, hr_tensor, random=True):
        """Degrade ``hr_tensor`` of size ``(B, C, H, W)``.

        Args:
            hr_tensor: High-resolution batch.
            random: When true a random kernel / noise level is drawn,
                otherwise the fixed test-time settings are used.

        Returns:
            Tuple ``(lr, lr_up)``: the quantized degraded batch viewed as
            ``(B, C, H // scale, W // scale)`` and its quantized bicubic
            upsampling viewed as ``(B, C, H, W)``.

        Raises:
            ValueError: If ``mode`` is neither ``'bicubic'`` nor ``'s-fold'``.
        """
        with torch.no_grad():
            B, C, H, W = hr_tensor.size()

            if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0:
                # only downsampling
                hr_blured = hr_tensor.view(-1, C, H, W)
            else:
                # gaussian blur + downsampling; one kernel shared by the whole batch
                b_kernels = self.gen_kernel(1, random)
                b_kernels = b_kernels.expand(B, -1, -1)

                hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels)
                hr_blured = hr_blured.view(-1, C, H, W)  # B, C, H, W

            # downsampling
            if self.mode == 'bicubic':
                lr_blured = self.bicubic(hr_blured, scale=1 / self.scale)
            elif self.mode == 's-fold':
                lr_blured = hr_blured.view(
                    -1, C, H // self.scale, self.scale, W // self.scale, self.scale
                )[:, :, :, 0, :, 0]
            else:
                # Previously this fell through to an UnboundLocalError below.
                raise ValueError(
                    "unsupported downsampling mode: {!r} (expected 'bicubic' or 's-fold')".format(self.mode)
                )

            # add noise
            if self.noise > 0:
                if random:
                    noise_level = torch.rand(1, 1, 1, 1).to(lr_blured.device) * self.noise
                else:
                    noise_level = self.noise
                noise = torch.randn_like(lr_blured).mul_(noise_level)
                # Out-of-place on purpose: in 's-fold' mode lr_blured is a view
                # that can share storage with the caller's hr_tensor, so an
                # in-place add_ would corrupt the input.
                lr_blured = lr_blured + noise

            hr_blured = self.bicubic(lr_blured, scale=self.scale)

            lr_blured = utility.quantize(lr_blured, self.rgb_range)
            hr_blured = utility.quantize(hr_blured, self.rgb_range)

            return lr_blured.view(B, C, H // int(self.scale), W // int(self.scale)), hr_blured.view(B, C, H, W)
315 |
316 |
--------------------------------------------------------------------------------
/trainer_stage2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import utility2
3 | import torch
4 | from decimal import Decimal
5 | import torch.nn.functional as F
6 | from utils import util
7 | from utils import util2
8 | from collections import OrderedDict
9 | import random
10 | import numpy as np
11 | import torch.nn as nn
12 |
13 |
class Trainer():
    """Stage-2 trainer: meta-training over random degradations plus evaluation.

    Each training batch is split into a query half and a support half. For
    every sampled degradation the weights are adapted with a few gradient
    steps on the query half, the support loss of the adapted weights is
    differentiated with respect to the original weights, and the averaged
    gradients are applied with the regular optimizer.
    """

    def __init__(self, args, loader, model_meta, model_meta_copy, my_loss, ckp):
        """Store the configuration and build the optimizer and scheduler.

        Args:
            args: Parsed options (``scale``, ``meta_batch_size``,
                ``task_batch_size``, ``task_iter``, ``test_iter``,
                ``lr_task``, ``load`` are read here).
            loader: Object exposing ``loader_train`` and ``loader_test``.
            model_meta: Model being trained.
            model_meta_copy: Second model instance; ``test`` loads the trained
                weights into it and fine-tunes it per image.
            my_loss: Loss object used for logging and for the final backward
                pass of each training batch.
            ckp: Checkpoint / logging helper.
        """
        self.test_res_psnr = []
        self.test_res_ssim = []
        self.args = args
        self.scale = args.scale
        self.loss1 = nn.L1Loss()
        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model_meta = model_meta
        self.model_meta_copy = model_meta_copy

        self.loss = my_loss
        self.meta_batch_size = args.meta_batch_size
        self.task_batch_size = args.task_batch_size
        self.task_iter = args.task_iter
        self.test_iter = args.test_iter
        self.optimizer = utility2.make_optimizer(args, self.model_meta)
        self.scheduler = utility2.make_scheduler(args, self.optimizer)
        self.lr_task = args.lr_task
        self.taskiter_weight = [1 / self.task_iter] * self.task_iter

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            # Fast-forward the scheduler to the epoch recorded in the log.
            for _ in range(len(ckp.log)):
                self.scheduler.step()

    def _save_checkpoint(self, tag):
        """Save model and optimizer state dicts under the suffix ``tag``."""
        torch.save(
            self.model_meta.get_model().state_dict(),
            os.path.join(self.ckp.dir, 'model', 'model_meta_{}.pt'.format(tag))
        )
        torch.save(
            self.optimizer.state_dict(),
            os.path.join(self.ckp.dir, 'optimzer', 'optimzer_meta_{}.pt'.format(tag))
        )

    def train(self):
        """Run one training epoch and write the checkpoints."""
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1

        # The learning rate is set by hand from the step-decay schedule.
        lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model_meta.train()

        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig_min=self.args.sig_min,
            sig_max=self.args.sig_max,
            lambda_min=self.args.lambda_min,
            lambda_max=self.args.lambda_max,
            noise=self.args.noise
        )

        timer = utility2.timer()
        losses_sr = utility2.AverageMeter()

        for batch, (hr, _,) in enumerate(self.loader_train):
            hr = hr.cuda()  # b, c, h, w
            b, c, h, w = hr.shape
            half = b // 2
            query_label = hr[:half, :, :, :]
            support_label = hr[half:, :, :, :]
            weights = OrderedDict(
                (name, param) for (name, param) in self.model_meta.model.named_parameters() if param.requires_grad)

            meta_grads = {name: 0 for name in weights}
            timer.tic()
            loss_train = [0] * self.task_iter
            for i in range(self.meta_batch_size):
                lr_blur, hr_blur = degrade(hr)  # b, c, h, w
                query_lr = lr_blur[:half, :, :, :]
                support_lr = lr_blur[half:, :, :, :]
                query_hr = hr_blur[:half, :, :, :]
                support_hr = hr_blur[half:, :, :, :]

                # Adaptation: task_iter plain gradient steps on the query half.
                fast_weights = weights
                for j in range(self.task_iter):
                    sr = self.model_meta(query_lr, query_hr, fast_weights)
                    loss = self.loss1(sr, query_label)
                    loss_train[j] += loss.item() / self.meta_batch_size
                    grads = torch.autograd.grad(loss, fast_weights.values())
                    fast_weights = OrderedDict(
                        (name, param - self.lr_task * grad)
                        for ((name, param), grad) in zip(fast_weights.items(), grads))

                # Support loss of the adapted weights, differentiated with
                # respect to the original weights.
                sr = self.model_meta(support_lr, support_hr, fast_weights)
                loss = self.loss1(sr, support_label)
                grads = torch.autograd.grad(loss, weights.values())
                for name, g in zip(meta_grads, grads):
                    meta_grads[name] = meta_grads[name] + g / self.meta_batch_size

            # Hooks replace whatever gradient the backward pass below produces
            # with the accumulated meta-gradient of the same parameter.
            hooks = [
                param.register_hook(lambda grad, key=name: meta_grads[key])
                for (name, param) in self.model_meta.model.named_parameters()
                if param.requires_grad
            ]

            sr = self.model_meta(query_lr, query_hr, weights)
            loss = self.loss(sr, query_label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Remove the hooks before next training phase
            for h in hooks:
                h.remove()

            timer.hold()

            if (batch + 1) % self.args.print_every == 0:
                # Built dynamically so any task_iter works (the old format
                # string indexed exactly five entries).
                loss_str = ','.join(
                    'SR{} loss:{:.3f}'.format(step, value) for step, value in enumerate(loss_train))
                self.ckp.write_log(
                    'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                    'Loss [{}]\t'
                    'Time [{:.1f}s]'.format(
                        epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                        loss_str,
                        timer.release(),
                    ))

        self.loss.end_log(len(self.loader_train))

        # save model
        if epoch > self.args.st_save_epoch or (epoch % 10 == 0):
            self._save_checkpoint(epoch)
        self._save_checkpoint('last')

    def test(self):
        """Evaluate with per-image fine-tuning and log PSNR / SSIM."""
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))

        timer_test = utility2.timer()
        self.model_meta.eval()
        degrade = util2.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig=self.args.sig,
            lambda_1=self.args.lambda_1,
            lambda_2=self.args.lambda_2,
            theta=self.args.theta,
            noise=self.args.noise
        )

        for d in self.loader_test:
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                # Average over the items of the loader being evaluated, not
                # over the number of test loaders.
                n_items = len(d)
                iter_psnr = [0] * (self.test_iter + 1)
                iter_ssim = [0] * (self.test_iter + 1)
                for idx_img, (hr, filename) in enumerate(d):
                    # Start every image from the trained weights.
                    self.model_meta_copy.model.load_state_dict(self.model_meta.state_dict())

                    hr = hr.cuda()  # b, c, h, w
                    hr = self.crop_border(hr, scale)
                    # inference
                    timer_test.tic()
                    hr_batch = hr
                    hr = utility2.quantize(hr, self.args.rgb_range)
                    lr_blur, hr_blur = degrade(hr_batch, random=False)

                    for step in range(self.test_iter):
                        if step == 0:
                            learning_rate = 1e-2
                        elif step < 5:
                            learning_rate = 5e-3
                        else:
                            learning_rate = 1e-3

                        sr = self.model_meta_copy(lr_blur, hr_blur)
                        loss = self.loss1(sr, hr_batch)
                        sr = utility2.quantize(sr, self.args.rgb_range)

                        iter_psnr[step] += utility2.calc_psnr(
                            sr, hr, scale, self.args.rgb_range,
                            benchmark=True
                        )
                        iter_ssim[step] += utility2.calc_ssim(
                            (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale,
                            benchmark=True
                        )

                        self.model_meta_copy.zero_grad()
                        loss.backward()

                        # Manual gradient step; parameters that received no
                        # gradient are skipped.
                        for param in self.model_meta_copy.parameters():
                            if param.requires_grad and param.grad is not None:
                                param.data.sub_(param.grad.data * learning_rate)

                    # Prediction of the adapted model.
                    sr = self.model_meta_copy(lr_blur, hr_blur)

                    timer_test.hold()

                    sr = utility2.quantize(sr, self.args.rgb_range)

                    # metrics
                    iter_psnr[-1] += utility2.calc_psnr(
                        sr, hr, scale, self.args.rgb_range,
                        benchmark=True
                    )
                    iter_ssim[-1] += utility2.calc_ssim(
                        (sr * 255).round().clamp(0, 255), (hr * 255).round().clamp(0, 255), scale,
                        benchmark=True
                    )

                    # save results
                    if self.args.save_results:
                        save_list = [sr]
                        filename = filename[0]
                        self.ckp.save_results(filename, save_list, scale)

                for t in range(self.test_iter + 1):
                    print("iter:", t, ",PSNR:", iter_psnr[t] / n_items, ",SSIM:", iter_ssim[t] / n_items)

                eval_psnr = iter_psnr[-1]
                eval_ssim = iter_ssim[-1]

                # Keep a short history for the running means in the log line.
                if len(self.test_res_psnr) > 10:
                    self.test_res_psnr.pop(0)
                    self.test_res_ssim.pop(0)
                self.test_res_psnr.append(eval_psnr / n_items)
                self.test_res_ssim.append(eval_ssim / n_items)

                self.ckp.log[-1, idx_scale] = eval_psnr / n_items

                self.ckp.write_log(
                    '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f} mean_PSNR: {:.3f} mean_SSIM: {:.4f}'.format(
                        self.args.resume,
                        self.args.data_test,
                        scale,
                        eval_psnr / n_items,
                        eval_ssim / n_items,
                        np.mean(self.test_res_psnr),
                        np.mean(self.test_res_ssim)
                    ))

    def crop_border(self, img_hr, scale):
        """Crop height and width down to the nearest multiple of ``scale``."""
        b, c, h, w = img_hr.size()

        img_hr = img_hr[:, :, :int(h // scale * scale), :int(w // scale * scale)]

        return img_hr

    def get_patch(self, img, patch_size=48, scale=4):
        """Return a random square crop of side ``round(scale * patch_size)``.

        Raises:
            ValueError: If the image is smaller than the requested crop.
        """
        tb, tc, th, tw = img.shape  # HR image
        tp = round(scale * patch_size)
        # +1 so the last valid offset can be drawn and an image exactly as
        # large as the patch no longer raises on an empty range.
        tx = random.randrange(0, tw - tp + 1)
        ty = random.randrange(0, th - tp + 1)

        return img[:, :, ty:ty + tp, tx:tx + tp]

    def crop(self, img_hr):
        """Stack ``task_batch_size`` random patches of ``img_hr`` along dim 0."""
        patches = [
            self.get_patch(img_hr, self.args.patch_size, self.scale[0])
            for _ in range(self.task_batch_size)
        ]
        return torch.cat(patches, dim=0)

    def terminate(self):
        """Return True when the run should stop (runs ``test`` in test-only mode)."""
        if self.args.test_only:
            self.test()
            return True
        epoch = self.scheduler.last_epoch + 1
        return epoch >= self.args.epochs_sr
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import utility
7 |
8 |
def cal_sigma(sig_x, sig_y, radians):
    """Return the batch of 2x2 matrices ``U @ diag(sig_x^2, sig_y^2) @ U^T``.

    ``U`` is the 2-D rotation matrix built from ``radians``. All three
    arguments are tensors that get flattened to one value per batch entry.
    """
    var_x = sig_x.view(-1, 1, 1) ** 2
    var_y = sig_y.view(-1, 1, 1) ** 2
    angle = radians.view(-1, 1, 1)

    # Diagonal matrix assembled row by row: [var_x, 0] on top of [0, var_y].
    row_x = F.pad(var_x, [0, 1, 0, 0])
    row_y = F.pad(var_y, [1, 0, 0, 0])
    diag = torch.cat([row_x, row_y], 1)

    cos_a = angle.cos()
    sin_a = angle.sin()
    rot_top = torch.cat([cos_a, -sin_a], 2)
    rot_bottom = torch.cat([sin_a, cos_a], 2)
    rot = torch.cat([rot_top, rot_bottom], 1)

    return torch.bmm(rot, torch.bmm(diag, rot.transpose(1, 2)))
20 |
21 |
def anisotropic_gaussian_kernel(batch, kernel_size, covar):
    """Evaluate normalized Gaussian kernels for a batch of 2x2 covariances.

    Args:
        batch: Number of kernels; must match the leading size of ``covar``.
        kernel_size: Side length of each square kernel.
        covar: Tensor of size ``(batch, 2, 2)`` with the covariance matrices.

    Returns:
        Tensor of size ``(batch, kernel_size, kernel_size)``; every kernel
        sums to 1. The result lives on the device of ``covar`` (the grid was
        previously forced onto CUDA, which made CPU use impossible).
    """
    # Pixel coordinates relative to the kernel centre.
    ax = torch.arange(kernel_size, dtype=torch.float32, device=covar.device) - kernel_size // 2

    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    xy = torch.stack([xx, yy], -1).view(batch, -1, 2)

    inverse_sigma = torch.inverse(covar)
    # exp(-0.5 * x^T Sigma^-1 x) at every grid position.
    kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size)

    return kernel / kernel.sum([1, 2], keepdim=True)
33 |
34 |
def isotropic_gaussian_kernel(batch, kernel_size, sigma):
    """Evaluate normalized isotropic Gaussian kernels.

    Args:
        batch: Number of kernels to produce.
        kernel_size: Side length of each square kernel.
        sigma: Tensor with one standard deviation per kernel.

    Returns:
        Tensor of size ``(batch, kernel_size, kernel_size)``; every kernel
        sums to 1. The result lives on the device of ``sigma`` (the grid was
        previously forced onto CUDA, which made CPU use impossible).
    """
    # Pixel coordinates relative to the kernel centre.
    ax = torch.arange(kernel_size, dtype=torch.float32, device=sigma.device) - kernel_size // 2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2))

    return kernel / kernel.sum([1, 2], keepdim=True)
42 |
43 |
def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0):
    """Sample ``batch`` random anisotropic Gaussian kernels on the GPU.

    The two scale parameters are drawn uniformly from
    ``[lambda_min, lambda_max)`` and the rotation angle uniformly from
    ``[0, pi)``.
    """
    # torch.rand is uniform on [0, 1), so scale it by pi directly. The former
    # `/ 180 * pi` degree-to-radian conversion confined every angle to less
    # than one degree. NOTE: this widens the sampled degradation distribution
    # compared with earlier runs.
    theta = torch.rand(batch).cuda() * math.pi
    lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min
    lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar)
    return kernel
52 |
53 |
def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0):
    """Build one fixed anisotropic Gaussian kernel on the GPU.

    ``theta`` is given in degrees and converted to radians here;
    ``lambda_1`` and ``lambda_2`` are the two scale parameters passed on to
    ``cal_sigma``.
    """
    angle = torch.ones(1).cuda() * theta / 180 * math.pi
    scale_a = torch.ones(1).cuda() * lambda_1
    scale_b = torch.ones(1).cuda() * lambda_2

    return anisotropic_gaussian_kernel(1, kernel_size, cal_sigma(scale_a, scale_b, angle))
62 |
63 |
def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0):
    """Sample ``batch`` isotropic Gaussian kernels with widths uniform in ``[sig_min, sig_max)``."""
    widths = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min
    return isotropic_gaussian_kernel(batch, kernel_size, widths)
68 |
69 |
def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
    """Build one isotropic Gaussian kernel of fixed width ``sig`` on the GPU."""
    width = torch.ones(1).cuda() * sig
    return isotropic_gaussian_kernel(1, kernel_size, width)
74 |
75 |
def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0):
    """Dispatch to the random kernel sampler selected by ``blur_type``.

    Returns ``None`` when ``blur_type`` is neither ``'iso_gaussian'`` nor
    ``'aniso_gaussian'``.
    """
    if blur_type == 'aniso_gaussian':
        return random_anisotropic_gaussian_kernel(
            batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max)
    if blur_type == 'iso_gaussian':
        return random_isotropic_gaussian_kernel(
            batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max)
    return None
81 |
82 |
def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0):
    """Dispatch to the fixed kernel builder selected by ``blur_type``.

    Returns ``None`` when ``blur_type`` is neither ``'iso_gaussian'`` nor
    ``'aniso_gaussian'``.
    """
    if blur_type == 'aniso_gaussian':
        return stable_anisotropic_gaussian_kernel(
            kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta)
    if blur_type == 'iso_gaussian':
        return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig)
    return None
88 |
89 |
90 | # implementation of matlab bicubic interpolation in pytorch
class bicubic(nn.Module):
    """Bicubic resizing in PyTorch, written to mimic MATLAB's bicubic interpolation.

    The module now works on whatever device the input lives on; it used to
    create all helper tensors on CUDA unconditionally.
    """

    def __init__(self):
        super(bicubic, self).__init__()

    def cubic(self, x):
        """Evaluate the piecewise cubic interpolation kernel (support ``|x| <= 2``)."""
        absx = torch.abs(x)
        absx2 = absx * absx
        absx3 = absx2 * absx

        condition1 = (absx <= 1).to(torch.float32)
        condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32)

        f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return f

    def contribute(self, in_size, out_size, scale, device=None):
        """Compute interpolation weights and source indices for both axes.

        Args:
            in_size: ``[height, width]`` of the input.
            out_size: ``[height, width]`` of the output.
            scale: Resize factor (``< 1`` shrinks, which widens the kernel).
            device: Device for the created tensors. ``None`` keeps the
                historical behaviour of using CUDA.

        Returns:
            ``(weight0, weight1, indice0, indice1)`` for the height (0) and
            width (1) axes; indices are 1-based and clamped to the input.
        """
        if device is None:
            device = torch.device('cuda')

        kernel_width = 4
        if scale < 1:
            kernel_width = 4 / scale

        # 1-based output coordinates and their positions in input space.
        x0 = torch.arange(start=1, end=out_size[0] + 1, dtype=torch.float32, device=device)
        x1 = torch.arange(start=1, end=out_size[1] + 1, dtype=torch.float32, device=device)

        u0 = x0 / scale + 0.5 * (1 - 1 / scale)
        u1 = x1 / scale + 0.5 * (1 - 1 / scale)

        left0 = torch.floor(u0 - kernel_width / 2)
        left1 = torch.floor(u1 - kernel_width / 2)

        # Number of candidate taps per output pixel.
        P = int(np.ceil(kernel_width)) + 2
        taps = torch.arange(P, dtype=torch.float32, device=device).unsqueeze(0)

        indice0 = left0.unsqueeze(1) + taps
        indice1 = left1.unsqueeze(1) + taps

        mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0)
        mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0)

        if scale < 1:
            weight0 = scale * self.cubic(mid0 * scale)
            weight1 = scale * self.cubic(mid1 * scale)
        else:
            weight0 = self.cubic(mid0)
            weight1 = self.cubic(mid1)

        weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2))
        weight1 = weight1 / (torch.sum(weight1, 2).unsqueeze(2))

        # Clamp source indices to the valid 1-based range of the input.
        indice0 = torch.clamp(indice0, min=1, max=in_size[0]).unsqueeze(0)
        indice1 = torch.clamp(indice1, min=1, max=in_size[1]).unsqueeze(0)

        # Drop the tap columns whose weight is zero for the first output pixel.
        kill0 = torch.eq(weight0, 0)[0][0]
        kill1 = torch.eq(weight1, 0)[0][0]

        weight0 = weight0[:, :, kill0 == 0]
        weight1 = weight1[:, :, kill1 == 0]

        indice0 = indice0[:, :, kill0 == 0]
        indice1 = indice1[:, :, kill1 == 0]

        return weight0, weight1, indice0, indice1

    def forward(self, input, scale=1/4):
        """Resize ``input`` of size ``(b, c, h, w)`` by ``scale``.

        Returns:
            Tensor of size ``(b, c, int(h * scale), int(w * scale))`` on the
            same device as ``input``.
        """
        b, c, h, w = input.shape

        weight0, weight1, indice0, indice1 = self.contribute(
            [h, w], [int(h * scale), int(w * scale)], scale, device=input.device)
        weight0 = weight0[0]
        weight1 = weight1[0]

        indice0 = indice0[0].long()
        indice1 = indice1[0].long()

        # Interpolate along the height axis (indices are 1-based, hence -1).
        out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = (torch.sum(out, dim=3))
        A = out.permute(0, 1, 3, 2)

        # Then along the width axis, working on the transposed result.
        out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = out.sum(3).permute(0, 1, 3, 2)

        return out
169 |
170 |
class Gaussin_Kernel(object):
    """Callable that produces Gaussian blur kernels, random or fixed.

    The constructor only stores the settings; the kernels are generated by
    ``random_gaussian_kernel`` / ``stable_gaussian_kernel`` on each call.
    """

    def __init__(self, kernel_size=21, blur_type='iso_gaussian',
                 sig=2.6, sig_min=0.2, sig_max=4.0,
                 lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0):
        self.kernel_size = kernel_size
        self.blur_type = blur_type

        # isotropic settings
        self.sig = sig
        self.sig_min = sig_min
        self.sig_max = sig_max

        # anisotropic settings
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.theta = theta
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max

    def __call__(self, batch, random):
        """Return ``batch`` random kernels if ``random == True``, else the fixed kernel."""
        if not (random == True):
            # stable kernel (the batch argument is not used on this path)
            return stable_gaussian_kernel(
                kernel_size=self.kernel_size, blur_type=self.blur_type,
                sig=self.sig,
                lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta)

        # random kernel
        return random_gaussian_kernel(
            batch, kernel_size=self.kernel_size, blur_type=self.blur_type,
            sig_min=self.sig_min, sig_max=self.sig_max,
            lambda_min=self.lambda_min, lambda_max=self.lambda_max)
200 |
class BatchBlur(nn.Module):
    """Blur a batch of images with one shared kernel or one kernel per sample.

    The input is reflection-padded before the convolution and the result is
    reshaped back to the input's ``(B, C, H, W)`` size.
    """

    def __init__(self, kernel_size=21):
        """Create the padding layer matching ``kernel_size``."""
        super(BatchBlur, self).__init__()
        self.kernel_size = kernel_size
        half = kernel_size // 2
        if kernel_size % 2 == 1:
            padding = half
        else:
            # Even kernels: pad one pixel less on the right and bottom sides.
            padding = (half, half - 1, half, half - 1)
        self.pad = nn.ReflectionPad2d(padding)

    def forward(self, input, kernel):
        """Convolve ``input`` with ``kernel``.

        Args:
            input: Tensor of size ``(B, C, H, W)``.
            kernel: Either a 2-D ``(k, k)`` kernel applied to every image and
                channel, or a tensor holding ``B`` kernels (one per sample)
                that is applied to all channels of the matching sample.

        Returns:
            Tensor of size ``(B, C, H, W)``.
        """
        B, C, H, W = input.size()
        padded = self.pad(input)
        H_p, W_p = padded.size()[-2:]
        k = self.kernel_size

        if kernel.dim() != 2:
            # One kernel per sample: fold batch and channels into the channel
            # axis and run a grouped convolution with B * C groups.
            per_sample = kernel.contiguous().view((B, 1, k, k))
            per_plane = per_sample.repeat(1, C, 1, 1).view((B * C, 1, k, k))
            stacked = padded.view((1, C * B, H_p, W_p))
            return F.conv2d(stacked, per_plane, groups=B * C).view((B, C, H, W))

        # Shared kernel: treat every (sample, channel) plane as its own image.
        shared = kernel.contiguous().view((1, 1, k, k))
        planes = padded.view((C * B, 1, H_p, W_p))
        return F.conv2d(planes, shared, padding=0).view((B, C, H, W))
226 |
227 |
class SRMDPreprocessing(object):
    """Degrade an HR batch with blur, downsampling and optional noise.

    Calling the object returns the degraded low-resolution batch together
    with the blur kernels and noise level that were used.
    """

    def __init__(self,
                 scale,
                 mode='bicubic',
                 kernel_size=21,
                 blur_type='iso_gaussian',
                 sig=2.6,
                 sig_min=0.2,
                 sig_max=4.0,
                 lambda_1=0.2,
                 lambda_2=4.0,
                 theta=0,
                 lambda_min=0.2,
                 lambda_max=4.0,
                 noise=0.0,
                 rgb_range=1
                 ):
        '''
        # sig, sig_min and sig_max are used for isotropic Gaussian blurs
        During training phase (random=True):
            the width of the blur kernel is randomly selected from [sig_min, sig_max]
        During test phase (random=False):
            the width of the blur kernel is set to sig

        # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs
        During training phase (random=True):
            the eigenvalues of the covariance are randomly selected from [lambda_min, lambda_max]
            the angle value is randomly selected from [0, pi]
        During test phase (random=False):
            the eigenvalues of the covariance are set to lambda_1 and lambda_2
            the angle value is set to theta

        # scale is the down-sampling factor, mode is 'bicubic' or 's-fold'
        # noise is given on a 0-255 scale and converted to rgb_range here
        '''
        self.kernel_size = kernel_size
        self.scale = scale
        self.mode = mode
        self.noise = noise / (255 / rgb_range)
        self.rgb_range = rgb_range

        self.gen_kernel = Gaussin_Kernel(
            kernel_size=kernel_size, blur_type=blur_type,
            sig=sig, sig_min=sig_min, sig_max=sig_max,
            lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max
        )
        self.blur = BatchBlur(kernel_size=kernel_size)
        self.bicubic = bicubic()

    def __call__(self, hr_tensor, random=True):
        """Degrade ``hr_tensor`` of size ``(B, C, H, W)``.

        Args:
            hr_tensor: High-resolution batch.
            random: When true a random kernel and noise level is drawn per
                sample, otherwise the fixed test-time settings are used.

        Returns:
            Tuple ``(lr, [b_kernels, noise_level])``. ``lr`` is the quantized
            degraded batch viewed as ``(B, C, H // scale, W // scale)``.
            ``b_kernels`` is ``None`` when no blur is applied, and
            ``noise_level`` is ``None`` when no noise is added.

        Raises:
            ValueError: If ``mode`` is neither ``'bicubic'`` nor ``'s-fold'``.
        """
        with torch.no_grad():
            B, C, H, W = hr_tensor.size()

            if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0:
                # only downsampling
                hr_blured = hr_tensor.view(-1, C, H, W)
                b_kernels = None
            else:
                # gaussian blur + downsampling
                b_kernels = self.gen_kernel(B, random)  # B degradations
                if b_kernels.size(0) != B:
                    # The fixed (random=False) generator yields a single
                    # kernel; share it across the batch so B > 1 works.
                    b_kernels = b_kernels.expand(B, -1, -1)

                hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels)
                hr_blured = hr_blured.view(-1, C, H, W)  # B, C, H, W

            # downsampling
            if self.mode == 'bicubic':
                lr_blured = self.bicubic(hr_blured, scale=1 / self.scale)
            elif self.mode == 's-fold':
                lr_blured = hr_blured.view(
                    -1, C, H // self.scale, self.scale, W // self.scale, self.scale
                )[:, :, :, 0, :, 0]
            else:
                # Previously this fell through to an UnboundLocalError below.
                raise ValueError(
                    "unsupported downsampling mode: {!r} (expected 'bicubic' or 's-fold')".format(self.mode)
                )

            # add noise
            noise_level = None
            if self.noise > 0:
                if random:
                    noise_level = torch.rand(B, 1, 1, 1).to(lr_blured.device) * self.noise
                else:
                    noise_level = self.noise
                noise = torch.randn_like(lr_blured).mul_(noise_level)
                # Out-of-place on purpose: in 's-fold' mode lr_blured is a view
                # that can share storage with the caller's hr_tensor, so an
                # in-place add_ would corrupt the input.
                lr_blured = lr_blured + noise

            lr_blured = utility.quantize(lr_blured, self.rgb_range)

            # Report a fixed noise level as a tensor too.
            if isinstance(noise_level, float):
                noise_level = torch.Tensor([noise_level]).view(1, 1, 1, 1).to(lr_blured.device)
            return lr_blured.view(B, C, H // int(self.scale), W // int(self.scale)), [b_kernels, noise_level]
311 |
312 |
class BicubicPreprocessing(object):
    """Bicubic-only degradation: downsample an HR batch and upsample it again."""

    def __init__(self,
                 scale,
                 rgb_range=1
                 ):
        '''
        # scale is the resize factor used for the down- and up-sampling
        # rgb_range is passed to utility.quantize for both outputs
        '''
        self.scale = scale
        self.rgb_range = rgb_range
        self.bicubic = bicubic()

    def __call__(self, hr_tensor, random=True):
        """Return ``(lr, lr_bic)`` for ``hr_tensor`` of size ``(B, C, H, W)``.

        ``lr`` is the bicubically downsampled batch and ``lr_bic`` is that
        batch upsampled back by ``scale``; both are quantized. ``random`` is
        accepted for interface parity but is not used.
        """
        with torch.no_grad():
            B, C, H, W = hr_tensor.size()
            factor = self.scale

            low_res = self.bicubic(hr_tensor, scale=1 / factor)
            # The upsampling uses the not-yet-quantized low-resolution batch.
            restored = self.bicubic(low_res, scale=factor)

            low_res = utility.quantize(low_res, self.rgb_range)
            restored = utility.quantize(restored, self.rgb_range)

            low_res = low_res.view(B, C, H // int(factor), W // int(factor))
            restored = restored.view(B, C, H, W)
            return low_res, restored
349 |
--------------------------------------------------------------------------------