├── .gitignore ├── README.md ├── images ├── attention.png └── flow.png ├── mkdata.py └── src ├── __init__.py ├── data ├── __init__.py ├── benchmark.py ├── common.py ├── demo.py ├── div2k.py ├── div2kjpeg.py ├── sr291.py ├── srdata.py └── video.py ├── dataloader.py ├── demo.sh ├── loss ├── __init__.py ├── adversarial.py ├── discriminator.py ├── pytorch_ssim │ └── __init__.py └── vgg.py ├── main.py ├── model ├── __init__.py ├── common.py ├── ddbpn.py ├── denseskip.py ├── drrn.py ├── edsr.py ├── lapsrn.py ├── mdsr.py ├── rcan.py ├── rdn.py ├── srcnn.py ├── srresnet.py └── vdsr.py ├── option.py ├── template.py ├── trainer.py └── utility.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pixel-level Self-Paced Learning for Super-Resolution 2 | 3 | This is an official implementation of the paper **Pixel-level Self-Paced Learning for Super-Resolution**, which has been accepted by ICASSP 2020. 4 | 5 | [[arXiv](https://arxiv.org/abs/2003.03113)][[PDF](https://arxiv.org/pdf/2003.03113)] 6 | 7 | Trained model files: [Baidu Pan](https://pan.baidu.com/s/1ZDqJbn0kxAqEmkvMSUMD9g) (code: v0be) 8 | 9 | ## Requirements 10 | 11 | This code is forked from [thstkdgus35/EDSR-PyTorch](https://github.com/thstkdgus35/EDSR-PyTorch). According to its README, the following libraries are required: 12 | 13 | - Python 3.6+ (Python 3.7.0 in my experiments) 14 | - PyTorch >= 1.0.0 (1.1.0 in my experiments) 15 | - numpy 16 | - skimage 17 | - imageio 18 | - matplotlib 19 | - tqdm 20 | 21 | ## Core Parts 22 | 23 | ![pspl framework](images/flow.png) 24 | 25 | The detailed code can be found in [Loss.forward](https://github.com/Elin24/PSPL/blob/2deb17d4bcf7db17463238e143ca94e438e51e2a/src/loss/__init__.py#L60), which can be simplified as: 26 | 27 | ```python 28 | # take L1 loss as an example 29 | 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | from . import pytorch_ssim 34 | 35 | class Loss(nn.modules.loss._Loss): 36 | def __init__(self, spl_alpha, spl_beta, spl_maxVal): 37 | super(Loss, self).__init__() 38 | self.loss = nn.L1Loss() 39 | self.alpha = spl_alpha 40 | self.beta = spl_beta 41 | self.maxVal = spl_maxVal 42 | 43 | def forward(self, sr, hr, step): 44 | # calc sigma value 45 | sigma = self.alpha * step + self.beta 46 | # define gauss function 47 | gauss = lambda x: torch.exp(-((x+1) / sigma) ** 2) * self.maxVal 48 | # ssim value 49 | ssim = pytorch_ssim.ssim(hr, sr, reduction='none').detach() 50 | # calc attention weight 51 | weight = gauss(ssim).detach() 52 | nsr, nhr = sr * weight, hr * weight 53 | # calc loss 54 | lossval = self.loss(nsr, nhr) 55 | return lossval 56 | ``` 57 | 58 | The library pytorch_ssim is forked from [Po-Hsun-Su/pytorch-ssim](https://github.com/Po-Hsun-Su/pytorch-ssim), with some details rewritten to adapt it to our requirements. 
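In other words, the attention weight of a pixel whose SSIM value is `s` at training step `t` is `w(s, t) = maxVal * exp(-((s + 1) / sigma)^2)` with `sigma = alpha * t + beta`: poorly reconstructed pixels (low SSIM) always receive weights close to `maxVal`, well-reconstructed pixels are down-weighted early in training, and the gap closes as `sigma` grows. A minimal sketch of this schedule (the hyper-parameter values below are placeholders for illustration only, not the defaults from `option.py`):

```python
import torch

# placeholder hyper-parameters for illustration; the real code reads them
# from option.py (args.splalpha, args.splbeta, args.splval)
alpha, beta, max_val = 1e-3, 1.0, 1.0

for step in (0, 1000, 5000):
    sigma = alpha * step + beta
    ssim = torch.tensor([-1.0, 0.0, 0.5, 1.0])  # sample per-pixel SSIM values
    weight = torch.exp(-((ssim + 1) / sigma) ** 2) * max_val
    print(f'step {step:>4}:', [round(w, 3) for w in weight.tolist()])
```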
59 | 60 | Attention weight values change according to the *SSIM index* and the training *step*: 61 | ![attention values](images/attention.png) 62 | 63 | ## Citation 64 | 65 | If you find this project useful for your research, please cite: 66 | 67 | ```bibtex 68 | @inproceedings{lin2020pixel, 69 | title={Pixel-Level Self-Paced Learning For Super-Resolution}, 70 | author={Lin, Wei and Gao, Junyu and Wang, Qi and Li, Xuelong}, 71 | booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 72 | year={2020}, 73 | pages={2538--2542} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Elin24/PSPL/91a07ed53902a1f9033ae7babcbd62ca252ab47b/images/attention.png -------------------------------------------------------------------------------- /images/flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Elin24/PSPL/91a07ed53902a1f9033ae7babcbd62ca252ab47b/images/flow.png -------------------------------------------------------------------------------- /mkdata.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import glob 4 | import os 5 | import PIL 6 | from PIL import Image 7 | from shutil import copyfile 8 | 9 | # Image.resize(size, PIL.Image.BICUBIC) 10 | 11 | def mkdir(dir): 12 | if not os.path.exists(dir): 13 | os.makedirs(dir) 14 | return dir 15 | 16 | def mkdata(root, scale): 17 | hrdir = mkdir(os.path.join(f'datasetx{scale}', 'benchmark', root, 'HR')) 18 | lrdir = mkdir(os.path.join(f'datasetx{scale}', 'benchmark', root, 'LR_bicubic', f'X{scale}')) 19 | for imgp in os.listdir(root): 20 | hrimgpath = os.path.join(hrdir, imgp) 21 | spimgp = os.path.splitext(imgp) 22 | lrimgpath = os.path.join(lrdir, f'{spimgp[0]}x{scale}{spimgp[1]}') 23 | 24 | imgp = os.path.join(root, imgp) 25 | hrimg = Image.open(imgp) 26 | w, h = hrimg.size 27 | nw = w if w % scale == 0 else (w - w % scale) 28 | nh = h if h % scale == 0 else (h - h % scale) 29 | if nw == w and nh == h: 30 | copyfile(imgp, hrimgpath) 31 | else: 32 | hrimg = hrimg.resize((nw, nh), PIL.Image.BICUBIC) 33 | hrimg.save(hrimgpath) 34 | nw, nh = nw // scale, nh // scale 35 | lrimg = hrimg.resize((nw, nh), PIL.Image.BICUBIC) 36 | lrimg.save(lrimgpath) 37 | print(f'{root} - {scale} : Ok.') 38 | 39 | 40 | if __name__ == '__main__': 41 | for root in ['Set5', 'Set14', 'B100', 'Urban100']: 42 | for scale in [2,3,4]: 43 | mkdata(root, scale) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Elin24/PSPL/91a07ed53902a1f9033ae7babcbd62ca252ab47b/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper class for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12
| def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' + module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /src/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | #print(img.shape, type(img)) 42 | #img = img[] 43 | #img = img[:,:,2, np.newaxis] 44 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 45 | elif n_channels == 3 and c == 1: 46 | img = np.concatenate([img] * n_channels, 2) 47 | 48 | return img 49 | 50 | return 
[_set_channel(a) for a in args] 51 | 52 | def np2Tensor(*args, rgb_range=255): 53 | def _np2Tensor(img): 54 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 55 | tensor = torch.from_numpy(np_transpose).float() 56 | tensor.mul_(rgb_range / 255) 57 | 58 | return tensor 59 | 60 | return [_np2Tensor(a) for a in args] 61 | 62 | def augment(*args, hflip=True, rot=True): 63 | hflip = hflip and random.random() < 0.5 64 | vflip = rot and random.random() < 0.5 65 | rot90 = rot and random.random() < 0.5 66 | 67 | def _augment(img): 68 | if hflip: img = img[:, ::-1, :] 69 | if vflip: img = img[::-1, :, :] 70 | if rot90: img = img.transpose(1, 0, 2) 71 | 72 | return img 73 | 74 | return [_augment(a) for a in args] 75 | 76 | -------------------------------------------------------------------------------- /src/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /src/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class 
DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /src/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name, train=train, benchmark=benchmark) 6 | 7 | -------------------------------------------------------------------------------- /src/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model in ['VDSR', 'DRRN', 'SRCNN']) 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 48 | for h in list_hr: 49 | b = h.replace(self.apath, path_bin) 50 | b = b.replace(self.ext[0], '.pt') 51 | self.images_hr.append(b) 52 | self._check_and_load(args.ext, h, b, verbose=True) 53 | for i, ll in enumerate(list_lr): 54 | for l in ll: 55 | b = l.replace(self.apath, path_bin) 56 | b = b.replace(self.ext[1], '.pt') 57 | self.images_lr[i].append(b) 58 | self._check_and_load(args.ext, l, b, verbose=True) 59 | if train: 60 | n_patches = args.batch_size * args.test_every 61 | n_images = len(args.data_train) * len(self.images_hr) 62 | if n_images == 0: 63 | self.repeat = 0 64 | else: 65 | self.repeat = max(n_patches // n_images, 1) 66 | 67 | # The functions below are used to prepare images 68 | def _scan(self): 69 | names_hr = sorted( 70 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 71 | ) 72 | names_lr = [[] for _ in self.scale] 73 | for f in names_hr: 74 | filename, _ = os.path.splitext(os.path.basename(f)) 75 | for si, s in enumerate(self.scale): 76 | names_lr[si].append(os.path.join( 77 | self.dir_lr, 'X{}/{}x{}{}'.format( 78 | s, filename, s, self.ext[1] 79 | ) 80 | )) 81 | 82 | return names_hr, names_lr
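    # _scan above assumes the standard EDSR data layout:
    #   <apath>/HR/<name><ext[0]>
    #   <apath>/LR_bicubic/X<scale>/<name>x<scale><ext[1]>
    # _set_filesystem below fills in the concrete directories and is
    # overridden by subclasses such as Benchmark and DIV2K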
83 | 84 | def _set_filesystem(self, dir_data): 85 | self.apath = os.path.join(dir_data, self.name) 86 | self.dir_hr = os.path.join(self.apath, 'HR') 87 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 88 | if self.input_large: self.dir_lr += 'L' 89 | self.ext = ('.png', '.png') 90 | 91 | def _check_and_load(self, ext, img, f, verbose=True): 92 | if not os.path.isfile(f) or ext.find('reset') >= 0: 93 | if verbose: 94 | print('Making a binary: {}'.format(f)) 95 | with open(f, 'wb') as _f: 96 | pickle.dump(imageio.imread(img), _f) 97 | 98 | def __getitem__(self, idx): 99 | lr, hr, filename = self._load_file(idx) 100 | pair = self.get_patch(lr, hr) 101 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 102 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 103 | 104 | return pair_t[0], pair_t[1], filename 105 | 106 | def __len__(self): 107 | if self.train: 108 | return len(self.images_hr) * self.repeat 109 | else: 110 | return len(self.images_hr) 111 | 112 | def _get_index(self, idx): 113 | if self.train: 114 | return idx % len(self.images_hr) 115 | else: 116 | return idx 117 | 118 | def _load_file(self, idx): 119 | idx = self._get_index(idx) 120 | f_hr = self.images_hr[idx] 121 | f_lr = self.images_lr[self.idx_scale][idx] 122 | 123 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 124 | if self.args.ext == 'img' or self.benchmark: 125 | hr = imageio.imread(f_hr) 126 | lr = imageio.imread(f_lr) 127 | elif self.args.ext.find('sep') >= 0: 128 | with open(f_hr, 'rb') as _f: 129 | hr = pickle.load(_f) 130 | with open(f_lr, 'rb') as _f: 131 | lr = pickle.load(_f) 132 | 133 | return lr, hr, filename 134 | 135 | def get_patch(self, lr, hr): 136 | scale = self.scale[self.idx_scale] 137 | if self.train: 138 | lr, hr = common.get_patch( 139 | lr, hr, 140 | patch_size=self.args.patch_size, 141 | scale=scale, 142 | multi=(len(self.scale) > 1), 143 | input_large=self.input_large 144 | ) 145 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 146 | else: 147 | ih, iw = lr.shape[:2] 148 | hr = hr[0:ih * scale, 0:iw * scale] 149 | 150 | return lr, hr 151 | 152 | def set_scale(self, idx_scale): 153 | if not self.input_large: 154 | self.idx_scale = idx_scale 155 | else: 156 | self.idx_scale = random.randint(0, len(self.scale) - 1) 157 | 158 | -------------------------------------------------------------------------------- /src/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | self.vidcap.release() 37 | return None
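    # NOTE: frames are decoded sequentially from self.vidcap, so this dataset
    # is meant to be read in order (no shuffling); once the capture is
    # exhausted it is released and __getitem__ returns None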
38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | import sys 4 | import torch 5 | import torch.multiprocessing as multiprocessing 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SequentialSampler 8 | from torch.utils.data import RandomSampler 9 | from torch.utils.data import BatchSampler 10 | from torch.utils.data import _utils 11 | from torch.utils.data.dataloader import _DataLoaderIter 12 | 13 | from torch.utils.data._utils import collate 14 | from torch.utils.data._utils import signal_handling 15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 16 | from torch.utils.data._utils import ExceptionWrapper 17 | from torch.utils.data._utils import IS_WINDOWS 18 | from torch.utils.data._utils.worker import ManagerWatchdog 19 | 20 | from torch._six import queue 21 | 22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 23 | try: 24 | collate._use_shared_memory = True 25 | signal_handling._set_worker_signal_handlers() 26 | 27 | torch.set_num_threads(1) 28 | random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | data_queue.cancel_join_thread() 32 | 33 | if init_fn is not None: 34 | init_fn(worker_id) 35 | 36 | watchdog = ManagerWatchdog() 37 | 38 | while watchdog.is_alive(): 39 | try: 40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 41 | except queue.Empty: 42 | continue 43 | 44 | if r is None: 45 | assert done_event.is_set() 46 | return 47 | elif done_event.is_set(): 48 | continue 49 | 50 | idx, batch_indices = r 51 | try: 52 | idx_scale = 0 53 | if len(scale) > 1 and dataset.train: 54 | idx_scale = random.randrange(0, len(scale)) 55 | dataset.set_scale(idx_scale) 56 | 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | samples.append(idx_scale) 59 | except Exception: 60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 61 | else: 62 | data_queue.put((idx, samples)) 63 | del samples 64 | 65 | except KeyboardInterrupt: 66 | pass 67 | 68 | class _MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 |
self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /src/demo.sh: -------------------------------------------------------------------------------- 1 | # EDSR baseline model (x2) + JPEG augmentation 2 | #python3 main.py --model EDSR --scale 4 --save edsr_x4 --reset --data_test Set5+Set14+B100+Urban100+DIV2K --n_GPUs 1 --epochs 300 --dir_data ../../datasetx4 --reset 3 | #python3 main.py --model EDSR --scale 4 --save edsr_x4_spl --reset --data_test Set5+Set14+B100+Urban100+DIV2K --n_GPUs 1 --epochs 300 --dir_data ../../datasetx4 --reset 4 | # test 5 | #python main.py --model EDSR --scale 2 --test_only --dir_data ../../dataset/testx2 --n_GPUs 1 --data_test Set5+Set14+B100+Urban100 --pre_train ../experiment/edsr_baseline_x2_L1/model/model_best.pt 6 | 7 | #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 8 | 9 | # EDSR baseline model (x3) - from EDSR baseline model (x2) 10 | #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 11 | 12 | # EDSR baseline model (x4) - from EDSR baseline model (x2) 13 | #python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 14 | 15 | # EDSR in the paper (x2) 16 | #python3 main.py --template EDSR_paper --scale 2 --save edsr_x2_spl_1011 --n_GPUs 1 --patch_size 96 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --resume -1 17 | #python3 main.py --template EDSR_paper --save edsr_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/edsr_x4_spl_1013/model/model_best.pt --reset 18 | #python3 main.py --template EDSR_paper --save edsr_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/edsr_x4_spl_1013/model/model_latest.pt --reset 19 | # EDSR in the paper (x3) - from EDSR (x2) 20 | #python3 main.py --template EDSR_paper --scale 3 --save edsr_spl_x3_1012 --reset --n_GPUs 1 
--patch_size 144 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --pre_train /media/E/linwei/SISR/EDSR-PyTorch-SPL/experiment/edsr_x2_spl_1011/model/model_best.pt 21 | #python main.py --model EDSR --scale 3 --save edsr_x3_spl_1012 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] 22 | 23 | # EDSR in the paper (x4) - from EDSR (x2) 24 | #python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] 25 | #python3 main.py --template EDSR_paper --scale 4 --save edsr_x4_spl_1013 --n_GPUs 1 --patch_size 192 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --resume -1 26 | #python3 main.py --template SRRESNET --data_test B100 --scale 4 --pre_train ../experiment-train/SRRESNETx4_SPL_1011/model/model_latest.pt --test_only --save_results --dir_data ../../datasetx4 --n_GPUs 1 --save test_SRResNet_spl_b100 27 | # MDSR baseline model 28 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models 29 | 30 | # MDSR in the paper 31 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models 32 | 33 | # Standard benchmarks (Ex. EDSR_baseline_x4) 34 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble 35 | 36 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble 37 | 38 | # Test your own images 39 | #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results 40 | 41 | # Advanced - Test with JPEG images 42 | #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results 43 | 44 | # Advanced - Training with adversarial loss 45 | #python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download 46 | 47 | # RDN BI model (x2) 48 | #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset 49 | # RDN BI model (x3) 50 | #python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset 51 | # RDN BI model (x4) 52 | #python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset 53 | 54 | # RCAN_BIX2_G10R20P48, input=48x48, output=96x96 55 | # pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0 56 | #python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96 57 | # RCAN_BIX3_G10R20P48, input=48x48, output=144x144 58 | #python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt 59 | # RCAN_BIX4_G10R20P48, input=48x48, output=192x192 60 | #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt 61 | # RCAN_BIX8_G10R20P48, input=48x48, output=384x384 62 | #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train 
../experiment/model/RCAN_BIX2.pt 63 | 64 | #VDSR 65 | #python3 main.py --template VDSR --save VDSR_x2_spl_1009 --scale 2 --reset --patch_size 96 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 1 66 | #python3 main.py --template VDSR --save VDSR_x3_spl_1009 --scale 3 --reset --patch_size 144 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --n_GPUs 1 67 | #python3 main.py --template VDSR --save VDSR_x4_spl_1009 --scale 4 --reset --patch_size 192 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 68 | 69 | #python3 main.py --template VDSR --save vdsr_x2_spl_test --scale 2 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 1 --test_only --pre_train ../experiment-train/VDSR_x2_spl_1009/model/model_best.pt --reset 70 | #python3 main.py --template VDSR --save vdsr_x3_spl_test --scale 3 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --n_GPUs 1 --test_only --pre_train ../experiment-train/VDSR_x3_spl_1009/model/model_best.pt --reset 71 | #python3 main.py --template VDSR --save vdsr_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/VDSR_x4_spl_1009/model/model_latest.pt --reset 72 | 73 | 74 | # LapSRN 75 | #python3 main.py --template LapSRN --save LapSRN_x2_spl_1010 --scale 2 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 1 --patch_size 128 76 | #python3 main.py --template LapSRN --save LapSRN_x3_spl_1010 --scale 3 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --n_GPUs 1 --patch_size 127 77 | #python3 main.py --template LapSRN --save LapSRN_x4_spl_1010 --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --patch_size 128 78 | 79 | #python3 main.py --template LapSRN --save lapsrn_x2_spl_test --scale 2 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 1 --test_only --pre_train ../experiment-train/LapSRN_x2_spl_1010/model/model_best.pt --reset 80 | #python3 main.py --template LapSRN --save lapsrn_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/LapSRN_x4_spl_1010/model/model_best.pt --reset 81 | 82 | # DRRN 83 | #python3 main.py --template DRRN --save DRRN_x2_spl_1010 --scale 2 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 2 84 | #python3 main.py --template DRRN --save DRRN_x3_spl_1010 --scale 3 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --n_GPUs 2 85 | #python3 main.py --template DRRN --save DRRN_x4_spl_1010 --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 2 86 | 87 | #python3 main.py --template DRRN --save drrn_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/DRRN_x4_spl_1010/model/model_best.pt --reset 88 | 89 | # SRCNN 90 | #python3 main.py --template SRCNN --save SRCNN_x2_spl_1010 --scale 2 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 --n_GPUs 1 91 | #python3 main.py --template SRCNN --save SRCNN_x3_spl_1010 --scale 3 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx3 --n_GPUs 1 92 | #python3 main.py --template SRCNN --save SRCNN_x4_spl_1019 --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data 
../../datasetx4 --n_GPUs 1 93 | 94 | #python3 main.py --template SRCNN --save SRCNN_x4_spl_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/SRCNN_x4_spl_1010/model/model_best.pt --reset 95 | 96 | # SRResNet 97 | #python3 main.py --template SRRESNET --scale 4 --save SRRESNETx4_1012 --reset --save_models --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 98 | 99 | #python3 main.py --template SRRESNET --save SRRESNET_x4_baseline_test --scale 4 --reset --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx4 --n_GPUs 1 --test_only --pre_train ../experiment-train/SRRESNETx4_1012/model/model_best.pt --reset 100 | 101 | # MDSR 102 | # python3 main.py --template MDSR --scale 2+4 --n_resblocks 80 --save MDSR_spl_1012 --reset --save_models --n_GPUs 1 --data_test Set5+Set14+B100+Urban100 --dir_data ../../datasetx2 103 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from . import pytorch_ssim 14 | 15 | class Loss(nn.modules.loss._Loss): 16 | def __init__(self, args, ckp): 17 | super(Loss, self).__init__() 18 | print('Preparing loss function:') 19 | 20 | self.n_GPUs = args.n_GPUs 21 | self.loss = [] 22 | self.loss_module = nn.ModuleList() 23 | for loss in args.loss.split('+'): 24 | weight, loss_type = loss.split('*') 25 | if loss_type == 'MSE': 26 | loss_function = nn.MSELoss()#(reduction='none') 27 | elif loss_type == 'L1': 28 | loss_function = nn.L1Loss()#(reduction='none') 29 | 30 | self.loss.append({ 31 | 'type': loss_type, 32 | 'weight': float(weight), 33 | 'function': loss_function} 34 | ) 35 | 36 | if len(self.loss) > 1: 37 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 38 | 39 | for l in self.loss: 40 | if l['function'] is not None: 41 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 42 | self.loss_module.append(l['function']) 43 | 44 | self.log = torch.Tensor() 45 | 46 | device = torch.device('cpu' if args.cpu else 'cuda') 47 | self.loss_module.to(device) 48 | if args.precision == 'half': self.loss_module.half() 49 | if not args.cpu and args.n_GPUs > 1: 50 | self.loss_module = nn.DataParallel( 51 | self.loss_module, range(args.n_GPUs) 52 | ) 53 | 54 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 55 | 56 | self.alpha = args.splalpha 57 | self.beta = args.splbeta 58 | self.maxVal = args.splval 59 | 60 | def forward(self, sr, hr, step=1000): 61 | #print(sr.shape, hr.shape) 62 | sigma = self.alpha * step + self.beta 63 | gauss = lambda x: torch.exp(-((x+1) / sigma) ** 2) * self.maxVal 64 | ssim = pytorch_ssim.ssim(hr, sr, reduction='none').detach() 65 | weight = gauss(ssim).detach() 66 | sr, hr = sr * weight, hr * weight 67 | 68 | losses = [] 69 | for i, l in enumerate(self.loss): 70 | if l['function'] is not None: 71 | loss = l['function'](sr, hr) 72 | effective_loss = l['weight'] * loss 73 | losses.append(effective_loss) 74 | self.log[-1, i] += effective_loss.item() 75 | elif l['type'] == 'DIS': 76 | self.log[-1, i] += self.loss[i - 1]['function'].loss 77 | 78 | loss_sum = sum(losses) 79 | if len(self.loss) > 1: 80 | self.log[-1, -1] += 
loss_sum.item() 81 | 82 | return loss_sum 83 | 84 | def step(self): 85 | for l in self.get_loss_module(): 86 | if hasattr(l, 'scheduler'): 87 | l.scheduler.step() 88 | 89 | def start_log(self): 90 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 91 | 92 | def end_log(self, n_batches): 93 | self.log[-1].div_(n_batches) 94 | 95 | def display_loss(self, batch): 96 | n_samples = batch + 1 97 | log = [] 98 | for l, c in zip(self.loss, self.log[-1]): 99 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 100 | 101 | return ''.join(log) 102 | 103 | def plot_loss(self, apath, epoch): 104 | axis = np.linspace(1, epoch, epoch) 105 | for i, l in enumerate(self.loss): 106 | label = '{} Loss'.format(l['type']) 107 | fig = plt.figure() 108 | plt.title(label) 109 | plt.plot(axis, self.log[:, i].numpy(), label=label) 110 | plt.legend() 111 | plt.xlabel('Epochs') 112 | plt.ylabel('Loss') 113 | plt.grid(True) 114 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 115 | plt.close(fig) 116 | 117 | def get_loss_module(self): 118 | if self.n_GPUs == 1: 119 | return self.loss_module 120 | else: 121 | return self.loss_module.module 122 | 123 | def save(self, apath): 124 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 125 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 126 | 127 | def load(self, apath, cpu=False): 128 | if cpu: 129 | kwargs = {'map_location': lambda storage, loc: storage} 130 | else: 131 | kwargs = {} 132 | 133 | self.load_state_dict(torch.load( 134 | os.path.join(apath, 'loss.pt'), 135 | **kwargs 136 | )) 137 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 138 | for l in self.get_loss_module(): 139 | if hasattr(l, 'scheduler'): 140 | for _ in range(len(self.log)): l.scheduler.step() 141 | 142 | -------------------------------------------------------------------------------- /src/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if gan_type == 'WGAN_GP': 19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 20 | optim_dict = { 21 | 'optimizer': 'ADAM', 22 | 'betas': (0, 0.9), 23 | 'epsilon': 1e-8, 24 | 'lr': 1e-5, 25 | 'weight_decay': args.weight_decay, 26 | 'decay': args.decay, 27 | 'gamma': args.gamma 28 | } 29 | optim_args = SimpleNamespace(**optim_dict) 30 | else: 31 | optim_args = args 32 | 33 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 34 | 35 | def forward(self, fake, real): 36 | # updating discriminator... 
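        # the discriminator is stepped gan_k times for every generator update;
        # self.loss accumulates the discriminator loss over those steps
        # (averaged below) purely for logging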
37 | self.loss = 0 38 | fake_detach = fake.detach() # do not backpropagate through G 39 | for _ in range(self.gan_k): 40 | self.optimizer.zero_grad() 41 | # d: B x 1 tensor 42 | d_fake = self.dis(fake_detach) 43 | d_real = self.dis(real) 44 | retain_graph = False 45 | if self.gan_type == 'GAN': 46 | loss_d = self.bce(d_real, d_fake) 47 | elif self.gan_type.find('WGAN') >= 0: 48 | loss_d = (d_fake - d_real).mean() 49 | if self.gan_type.find('GP') >= 0: 50 | epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device) 51 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 52 | hat.requires_grad = True 53 | d_hat = self.dis(hat) 54 | gradients = torch.autograd.grad( 55 | outputs=d_hat.sum(), inputs=hat, 56 | retain_graph=True, create_graph=True, only_inputs=True 57 | )[0] 58 | gradients = gradients.view(gradients.size(0), -1) 59 | gradient_norm = gradients.norm(2, dim=1) 60 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 61 | loss_d += gradient_penalty 62 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 63 | elif self.gan_type == 'RGAN': 64 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 65 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 66 | loss_d = self.bce(better_real, better_fake) 67 | retain_graph = True 68 | 69 | # Discriminator update 70 | self.loss += loss_d.item() 71 | loss_d.backward(retain_graph=retain_graph) 72 | self.optimizer.step() 73 | 74 | if self.gan_type == 'WGAN': 75 | for p in self.dis.parameters(): 76 | p.data.clamp_(-1, 1) 77 | 78 | self.loss /= self.gan_k 79 | 80 | # updating generator... 81 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 82 | if self.gan_type == 'GAN': 83 | label_real = torch.ones_like(d_fake_bp) 84 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 85 | elif self.gan_type.find('WGAN') >= 0: 86 | loss_g = -d_fake_bp.mean() 87 | elif self.gan_type == 'RGAN': 88 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) 89 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) 90 | loss_g = self.bce(better_fake, better_real) 91 | 92 | # Generator loss 93 | return loss_g 94 | 95 | def state_dict(self, *args, **kwargs): 96 | state_discriminator = self.dis.state_dict(*args, **kwargs) 97 | state_optimizer = self.optimizer.state_dict() 98 | 99 | return dict(**state_discriminator, **state_optimizer) 100 | 101 | def bce(self, real, fake): 102 | label_real = torch.ones_like(real) 103 | label_fake = torch.zeros_like(fake) 104 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 105 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 106 | bce_loss = bce_real + bce_fake 107 | return bce_loss 108 | 109 | # Some references 110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 111 | # OR 112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 113 | -------------------------------------------------------------------------------- /src/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24
| bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /src/loss/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, reduction = 'mean'): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | # print('ssim:', ssim_map.shape, ssim_map.max().item(), ssim_map.min().item()) 34 | 35 | if reduction == 'mean': 36 | return ssim_map.mean() 37 | elif reduction == 'none': 38 | return ssim_map 39 | elif reduction == 'navg': 40 | return ssim_map.mean(1).mean(1).mean(1) 41 | 42 | class SSIM(torch.nn.Module): 43 | def __init__(self, window_size = 11, size_average = True): 44 | super(SSIM, self).__init__() 45 | self.window_size = window_size 46 | self.size_average = size_average 47 | self.channel = 1 48 | self.window = create_window(window_size, self.channel) 49 | 50 | def forward(self, img1, img2): 51 | (_, channel, _, _) = img1.size() 52 | 53 | if channel == self.channel and self.window.data.type() == img1.data.type(): 54 | window = self.window 55 | else: 56 | window = create_window(self.window_size, channel) 57 | 58 | if img1.is_cuda: 59 | window = window.cuda(img1.get_device()) 60 | window = window.type_as(img1) 61 | 62 | self.window = window 63 | self.channel = channel 64 | 65 | if self.size_average: 66 | return _ssim(img1, img2, window, self.window_size, 
channel) 67 | else: 68 | return _ssim(img1, img2, window, self.window_size, channel, 'navg') 69 | 70 | def ssim(img1, img2, window_size = 11, reduction = 'mean'): 71 | (_, channel, _, _) = img1.size() 72 | window = create_window(window_size, channel) 73 | 74 | if img1.is_cuda: 75 | window = window.cuda(img1.get_device()) 76 | window = window.type_as(img1) 77 | 78 | return _ssim(img1, img2, window, window_size, channel, reduction) 79 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args, checkpoint) 18 | t = VideoTester(args, model, checkpoint) 19 | t.test() 20 | else: 21 | if checkpoint.ok: 22 | loader = data.Data(args) 23 | _model = model.Model(args, checkpoint) 24 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 25 | t = Trainer(args, loader, _model, _loss, checkpoint) 26 | while not t.terminate(): 27 | t.train() 28 | t.test() 29 | 30 | checkpoint.done() 31 | 32 | if __name__ == '__main__': 33 | main() 34 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.input_large = (args.model in ['VDSR', 'DRRN', 'SRCNN']) 17 | self.self_ensemble = args.self_ensemble 18 | self.chop = args.chop 19 | self.precision = args.precision 20 | self.cpu = args.cpu 21 | self.device = torch.device('cpu' if args.cpu else 'cuda') 22 | self.n_GPUs = args.n_GPUs 23 | self.save_models = 
args.save_models 24 | 25 | module = import_module('model.' + args.model.lower()) 26 | self.model = module.make_model(args).to(self.device) 27 | if args.precision == 'half': 28 | self.model.half() 29 | 30 | self.load( 31 | ckp.get_path('model'), 32 | pre_train=args.pre_train, 33 | resume=args.resume, 34 | cpu=args.cpu 35 | ) 36 | print(self.model, file=ckp.log_file) 37 | 38 | def forward(self, x, idx_scale): 39 | self.idx_scale = idx_scale 40 | if hasattr(self.model, 'set_scale'): 41 | self.model.set_scale(idx_scale) 42 | 43 | if self.training: 44 | if self.n_GPUs > 1: 45 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 46 | else: 47 | return self.model(x) 48 | else: 49 | if self.chop: 50 | forward_function = self.forward_chop 51 | else: 52 | forward_function = self.model.forward 53 | 54 | if self.self_ensemble: 55 | return self.forward_x8(x, forward_function=forward_function) 56 | else: 57 | return forward_function(x) 58 | 59 | def save(self, apath, epoch, is_best=False): 60 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 61 | 62 | if is_best: 63 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 64 | if self.save_models: 65 | save_dirs.append( 66 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 67 | ) 68 | 69 | for s in save_dirs: 70 | torch.save(self.model.state_dict(), s) 71 | 72 | def load(self, apath, pre_train='', resume=-1, cpu=False): 73 | load_from = None 74 | kwargs = {} 75 | if cpu: 76 | kwargs = {'map_location': lambda storage, loc: storage} 77 | 78 | if resume == -1: 79 | load_from = torch.load( 80 | os.path.join(apath, 'model_latest.pt'), 81 | **kwargs 82 | ) 83 | elif resume == 0: 84 | if pre_train == 'download': 85 | print('Download the model') 86 | dir_model = os.path.join('..', 'models') 87 | os.makedirs(dir_model, exist_ok=True) 88 | load_from = torch.utils.model_zoo.load_url( 89 | self.model.url, 90 | model_dir=dir_model, 91 | **kwargs 92 | ) 93 | elif pre_train: 94 | print('Load the model from {}'.format(pre_train)) 95 | load_from = torch.load(pre_train, **kwargs) 96 | else: 97 | load_from = torch.load( 98 | os.path.join(apath, 'model_{}.pt'.format(resume)), 99 | **kwargs 100 | ) 101 | 102 | if load_from: 103 | self.model.load_state_dict(load_from, strict=False) 104 | 105 | def forward_chop(self, *args, shave=10, min_size=160000): 106 | scale = 1 if self.input_large else self.scale[self.idx_scale] 107 | n_GPUs = min(self.n_GPUs, 4) 108 | # height, width 109 | h, w = args[0].size()[-2:] 110 | 111 | top = slice(0, h//2 + shave) 112 | bottom = slice(h - h//2 - shave, h) 113 | left = slice(0, w//2 + shave) 114 | right = slice(w - w//2 - shave, w) 115 | x_chops = [torch.cat([ 116 | a[..., top, left], 117 | a[..., top, right], 118 | a[..., bottom, left], 119 | a[..., bottom, right] 120 | ]) for a in args] 121 | 122 | y_chops = [] 123 | if h * w < 4 * min_size: 124 | for i in range(0, 4, n_GPUs): 125 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 126 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 127 | if not isinstance(y, list): y = [y] 128 | if not y_chops: 129 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 130 | else: 131 | for y_chop, _y in zip(y_chops, y): 132 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 133 | else: 134 | for p in zip(*x_chops): 135 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 136 | if not isinstance(y, list): y = [y] 137 | if not y_chops: 138 | y_chops = [[_y] for _y in y] 139 | else: 140 | for y_chop, _y in zip(y_chops, y): y_chop.append(_y) 141 | 142 | h *= scale 143 | w *= scale 144 | 
top = slice(0, h//2) 145 | bottom = slice(h - h//2, h) 146 | bottom_r = slice(h//2 - h, None) 147 | left = slice(0, w//2) 148 | right = slice(w - w//2, w) 149 | right_r = slice(w//2 - w, None) 150 | 151 | # batch size, number of color channels 152 | b, c = y_chops[0][0].size()[:-2] 153 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 154 | for y_chop, _y in zip(y_chops, y): 155 | _y[..., top, left] = y_chop[0][..., top, left] 156 | _y[..., top, right] = y_chop[1][..., top, right_r] 157 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 158 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 159 | 160 | if len(y) == 1: y = y[0] 161 | 162 | return y 163 | 164 | def forward_x8(self, *args, forward_function=None): 165 | def _transform(v, op): 166 | if self.precision != 'single': v = v.float() 167 | 168 | v2np = v.data.cpu().numpy() 169 | if op == 'v': 170 | tfnp = v2np[:, :, :, ::-1].copy() 171 | elif op == 'h': 172 | tfnp = v2np[:, :, ::-1, :].copy() 173 | elif op == 't': 174 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 175 | 176 | ret = torch.Tensor(tfnp).to(self.device) 177 | if self.precision == 'half': ret = ret.half() 178 | 179 | return ret 180 | 181 | list_x = [] 182 | for a in args: 183 | x = [a] 184 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 185 | 186 | list_x.append(x) 187 | 188 | list_y = [] 189 | for x in zip(*list_x): 190 | y = forward_function(*x) 191 | if not isinstance(y, list): y = [y] 192 | if not list_y: 193 | list_y = [[_y] for _y in y] 194 | else: 195 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 196 | 197 | for _list_y in list_y: 198 | for i in range(len(_list_y)): 199 | if i > 3: 200 | _list_y[i] = _transform(_list_y[i], 't') 201 | if i % 4 > 1: 202 | _list_y[i] = _transform(_list_y[i], 'h') 203 | if (i % 4) % 2 == 1: 204 | _list_y[i] = _transform(_list_y[i], 'v') 205 | 206 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 207 | if len(y) == 1: y = y[0] 208 | 209 | return y 210 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i 
in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /src/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from . import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = 
nn.ModuleList()
89 |         self.downmodules = nn.ModuleList()
90 |         channels = nr
91 |         for i in range(self.depth):
92 |             self.upmodules.append(
93 |                 DenseProjection(channels, nr, scale, True, i > 1)
94 |             )
95 |             if i != 0:
96 |                 channels += nr
97 | 
98 |         channels = nr
99 |         for i in range(self.depth - 1):
100 |             self.downmodules.append(
101 |                 DenseProjection(channels, nr, scale, False, i != 0)
102 |             )
103 |             channels += nr
104 | 
105 |         reconstruction = [
106 |             nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
107 |         ]
108 |         self.reconstruction = nn.Sequential(*reconstruction)
109 | 
110 |         self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
111 | 
112 |     def forward(self, x):
113 |         x = self.sub_mean(x)
114 |         x = self.initial(x)
115 | 
116 |         h_list = []
117 |         l_list = []
118 |         for i in range(self.depth - 1):
119 |             if i == 0:
120 |                 l = x
121 |             else:
122 |                 l = torch.cat(l_list, dim=1)
123 |             h_list.append(self.upmodules[i](l))
124 |             l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
125 | 
126 |         h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
127 |         out = self.reconstruction(torch.cat(h_list, dim=1))
128 |         out = self.add_mean(out)
129 | 
130 |         return out
131 | 
132 | 
--------------------------------------------------------------------------------
/src/model/denseskip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | from model import common
6 | 
7 | def make_model(args, parent=False):
8 |     return DenseSkip(args)
9 | 
10 | class DenseBlock(nn.Module):
11 |     def __init__(self, growth_rate, n_feat_in, n_layers, conv=common.default_conv):
12 |         super(DenseBlock, self).__init__()
13 | 
14 |         kernel_size = 3
15 |         feat = n_feat_in
16 |         body = []
17 | 
18 |         for i in range(n_layers):
19 |             to_concat = (i != 0)
20 |             layer = common.DenseLayer(conv, feat, growth_rate, kernel_size, True,
21 |                                       to_concat)
22 |             #self.add_module('DenseLayer{}'.format(i+1),layer)
23 |             body.append(layer)
24 | 
25 |             if i == 0: feat = growth_rate
26 |             else: feat += growth_rate
27 | 
28 |         self.body = nn.Sequential(*body)
29 | 
30 |     def forward(self, x):
31 |         return self.body(x)
32 | 
33 | class DenseSkip(nn.Module):
34 |     def __init__(self, args, conv=common.default_conv):
35 |         super(DenseSkip, self).__init__()
36 | 
37 |         self.act = nn.ReLU(True)
38 |         kernel_size = 3
39 |         growth_rate = args.growth_rate
40 |         n_feats = args.n_feats
41 |         n_denseblocks = args.n_denseblocks
42 |         scale = args.scale[0]
43 |         self.is_sub_mean = args.is_sub_mean
44 | 
45 |         rgb_mean = (0.4488, 0.4371, 0.4040)
46 |         rgb_std = (1.0, 1.0, 1.0)
47 |         self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
48 | 
49 |         self.head = nn.Sequential(*[nn.Conv2d(in_channels=args.n_channel_in, out_channels=n_feats,
50 |                                               kernel_size=kernel_size, padding=1),
51 |                                     nn.ReLU(True)])
52 | 
53 |         self.dense_blocks = []
54 |         for i in range(n_denseblocks):
55 |             db = DenseBlock(growth_rate, n_feats, args.n_layers)
56 |             self.add_module('DenseBlock{}'.format(i+1), db)
57 |             self.dense_blocks.append(db)
58 | 
59 |         self.bottleneck = nn.Conv2d(in_channels=n_feats*(n_denseblocks+1),
60 |                                     out_channels=n_feats*2, kernel_size=1,
61 |                                     stride=1, padding=0, bias=False)
62 | 
63 |         # common.Upsampler (see model/common.py) upsamples via PixelShuffle
64 |         self.tail = [common.Upsampler(conv, scale, n_feats*2, act=False, bias=False)]
65 |         self.tail = nn.Sequential(*self.tail)
66 | 
67 |         self.reconstruction = nn.Conv2d(in_channels=n_feats*2, out_channels=args.n_channel_out,
68 |                                         kernel_size=kernel_size, stride=1, padding=1, bias=False)
69 | 
70 |         self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
71 | 
72 |     def forward(self, x):
73 |         if self.is_sub_mean:
74 |             x = self.sub_mean(x)
75 | 
76 |         x = self.head(x)
77 |         outs = [x]
78 | 
79 |         for db in self.dense_blocks:
80 |             x = db(x)
81 | 
82 |             outs.append(x)
83 | 
84 |         x = torch.cat(outs, 1)
85 |         x = self.bottleneck(x)
86 |         x = self.tail(x)
87 |         x = self.reconstruction(x)
88 | 
89 |         if self.is_sub_mean:
90 |             x = self.add_mean(x)
91 | 
92 |         return x
--------------------------------------------------------------------------------
/src/model/drrn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from . import common
3 | 
4 | def make_model(args):
5 |     return DRRN(args)
6 | 
7 | class DRRN(nn.Module):
8 |     def __init__(self, args, conv=common.default_conv):
9 |         super(DRRN, self).__init__()
10 | 
11 |         kernel_size = 3
12 |         self.n_layers = args.n_layers
13 |         self.is_residual = True  # DRRN always predicts the residual w.r.t. its input
14 | 
15 |         self.head = conv(args.n_colors, args.n_feats, kernel_size, bias=False)
16 |         self.conv1 = conv(args.n_feats, args.n_feats, kernel_size, bias=False)
17 |         self.conv2 = conv(args.n_feats, args.n_feats, kernel_size, bias=False)
18 |         self.tail = conv(args.n_feats, args.n_colors, kernel_size, bias=False)
19 |         self.act = nn.ReLU(inplace=True)
20 | 
21 |     def forward(self, x):
22 |         residual = x
23 |         inputs = self.head(self.act(x))
24 |         out = inputs
25 | 
26 |         for _ in range(self.n_layers):
27 |             out = self.conv2(self.act(self.conv1(self.act(out))))
28 |             out += inputs
29 | 
30 |         self.features = [self.act(out)]
31 |         out = self.tail(self.features[0])
32 | 
33 |         return out + residual if (self.is_residual) else out
34 | 
--------------------------------------------------------------------------------
/src/model/edsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | 
3 | import torch.nn as nn
4 | 
5 | url = {
6 |     'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
7 |     'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
8 |     'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
9 |     'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
10 |     'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
11 |     'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
12 | }
13 | 
14 | def make_model(args, parent=False):
15 |     return EDSR(args)
16 | 
17 | class EDSR(nn.Module):
18 |     def __init__(self, args, conv=common.default_conv):
19 |         super(EDSR, self).__init__()
20 | 
21 |         n_resblocks = args.n_resblocks
22 |         n_feats = args.n_feats
23 |         kernel_size = 3
24 |         scale = args.scale[0]
25 |         act = nn.ReLU(True)
26 |         url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
27 |         if url_name in url:
28 |             self.url = url[url_name]
29 |         else:
30 |             self.url = None
31 |         self.sub_mean = common.MeanShift(args.rgb_range)
32 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
33 | 
34 |         # define head module
35 |         m_head = [conv(args.n_colors, n_feats, kernel_size)]
36 | 
37 |         # define body module
38 |         m_body = [
39 |             common.ResBlock(
40 |                 conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
41 |             ) for _ in range(n_resblocks)
42 |         ]
43 |         m_body.append(conv(n_feats, n_feats, kernel_size))
44 | 
45 |         # define tail module
46 |         m_tail = [
47 |             common.Upsampler(conv, scale,
n_feats, act=False), 48 | conv(n_feats, args.n_colors, kernel_size) 49 | ] 50 | 51 | self.head = nn.Sequential(*m_head) 52 | self.body = nn.Sequential(*m_body) 53 | self.tail = nn.Sequential(*m_tail) 54 | 55 | def forward(self, x): 56 | x = self.sub_mean(x) 57 | x = self.head(x) 58 | 59 | res = self.body(x) 60 | res += x 61 | 62 | x = self.tail(res) 63 | x = self.add_mean(x) 64 | 65 | return x 66 | 67 | def load_state_dict(self, state_dict, strict=True): 68 | own_state = self.state_dict() 69 | for name, param in state_dict.items(): 70 | if name in own_state: 71 | if isinstance(param, nn.Parameter): 72 | param = param.data 73 | try: 74 | own_state[name].copy_(param) 75 | except Exception: 76 | if name.find('tail') == -1: 77 | raise RuntimeError('While copying the parameter named {}, ' 78 | 'whose dimensions in the model are {} and ' 79 | 'whose dimensions in the checkpoint are {}.' 80 | .format(name, own_state[name].size(), param.size())) 81 | elif strict: 82 | if name.find('tail') == -1: 83 | raise KeyError('unexpected key "{}" in state_dict' 84 | .format(name)) 85 | 86 | -------------------------------------------------------------------------------- /src/model/lapsrn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | try: 5 | # Python 2 6 | from itertools import izip 7 | except ImportError: 8 | # Python 3 9 | izip = zip 10 | xrange = range 11 | 12 | def make_model(args): 13 | return LapSRN(args) 14 | 15 | class Conv_Block(nn.Module): 16 | def __init__(self, n_layers, n_feats, negative_slope, kernel_size, upsample=True): 17 | super(Conv_Block, self).__init__() 18 | body = [] 19 | tail = [] 20 | 21 | for _ in xrange(n_layers): 22 | layer = nn.Conv2d(in_channels=n_feats, out_channels=n_feats, 23 | kernel_size=kernel_size,stride=1,padding=1,bias=False) 24 | act = nn.LeakyReLU(negative_slope=negative_slope, inplace=True) 25 | body.extend([layer,act]) 26 | 27 | if upsample: 28 | upsampled = nn.ConvTranspose2d(in_channels=n_feats, out_channels=n_feats, 29 | kernel_size=4, stride=2, padding=1,bias=False) 30 | act = nn.LeakyReLU(negative_slope, inplace=True) 31 | tail.extend([upsampled, act]) 32 | 33 | self.body = nn.Sequential(*body) 34 | self.tail = nn.Sequential(*tail) 35 | 36 | def forward(self, x): 37 | self.down_feats = self.body(x) 38 | return self.tail(self.down_feats) 39 | 40 | def load_state_dict(self, state_dict, strict=True): 41 | own_state = self.state_dict() 42 | 43 | for (k1, p1), (k2, p2) in izip(state_dict.items(), own_state.items()): 44 | if isinstance(p1, nn.Parameter): 45 | p1 = p1.data 46 | try: 47 | own_state[k2].copy_(p1) 48 | except Exception: 49 | raise RuntimeError('p1 dims = {}, p2 dims = {}'.format(p1.size(), p2.size())) 50 | 51 | class LapSRN(nn.Module): 52 | def __init__(self, args): 53 | super(LapSRN, self).__init__() 54 | 55 | kernel_size = 3 56 | self.scale = args.scale[0] 57 | 58 | head = [nn.Conv2d(in_channels=args.n_colors, out_channels=args.n_feats, 59 | kernel_size=kernel_size, stride=1, padding=1, bias=False), 60 | nn.LeakyReLU(args.negative_slope, True)] 61 | self.head = nn.Sequential(*head) 62 | 63 | self.feats_branch, self.images_branch, self.residuals_branch = [], [], [] 64 | 65 | n_iters = 1 if (self.scale==1) else int(math.log(self.scale, 2)) 66 | for i in xrange(n_iters): 67 | feat_branch = Conv_Block(args.n_layers, args.n_feats, args.negative_slope, 68 | kernel_size, upsample=not(self.scale==1)) 69 | if not (self.scale==1): 70 | img_branch = 
nn.ConvTranspose2d(in_channels=args.n_colors,
71 |                                                out_channels=args.n_colors,
72 |                                                kernel_size=4, stride=2, padding=1,
73 |                                                bias=False)
74 |             else:
75 |                 img_branch = nn.Conv2d(in_channels=args.n_colors,
76 |                                        out_channels=args.n_colors,
77 |                                        kernel_size=kernel_size,
78 |                                        stride=1, padding=1, bias=True)
79 | 
80 |             res_branch = nn.Conv2d(in_channels=args.n_feats, out_channels=args.n_colors,
81 |                                    kernel_size=kernel_size, stride=1, padding=1, bias=False)
82 | 
83 |             self.add_module('img_branch_{}'.format(i+1), img_branch)
84 |             self.add_module('residual_branch_{}'.format(i+1), res_branch)
85 |             self.add_module('feat_branch_{}'.format(i+1), feat_branch)
86 | 
87 |             self.feats_branch.append(feat_branch)
88 |             self.images_branch.append(img_branch)
89 |             self.residuals_branch.append(res_branch)
90 | 
91 |     def forward(self, x):
92 |         fx = self.head(x)
93 | 
94 |         self.features = []
95 |         self.down_feats = None
96 | 
97 |         for feat, img, res in izip(self.feats_branch, self.images_branch,
98 |                                    self.residuals_branch):
99 |             fx = feat(fx)
100 |             ix = img(x)
101 |             rx = res(fx)
102 | 
103 |             x = rx + ix
104 | 
105 |             self.features.append(fx)
106 | 
107 |             if self.down_feats is None:
108 |                 self.down_feats = feat.down_feats
109 | 
110 |         return x
111 | 
112 |     def load_state_dict(self, state_dict, strict=True):
113 |         own_state = self.state_dict()
114 | 
115 |         for (k1, p1), (k2, p2) in izip(state_dict.items(), own_state.items()):
116 |             if isinstance(p1, nn.Parameter):
117 |                 p1 = p1.data
118 |             try:
119 |                 own_state[k2].copy_(p1)
120 |             except Exception:
121 |                 raise RuntimeError('size mismatch copying {} -> {}: {} vs {}'.format(k1, k2, p1.size(), p2.size()))
--------------------------------------------------------------------------------
/src/model/mdsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | 
3 | import torch.nn as nn
4 | 
5 | url = {
6 |     'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt',
7 |     'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt'
8 | }
9 | 
10 | def make_model(args, parent=False):
11 |     return MDSR(args)
12 | 
13 | class MDSR(nn.Module):
14 |     def __init__(self, args, conv=common.default_conv):
15 |         super(MDSR, self).__init__()
16 |         n_resblocks = args.n_resblocks
17 |         n_feats = args.n_feats
18 |         kernel_size = 3
19 |         act = nn.ReLU(True)
20 |         self.scale_idx = 0
21 |         self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
22 |         self.sub_mean = common.MeanShift(args.rgb_range)
23 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
24 | 
25 |         m_head = [conv(args.n_colors, n_feats, kernel_size)]
26 | 
27 |         self.pre_process = nn.ModuleList([
28 |             nn.Sequential(
29 |                 common.ResBlock(conv, n_feats, 5, act=act),
30 |                 common.ResBlock(conv, n_feats, 5, act=act)
31 |             ) for _ in args.scale
32 |         ])
33 | 
34 |         m_body = [
35 |             common.ResBlock(
36 |                 conv, n_feats, kernel_size, act=act
37 |             ) for _ in range(n_resblocks)
38 |         ]
39 |         m_body.append(conv(n_feats, n_feats, kernel_size))
40 | 
41 |         self.upsample = nn.ModuleList([
42 |             common.Upsampler(conv, s, n_feats, act=False) for s in args.scale
43 |         ])
44 | 
45 |         m_tail = [conv(n_feats, args.n_colors, kernel_size)]
46 | 
47 |         self.head = nn.Sequential(*m_head)
48 |         self.body = nn.Sequential(*m_body)
49 |         self.tail = nn.Sequential(*m_tail)
50 | 
51 |     def forward(self, x):
52 |         x = self.sub_mean(x)
53 |         x = self.head(x)
54 |         x = self.pre_process[self.scale_idx](x)
55 | 
56 |         res = self.body(x)
57 |         res += x
58 | 
59 |         x =
self.upsample[self.scale_idx](res) 60 | x = self.tail(x) 61 | x = self.add_mean(x) 62 | 63 | return x 64 | 65 | def set_scale(self, scale_idx): 66 | self.scale_idx = scale_idx 67 | 68 | -------------------------------------------------------------------------------- /src/model/rcan.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | 7 | def make_model(args, parent=False): 8 | return RCAN(args) 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | # global average pooling: feature --> point 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | # feature channel downscale and upscale --> channel weight 17 | self.conv_du = nn.Sequential( 18 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | y = self.avg_pool(x) 26 | y = self.conv_du(y) 27 | return x * y 28 | 29 | ## Residual Channel Attention Block (RCAB) 30 | class RCAB(nn.Module): 31 | def __init__( 32 | self, conv, n_feat, kernel_size, reduction, 33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 34 | 35 | super(RCAB, self).__init__() 36 | modules_body = [] 37 | for i in range(2): 38 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 39 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 40 | if i == 0: modules_body.append(act) 41 | modules_body.append(CALayer(n_feat, reduction)) 42 | self.body = nn.Sequential(*modules_body) 43 | self.res_scale = res_scale 44 | 45 | def forward(self, x): 46 | res = self.body(x) 47 | #res = self.body(x).mul(self.res_scale) 48 | res += x 49 | return res 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | ## Residual Channel Attention Network (RCAN) 69 | class RCAN(nn.Module): 70 | def __init__(self, args, conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = args.n_resgroups 74 | n_resblocks = args.n_resblocks 75 | n_feats = args.n_feats 76 | kernel_size = 3 77 | reduction = args.reduction 78 | scale = args.scale[0] 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 82 | self.sub_mean = common.MeanShift(args.rgb_range) 83 | 84 | # define head module 85 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 86 | 87 | # define body module 88 | modules_body = [ 89 | ResidualGroup( 90 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 91 | for _ in range(n_resgroups)] 92 | 93 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 94 | 95 | # define tail module 96 | modules_tail = [ 97 | common.Upsampler(conv, scale, n_feats, act=False), 98 | conv(n_feats, 
args.n_colors, kernel_size)]
99 | 
100 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
101 | 
102 |         self.head = nn.Sequential(*modules_head)
103 |         self.body = nn.Sequential(*modules_body)
104 |         self.tail = nn.Sequential(*modules_tail)
105 | 
106 |     def forward(self, x):
107 |         x = self.sub_mean(x)
108 |         x = self.head(x)
109 | 
110 |         res = self.body(x)
111 |         res += x
112 | 
113 |         x = self.tail(res)
114 |         x = self.add_mean(x)
115 | 
116 |         return x
117 | 
118 |     def load_state_dict(self, state_dict, strict=False):
119 |         own_state = self.state_dict()
120 |         for name, param in state_dict.items():
121 |             if name in own_state:
122 |                 if isinstance(param, nn.Parameter):
123 |                     param = param.data
124 |                 try:
125 |                     own_state[name].copy_(param)
126 |                 except Exception:
127 |                     if name.find('tail') >= 0:
128 |                         print('Replace pre-trained upsampler to new one...')
129 |                     else:
130 |                         raise RuntimeError('While copying the parameter named {}, '
131 |                                            'whose dimensions in the model are {} and '
132 |                                            'whose dimensions in the checkpoint are {}.'
133 |                                            .format(name, own_state[name].size(), param.size()))
134 |             elif strict:
135 |                 if name.find('tail') == -1:
136 |                     raise KeyError('unexpected key "{}" in state_dict'
137 |                                    .format(name))
138 | 
139 |         if strict:
140 |             missing = set(own_state.keys()) - set(state_dict.keys())
141 |             if len(missing) > 0:
142 |                 raise KeyError('missing keys in state_dict: "{}"'.format(missing))
143 | 
--------------------------------------------------------------------------------
/src/model/rdn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from model import common
4 | 
5 | 
6 | def make_model(args, parent=False):
7 |     return RDN(args)
8 | 
9 | class RDN(nn.Module):
10 |     def __init__(self, args, conv=common.default_conv):
11 |         super(RDN, self).__init__()
12 | 
13 |         rgb_mean = (0.4488, 0.4371, 0.4040)
14 |         rgb_std = (1.0, 1.0, 1.0)
15 |         self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
16 | 
17 |         kernel_size = 3
18 |         self.is_sub_mean = args.is_sub_mean
19 | 
20 |         self.conv1 = conv(args.n_channel_in, args.n_feats, kernel_size, bias=True)
21 |         self.conv2 = conv(args.n_feats, args.n_feats, kernel_size, bias=True)
22 | 
23 |         self.RDBs = []
24 |         for i in range(args.n_denseblocks):
25 |             RDB = common.RDB(args.n_feats, args.n_layers, args.growth_rate, conv, kernel_size, True)
26 |             self.add_module('RDB{}'.format(i+1), RDB)
27 |             self.RDBs.append(RDB)
28 | 
29 |         self.gff_1 = nn.Conv2d(args.n_feats*args.n_denseblocks, args.n_feats,
30 |                                kernel_size=1, padding=0, bias=True)
31 |         self.gff_3 = conv(args.n_feats, args.n_feats, kernel_size, bias=True)
32 | 
33 |         m_tail = [common.Upsampler(conv, args.scale[0], args.n_feats, act=False),
34 |                   conv(args.n_feats, args.n_channel_out, kernel_size)]
35 | 
36 |         self.tail = nn.Sequential(*m_tail)
37 |         self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
38 | 
39 |     def forward(self, x):
40 |         if self.is_sub_mean:
41 |             x = self.sub_mean(x)
42 | 
43 |         F_minus = self.conv1(x)
44 |         x = self.conv2(F_minus)
45 |         to_concat = []
46 | 
47 |         for db in self.RDBs:
48 |             x = db(x)
49 |             to_concat.append(x)
50 | 
51 |         x = torch.cat(to_concat, 1)
52 |         x = self.gff_1(x)
53 |         x = self.gff_3(x)
54 |         x = x + F_minus
55 | 
56 |         self.down_feats = x
57 | 
58 |         out = self.tail(x)
59 | 
60 |         if self.is_sub_mean:
61 |             out = self.add_mean(out)
62 | 
63 |         return out
64 | 
65 |     def load_state_dict(self, state_dict, strict=True):
66 |         own_state = self.state_dict()
67 | 
68 |         for (k1, p1), (k2, p2) in zip(state_dict.items(), own_state.items()):
69 |             if (k1.split('.')[0] == '0') or (k1.split('.')[0] == '5'): #do not copy shift mean layer
70 |                 continue
71 | 
72 |             if isinstance(p1, nn.Parameter):
73 |                 p1 = p1.data
74 | 
75 |             try:
76 |                 own_state[k2].copy_(p1)
77 |             except Exception:
78 |                 raise RuntimeError('size mismatch copying {} -> {}: {} vs {}'.format(k1, k2, p1.size(), p2.size()))
79 | 
--------------------------------------------------------------------------------
/src/model/srcnn.py:
--------------------------------------------------------------------------------
1 | from . import common
2 | import torch.nn as nn
3 | 
4 | def make_model(args):
5 |     return SRCNN(args)
6 | 
7 | class SRCNN(nn.Module):
8 |     def __init__(self, args, conv=common.default_conv):
9 |         super(SRCNN, self).__init__()
10 | 
11 |         #kernel_size = 3
12 |         #scale = args.scale[0]
13 |         #act = nn.LeakyReLU(negative_slope=0.2)
14 | 
15 |         rgb_mean = (0.4488, 0.4371, 0.4040)
16 |         rgb_std = (1.0, 1.0, 1.0)
17 |         self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
18 | 
19 |         self.body = nn.Sequential(
20 |             conv(args.n_colors, args.n_feats * 2, 9),
21 |             nn.ReLU(inplace=True),
22 |             conv(args.n_feats * 2, args.n_feats, 5),
23 |             nn.ReLU(inplace=True),
24 |             conv(args.n_feats, args.n_colors, 5)
25 |         )
26 | 
27 |         self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
28 | 
29 |     def forward(self, x):
30 |         x = self.sub_mean(x)
31 | 
32 |         x = self.body(x)
33 | 
34 |         x = self.add_mean(x)
35 | 
36 |         return x
--------------------------------------------------------------------------------
/src/model/srresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from model import common
5 | 
6 | def make_model(args, parent=False):
7 |     return SRResNet(args)
8 | 
9 | class _Residual_Block(nn.Module):
10 |     def __init__(self):
11 |         super(_Residual_Block, self).__init__()
12 | 
13 |         self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
14 |         self.in1 = nn.InstanceNorm2d(64, affine=True)
15 |         self.relu = nn.LeakyReLU(0.2, inplace=True)
16 |         self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
17 |         self.in2 = nn.InstanceNorm2d(64, affine=True)
18 | 
19 |     def forward(self, x):
20 |         identity_data = x
21 |         output = self.relu(self.in1(self.conv1(x)))
22 |         output = self.in2(self.conv2(output))
23 |         output = torch.add(output, identity_data)
24 |         return output
25 | 
26 | class SRResNet(nn.Module):
27 |     def __init__(self, args):
28 |         super(SRResNet, self).__init__()
29 | 
30 |         self.sub_mean = common.MeanShift(args.rgb_range)
31 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
32 | 
33 |         self.conv_input = nn.Conv2d(in_channels=args.n_colors, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
34 |         self.relu = nn.LeakyReLU(0.2, inplace=True)
35 | 
36 |         self.residual = self.make_layer(_Residual_Block, args.n_feats)
37 | 
38 |         self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
39 |         self.bn_mid = nn.InstanceNorm2d(64, affine=True)
40 | 
41 |         self.upscale4x = nn.Sequential(
42 |             nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
43 |             nn.PixelShuffle(2),
44 |             nn.LeakyReLU(0.2, inplace=True),
45 |             nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
46 |             nn.PixelShuffle(2),
47 |             nn.LeakyReLU(0.2, inplace=True),
48 |         )
49 | 
50 |         self.conv_output =
nn.Conv2d(in_channels=64, out_channels=args.n_colors, kernel_size=9, stride=1, padding=4, bias=False) 51 | 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 55 | m.weight.data.normal_(0, math.sqrt(2. / n)) 56 | if m.bias is not None: 57 | m.bias.data.zero_() 58 | 59 | def make_layer(self, block, num_of_layer): 60 | layers = [] 61 | for _ in range(num_of_layer): 62 | layers.append(block()) 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | x = self.sub_mean(x) 67 | out = self.relu(self.conv_input(x)) 68 | residual = out 69 | out = self.residual(out) 70 | out = self.bn_mid(self.conv_mid(out)) 71 | out = torch.add(out,residual) 72 | out = self.upscale4x(out) 73 | out = self.conv_output(out) 74 | out = self.add_mean(out) 75 | return out -------------------------------------------------------------------------------- /src/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, default=6, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='../../dataset', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../test', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='DIV2K', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, 
default='DIV2K',
29 |                     help='test dataset name')
30 | parser.add_argument('--data_range', type=str, default='1-800/801-900',
31 |                     help='train/test data range')
32 | parser.add_argument('--ext', type=str, default='sep',
33 |                     help='dataset file extension')
34 | parser.add_argument('--scale', type=str, default='4',
35 |                     help='super resolution scale')
36 | parser.add_argument('--patch_size', type=int, default=192,
37 |                     help='output patch size')
38 | parser.add_argument('--rgb_range', type=int, default=255,
39 |                     help='maximum value of RGB')
40 | parser.add_argument('--n_colors', type=int, default=3,
41 |                     help='number of color channels to use')
42 | parser.add_argument('--chop', action='store_true',
43 |                     help='enable memory-efficient forward')
44 | parser.add_argument('--no_augment', action='store_true',
45 |                     help='do not use data augmentation')
46 | 
47 | # Model specifications
48 | parser.add_argument('--model', default='EDSR',
49 |                     help='model name')
50 | 
51 | parser.add_argument('--act', type=str, default='relu',
52 |                     help='activation function')
53 | parser.add_argument('--negative_slope', type=float, default=0.2,
54 |                     help='negative slope of LeakyReLU (used by LapSRN and SRResNet)')
55 | parser.add_argument('--pre_train', type=str, default='',
56 |                     help='pre-trained model directory')
57 | parser.add_argument('--extend', type=str, default='.',
58 |                     help='pre-trained model directory')
59 | parser.add_argument('--n_resblocks', type=int, default=16,
60 |                     help='number of residual blocks')
61 | parser.add_argument('--n_layers', type=int, default=8,
62 |                     help='number of layers inside a dense block [DenseSkip]')
63 | parser.add_argument('--n_feats', type=int, default=64,
64 |                     help='number of feature maps')
65 | parser.add_argument('--res_scale', type=float, default=1,
66 |                     help='residual scaling')
67 | parser.add_argument('--shift_mean', default=True,
68 |                     help='subtract pixel mean from the input')
69 | parser.add_argument('--dilation', action='store_true',
70 |                     help='use dilated convolution')
71 | parser.add_argument('--precision', type=str, default='single',
72 |                     choices=('single', 'half'),
73 |                     help='FP precision for test (single | half)')
74 | 
75 | # Option for Residual dense network (RDN)
76 | parser.add_argument('--G0', type=int, default=64,
77 |                     help='default number of filters. (Use in RDN)')
78 | parser.add_argument('--RDNkSize', type=int, default=3,
79 |                     help='default kernel size. (Use in RDN)')
80 | parser.add_argument('--RDNconfig', type=str, default='B',
81 |                     help='parameters config of RDN. (Use in RDN)')
82 | 
83 | # Option for Residual channel attention network (RCAN)
84 | parser.add_argument('--n_resgroups', type=int, default=10,
85 |                     help='number of residual groups')
86 | parser.add_argument('--reduction', type=int, default=16,
87 |                     help='number of feature maps reduction')
88 | 
89 | # Training specifications
90 | parser.add_argument('--reset', action='store_true',
91 |                     help='reset the training')
92 | parser.add_argument('--test_every', type=int, default=1000,
93 |                     help='do test per every N batches')
94 | parser.add_argument('--epochs', type=int, default=300,
95 |                     help='number of epochs to train')
96 | parser.add_argument('--batch_size', type=int, default=16,
97 |                     help='input batch size for training')
98 | parser.add_argument('--split_batch', type=int, default=1,
99 |                     help='split the batch into smaller chunks')
100 | parser.add_argument('--self_ensemble', action='store_true',
101 |                     help='use self-ensemble method for test')
102 | parser.add_argument('--test_only', action='store_true',
103 |                     help='set this option to test the model')
104 | parser.add_argument('--gan_k', type=int, default=1,
105 |                     help='k value for adversarial loss')
106 | 
107 | # Optimization specifications
108 | parser.add_argument('--lr', type=float, default=1e-4,
109 |                     help='learning rate')
110 | parser.add_argument('--decay', type=str, default='200',
111 |                     help='learning rate decay milestones, joined by "-" (e.g. 100-200)')
112 | parser.add_argument('--gamma', type=float, default=0.5,
113 |                     help='learning rate decay factor for step decay')
114 | parser.add_argument('--optimizer', default='ADAM',
115 |                     choices=('SGD', 'ADAM', 'RMSprop'),
116 |                     help='optimizer to use (SGD | ADAM | RMSprop)')
117 | parser.add_argument('--momentum', type=float, default=0.9,
118 |                     help='SGD momentum')
119 | parser.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999),
120 |                     help='ADAM betas (beta1 beta2)')
121 | parser.add_argument('--epsilon', type=float, default=1e-8,
122 |                     help='ADAM epsilon for numerical stability')
123 | parser.add_argument('--weight_decay', type=float, default=0,
124 |                     help='weight decay')
125 | parser.add_argument('--gclip', type=float, default=0,
126 |                     help='gradient clipping threshold (0 = no clipping)')
127 | 
128 | # Loss specifications
129 | parser.add_argument('--loss', type=str, default='1*L1',
130 |                     help='loss function configuration')
131 | parser.add_argument('--skip_threshold', type=float, default=1e8,
132 |                     help='skip batches whose error is abnormally large')
133 | 
134 | # Log specifications
135 | parser.add_argument('--save', type=str, default='test',
136 |                     help='file name to save')
137 | parser.add_argument('--load', type=str, default='',
138 |                     help='file name to load')
139 | parser.add_argument('--resume', type=int, default=0,
140 |                     help='resume from specific checkpoint')
141 | parser.add_argument('--save_models', action='store_true',
142 |                     help='save all intermediate models')
143 | parser.add_argument('--print_every', type=int, default=100,
144 |                     help='how many batches to wait before logging training status')
145 | parser.add_argument('--save_results', action='store_true',
146 |                     help='save output results')
147 | parser.add_argument('--save_gt', action='store_true',
148 |                     help='save low-resolution and high-resolution images together')
149 | 
150 | # SPL parameters (used by the self-paced loss in src/loss/__init__.py)
151 | parser.add_argument('--splalpha', type=float, default=1,
152 |                     help='slope of sigma w.r.t. the training age (alpha)')
153 | parser.add_argument('--splbeta', type=float, default=0,
154 |                     help='initial sigma value (beta)')
155 | parser.add_argument('--splval', type=float, default=2,
156 |                     help='upper bound of the attention weight (maxVal)')
157 | 
158 | args = parser.parse_args()
159 | 
template.set_template(args) 160 | 161 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 162 | args.data_train = args.data_train.split('+') 163 | args.data_test = args.data_test.split('+') 164 | 165 | if args.epochs == 0: 166 | args.epochs = 1e8 167 | 168 | for arg in vars(args): 169 | if vars(args)[arg] == 'True': 170 | vars(args)[arg] = True 171 | elif vars(args)[arg] == 'False': 172 | vars(args)[arg] = False 173 | 174 | -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-4 53 | if args.template.find('DRRN') >= 0: 54 | args.model = 'DRRN' 55 | args.n_colors = 1 56 | args.n_layers = 25 57 | args.n_feats = 128 58 | if args.template.find('LapSRN') >= 0: 59 | args.model = 'LapSRN' 60 | args.n_layers=10 61 | args.loss="1*MSE" 62 | args.n_colors = 1 63 | 64 | if args.template.find('SRCNN') >= 0: 65 | args.model = 'SRCNN' 66 | args.loss="1*MSE" 67 | #args.n_colors = 1 68 | args.is_sub_mean = True 69 | args.patch_size = 128 70 | args.n_feats = 32 71 | if args.template.find("SRRESNET") >= 0: 72 | args.model = 'SRRESNET' 73 | args.n_feats = 16 74 | args.epochs = 500 75 | args.loss = '1*MSE' 76 | args.patch_size = 128 -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | import torch.nn.utils as utils 9 | from tqdm import tqdm 10 | 11 | class Trainer(): 12 | def __init__(self, args, loader, my_model, my_loss, ckp): 13 | self.args = args 14 | self.scale = args.scale 15 | 16 | self.ckp = ckp 17 | self.loader_train = loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer = utility.make_optimizer(args, self.model) 22 | 23 | if self.args.load != '': 24 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 25 | 26 | self.error_last = 1e8 27 | self.age = 0. 
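        # 'age' measures training progress in fractional epochs; train() below passes it to the self-paced loss as its step value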
28 | 
29 |     def train(self):
30 |         self.loss.step()
31 |         epoch = self.optimizer.get_last_epoch() + 1
32 |         lr = self.optimizer.get_lr()
33 | 
34 |         self.ckp.write_log(
35 |             '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
36 |         )
37 |         self.loss.start_log()
38 |         self.model.train()
39 | 
40 |         timer_data, timer_model = utility.timer(), utility.timer()
41 |         # TEMP
42 |         self.loader_train.dataset.set_scale(0)
43 |         n_batches = len(self.loader_train)
44 |         for batch, (lr, hr, _,) in enumerate(self.loader_train):
45 |             lr, hr = self.prepare(lr, hr)
46 |             timer_data.hold()
47 |             timer_model.tic()
48 |             self.age += 1.0 / n_batches
49 |             self.optimizer.zero_grad()
50 |             sr = self.model(lr, 0)
51 |             #print(lr.shape, sr.shape, hr.shape)
52 |             loss = self.loss(sr, hr, self.age)
53 |             loss.backward()
54 |             if self.args.gclip > 0:
55 |                 utils.clip_grad_value_(
56 |                     self.model.parameters(),
57 |                     self.args.gclip
58 |                 )
59 |             self.optimizer.step()
60 | 
61 |             timer_model.hold()
62 | 
63 |             if (batch + 1) % self.args.print_every == 0:
64 |                 self.ckp.write_log('[{}/{}]\t{}\t{}\t{:.1f}+{:.1f}s'.format(
65 |                     (batch + 1) * self.args.batch_size,
66 |                     len(self.loader_train.dataset),
67 |                     f'[age: {self.age:.2f}]',
68 |                     self.loss.display_loss(batch),
69 |                     timer_model.release(),
70 |                     timer_data.release()))
71 | 
72 |             timer_data.tic()
73 | 
74 |         self.loss.end_log(len(self.loader_train))
75 |         self.error_last = self.loss.log[-1, -1]
76 |         self.optimizer.schedule()
77 | 
78 |     def test(self):
79 |         torch.set_grad_enabled(False)
80 | 
81 |         epoch = self.optimizer.get_last_epoch()
82 |         self.ckp.write_log('\nEvaluation:')
83 |         self.ckp.add_log(
84 |             torch.zeros(1, len(self.loader_test), len(self.scale))
85 |         )
86 |         self.model.eval()
87 | 
88 |         timer_test = utility.timer()
89 |         if self.args.save_results: self.ckp.begin_background()
90 |         for idx_data, d in enumerate(self.loader_test):
91 |             for idx_scale, scale in enumerate(self.scale):
92 |                 d.dataset.set_scale(idx_scale)
93 |                 ssim = 0.0
94 |                 for lr, hr, filename in tqdm(d, ncols=80):
95 |                     lr, hr = self.prepare(lr, hr)
96 |                     sr = self.model(lr, idx_scale)
97 |                     sr = utility.quantize(sr, self.args.rgb_range)
98 |                     save_list = [sr]
99 |                     self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
100 |                         sr, hr, scale, self.args.rgb_range, dataset=d
101 |                     )
102 |                     ssim += utility.calc_ssim(sr, hr, self.args.rgb_range)
103 |                     if self.args.save_gt:
104 |                         save_list.extend([lr, hr])
105 | 
106 |                     if self.args.save_results:
107 |                         self.ckp.save_results(d, filename[0], save_list, scale)
108 | 
109 |                 self.ckp.log[-1, idx_data, idx_scale] /= len(d)
110 |                 best = self.ckp.log.max(0)
111 |                 self.ckp.write_log(
112 |                     '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
113 |                         d.dataset.name,
114 |                         scale,
115 |                         self.ckp.log[-1, idx_data, idx_scale],
116 |                         best[0][idx_data, idx_scale],
117 |                         best[1][idx_data, idx_scale] + 1
118 |                     )
119 |                 )
120 |                 self.ckp.write_log(f'[{d.dataset.name} x{scale}]\tSSIM: {ssim / len(d):.4f}')
121 |         self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
122 |         self.ckp.write_log('Saving...')
123 | 
124 |         if self.args.save_results:
125 |             self.ckp.end_background()
126 | 
127 |         if not self.args.test_only:
128 |             self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
129 | 
130 |         self.ckp.write_log(
131 |             'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
132 |         )
133 | 
134 |         torch.set_grad_enabled(True)
135 | 
136 |     def prepare(self, *args):
137 |         device = torch.device('cpu' if self.args.cpu else 'cuda')
138 |         def _prepare(tensor):
139 |             if self.args.precision == 'half': tensor = tensor.half()
140 |             return tensor.to(device)
141 | 
142
| return [_prepare(a) for a in args] 143 | 144 | def terminate(self): 145 | if self.args.test_only: 146 | self.test() 147 | return True 148 | else: 149 | epoch = self.optimizer.get_last_epoch() + 1 150 | return epoch >= self.args.epochs 151 | 152 | -------------------------------------------------------------------------------- /src/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from multiprocessing import Process 6 | from multiprocessing import Queue 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import numpy as np 13 | import imageio 14 | 15 | import torch 16 | import torch.optim as optim 17 | import torch.optim.lr_scheduler as lrs 18 | 19 | from loss import pytorch_ssim 20 | 21 | class timer(): 22 | def __init__(self): 23 | self.acc = 0 24 | self.tic() 25 | 26 | def tic(self): 27 | self.t0 = time.time() 28 | 29 | def toc(self, restart=False): 30 | diff = time.time() - self.t0 31 | if restart: self.t0 = time.time() 32 | return diff 33 | 34 | def hold(self): 35 | self.acc += self.toc() 36 | 37 | def release(self): 38 | ret = self.acc 39 | self.acc = 0 40 | 41 | return ret 42 | 43 | def reset(self): 44 | self.acc = 0 45 | 46 | class checkpoint(): 47 | def __init__(self, args): 48 | self.args = args 49 | self.ok = True 50 | self.log = torch.Tensor() 51 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 52 | 53 | if not args.load: 54 | if not args.save: 55 | args.save = now 56 | self.dir = os.path.join('..', 'experiment', args.save) 57 | else: 58 | self.dir = os.path.join('..', 'experiment', args.load) 59 | if os.path.exists(self.dir): 60 | self.log = torch.load(self.get_path('psnr_log.pt')) 61 | print('Continue from epoch {}...'.format(len(self.log))) 62 | else: 63 | args.load = '' 64 | 65 | if args.reset: 66 | os.system('rm -rf ' + self.dir) 67 | args.load = '' 68 | 69 | os.makedirs(self.dir, exist_ok=True) 70 | os.makedirs(self.get_path('model'), exist_ok=True) 71 | for d in args.data_test: 72 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 73 | 74 | open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' 75 | self.log_file = open(self.get_path('log.txt'), open_type) 76 | with open(self.get_path('config.txt'), open_type) as f: 77 | f.write(now + '\n\n') 78 | for arg in vars(args): 79 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 80 | f.write('\n') 81 | 82 | self.n_processes = 8 83 | 84 | def get_path(self, *subdir): 85 | return os.path.join(self.dir, *subdir) 86 | 87 | def save(self, trainer, epoch, is_best=False): 88 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 89 | trainer.loss.save(self.dir) 90 | trainer.loss.plot_loss(self.dir, epoch) 91 | 92 | self.plot_psnr(epoch) 93 | trainer.optimizer.save(self.dir) 94 | torch.save(self.log, self.get_path('psnr_log.pt')) 95 | 96 | def add_log(self, log): 97 | self.log = torch.cat([self.log, log]) 98 | 99 | def write_log(self, log, refresh=False): 100 | print(log) 101 | self.log_file.write(log + '\n') 102 | if refresh: 103 | self.log_file.close() 104 | self.log_file = open(self.get_path('log.txt'), 'a') 105 | 106 | def done(self): 107 | self.log_file.close() 108 | 109 | def plot_psnr(self, epoch): 110 | axis = np.linspace(1, epoch, epoch) 111 | for idx_data, d in enumerate(self.args.data_test): 112 | label = 'SR on {}'.format(d) 113 | fig = plt.figure() 114 | plt.title(label) 115 | for idx_scale, scale in 
enumerate(self.args.scale): 116 | plt.plot( 117 | axis, 118 | self.log[:, idx_data, idx_scale].numpy(), 119 | label='Scale {}'.format(scale) 120 | ) 121 | plt.legend() 122 | plt.xlabel('Epochs') 123 | plt.ylabel('PSNR') 124 | plt.grid(True) 125 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 126 | plt.close(fig) 127 | 128 | def begin_background(self): 129 | self.queue = Queue() 130 | 131 | def bg_target(queue): 132 | while True: 133 | if not queue.empty(): 134 | filename, tensor = queue.get() 135 | if filename is None: break 136 | imageio.imwrite(filename, tensor.numpy()) 137 | 138 | self.process = [ 139 | Process(target=bg_target, args=(self.queue,)) \ 140 | for _ in range(self.n_processes) 141 | ] 142 | 143 | for p in self.process: p.start() 144 | 145 | def end_background(self): 146 | for _ in range(self.n_processes): self.queue.put((None, None)) 147 | while not self.queue.empty(): time.sleep(1) 148 | for p in self.process: p.join() 149 | 150 | def save_results(self, dataset, filename, save_list, scale): 151 | if self.args.save_results: 152 | filename = self.get_path( 153 | 'results-{}'.format(dataset.dataset.name), 154 | '{}_x{}_'.format(filename, scale) 155 | ) 156 | 157 | postfix = ('SR', 'LR', 'HR') 158 | for v, p in zip(save_list, postfix): 159 | normalized = v[0].mul(255 / self.args.rgb_range) 160 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 161 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 162 | 163 | def quantize(img, rgb_range): 164 | pixel_range = 255 / rgb_range 165 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 166 | 167 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 168 | if hr.nelement() == 1: return 0 169 | 170 | diff = (sr - hr) / rgb_range 171 | if dataset and dataset.dataset.benchmark: 172 | shave = scale 173 | if diff.size(1) > 1: 174 | gray_coeffs = [65.738, 129.057, 25.064] 175 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 176 | diff = diff.mul(convert).sum(dim=1) 177 | else: 178 | shave = scale + 6 179 | 180 | valid = diff[..., shave:-shave, shave:-shave] 181 | mse = valid.pow(2).mean() 182 | 183 | return -10 * math.log10(mse) 184 | 185 | def make_optimizer(args, target): 186 | ''' 187 | make optimizer and scheduler together 188 | ''' 189 | # optimizer 190 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 191 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 192 | 193 | if args.optimizer == 'SGD': 194 | optimizer_class = optim.SGD 195 | kwargs_optimizer['momentum'] = args.momentum 196 | elif args.optimizer == 'ADAM': 197 | optimizer_class = optim.Adam 198 | kwargs_optimizer['betas'] = args.betas 199 | kwargs_optimizer['eps'] = args.epsilon 200 | elif args.optimizer == 'RMSprop': 201 | optimizer_class = optim.RMSprop 202 | kwargs_optimizer['eps'] = args.epsilon 203 | 204 | # scheduler 205 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 206 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 207 | scheduler_class = lrs.MultiStepLR 208 | 209 | class CustomOptimizer(optimizer_class): 210 | def __init__(self, *args, **kwargs): 211 | super(CustomOptimizer, self).__init__(*args, **kwargs) 212 | 213 | def _register_scheduler(self, scheduler_class, **kwargs): 214 | self.scheduler = scheduler_class(self, **kwargs) 215 | 216 | def save(self, save_dir): 217 | torch.save(self.state_dict(), self.get_dir(save_dir)) 218 | 219 | def load(self, load_dir, epoch=1): 220 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 
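            # when resuming, replay the decay schedule so the learning rate matches the checkpointed epoch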
221 | if epoch > 1: 222 | for _ in range(epoch): self.scheduler.step() 223 | 224 | def get_dir(self, dir_path): 225 | return os.path.join(dir_path, 'optimizer.pt') 226 | 227 | def schedule(self): 228 | self.scheduler.step() 229 | 230 | def get_lr(self): 231 | return self.scheduler.get_lr()[0] 232 | 233 | def get_last_epoch(self): 234 | return self.scheduler.last_epoch 235 | 236 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 237 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 238 | return optimizer 239 | 240 | def calc_ssim(img1, img2, rgb_range): 241 | gray_coeffs = [65.738, 129.057, 25.064] 242 | convert = img1.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 243 | nimg1 = (img1 / rgb_range).mul(convert).sum(dim=1).unsqueeze(1) 244 | nimg2 = (img2 / rgb_range).mul(convert).sum(dim=1).unsqueeze(1) 245 | #print(nimg1.shape) 246 | #nimg1, nimg2 = img1, img2 247 | ssim = pytorch_ssim.ssim(nimg1, nimg2).item() 248 | #print(ssim) 249 | return ssim --------------------------------------------------------------------------------
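
The metric helpers at the end of `src/utility.py` (`quantize`, `calc_psnr`, `calc_ssim`) can be smoke-tested on dummy tensors. The snippet below is a minimal sketch, not part of the repository: it assumes it is run from inside `src/` so that `utility` (and its `loss.pytorch_ssim` dependency) is importable, and it follows the `(B, C, H, W)` layout and `rgb_range=255` convention used above.

```python
# Hypothetical smoke test for the metric helpers in src/utility.py.
import torch
import utility

hr = torch.rand(1, 3, 48, 48) * 255                 # fake HR patch in [0, 255]
sr = (hr + 5 * torch.randn_like(hr)).clamp(0, 255)  # fake SR output = HR + noise

sr = utility.quantize(sr, 255)                      # snap to valid 8-bit pixel values
psnr = utility.calc_psnr(sr, hr, scale=4, rgb_range=255)  # dataset=None -> shave = scale + 6
ssim = utility.calc_ssim(sr, hr, 255)               # converts RGB to Y, then SSIM
print('PSNR: {:.2f} dB / SSIM: {:.4f}'.format(psnr, ssim))
```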