├── README.md ├── code ├── __init__.py ├── main.py ├── template.py ├── data │ ├── demo.py │ ├── benchmark.py │ ├── __init__.py │ ├── ntire_val.py │ ├── div2k.py │ ├── ntire.py │ ├── common.py │ └── srdata.py ├── loss │ ├── vgg.py │ ├── discriminator.py │ ├── adversarial.py │ └── __init__.py ├── model │ ├── mdsr.py │ ├── edsr.py │ ├── ddbpn.py │ ├── common.py │ ├── rcan.py │ ├── __init__.py │ └── frn_updown.py ├── trainer.py ├── utility.py ├── dataloader.py └── option.py └── .gitignore /README.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | import os 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | if checkpoint.ok: 14 | loader = data.Data(args) 15 | model = model.Model(args, checkpoint) 16 | torch.cuda.empty_cache() 17 | # print(model) 18 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 19 | t = Trainer(args, loader, model, loss, checkpoint) 20 | while not t.terminate(): 21 | t.train() 22 | t.test() 23 | 24 | checkpoint.done() 25 | 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | 
.eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | 60 | # PyTorch 61 | *.pt 62 | *.pdf 63 | *.png 64 | *.txt 65 | *.swp 66 | .vscode 67 | -------------------------------------------------------------------------------- /code/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.lr_decay = 100 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.lr_decay = 500 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.lr_decay = 150 39 | 40 | 
-------------------------------------------------------------------------------- /code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, train=False): 13 | self.args = args 14 | self.name = 'Demo' 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = False 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.split(self.filelist[idx])[-1] 28 | filename, _ = os.path.splitext(filename) 29 | lr = misc.imread(self.filelist[idx]) 30 | lr = common.set_channel([lr], self.args.n_colors)[0] 31 | 32 | return common.np2Tensor([lr], self.args.rgb_range)[0], -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | from torch.autograd import Variable 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, conv_index, rgb_range=1): 11 | super(VGG, self).__init__() 12 | vgg_features = models.vgg19(pretrained=True).features 13 | modules = [m for m in vgg_features] 14 | if conv_index == '22': 15 | self.vgg = nn.Sequential(*modules[:8]) 16 | elif conv_index == '54': 17 | self.vgg = nn.Sequential(*modules[:35]) 18 | 19 | vgg_mean = (0.485, 
0.456, 0.406) 20 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 21 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 22 | for p in self.parameters(): p.requires_grad = False  # assigning Module.requires_grad only sets a plain attribute; freeze every VGG/sub_mean parameter explicitly 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Benchmark(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(Benchmark, self).__init__(args, train, benchmark=True) 15 | 16 | def _scan(self): 17 | list_hr = [] 18 | list_lr = [[] for _ in self.scale] 19 | for entry in sorted(os.scandir(self.dir_hr), key=lambda e: e.name):  # os.scandir order is arbitrary; sort once so HR and every LR list share one order 20 | filename = os.path.splitext(entry.name)[0] 21 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 22 | for si, s in enumerate(self.scale): 23 | list_lr[si].append(os.path.join( 24 | self.dir_lr, 25 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext) 26 | )) 27 | 28 | # No per-list re-sort: sorting 'NAMExS.png' separately from 'NAME.png' misaligns HR/LR pairs when one stem prefixes another (e.g. 'a' vs 'ab'). 29 | 30 | 31 | 32 | return list_hr, list_lr 33 | 34 | def _set_filesystem(self, dir_data): 35 | self.apath = os.path.join(dir_data, 'benchmark', self.args.data_test) 36 | self.dir_hr = os.path.join(self.apath, 'HR') 37 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 38 | self.ext = '.png' 39 | -------------------------------------------------------------------------------- /code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import
torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, args, gan_type='GAN'): 7 | super(Discriminator, self).__init__() 8 | 9 | in_channels = 3 10 | out_channels = 64 11 | depth = 7 12 | #bn = not gan_type == 'WGAN_GP' 13 | bn = True 14 | act = nn.LeakyReLU(negative_slope=0.2, inplace=True) 15 | 16 | m_features = [ 17 | common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act) 18 | ] 19 | for i in range(depth): 20 | in_channels = out_channels 21 | if i % 2 == 1: 22 | stride = 1 23 | out_channels *= 2 24 | else: 25 | stride = 2 26 | m_features.append(common.BasicBlock( 27 | in_channels, out_channels, 3, stride=stride, bn=bn, act=act 28 | )) 29 | 30 | self.features = nn.Sequential(*m_features) 31 | 32 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 33 | m_classifier = [ 34 | nn.Linear(out_channels * patch_size**2, 1024), 35 | act, 36 | nn.Linear(1024, 1) 37 | ] 38 | self.classifier = nn.Sequential(*m_classifier) 39 | 40 | def forward(self, x): 41 | features = self.features(x) 42 | output = self.classifier(features.view(features.size(0), -1)) 43 | 44 | return output 45 | 46 | -------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | from dataloader import MSDataLoader 4 | from torch.utils.data.dataloader import default_collate 5 | 6 | class Data: 7 | def __init__(self, args): 8 | kwargs = {} 9 | if not args.cpu: 10 | kwargs['collate_fn'] = default_collate 11 | kwargs['pin_memory'] = True 12 | else: 13 | kwargs['collate_fn'] = default_collate 14 | kwargs['pin_memory'] = False 15 | 16 | self.loader_train = None 17 | if not args.test_only: 18 | module_train = import_module('data.' 
+ args.data_train.lower()) 19 | trainset = getattr(module_train, args.data_train)(args) 20 | self.loader_train = MSDataLoader( 21 | args, 22 | trainset, 23 | batch_size=args.batch_size, 24 | shuffle=True, 25 | **kwargs 26 | ) 27 | 28 | if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']: 29 | if not args.benchmark_noise: 30 | module_test = import_module('data.benchmark') 31 | testset = getattr(module_test, 'Benchmark')(args, train=False) 32 | else: 33 | module_test = import_module('data.benchmark_noise') 34 | testset = getattr(module_test, 'BenchmarkNoise')( 35 | args, 36 | train=False 37 | ) 38 | 39 | else: 40 | module_test = import_module('data.' + args.data_test.lower()) 41 | testset = getattr(module_test, args.data_test)(args, train=False) 42 | 43 | self.loader_test = MSDataLoader( 44 | args, 45 | testset, 46 | batch_size=1, 47 | shuffle=False, 48 | **kwargs 49 | ) 50 | -------------------------------------------------------------------------------- /code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 
36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /code/data/ntire_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class NTIRE_VAL(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(NTIRE_VAL, self).__init__(args, train) 15 | self.repeat = args.test_every / (args.n_train / args.batch_size) 16 | 17 | def _scan(self): 18 | list_hr = [] 19 | list_lr = [[] for _ in self.scale] 20 | if self.train: 21 | raise NotImplementedError('NTIRE_VAL only provides a validation split (train=False)')  # a bare expression was a no-op and led to NameError on idx_begin 22 | else: 23 | idx_begin = 0 24 | idx_end = self.args.n_val 25 | 26 | for j in [1,2]: 27 | for i in range(idx_begin + 1, idx_end + 1): 28 | filename = 'cam{}_0{}'.format(j, i) 29 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 30 | 31 | list_lr[0].append(os.path.join( 32 | self.dir_lr, 33 | '{}{}'.format(filename, self.ext) 34 | )) 35 | 36 | return list_hr, list_lr 37 | 38 | def _set_filesystem(self, dir_data): 39 | self.apath = dir_data +
'/ntire2019' 40 | self.dir_hr = os.path.join(self.apath, 'val') 41 | self.dir_lr = os.path.join(self.apath, 'val') 42 | self.ext = '.png' 43 | 44 | def _name_hrbin(self): 45 | return os.path.join( 46 | self.apath, 47 | 'bin', 48 | '{}_bin_HR.npy'.format(self.split) 49 | ) 50 | 51 | def _name_lrbin(self, scale): 52 | return os.path.join( 53 | self.apath, 54 | 'bin', 55 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 56 | ) 57 | 58 | def __len__(self): 59 | if self.train: 60 | return int(len(self.images_hr) * self.repeat) 61 | else: 62 | return len(self.images_hr) 63 | 64 | def _get_index(self, idx): 65 | if self.train: 66 | return idx % len(self.images_hr) 67 | else: 68 | return idx 69 | 70 | -------------------------------------------------------------------------------- /code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class DIV2K(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(DIV2K, self).__init__(args, train) 15 | self.repeat = max(1, args.test_every // max(1, args.n_train // args.batch_size))  # guards: n_train < batch_size divided by zero, and repeat == 0 made __len__ return 0 16 | 17 | def _scan(self): 18 | list_hr = [] 19 | list_lr = [[] for _ in self.scale] 20 | if self.train: 21 | idx_begin = 0 22 | idx_end = self.args.n_train 23 | else: 24 | idx_begin = self.args.n_train  # NOTE(review): NTIRE._scan uses offset_val here; the two agree only when offset_val == n_train 25 | idx_end = self.args.offset_val + self.args.n_val 26 | 27 | for j in range(1,3): 28 | for i in range(idx_begin + 1, idx_end + 1): 29 | filename = 'cam{}_0{}'.format(j, i) 30 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 31 | 32 | list_lr[0].append(os.path.join( 33 | self.dir_lr, 34 | '{}{}'.format(filename, self.ext) 35 | )) 36 | 37 | return list_hr, list_lr 38 | 39 | def _set_filesystem(self, dir_data): 40 | self.apath = dir_data + '/ntire2019' 41 | self.dir_hr = os.path.join(self.apath, 'HR')
42 | self.dir_lr = os.path.join(self.apath, 'LR') 43 | self.ext = '.png' 44 | 45 | def _name_hrbin(self): 46 | return os.path.join( 47 | self.apath, 48 | 'bin', 49 | '{}_bin_HR.npy'.format(self.split) 50 | ) 51 | 52 | def _name_lrbin(self, scale): 53 | return os.path.join( 54 | self.apath, 55 | 'bin', 56 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 57 | ) 58 | 59 | def __len__(self): 60 | if self.train: 61 | return len(self.images_hr) * self.repeat 62 | else: 63 | return len(self.images_hr) 64 | 65 | def _get_index(self, idx): 66 | if self.train: 67 | return idx % len(self.images_hr) 68 | else: 69 | return idx 70 | 71 | -------------------------------------------------------------------------------- /code/data/ntire.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class NTIRE(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(NTIRE, self).__init__(args, train) 15 | self.repeat = args.test_every / (args.n_train / args.batch_size) 16 | 17 | def _scan(self): 18 | list_hr = [] 19 | list_lr = [[] for _ in self.scale] 20 | if self.train: 21 | idx_begin = 0 22 | idx_end = self.args.n_train 23 | else: 24 | idx_begin = self.args.offset_val 25 | idx_end = self.args.offset_val + self.args.n_val 26 | 27 | for j in [1,2]: 28 | for i in range(idx_begin + 1, idx_end + 1): 29 | filename = 'cam{}_0{}'.format(j, i) 30 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 31 | 32 | list_lr[0].append(os.path.join( 33 | self.dir_lr, 34 | '{}{}'.format(filename, self.ext) 35 | )) 36 | 37 | return list_hr, list_lr 38 | 39 | def _set_filesystem(self, dir_data): 40 | self.apath = dir_data + '/ntire2019' 41 | self.dir_hr = os.path.join(self.apath, 'HR') 42 | self.dir_lr = os.path.join(self.apath, 'LR') 43 | self.ext = '.png' 44 
| 45 | def _name_hrbin(self): 46 | return os.path.join( 47 | self.apath, 48 | 'bin', 49 | '{}_bin_HR.npy'.format(self.split) 50 | ) 51 | 52 | def _name_lrbin(self, scale): 53 | return os.path.join( 54 | self.apath, 55 | 'bin', 56 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 57 | ) 58 | 59 | def __len__(self): 60 | if self.train: 61 | return int(len(self.images_hr) * self.repeat) 62 | else: 63 | return len(self.images_hr) 64 | 65 | def _get_index(self, idx): 66 | if self.train: 67 | return idx % len(self.images_hr) 68 | else: 69 | return idx 70 | 71 | -------------------------------------------------------------------------------- /code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.io as sio 5 | import skimage.color as sc 6 | import skimage.transform as st 7 | 8 | import torch 9 | from torchvision import transforms 10 | 11 | def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False): 12 | ih, iw = img_in.shape[:2] 13 | 14 | p = scale if multi_scale else 1 15 | tp = p * patch_size 16 | ip = tp // scale 17 | 18 | ix = random.randrange(0, iw - ip + 1) 19 | iy = random.randrange(0, ih - ip + 1) 20 | tx, ty = scale * ix, scale * iy 21 | 22 | img_in = img_in[iy:iy + ip, ix:ix + ip, :] 23 | img_tar = img_tar[ty:ty + tp, tx:tx + tp, :] 24 | 25 | return img_in, img_tar 26 | 27 | def set_channel(l, n_channel): 28 | def _set_channel(img): 29 | if img.ndim == 2: 30 | img = np.expand_dims(img, axis=2) 31 | 32 | c = img.shape[2] 33 | if n_channel == 1 and c == 3: 34 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 35 | elif n_channel == 3 and c == 1: 36 | img = np.concatenate([img] * n_channel, 2) 37 | 38 | return img 39 | 40 | return [_set_channel(_l) for _l in l] 41 | 42 | def np2Tensor(l, rgb_range): 43 | def _np2Tensor(img): 44 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 45 | tensor = torch.from_numpy(np_transpose).float() 46 
| tensor.mul_(rgb_range / 255) 47 | 48 | return tensor 49 | 50 | return [_np2Tensor(_l) for _l in l] 51 | 52 | def add_noise(x, noise='.'): 53 | if noise != '.':  # '.' means "no noise"; compare by value, not identity 54 | noise_type = noise[0] 55 | noise_value = int(noise[1:]) 56 | if noise_type == 'G': 57 | noises = np.random.normal(scale=noise_value, size=x.shape) 58 | noises = noises.round() 59 | elif noise_type == 'S': 60 | noises = np.random.poisson(x.astype(np.float64) * noise_value) / noise_value  # cast first: an integer image (e.g. uint8) times noise_value can wrap around 61 | noises = noises - noises.mean(axis=0).mean(axis=0)  # NOTE(review): this centres the Poisson-sampled image itself, so x + noises is roughly 2x - mean(x) + noise; confirm intended 62 | else: raise ValueError('unknown noise type: {}'.format(noise_type)) 63 | x_noise = x.astype(np.int16) + noises.astype(np.int16) 64 | x_noise = x_noise.clip(0, 255).astype(np.uint8) 65 | return x_noise 66 | else: 67 | return x 68 | 69 | def augment(l, hflip=True, rot=True): 70 | hflip = hflip and random.random() < 0.5 71 | vflip = rot and random.random() < 0.5 72 | rot90 = rot and random.random() < 0.5 73 | 74 | def _augment(img): 75 | if hflip: img = img[:, ::-1, :] 76 | if vflip: img = img[::-1, :, :] 77 | if rot90: img = img.transpose(1, 0, 2) 78 | 79 | return img 80 | 81 | return [_augment(_l) for _l in l] 82 | -------------------------------------------------------------------------------- /code/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return EDSR(args) 7 | 8 | class EDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(EDSR, self).__init__() 11 | 12 | n_resblock = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | # define head module 23 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 24 | 25 | # define body module 26 | m_body = [ 27 | common.ResBlock( 28 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
29 | ) for _ in range(n_resblock) 30 | ] 31 | m_body.append(conv(n_feats, n_feats, kernel_size)) 32 | 33 | # define tail module 34 | m_tail = [ 35 | common.Upsampler(conv, scale, n_feats, act=False), 36 | conv(n_feats, args.n_colors, kernel_size) 37 | ] 38 | 39 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 40 | 41 | self.head = nn.Sequential(*m_head) 42 | self.body = nn.Sequential(*m_body) 43 | self.tail = nn.Sequential(*m_tail) 44 | 45 | def forward(self, x): 46 | x = self.sub_mean(x) 47 | x = self.head(x) 48 | 49 | res = self.body(x) 50 | res += x 51 | 52 | x = self.tail(res) 53 | x = self.add_mean(x) 54 | 55 | return x 56 | 57 | def load_state_dict(self, state_dict, strict=True): 58 | own_state = self.state_dict() 59 | for name, param in state_dict.items(): 60 | if name in own_state: 61 | if isinstance(param, nn.Parameter): 62 | param = param.data 63 | try: 64 | own_state[name].copy_(param) 65 | except Exception: 66 | if name.find('tail') == -1: 67 | raise RuntimeError('While copying the parameter named {}, ' 68 | 'whose dimensions in the model are {} and ' 69 | 'whose dimensions in the checkpoint are {}.' 
70 | .format(name, own_state[name].size(), param.size())) 71 | elif strict: 72 | if name.find('tail') == -1: 73 | raise KeyError('unexpected key "{}" in state_dict' 74 | .format(name)) 75 | 76 | -------------------------------------------------------------------------------- /code/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from model import common 3 | from loss import discriminator 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | 11 | class Adversarial(nn.Module): 12 | def __init__(self, args, gan_type): 13 | super(Adversarial, self).__init__() 14 | self.gan_type = gan_type 15 | self.gan_k = args.gan_k 16 | self.discriminator = discriminator.Discriminator(args, gan_type) 17 | if gan_type != 'WGAN_GP': 18 | self.optimizer = utility.make_optimizer(args, self.discriminator) 19 | else: 20 | self.optimizer = optim.Adam( 21 | self.discriminator.parameters(), 22 | betas=(0, 0.9), eps=1e-8, lr=1e-5 23 | ) 24 | self.scheduler = utility.make_scheduler(args, self.optimizer) 25 | 26 | def forward(self, fake, real): 27 | fake_detach = fake.detach() 28 | 29 | self.loss = 0 30 | for _ in range(self.gan_k): 31 | self.optimizer.zero_grad() 32 | d_fake = self.discriminator(fake_detach) 33 | d_real = self.discriminator(real) 34 | if self.gan_type == 'GAN': 35 | label_fake = torch.zeros_like(d_fake) 36 | label_real = torch.ones_like(d_real) 37 | loss_d \ 38 | = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ 39 | + F.binary_cross_entropy_with_logits(d_real, label_real) 40 | elif self.gan_type.find('WGAN') >= 0: 41 | loss_d = (d_fake - d_real).mean() 42 | if self.gan_type.find('GP') >= 0: 43 | epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device, dtype=fake.dtype)  # one mixing weight per sample; rand_like(fake).view(-1, 1, 1, 1) had B*C*H*W rows and could not broadcast against (B, C, H, W) 44 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 45 | hat.requires_grad = True 46 | d_hat = self.discriminator(hat) 47 | gradients =
torch.autograd.grad( 48 | outputs=d_hat.sum(), inputs=hat, 49 | retain_graph=True, create_graph=True, only_inputs=True 50 | )[0] 51 | gradients = gradients.view(gradients.size(0), -1) 52 | gradient_norm = gradients.norm(2, dim=1) 53 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 54 | loss_d += gradient_penalty 55 | 56 | # Discriminator update 57 | self.loss += loss_d.item() 58 | loss_d.backward() 59 | self.optimizer.step() 60 | 61 | if self.gan_type == 'WGAN': 62 | for p in self.discriminator.parameters(): 63 | p.data.clamp_(-1, 1) 64 | 65 | self.loss /= self.gan_k 66 | 67 | d_fake_for_g = self.discriminator(fake) 68 | if self.gan_type == 'GAN': 69 | loss_g = F.binary_cross_entropy_with_logits( 70 | d_fake_for_g, label_real 71 | ) 72 | elif self.gan_type.find('WGAN') >= 0: 73 | loss_g = -d_fake_for_g.mean() 74 | 75 | # Generator loss 76 | return loss_g 77 | 78 | def state_dict(self, *args, **kwargs): 79 | state_discriminator = self.discriminator.state_dict(*args, **kwargs) 80 | state_optimizer = self.optimizer.state_dict() 81 | 82 | return dict(**state_discriminator, **state_optimizer) 83 | 84 | # Some references 85 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 86 | # OR 87 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 88 | -------------------------------------------------------------------------------- /code/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 
| 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | 
DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Variable 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | class invPixelShuffle(nn.Module): 14 | 15 | def __init__(self, ratio=2): 16 | super(invPixelShuffle, self).__init__() 17 | self.ratio = ratio 18 | def forward(self, tensor): 19 | ratio = self.ratio 20 | b = tensor.size(0) 21 | ch = tensor.size(1) 22 | y = tensor.size(2) 23 | x = tensor.size(3) 24 | assert x % ratio ==0 and y % ratio ==0, 'x, y, ratio : {}, {}, {}'.format(x, y, ratio) 25 | 26 | return tensor.view(b, ch, y//ratio, ratio, x//ratio, ratio).permute(0,1,3,5,2,4).contiguous().view(b, -1, y//ratio, x//ratio) 27 | 28 | class MeanShift(nn.Conv2d): 29 | def __init__(self, rgb_range, rgb_mean, 
rgb_std, sign=-1): 30 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 31 | std = torch.Tensor(rgb_std) 32 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 33 | self.weight.data.div_(std.view(3, 1, 1, 1)) 34 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 35 | self.bias.data.div_(std) 36 | for p in self.parameters(): p.requires_grad = False  # assigning Module.requires_grad only sets a plain attribute; freeze the fixed mean/std conv explicitly. NOTE(review): if utility.make_optimizer filters on requires_grad, these params leave the optimizer, so resuming an older optimizer state may need care 37 | 38 | class BasicBlock(nn.Sequential): 39 | def __init__( 40 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 41 | bn=True, act=nn.ReLU(True)): 42 | 43 | m = [nn.Conv2d( 44 | in_channels, out_channels, kernel_size, 45 | padding=(kernel_size//2), stride=stride, bias=bias) 46 | ] 47 | if bn: m.append(nn.BatchNorm2d(out_channels)) 48 | if act is not None: m.append(act) 49 | super(BasicBlock, self).__init__(*m) 50 | 51 | class ResBlock(nn.Module): 52 | def __init__( 53 | self, conv, n_feat, kernel_size, 54 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 55 | 56 | super(ResBlock, self).__init__() 57 | m = [] 58 | for i in range(2): 59 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 60 | if bn: m.append(nn.BatchNorm2d(n_feat)) 61 | if i == 0: m.append(act) 62 | 63 | self.body = nn.Sequential(*m) 64 | self.res_scale = res_scale 65 | 66 | def forward(self, x): 67 | res = self.body(x).mul(self.res_scale) 68 | res += x 69 | 70 | return res 71 | 72 | class Upsampler(nn.Sequential): 73 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 74 | 75 | m = [] 76 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
77 | for _ in range(int(math.log(scale, 2))): 78 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 79 | m.append(nn.PixelShuffle(2)) 80 | if bn: m.append(nn.BatchNorm2d(n_feat)) 81 | if act: m.append(act()) 82 | elif scale == 3: 83 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 84 | m.append(nn.PixelShuffle(3)) 85 | if bn: m.append(nn.BatchNorm2d(n_feat)) 86 | if act: m.append(act()) 87 | else: 88 | raise NotImplementedError 89 | 90 | super(Upsampler, self).__init__(*m) 91 | class invUpsampler(nn.Sequential): 92 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 93 | 94 | m = [] 95 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 96 | for _ in range(int(math.log(scale, 2))): 97 | m.append(invPixelShuffle(2)) 98 | m.append(conv(n_feat*4, n_feat, 3, bias)) 99 | if bn: m.append(nn.BatchNorm2d(n_feat)) 100 | if act: m.append(act()) 101 | elif scale == 3: 102 | m.append(invPixelShuffle(3)) 103 | m.append(conv(n_feat*9, n_feat, 3, bias)) 104 | if bn: m.append(nn.BatchNorm2d(n_feat)) 105 | if act: m.append(act()) 106 | else: 107 | raise NotImplementedError 108 | 109 | super(invUpsampler, self).__init__(*m) 110 | -------------------------------------------------------------------------------- /code/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class SRData(data.Dataset): 12 | def __init__(self, args, train=True, benchmark=False): 13 | self.args = args 14 | self.train = train 15 | self.split = 'train' if train else 'test' 16 | self.benchmark = benchmark 17 | self.scale = args.scale 18 | self.idx_scale = 0 19 | 20 | self._set_filesystem(args.dir_data) 21 | 22 | def _load_bin(): 23 | self.images_hr = np.load(self._name_hrbin()) 24 | self.images_lr = [ 25 | np.load(self._name_lrbin(s)) for s in self.scale 26 | ] 27 | 28 | if args.ext == 
'img' or benchmark: 29 | self.images_hr, self.images_lr = self._scan() 30 | elif args.ext.find('sep') >= 0: 31 | self.images_hr, self.images_lr = self._scan() 32 | if args.ext.find('reset') >= 0: 33 | print('Preparing seperated binary files') 34 | for v in self.images_hr: 35 | hr = misc.imread(v) 36 | name_sep = v.replace(self.ext, '.npy') 37 | np.save(name_sep, hr) 38 | for si, s in enumerate(self.scale): 39 | for v in self.images_lr[si]: 40 | lr = misc.imread(v) 41 | name_sep = v.replace(self.ext, '.npy') 42 | np.save(name_sep, lr) 43 | 44 | self.images_hr = [ 45 | v.replace(self.ext, '.npy') for v in self.images_hr 46 | ] 47 | self.images_lr = [ 48 | [v.replace(self.ext, '.npy') for v in self.images_lr[i]] 49 | for i in range(len(self.scale)) 50 | ] 51 | 52 | elif args.ext.find('bin') >= 0: 53 | try: 54 | if args.ext.find('reset') >= 0: 55 | raise IOError 56 | print('Loading a binary file') 57 | _load_bin() 58 | except: 59 | print('Preparing a binary file') 60 | bin_path = os.path.join(self.apath, 'bin') 61 | if not os.path.isdir(bin_path): 62 | os.mkdir(bin_path) 63 | 64 | list_hr, list_lr = self._scan() 65 | hr = [misc.imread(f) for f in list_hr] 66 | np.save(self._name_hrbin(), hr) 67 | del hr 68 | for si, s in enumerate(self.scale): 69 | lr_scale = [misc.imread(f) for f in list_lr[si]] 70 | np.save(self._name_lrbin(s), lr_scale) 71 | del lr_scale 72 | _load_bin() 73 | else: 74 | print('Please define data type') 75 | 76 | def _scan(self): 77 | raise NotImplementedError 78 | 79 | def _set_filesystem(self, dir_data): 80 | raise NotImplementedError 81 | 82 | def _name_hrbin(self): 83 | raise NotImplementedError 84 | 85 | def _name_lrbin(self, scale): 86 | raise NotImplementedError 87 | 88 | def __getitem__(self, idx): 89 | lr, hr, filename = self._load_file(idx) 90 | lr, hr = self._get_patch(lr, hr) 91 | lr, hr = common.set_channel([lr, hr], self.args.n_colors) 92 | lr_tensor, hr_tensor = common.np2Tensor([lr, hr], self.args.rgb_range) 93 | return lr_tensor, 
hr_tensor, filename 94 | 95 | def __len__(self): 96 | return len(self.images_hr) 97 | 98 | def _get_index(self, idx): 99 | return idx 100 | 101 | def _load_file(self, idx): 102 | idx = self._get_index(idx) 103 | lr = self.images_lr[self.idx_scale][idx] 104 | hr = self.images_hr[idx] 105 | if self.args.ext == 'img' or self.benchmark: 106 | filename = hr 107 | lr = misc.imread(lr) 108 | hr = misc.imread(hr) 109 | elif self.args.ext.find('sep') >= 0: 110 | filename = hr 111 | lr = np.load(lr) 112 | hr = np.load(hr) 113 | else: 114 | filename = str(idx + 1) 115 | 116 | filename = os.path.splitext(os.path.split(filename)[-1])[0] 117 | 118 | return lr, hr, filename 119 | 120 | def _get_patch(self, lr, hr): 121 | patch_size = self.args.patch_size 122 | scale = 1#self.scale[self.idx_scale] 123 | multi_scale = len(self.scale) > 1 124 | if self.train: 125 | lr, hr = common.get_patch( 126 | lr, hr, patch_size, scale, multi_scale=multi_scale 127 | ) 128 | lr, hr = common.augment([lr, hr]) 129 | lr = common.add_noise(lr, self.args.noise) 130 | else: 131 | ih, iw = lr.shape[0:2] 132 | hr = hr[0:ih*scale, 0:iw*scale] 133 | 134 | return lr, hr 135 | 136 | def set_scale(self, idx_scale): 137 | self.idx_scale = idx_scale 138 | 139 | -------------------------------------------------------------------------------- /code/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | print('Preparing loss function:') 18 | 19 | self.n_GPUs = args.n_GPUs 20 | self.loss = [] 21 | self.loss_module = nn.ModuleList() 22 | for loss in args.loss.split('+'): 23 | weight, loss_type = 
loss.split('*') 24 | if loss_type == 'MSE': 25 | loss_function = nn.MSELoss() 26 | elif loss_type == 'L1': 27 | loss_function = nn.L1Loss() 28 | elif loss_type.find('VGG') >= 0: 29 | module = import_module('loss.vgg') 30 | loss_function = getattr(module, 'VGG')( 31 | loss_type[3:], 32 | rgb_range=args.rgb_range 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | 41 | self.loss.append({ 42 | 'type': loss_type, 43 | 'weight': float(weight), 44 | 'function': loss_function} 45 | ) 46 | if loss_type.find('GAN') >= 0: 47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 48 | 49 | if len(self.loss) > 1: 50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 51 | 52 | for l in self.loss: 53 | if l['function'] is not None: 54 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 55 | self.loss_module.append(l['function']) 56 | 57 | self.log = torch.Tensor() 58 | 59 | device = torch.device('cpu' if args.cpu else 'cuda') 60 | self.loss_module.to(device) 61 | if args.precision == 'half': self.loss_module.half() 62 | if not args.cpu and args.n_GPUs > 1: 63 | self.loss_module = nn.DataParallel( 64 | self.loss_module, range(args.n_GPUs) 65 | ) 66 | 67 | if args.load != '.': self.load(ckp.dir, cpu=args.cpu) 68 | 69 | def forward(self, sr, hr): 70 | losses = [] 71 | for i, l in enumerate(self.loss): 72 | if l['function'] is not None: 73 | loss = l['function'](sr, hr) 74 | effective_loss = l['weight'] * loss 75 | losses.append(effective_loss) 76 | self.log[-1, i] += effective_loss.item() 77 | elif l['type'] == 'DIS': 78 | self.log[-1, i] += self.loss[i - 1]['function'].loss 79 | 80 | loss_sum = sum(losses) 81 | if len(self.loss) > 1: 82 | self.log[-1, -1] += loss_sum.item() 83 | 84 | return loss_sum 85 | 86 | def step(self): 87 | for l in self.get_loss_module(): 88 | if hasattr(l, 'scheduler'): 89 | l.scheduler.step() 
90 | 91 | def start_log(self): 92 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 93 | 94 | def end_log(self, n_batches): 95 | self.log[-1].div_(n_batches) 96 | 97 | def display_loss(self, batch): 98 | n_samples = batch + 1 99 | log = [] 100 | for l, c in zip(self.loss, self.log[-1]): 101 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 102 | 103 | return ''.join(log) 104 | 105 | def plot_loss(self, apath, epoch): 106 | axis = np.linspace(1, epoch, epoch) 107 | for i, l in enumerate(self.loss): 108 | label = '{} Loss'.format(l['type']) 109 | fig = plt.figure() 110 | plt.title(label) 111 | plt.plot(axis, self.log[:, i].numpy(), label=label) 112 | plt.legend() 113 | plt.xlabel('Epochs') 114 | plt.ylabel('Loss') 115 | plt.grid(True) 116 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) 117 | plt.close(fig) 118 | 119 | def get_loss_module(self): 120 | if self.n_GPUs == 1: 121 | return self.loss_module 122 | else: 123 | return self.loss_module.module 124 | 125 | def save(self, apath): 126 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 127 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 128 | 129 | def load(self, apath, cpu=False): 130 | if cpu: 131 | kwargs = {'map_location': lambda storage, loc: storage} 132 | else: 133 | kwargs = {} 134 | 135 | self.load_state_dict(torch.load( 136 | os.path.join(apath, 'loss.pt'), 137 | **kwargs 138 | )) 139 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 140 | for l in self.loss_module: 141 | if hasattr(l, 'scheduler'): 142 | for _ in range(len(self.log)): l.scheduler.step() 143 | 144 | -------------------------------------------------------------------------------- /code/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | from tqdm import tqdm 10 | 11 | class 
Trainer(): 12 | def __init__(self, args, loader, my_model, my_loss, ckp): 13 | self.args = args 14 | self.scale = args.scale 15 | 16 | self.ckp = ckp 17 | self.loader_train = loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer = utility.make_optimizer(args, self.model) 22 | self.scheduler = utility.make_scheduler(args, self.optimizer) 23 | 24 | if self.args.load != '.': 25 | self.optimizer.load_state_dict( 26 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 27 | ) 28 | for _ in range(len(ckp.log)): self.scheduler.step() 29 | 30 | self.error_last = 1e22 31 | 32 | def train(self): 33 | torch.cuda.empty_cache() 34 | self.scheduler.step() 35 | self.loss.step() 36 | epoch = self.scheduler.last_epoch + 1 37 | lr = self.scheduler.get_lr()[0] 38 | 39 | self.ckp.write_log( 40 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 41 | ) 42 | self.loss.start_log() 43 | self.model.train() 44 | 45 | timer_data, timer_model = utility.timer(), utility.timer() 46 | for batch, (lr, hr, _, idx_scale) in tqdm(enumerate(self.loader_train)): 47 | lr, hr = self.prepare([lr, hr]) 48 | timer_data.hold() 49 | timer_model.tic() 50 | 51 | self.optimizer.zero_grad() 52 | sr = self.model(lr, idx_scale) 53 | loss = self.loss(sr, hr) 54 | 55 | sr_ = self.model.model.forward_(hr) 56 | loss += self.loss(sr_, hr)#self.loss(sr_, hr) 57 | if loss.item() < self.args.skip_threshold * self.error_last: 58 | loss.backward() 59 | self.optimizer.step() 60 | else: 61 | print('Skip this batch {}! 
(Loss: {})'.format( 62 | batch + 1, loss.item() 63 | )) 64 | 65 | timer_model.hold() 66 | 67 | if (batch + 1) % self.args.print_every == 0: 68 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 69 | (batch + 1) * self.args.batch_size, 70 | len(self.loader_train.dataset), 71 | self.loss.display_loss(batch), 72 | timer_model.release(), 73 | timer_data.release())) 74 | 75 | timer_data.tic() 76 | 77 | self.loss.end_log(len(self.loader_train)) 78 | self.error_last = self.loss.log[-1, -1] 79 | 80 | def test(self): 81 | epoch = self.scheduler.last_epoch + 1 82 | self.ckp.write_log('\nEvaluation:') 83 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 84 | self.model.eval() 85 | 86 | timer_test = utility.timer() 87 | with torch.no_grad(): 88 | for idx_scale, scale in enumerate(self.scale): 89 | eval_acc = 0 90 | self.loader_test.dataset.set_scale(idx_scale) 91 | tqdm_test = tqdm(self.loader_test, ncols=80) 92 | for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test): 93 | filename = filename[0] 94 | no_eval = (hr.nelement() == 1) 95 | if not no_eval: 96 | lr, hr = self.prepare([lr, hr]) 97 | else: 98 | lr = self.prepare([lr])[0] 99 | scale_ = scale *self.args.px 100 | 101 | x, y = (lr.size(-2)//scale_)*scale_, (lr.size(-1)//scale_)*scale_ 102 | sr = self.model(lr[:,:,:x,:y], idx_scale) 103 | sr = utility.quantize(sr, self.args.rgb_range) 104 | sr = torch.cat([sr, lr[:,:,x:,:y]], -2) 105 | sr = torch.cat([sr, lr[:,:,:,y:]], -1) 106 | save_list = [sr] 107 | if not no_eval: 108 | eval_acc += utility.calc_psnr( 109 | sr, hr, scale, self.args.rgb_range, 110 | benchmark=self.loader_test.dataset.benchmark 111 | ) 112 | save_list.extend([lr, hr]) 113 | 114 | if self.args.save_results: 115 | self.ckp.save_results(filename, save_list, scale) 116 | 117 | self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) 118 | best = self.ckp.log.max(0) 119 | self.ckp.write_log( 120 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 121 | self.args.data_test, 122 | 
scale, 123 | self.ckp.log[-1, idx_scale], 124 | best[0][idx_scale], 125 | best[1][idx_scale] + 1 126 | ) 127 | ) 128 | 129 | self.ckp.write_log( 130 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 131 | ) 132 | if not self.args.test_only: 133 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) 134 | 135 | def prepare(self, l, volatile=False): 136 | device = torch.device('cpu' if self.args.cpu else 'cuda') 137 | def _prepare(tensor): 138 | if self.args.precision == 'half': tensor = tensor.half() 139 | return tensor.to(device) 140 | 141 | return [_prepare(_l) for _l in l] 142 | 143 | def terminate(self): 144 | if self.args.test_only: 145 | self.test() 146 | return True 147 | else: 148 | epoch = self.scheduler.last_epoch + 1 149 | return epoch >= self.args.epochs 150 | 151 | -------------------------------------------------------------------------------- /code/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from functools import reduce 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | import numpy as np 12 | import scipy.misc as misc 13 | 14 | import torch 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lrs 17 | 18 | class timer(): 19 | def __init__(self): 20 | self.acc = 0 21 | self.tic() 22 | 23 | def tic(self): 24 | self.t0 = time.time() 25 | 26 | def toc(self): 27 | return time.time() - self.t0 28 | 29 | def hold(self): 30 | self.acc += self.toc() 31 | 32 | def release(self): 33 | ret = self.acc 34 | self.acc = 0 35 | 36 | return ret 37 | 38 | def reset(self): 39 | self.acc = 0 40 | 41 | class checkpoint(): 42 | def __init__(self, args): 43 | self.args = args 44 | self.ok = True 45 | self.log = torch.Tensor() 46 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 47 | 48 | if args.load == '.': 49 | if args.save == '.': args.save = now 50 | 
self.dir = '../experiment/' + args.save 51 | else: 52 | self.dir = '../experiment/' + args.load 53 | if not os.path.exists(self.dir): 54 | args.load = '.' 55 | else: 56 | self.log = torch.load(self.dir + '/psnr_log.pt') 57 | print('Continue from epoch {}...'.format(len(self.log))) 58 | 59 | if args.reset: 60 | os.system('rm -rf ' + self.dir) 61 | args.load = '.' 62 | 63 | def _make_dir(path): 64 | if not os.path.exists(path): os.makedirs(path) 65 | 66 | _make_dir(self.dir) 67 | _make_dir(self.dir + '/model') 68 | _make_dir(self.dir + '/results') 69 | 70 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 71 | self.log_file = open(self.dir + '/log.txt', open_type) 72 | with open(self.dir + '/config.txt', open_type) as f: 73 | f.write(now + '\n\n') 74 | for arg in vars(args): 75 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 76 | f.write('\n') 77 | 78 | def save(self, trainer, epoch, is_best=False): 79 | trainer.model.save(self.dir, epoch, is_best=is_best) 80 | trainer.loss.save(self.dir) 81 | trainer.loss.plot_loss(self.dir, epoch) 82 | 83 | self.plot_psnr(epoch) 84 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) 85 | torch.save( 86 | trainer.optimizer.state_dict(), 87 | os.path.join(self.dir, 'optimizer.pt') 88 | ) 89 | 90 | def add_log(self, log): 91 | self.log = torch.cat([self.log, log]) 92 | 93 | def write_log(self, log, refresh=False): 94 | print(log) 95 | self.log_file.write(log + '\n') 96 | if refresh: 97 | self.log_file.close() 98 | self.log_file = open(self.dir + '/log.txt', 'a') 99 | 100 | def done(self): 101 | self.log_file.close() 102 | 103 | def plot_psnr(self, epoch): 104 | axis = np.linspace(1, epoch, epoch) 105 | label = 'SR on {}'.format(self.args.data_test) 106 | fig = plt.figure() 107 | plt.title(label) 108 | for idx_scale, scale in enumerate(self.args.scale): 109 | plt.plot( 110 | axis, 111 | self.log[:, idx_scale].numpy(), 112 | label='Scale {}'.format(scale) 113 | ) 114 | plt.legend() 115 | 
plt.xlabel('Epochs') 116 | plt.ylabel('PSNR') 117 | plt.grid(True) 118 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) 119 | plt.close(fig) 120 | 121 | def save_results(self, filename, save_list, scale): 122 | filename = '{}/results/{}'.format(self.dir, filename)#, scale) 123 | postfix = ('', 'LR', 'HR') 124 | for v, p in zip(save_list, postfix): 125 | normalized = v[0].data.mul(255 / self.args.rgb_range) 126 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 127 | misc.imsave('{}{}.png'.format(filename, p), ndarr) 128 | 129 | def quantize(img, rgb_range): 130 | pixel_range = 255 / rgb_range 131 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 132 | 133 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): 134 | diff = (sr - hr).data.div(rgb_range) 135 | shave = scale 136 | valid = diff[:, :, shave:-shave, shave:-shave] 137 | mse = valid.pow(2).mean() 138 | 139 | return -10 * math.log10(mse) 140 | 141 | def make_optimizer(args, my_model): 142 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 143 | 144 | if args.optimizer == 'SGD': 145 | optimizer_function = optim.SGD 146 | kwargs = {'momentum': args.momentum} 147 | elif args.optimizer == 'ADAM': 148 | optimizer_function = optim.Adam 149 | kwargs = { 150 | 'betas': (args.beta1, args.beta2), 151 | 'eps': args.epsilon 152 | } 153 | elif args.optimizer == 'RMSprop': 154 | optimizer_function = optim.RMSprop 155 | kwargs = {'eps': args.epsilon} 156 | 157 | kwargs['lr'] = args.lr 158 | kwargs['weight_decay'] = args.weight_decay 159 | 160 | return optimizer_function(trainable, **kwargs) 161 | 162 | def make_scheduler(args, my_optimizer): 163 | if args.decay_type == 'step': 164 | scheduler = lrs.StepLR( 165 | my_optimizer, 166 | step_size=args.lr_decay, 167 | gamma=args.gamma 168 | ) 169 | elif args.decay_type.find('step') >= 0: 170 | milestones = args.decay_type.split('_') 171 | milestones.pop(0) 172 | milestones = list(map(lambda x: int(x), 
milestones)) 173 | scheduler = lrs.MultiStepLR( 174 | my_optimizer, 175 | milestones=milestones, 176 | gamma=args.gamma 177 | ) 178 | 179 | return scheduler 180 | 181 | -------------------------------------------------------------------------------- /code/model/rcan.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return RCAN(args) 7 | 8 | ## Channel Attention (CA) Layer 9 | class CALayer(nn.Module): 10 | def __init__(self, channel, reduction=16): 11 | super(CALayer, self).__init__() 12 | # global average pooling: feature --> point 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | # feature channel downscale and upscale --> channel weight 16 | self.conv_du = nn.Sequential( 17 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 20 | ) 21 | 22 | self.sigmoid = nn.Sigmoid() 23 | 24 | def forward(self, x): 25 | y = self.avg_pool(x) 26 | # z = self.max_pool(x) 27 | y = self.conv_du(y) 28 | # z = self.conv_du(z) 29 | return x * self.sigmoid(y)# + z) 30 | 31 | class SALayer(nn.Module): 32 | def __init__(self, channel, reduction=16): 33 | super(SALayer, self).__init__() 34 | # global average pooling: feature --> point 35 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 36 | self.max_pool = nn.AdaptiveMaxPool2d(1) 37 | # feature channel downscale and upscale --> channel weight 38 | self.conv_du = nn.Sequential( 39 | nn.Conv2d(2, 1, 7, padding=3, bias=True), 40 | ) 41 | 42 | self.sigmoid = nn.Sigmoid() 43 | 44 | def forward(self, x): 45 | y = torch.mean(x,1, keepdim=True) 46 | z, _ = torch.max(x,1, keepdim=True) 47 | q = torch.cat((y,z), 1) 48 | q = self.conv_du(q) 49 | 50 | return x * self.sigmoid(q) 51 | ## Residual Channel Attention Block (RCAB) 52 | class 
RCAB(nn.Module): 53 | def __init__( 54 | self, conv, n_feat, kernel_size, reduction, 55 | bias=True, bn=False, act=nn.LeakyReLU(0.2, True), res_scale=1): 56 | 57 | super(RCAB, self).__init__() 58 | modules_body = [] 59 | for i in range(2): 60 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 61 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 62 | modules_body.append(act) 63 | modules_body.append(SALayer(n_feat, reduction)) 64 | self.body = nn.Sequential(*modules_body) 65 | self.res_scale = res_scale 66 | 67 | def forward(self, x): 68 | res = self.body(x) 69 | #res = self.body(x).mul(self.res_scale) 70 | res += x 71 | return res 72 | 73 | ## Residual Group (RG) 74 | class ResidualGroup(nn.Module): 75 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 76 | super(ResidualGroup, self).__init__() 77 | modules_body = [] 78 | modules_body = [ 79 | RCAB( 80 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 81 | for _ in range(n_resblocks)] 82 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 83 | self.body = nn.Sequential(*modules_body) 84 | 85 | def forward(self, x): 86 | res = self.body(x) 87 | res += x 88 | return res 89 | 90 | ## Residual Channel Attention Network (RCAN) 91 | class RCAN(nn.Module): 92 | def __init__(self, args, conv=common.default_conv): 93 | super(RCAN, self).__init__() 94 | 95 | n_resgroups = args.n_resgroups 96 | n_resblocks = args.n_resblocks 97 | n_feats = args.n_feats 98 | kernel_size = 3 99 | reduction = args.reduction 100 | scale = args.scale[0] 101 | act = nn.ReLU(True) 102 | 103 | # RGB mean for DIV2K 104 | rgb_mean = (0.4488, 0.4371, 0.4040) 105 | rgb_std = (1.0, 1.0, 1.0) 106 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 107 | 108 | # define head module 109 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 110 | 111 | # define body module 112 | modules_body = [ 113 | ResidualGroup( 114 | conv, 
n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 115 | for _ in range(n_resgroups)] 116 | 117 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 118 | 119 | # define tail module 120 | modules_tail = [ 121 | # common.Upsampler(conv, scale, n_feats, act=False), 122 | conv(n_feats, args.n_colors, kernel_size)] 123 | 124 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 125 | 126 | self.head = nn.Sequential(*modules_head) 127 | self.body = nn.Sequential(*modules_body) 128 | self.tail = nn.Sequential(*modules_tail) 129 | 130 | def forward(self, x): 131 | x = self.sub_mean(x) 132 | x = self.head(x) 133 | 134 | res = self.body(x) 135 | res += x 136 | 137 | x = self.tail(res) 138 | x = self.add_mean(x) 139 | 140 | return x 141 | 142 | def load_state_dict(self, state_dict, strict=False): 143 | own_state = self.state_dict() 144 | for name, param in state_dict.items(): 145 | if name in own_state: 146 | if isinstance(param, nn.Parameter): 147 | param = param.data 148 | try: 149 | own_state[name].copy_(param) 150 | except Exception: 151 | if name.find('tail') >= 0: 152 | print('Replace pre-trained upsampler to new one...') 153 | else: 154 | raise RuntimeError('While copying the parameter named {}, ' 155 | 'whose dimensions in the model are {} and ' 156 | 'whose dimensions in the checkpoint are {}.' 
157 | .format(name, own_state[name].size(), param.size())) 158 | elif strict: 159 | if name.find('tail') == -1: 160 | raise KeyError('unexpected key "{}" in state_dict' 161 | .format(name)) 162 | 163 | if strict: 164 | missing = set(own_state.keys()) - set(state_dict.keys()) 165 | if len(missing) > 0: 166 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 167 | -------------------------------------------------------------------------------- /code/dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | import queue 4 | import random 5 | import collections 6 | 7 | import torch 8 | import torch.multiprocessing as multiprocessing 9 | 10 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 11 | _remove_worker_pids, _error_if_any_worker_fails 12 | from torch.utils.data.dataloader import DataLoader 13 | from torch.utils.data.dataloader import _DataLoaderIter 14 | from torch.utils.data.dataloader import ManagerWatchdog 15 | from torch.utils.data.dataloader import _pin_memory_loop 16 | from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL 17 | 18 | from torch.utils.data.dataloader import ExceptionWrapper 19 | from torch.utils.data.dataloader import _use_shared_memory 20 | from torch.utils.data.dataloader import numpy_type_map 21 | from torch.utils.data.dataloader import default_collate 22 | from torch.utils.data.dataloader import pin_memory_batch 23 | from torch.utils.data.dataloader import _SIGCHLD_handler_set 24 | from torch.utils.data.dataloader import _set_SIGCHLD_handler 25 | 26 | if sys.version_info[0] == 2: 27 | import Queue as queue 28 | else: 29 | import queue 30 | 31 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 32 | try: 33 | global _use_shared_memory 34 | _use_shared_memory = True 35 | _set_worker_signal_handlers() 36 | 37 | torch.set_num_threads(1) 38 | random.seed(seed) 
39 | torch.manual_seed(seed) 40 | data_queue.cancel_join_thread() 41 | 42 | if init_fn is not None: 43 | init_fn(worker_id) 44 | 45 | watchdog = ManagerWatchdog() 46 | 47 | while watchdog.is_alive(): 48 | try: 49 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 50 | except queue.Empty: 51 | continue 52 | 53 | if r is None: 54 | assert done_event.is_set() 55 | return 56 | elif done_event.is_set(): 57 | continue 58 | idx, batch_indices = r 59 | try: 60 | idx_scale = 0 61 | if len(scale) > 1 and dataset.train: 62 | idx_scale = random.randrange(0, len(scale)) 63 | dataset.set_scale(idx_scale) 64 | 65 | samples = collate_fn([dataset[i] for i in batch_indices]) 66 | samples.append(idx_scale) 67 | except Exception: 68 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 69 | else: 70 | data_queue.put((idx, samples)) 71 | except KeyboardInterrupt: 72 | pass 73 | 74 | class _MSDataLoaderIter(_DataLoaderIter): 75 | def __init__(self, loader): 76 | self.dataset = loader.dataset 77 | self.scale = loader.scale 78 | self.collate_fn = loader.collate_fn 79 | self.batch_sampler = loader.batch_sampler 80 | self.num_workers = loader.num_workers 81 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 82 | self.timeout = loader.timeout 83 | 84 | self.sample_iter = iter(self.batch_sampler) 85 | 86 | base_seed = torch.LongTensor(1).random_().item() 87 | 88 | if self.num_workers > 0: 89 | self.worker_init_fn = loader.worker_init_fn 90 | self.worker_queue_idx = 0 91 | self.worker_result_queue = multiprocessing.Queue() 92 | self.batches_outstanding = 0 93 | self.worker_pids_set = False 94 | self.shutdown = False 95 | self.send_idx = 0 96 | self.rcvd_idx = 0 97 | self.reorder_dict = {} 98 | self.done_event = multiprocessing.Event() 99 | 100 | base_seed = torch.LongTensor(1).random_()[0] 101 | 102 | self.index_queues = [] 103 | self.workers = [] 104 | for i in range(self.num_workers): 105 | index_queue = multiprocessing.Queue() 106 | index_queue.cancel_join_thread() 
107 | w = multiprocessing.Process( 108 | target=_ms_loop, 109 | args=( 110 | self.dataset, 111 | index_queue, 112 | self.worker_result_queue, 113 | self.done_event, 114 | self.collate_fn, 115 | self.scale, 116 | base_seed + i, 117 | self.worker_init_fn, 118 | i 119 | ) 120 | ) 121 | w.start() 122 | self.index_queues.append(index_queue) 123 | self.workers.append(w) 124 | 125 | if self.pin_memory: 126 | self.data_queue = queue.Queue() 127 | pin_memory_thread = threading.Thread( 128 | target=_pin_memory_loop, 129 | args=( 130 | self.worker_result_queue, 131 | self.data_queue, 132 | torch.cuda.current_device(), 133 | self.done_event 134 | ) 135 | ) 136 | pin_memory_thread.daemon = True 137 | pin_memory_thread.start() 138 | self.pin_memory_thread = pin_memory_thread 139 | else: 140 | self.data_queue = self.worker_result_queue 141 | 142 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 143 | _set_SIGCHLD_handler() 144 | self.worker_pids_set = True 145 | 146 | for _ in range(2 * self.num_workers): 147 | self._put_indices() 148 | 149 | class MSDataLoader(DataLoader): 150 | def __init__( 151 | self, args, dataset, batch_size=1, shuffle=False, 152 | sampler=None, batch_sampler=None, 153 | collate_fn=default_collate, pin_memory=False, drop_last=False, 154 | timeout=0, worker_init_fn=None): 155 | 156 | super(MSDataLoader, self).__init__( 157 | dataset, 158 | batch_size=batch_size, 159 | shuffle=shuffle, 160 | sampler=sampler, 161 | batch_sampler=batch_sampler, 162 | num_workers=args.n_threads, 163 | collate_fn=collate_fn, 164 | pin_memory=pin_memory, 165 | drop_last=drop_last, 166 | timeout=timeout, 167 | worker_init_fn=worker_init_fn 168 | ) 169 | 170 | self.scale = args.scale 171 | 172 | def __iter__(self): 173 | return _MSDataLoaderIter(self) 174 | -------------------------------------------------------------------------------- /code/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from 
importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | 23 | module = import_module('model.' + args.model.lower()) 24 | self.model = module.make_model(args).to(self.device) 25 | if args.precision == 'half': self.model.half() 26 | 27 | if not args.cpu and args.n_GPUs > 1: 28 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 29 | 30 | self.load( 31 | ckp.dir, 32 | pre_train=args.pre_train, 33 | resume=args.resume, 34 | cpu=args.cpu 35 | ) 36 | if args.print_model: print(self.model) 37 | 38 | def forward(self, x, idx_scale): 39 | self.idx_scale = idx_scale 40 | target = self.get_model() 41 | if hasattr(target, 'set_scale'): 42 | target.set_scale(idx_scale) 43 | 44 | if self.self_ensemble and not self.training: 45 | if self.chop: 46 | forward_function = self.forward_chop 47 | else: 48 | forward_function = self.model.forward 49 | 50 | return self.forward_x8(x, forward_function) 51 | elif self.chop and not self.training: 52 | return self.forward_chop(x) 53 | else: 54 | return self.model(x) 55 | 56 | def get_model(self): 57 | if self.n_GPUs == 1: 58 | return self.model 59 | else: 60 | return self.model.module 61 | 62 | def state_dict(self, **kwargs): 63 | target = self.get_model() 64 | return target.state_dict(**kwargs) 65 | 66 | def save(self, apath, epoch, is_best=False): 67 | target = self.get_model() 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model', 'model_latest.pt') 71 | ) 72 | if is_best: 73 | torch.save( 74 | 
target.state_dict(), 75 | os.path.join(apath, 'model', 'model_best.pt') 76 | ) 77 | 78 | if self.save_models: 79 | torch.save( 80 | target.state_dict(), 81 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 82 | ) 83 | 84 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 85 | if cpu: 86 | kwargs = {'map_location': lambda storage, loc: storage} 87 | else: 88 | kwargs = {} 89 | 90 | if resume == -1: 91 | self.get_model().load_state_dict( 92 | torch.load( 93 | os.path.join(apath, 'model', 'model_latest.pt'), 94 | **kwargs 95 | ), 96 | strict=False 97 | ) 98 | elif resume == 0: 99 | if pre_train != '.': 100 | print('Loading model from {}'.format(pre_train)) 101 | self.get_model().load_state_dict( 102 | torch.load(pre_train, **kwargs), 103 | strict=False 104 | ) 105 | else: 106 | self.get_model().load_state_dict( 107 | torch.load( 108 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 109 | **kwargs 110 | ), 111 | strict=False 112 | ) 113 | 114 | def forward_chop(self, x, shave=10, min_size=160000): 115 | scale = self.scale[self.idx_scale] 116 | n_GPUs = min(self.n_GPUs, 4) 117 | b, c, h, w = x.size() 118 | h_half, w_half = h // 2, w // 2 119 | h_size, w_size = h_half + shave, w_half + shave 120 | lr_list = [ 121 | x[:, :, 0:h_size, 0:w_size], 122 | x[:, :, 0:h_size, (w - w_size):w], 123 | x[:, :, (h - h_size):h, 0:w_size], 124 | x[:, :, (h - h_size):h, (w - w_size):w]] 125 | 126 | if w_size * h_size < min_size: 127 | sr_list = [] 128 | for i in range(0, 4, n_GPUs): 129 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 130 | sr_batch = self.model(lr_batch) 131 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 132 | else: 133 | sr_list = [ 134 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 135 | for patch in lr_list 136 | ] 137 | 138 | h, w = scale * h, scale * w 139 | h_half, w_half = scale * h_half, scale * w_half 140 | h_size, w_size = scale * h_size, scale * w_size 141 | shave *= scale 142 | 143 | output = x.new(b, c, 
h, w) 144 | output[:, :, 0:h_half, 0:w_half] \ 145 | = sr_list[0][:, :, 0:h_half, 0:w_half] 146 | output[:, :, 0:h_half, w_half:w] \ 147 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 148 | output[:, :, h_half:h, 0:w_half] \ 149 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 150 | output[:, :, h_half:h, w_half:w] \ 151 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 152 | 153 | return output 154 | 155 | def forward_x8(self, x, forward_function): 156 | def _transform(v, op): 157 | if self.precision != 'single': v = v.float() 158 | 159 | v2np = v.data.cpu().numpy() 160 | if op == 'v': 161 | tfnp = v2np[:, :, :, ::-1].copy() 162 | elif op == 'h': 163 | tfnp = v2np[:, :, ::-1, :].copy() 164 | elif op == 't': 165 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 166 | 167 | ret = torch.Tensor(tfnp).to(self.device) 168 | if self.precision == 'half': ret = ret.half() 169 | 170 | return ret 171 | 172 | lr_list = [x] 173 | for tf in 'v', 'h', 't': 174 | lr_list.extend([_transform(t, tf) for t in lr_list]) 175 | 176 | sr_list = [forward_function(aug) for aug in lr_list] 177 | for i in range(len(sr_list)): 178 | if i > 3: 179 | sr_list[i] = _transform(sr_list[i], 't') 180 | if i % 4 > 1: 181 | sr_list[i] = _transform(sr_list[i], 'h') 182 | if (i % 4) % 2 == 1: 183 | sr_list[i] = _transform(sr_list[i], 'v') 184 | 185 | output_cat = torch.cat(sr_list, dim=0) 186 | output = output_cat.mean(dim=0, keepdim=True) 187 | 188 | return output 189 | 190 | -------------------------------------------------------------------------------- /code/model/frn_updown.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | def make_model(args, parent=False): 6 | return FRN_UPDOWN(args) 7 | 8 | 9 | 10 | 11 | ## Channel Attention (CA) Layer 12 | class CALayer(nn.Module): 13 | def __init__(self, channel, 
reduction=16): 14 | super(CALayer, self).__init__() 15 | # global average pooling: feature --> point 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | # feature channel downscale and upscale --> channel weight 18 | self.conv_du = nn.Sequential( 19 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 22 | ) 23 | 24 | self.sigmoid = nn.Sigmoid() 25 | self.ch = channel 26 | def forward(self, x): 27 | y = self.avg_pool(x) 28 | y = self.conv_du(y) 29 | y = self.sigmoid(y) 30 | 31 | return x * y 32 | ## Residual Channel Attention Block (RCAB) 33 | 34 | class RCAB(nn.Module): 35 | def __init__( 36 | self, conv, n_feat, kernel_size, reduction, 37 | bias=True, bn=False, act=nn.LeakyReLU(0.2, True), res_scale=1, px=1): 38 | 39 | super(RCAB, self).__init__() 40 | modules_body = [] 41 | # modules_body.append(common.invPixelShuffle(2)) 42 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 43 | modules_body.append(act) 44 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 45 | 46 | if px != 1: 47 | modules_body.append(common.invPixelShuffle(px)) 48 | modules_body.append(CALayer(n_feat*px**2, reduction)) 49 | if px != 1: 50 | modules_body.append(nn.PixelShuffle(px)) 51 | self.body = nn.Sequential(*modules_body) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x) 56 | #res = self.body(x).mul(self.res_scale) 57 | res += x 58 | return res 59 | ## Residual Group (RG) 60 | class ResidualGroup(nn.Module): 61 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks, n_resgroups, px): 62 | super(ResidualGroup, self).__init__() 63 | modules_body = [] 64 | if len(n_resgroups) ==0: 65 | modules_body = [ 66 | RCAB( 67 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, px=px) \ 68 | for _ in range(n_resblocks)] 69 | else: 70 | modules_body 
= [ 71 | ResidualGroup(conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks, n_resgroups[1:], px=px)\ 72 | for _ in range(n_resgroups[0])] 73 | 74 | 75 | 76 | 77 | 78 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 79 | self.body = nn.Sequential(*modules_body) 80 | 81 | def forward(self, x): 82 | res = self.body(x) 83 | res += x 84 | return res 85 | 86 | ## Residual Channel Attention Network (RCAN) 87 | class FRN_UPDOWN(nn.Module): 88 | def __init__(self, args, conv=common.default_conv): 89 | super(FRN_UPDOWN, self).__init__() 90 | 91 | n_resgroups = args.n_resgroups 92 | # n_resgroups_ = args.n_resgroups2 93 | # n_resgroups__ = args.n_resgroups3 94 | 95 | 96 | n_resblocks = args.n_resblocks 97 | n_feats = args.n_feats 98 | kernel_size = 3 99 | reduction = args.reduction 100 | scale = args.scale[0] 101 | act = nn.ReLU(True) 102 | 103 | # RGB mean for DIV2K 104 | rgb_mean = (0.4488, 0.4371, 0.4040) 105 | rgb_std = (1.0, 1.0, 1.0) 106 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 107 | 108 | # define head module 109 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 110 | 111 | # define body module 112 | modules_body = [ 113 | ResidualGroup( 114 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks, n_resgroups=n_resgroups[1:], px=args.px) \ 115 | for _ in range(n_resgroups[0])] 116 | 117 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 118 | # define tail module 119 | modules_tail = [ 120 | common.Upsampler(conv, scale, n_feats, act=False), 121 | conv(n_feats, args.n_colors, kernel_size)] 122 | 123 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 124 | m_down = [ 125 | common.invUpsampler(conv, scale, n_feats, act=False), 126 | ] 127 | 128 | 129 | self.down = nn.Sequential(*m_down) 130 | 131 | # self.head = nn.Sequential(*modules_head) 132 | self.new_head = nn.Sequential(*modules_head) 133 | self.body = nn.Sequential(*modules_body) 
134 | # self.tail = nn.Sequential(*modules_tail) 135 | self.new_tail = nn.Sequential(*modules_tail) 136 | 137 | for m in self.body.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | m.weight.data.normal_(0,0.01) 140 | m.bias.data.fill_(0) 141 | 142 | def forward(self, x): 143 | x = self.sub_mean(x) 144 | x = self.new_head(x) 145 | x = self.down(x) 146 | 147 | res = self.body(x) 148 | res += x 149 | 150 | 151 | x = self.new_tail(res) 152 | x = self.add_mean(x) 153 | 154 | return x 155 | 156 | def forward_(self, x): 157 | 158 | x = self.sub_mean(x) 159 | x = self.new_head(x) 160 | x = self.down(x) 161 | 162 | 163 | x = self.new_tail(x) 164 | x = self.add_mean(x) 165 | 166 | return x 167 | def load_state_dict(self, state_dict, strict=False): 168 | own_state = self.state_dict() 169 | for name, param in state_dict.items(): 170 | if name in own_state: 171 | if isinstance(param, nn.Parameter): 172 | param = param.data 173 | try: 174 | own_state[name].copy_(param) 175 | except Exception: 176 | if name.find('tail') >= 0: 177 | print('Replace pre-trained upsampler to new one...') 178 | else: 179 | raise RuntimeError('While copying the parameter named {}, ' 180 | 'whose dimensions in the model are {} and ' 181 | 'whose dimensions in the checkpoint are {}.' 
182 | .format(name, own_state[name].size(), param.size())) 183 | elif strict: 184 | if name.find('tail') == -1: 185 | raise KeyError('unexpected key "{}" in state_dict' 186 | .format(name)) 187 | 188 | if strict: 189 | missing = set(own_state.keys()) - set(state_dict.keys()) 190 | if len(missing) > 0: 191 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 192 | -------------------------------------------------------------------------------- /code/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, default=3, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='/home/junhk/dataset', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../test', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='NTIRE', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, default='NTIRE', # or NTIRE_VAL 29 | help='test dataset name') 30 | parser.add_argument('--benchmark_noise', action='store_true', 31 | help='use noisy benchmark sets') 32 | parser.add_argument('--n_train', type=int, default=30, 33 | help='number of training set') 34 | parser.add_argument('--n_val', type=int, default=10, 35 | help='number of 
validation set') 36 | parser.add_argument('--offset_val', type=int, default=30, 37 | help='validation index offest') 38 | parser.add_argument('--ext', type=str, default='img', 39 | help='dataset file extension') 40 | parser.add_argument('--scale', default='4', 41 | help='super resolution scale') 42 | parser.add_argument('--px', type=int, default=1, 43 | help='pixshuff downup scale (in RCAN-PS)') 44 | parser.add_argument('--patch_size', type=int, default=192, 45 | help='output patch size') 46 | parser.add_argument('--rgb_range', type=int, default=255, 47 | help='maximum value of RGB') 48 | parser.add_argument('--n_colors', type=int, default=3, 49 | help='number of color channels to use') 50 | parser.add_argument('--noise', type=str, default='.', 51 | help='Gaussian noise std.') 52 | parser.add_argument('--chop', action='store_true', 53 | help='enable memory-efficient forward') 54 | 55 | # Model specifications 56 | parser.add_argument('--model', default='FRN_UPDOWN', 57 | help='model name') 58 | 59 | parser.add_argument('--act', type=str, default='relu', 60 | help='activation function') 61 | parser.add_argument('--pre_train', type=str, default='.', 62 | help='pre-trained model directory') 63 | parser.add_argument('--extend', type=str, default='.', 64 | help='pre-trained model directory') 65 | parser.add_argument('--n_resblocks', type=int, default=8, 66 | help='number of innermost blocks of FRN (i.e. 
RCAN-PS)') 67 | parser.add_argument('--n_feats', type=int, default=64, 68 | help='number of feature maps') 69 | parser.add_argument('--res_scale', type=float, default=1, 70 | help='residual scaling') 71 | parser.add_argument('--shift_mean', default=True, 72 | help='subtract pixel mean from the input') 73 | parser.add_argument('--precision', type=str, default='single', 74 | choices=('single', 'half'), 75 | help='FP precision for test (single | half)') 76 | 77 | # Training specifications 78 | parser.add_argument('--reset', action='store_true', 79 | help='reset the training') 80 | parser.add_argument('--test_every', type=int, default=200, 81 | help='do test per every N batches') 82 | parser.add_argument('--epochs', type=int, default=1000, 83 | help='number of epochs to train') 84 | parser.add_argument('--batch_size', type=int, default=8, 85 | help='input batch size for training') 86 | parser.add_argument('--split_batch', type=int, default=1, 87 | help='split the batch into smaller chunks') 88 | parser.add_argument('--self_ensemble', action='store_true', 89 | help='use self-ensemble method for test') 90 | parser.add_argument('--test_only', action='store_true', 91 | help='set this option to test the model') 92 | parser.add_argument('--gan_k', type=int, default=1, 93 | help='k value for adversarial loss') 94 | 95 | # Optimization specifications 96 | parser.add_argument('--lr', type=float, default=1e-4, 97 | help='learning rate') 98 | parser.add_argument('--lr_decay', type=int, default=400, 99 | help='learning rate decay per N epochs') 100 | parser.add_argument('--decay_type', type=str, default='step', 101 | help='learning rate decay type') 102 | parser.add_argument('--gamma', type=float, default=0.5, 103 | help='learning rate decay factor for step decay') 104 | parser.add_argument('--optimizer', default='ADAM', 105 | choices=('SGD', 'ADAM', 'RMSprop'), 106 | help='optimizer to use (SGD | ADAM | RMSprop)') 107 | parser.add_argument('--momentum', type=float, default=0.9, 
108 | help='SGD momentum') 109 | parser.add_argument('--beta1', type=float, default=0.9, 110 | help='ADAM beta1') 111 | parser.add_argument('--beta2', type=float, default=0.999, 112 | help='ADAM beta2') 113 | parser.add_argument('--epsilon', type=float, default=1e-8, 114 | help='ADAM epsilon for numerical stability') 115 | parser.add_argument('--weight_decay', type=float, default=0, 116 | help='weight decay') 117 | 118 | # Loss specifications 119 | parser.add_argument('--loss', type=str, default='1*L1', 120 | help='loss function configuration') 121 | parser.add_argument('--skip_threshold', type=float, default='1e6', 122 | help='skipping batch that has large error') 123 | 124 | # Log specifications 125 | parser.add_argument('--save', type=str, default='test', 126 | help='file name to save') 127 | parser.add_argument('--load', type=str, default='.', 128 | help='file name to load') 129 | parser.add_argument('--resume', type=int, default=0, 130 | help='resume from specific checkpoint') 131 | parser.add_argument('--print_model', action='store_true', 132 | help='print model') 133 | parser.add_argument('--save_models', action='store_true', 134 | help='save all intermediate models') 135 | parser.add_argument('--print_every', type=int, default=100, 136 | help='how many batches to wait before logging training status') 137 | parser.add_argument('--save_results', action='store_true', 138 | help='save output results') 139 | 140 | # options for residual group and feature channel reduction 141 | parser.add_argument('--n_resgroups', type=list, nargs='+', default=[2, 4, 8], 142 | help='number of residual groups in fractal residual groups') 143 | parser.add_argument('--reduction', type=int, default=16, 144 | help='number of feature maps reduction') 145 | # options for test 146 | parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our', 147 | help='dataset directory for testing') 148 | parser.add_argument('--testset', type=str, default='Set5', 149 | 
help='dataset name for testing') 150 | 151 | args = parser.parse_args() 152 | template.set_template(args) 153 | 154 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 155 | 156 | if args.epochs == 0: 157 | args.epochs = 1e8 158 | 159 | for arg in vars(args): 160 | if vars(args)[arg] == 'True': 161 | vars(args)[arg] = True 162 | elif vars(args)[arg] == 'False': 163 | vars(args)[arg] = False 164 | 165 | --------------------------------------------------------------------------------