├── .gitignore ├── LICENSE ├── README.md └── src ├── __init__.py ├── data ├── __init__.py ├── benchmark.py ├── common.py ├── demo.py ├── df2k.py ├── div2k.py ├── div2kjpeg.py ├── sr291.py ├── srdata.py └── video.py ├── dataloader.py ├── loss ├── __init__.py ├── __loss__.py ├── adversarial.py ├── demo.sh ├── discriminator.py ├── hash.py └── vgg.py ├── main.py ├── model ├── LICENSE ├── README.md ├── __init__.py ├── attention.py ├── common.py ├── ddbpn.py ├── edsr.py ├── mdsr.py ├── mssr.py ├── nlsn.py ├── rcan.py ├── rdn.py ├── utils │ ├── __init__.py │ └── tools.py └── vdsr.py ├── option.py ├── template.py ├── test.sh ├── train.sh ├── trainer.py ├── utility.py ├── utils ├── __init__.py └── tools.py └── videotester.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore Mac system files 2 | .DS_store 3 | 4 | dataset/ 5 | 6 | experiment/ 7 | 8 | __pycache__ 9 | 10 | *.swp 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 njulj 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaDM 2 | ## [AdaDM: Enabling Normalization for Image Super-Resolution](https://arxiv.org/abs/2111.13905). 3 | You can apply BN, LN or GN in SR networks with our AdaDM. Pretrained models (EDSR\*/RDN\*/NLSN\*) can be downloaded from 4 | [Google Drive](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) or 5 | [BaiduYun](https://pan.baidu.com/s/18I3j4DJFvbNvTFHzDwsssA). The password for BaiduYun is `kymj`. 6 | 7 | :loudspeaker: If you use [BasicSR](https://github.com/xinntao/BasicSR) framework, you need to turn off the Exponential Moving Average (EMA) option when 8 | applying BN in the generator network (e.g., RRDBNet). You can disable EMA by setting `ema_decay=0` in corresponding `.yml` configuration file. 
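For reference, the normalization variants mentioned above map onto standard PyTorch layers (`nn.BatchNorm2d` for BN, `nn.GroupNorm(1, C)` as LN over feature maps, `nn.GroupNorm(g, C)` for GN). The sketch below is purely illustrative — the hypothetical `NormResBlock` only shows where such a layer sits inside an EDSR-style residual block and is *not* the AdaDM module shipped in `src/model/`:

```python
# Illustrative only: an EDSR-style residual block with a switchable
# normalization layer (BN / LN / GN).  This is NOT the AdaDM module from
# this repository; it merely maps the options named above to PyTorch layers.
import torch
import torch.nn as nn

def make_norm(kind, n_feats):
    if kind == 'BN':
        return nn.BatchNorm2d(n_feats)
    if kind == 'LN':
        # GroupNorm with a single group normalizes over (C, H, W) per sample.
        return nn.GroupNorm(1, n_feats)
    if kind == 'GN':
        return nn.GroupNorm(8, n_feats)
    raise ValueError(kind)

class NormResBlock(nn.Module):
    def __init__(self, n_feats=64, norm='BN', res_scale=0.1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            make_norm(norm, n_feats),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            make_norm(norm, n_feats),
        )
        self.res_scale = res_scale

    def forward(self, x):
        # Scaled residual connection, as in the training command below (--res_scale 0.1).
        return x + self.res_scale * self.body(x)

print(NormResBlock(norm='GN')(torch.randn(2, 64, 48, 48)).shape)  # (2, 64, 48, 48)
```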
9 | 10 | | Model | Scale | File name (.pt) | Urban100 | Manga109 | 11 | | --- | --- | --- | --- | --- | 12 | |**EDSR** | 2 | | 32.93 | 39.10 | 13 | || 3 || 28.80 | 34.17 | 14 | || 4 || 26.64 | 31.02 | 15 | |**EDSR***| 2 | [EDSR_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.12 | 39.31 | 16 | || 3 | [EDSR_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 29.02 | 34.48 | 17 | || 4 | [EDSR_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 26.83 | 31.24 | 18 | |**RDN** | 2 | | 32.89 | 39.18 | 19 | || 3 | | 28.80 | 34.13 | 20 | || 4 | | 26.61 | 31.00 | 21 | |**RDN***| 2 | [RDN_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.03 | 39.18 | 22 | || 3 | [RDN_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 28.95 | 34.29 | 23 | || 4 | [RDN_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 26.72 | 31.18 | 24 | |**NLSN** | 2 | | 33.42 | 39.59 | 25 | || 3 | | 29.25 | 34.57 | 26 | || 4 | | 26.96 | 31.27 | 27 | |**NLSN*** | 2 | [NLSN_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.59 | 39.67 | 28 | || 3 | [NLSN_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 29.53 | 34.95 | 29 | || 4 | [NLSN_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 27.24 | 31.73 | 30 | 31 | ## Preparation 32 | Please refer to [EDSR](https://github.com/thstkdgus35/EDSR-PyTorch) for instructions on dataset download and software installation, then clone our repository as follows: 33 | ```bash 34 | git clone https://github.com/njulj/AdaDM.git 35 | ``` 36 | 37 | ## Training 38 | ```bash 39 | cd AdaDM/src 40 | bash train.sh 41 | ``` 42 | Example training command in train.sh looks like: 43 | ```bash 44 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 2\ 45 | --n_GPUs 1 --batch_size 16 --patch_size 96 --rgb_range 255 --res_scale 0.1\ 46 | --save EDSR_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\ 47 | --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results 48 | ``` 49 | Here, `$GPU_ID` specifies the GPU id used for training. `EDSR_AdaDM_Test_DIV2K_X2` is the directory where all files are saved during training. 50 | `--dir_data` specifies the root directory for all datasets, you should place the DIV2K and benchmark (e.g., Urban100) datasets under this directory. 51 | 52 | ## Testing 53 | ```bash 54 | cd AdaDM/src 55 | bash test.sh 56 | ``` 57 | Example testing command in test.sh looks like: 58 | ```bash 59 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale $SCALE\ 60 | --pre_train ../experiment/test/model/EDSR_AdaDM_DIV2K_X$SCALE.pt\ 61 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET 62 | ``` 63 | Here, `$GPU_ID` specifies the GPU id used for testing. `$SCALE` indicates the upscaling factor (e.g., 2, 3, 4). `--pre_train` specifies the path of 64 | saved checkpoints. `$TEST_DATASET` indicates the dataset to be tested. 65 | 66 | ## Acknowledgement 67 | This repository is built on [EDSR](https://github.com/thstkdgus35/EDSR-PyTorch) and [NLSN](https://github.com/HarukiYqM/Non-Local-Sparse-Attention). 
We thank the authors for sharing their codes. 68 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /src/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class 
Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /src/data/df2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DF2K(srdata.SRData): 5 | def __init__(self, args, name='DF2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DF2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DF2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DF2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DF2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DF2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /src/data/div2kjpeg.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /src/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /src/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 48 | for h in list_hr: 49 | b = h.replace(self.apath, path_bin) 50 | b = b.replace(self.ext[0], '.pt') 51 | self.images_hr.append(b) 52 | self._check_and_load(args.ext, h, b, verbose=True) 53 | for i, ll in enumerate(list_lr): 54 | for l in ll: 55 | b = l.replace(self.apath, path_bin) 56 | b = b.replace(self.ext[1], '.pt') 57 | self.images_lr[i].append(b) 58 | self._check_and_load(args.ext, l, b, verbose=True) 59 | if train: 60 | n_patches = args.batch_size * args.test_every 61 | n_images = len(args.data_train) * len(self.images_hr) 62 | if n_images == 0: 63 | self.repeat = 0 64 | else: 65 | self.repeat = max(n_patches // n_images, 1) 66 | 67 | # Below functions as used to prepare images 68 | def _scan(self): 69 | names_hr = sorted( 70 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 71 | ) 72 | names_lr = [[] for _ in self.scale] 73 | for f in names_hr: 74 | filename, _ = os.path.splitext(os.path.basename(f)) 75 | if self.name == 'Manga109': 76 | filename = 
filename[:-6] 77 | for si, s in enumerate(self.scale): 78 | names_lr[si].append(os.path.join( 79 | self.dir_lr, 'X{}/{}x{}{}'.format( 80 | s, filename, s, self.ext[1] 81 | ) 82 | )) 83 | 84 | return names_hr, names_lr 85 | 86 | def _set_filesystem(self, dir_data): 87 | self.apath = os.path.join(dir_data, self.name) 88 | self.dir_hr = os.path.join(self.apath, 'HR') 89 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 90 | if self.input_large: self.dir_lr += 'L' 91 | self.ext = ('.png', '.png') 92 | 93 | def _check_and_load(self, ext, img, f, verbose=True): 94 | if not os.path.isfile(f) or ext.find('reset') >= 0: 95 | if verbose: 96 | print('Making a binary: {}'.format(f)) 97 | with open(f, 'wb') as _f: 98 | pickle.dump(imageio.imread(img), _f) 99 | 100 | def __getitem__(self, idx): 101 | lr, hr, filename = self._load_file(idx) 102 | pair = self.get_patch(lr, hr) 103 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 104 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 105 | 106 | return pair_t[0], pair_t[1], filename 107 | 108 | def __len__(self): 109 | if self.train: 110 | return len(self.images_hr) * self.repeat 111 | else: 112 | return len(self.images_hr) 113 | 114 | def _get_index(self, idx): 115 | if self.train: 116 | return idx % len(self.images_hr) 117 | else: 118 | return idx 119 | 120 | def _load_file(self, idx): 121 | idx = self._get_index(idx) 122 | f_hr = self.images_hr[idx] 123 | f_lr = self.images_lr[self.idx_scale][idx] 124 | 125 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 126 | if self.args.ext == 'img' or self.benchmark: 127 | hr = imageio.imread(f_hr) 128 | lr = imageio.imread(f_lr) 129 | elif self.args.ext.find('sep') >= 0: 130 | with open(f_hr, 'rb') as _f: 131 | hr = pickle.load(_f) 132 | with open(f_lr, 'rb') as _f: 133 | lr = pickle.load(_f) 134 | 135 | return lr, hr, filename 136 | 137 | def get_patch(self, lr, hr): 138 | scale = self.scale[self.idx_scale] 139 | if self.train: 140 | lr, hr = common.get_patch( 141 | lr, hr, 142 | patch_size=self.args.patch_size, 143 | scale=scale, 144 | multi=(len(self.scale) > 1), 145 | input_large=self.input_large 146 | ) 147 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 148 | else: 149 | ih, iw = lr.shape[:2] 150 | hr = hr[0:ih * scale, 0:iw * scale] 151 | 152 | return lr, hr 153 | 154 | def set_scale(self, idx_scale): 155 | if not self.input_large: 156 | self.idx_scale = idx_scale 157 | else: 158 | self.idx_scale = random.randint(0, len(self.scale) - 1) 159 | 160 | -------------------------------------------------------------------------------- /src/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, 
n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | 4 | import torch 5 | import torch.multiprocessing as multiprocessing 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SequentialSampler 8 | from torch.utils.data import RandomSampler 9 | from torch.utils.data import BatchSampler 10 | from torch.utils.data import _utils 11 | from torch.utils.data.dataloader import _DataLoaderIter 12 | 13 | from torch.utils.data._utils import collate 14 | from torch.utils.data._utils import signal_handling 15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 16 | from torch.utils.data._utils import ExceptionWrapper 17 | from torch.utils.data._utils import IS_WINDOWS 18 | from torch.utils.data._utils.worker import ManagerWatchdog 19 | 20 | from torch._six import queue 21 | 22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 23 | try: 24 | collate._use_shared_memory = True 25 | signal_handling._set_worker_signal_handlers() 26 | 27 | torch.set_num_threads(1) 28 | random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | data_queue.cancel_join_thread() 32 | 33 | if init_fn is not None: 34 | init_fn(worker_id) 35 | 36 | watchdog = ManagerWatchdog() 37 | 38 | while watchdog.is_alive(): 39 | try: 40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 41 | except queue.Empty: 42 | continue 43 | 44 | if r is None: 45 | assert done_event.is_set() 46 | return 47 | elif done_event.is_set(): 48 | continue 49 | 50 | idx, batch_indices = r 51 | try: 52 | idx_scale = 0 53 | if len(scale) > 1 and dataset.train: 54 | idx_scale = random.randrange(0, len(scale)) 55 | dataset.set_scale(idx_scale) 56 | 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | samples.append(idx_scale) 59 | except Exception: 60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 61 | else: 62 | data_queue.put((idx, samples)) 63 | del samples 64 | 65 | except KeyboardInterrupt: 66 | pass 67 | 68 | class _MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | base_seed = torch.LongTensor(1).random_()[0] 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in 
range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 | self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | print('Preparing loss function:') 18 | 19 | self.n_GPUs = args.n_GPUs 20 | self.loss = [] 21 | self.loss_module = nn.ModuleList() 22 | for loss in args.loss.split('+'): 23 | weight, loss_type = loss.split('*') 24 | if loss_type == 'MSE': 25 | loss_function = nn.MSELoss() 26 | elif loss_type == 'L1': 27 | loss_function = nn.L1Loss() 28 | elif loss_type.find('VGG') >= 0: 29 | module = import_module('loss.vgg') 30 | loss_function = getattr(module, 'VGG')( 31 | loss_type[3:], 32 | rgb_range=args.rgb_range 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | 41 | self.loss.append({ 42 | 'type': loss_type, 43 | 'weight': float(weight), 44 | 'function': loss_function} 45 | ) 46 | if loss_type.find('GAN') >= 0: 47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 48 | 49 | if len(self.loss) > 1: 50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 51 | 52 | for l in self.loss: 53 | if l['function'] is not None: 54 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 55 | self.loss_module.append(l['function']) 56 | 57 | self.log = torch.Tensor() 58 | 59 | device = torch.device('cpu' if args.cpu else 'cuda') 60 | self.loss_module.to(device) 61 | if args.precision == 'half': self.loss_module.half() 62 | if not args.cpu and args.n_GPUs > 1: 63 | 
self.loss_module = nn.DataParallel(self.loss_module,range(args.n_GPUs)) 64 | 65 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 66 | 67 | def forward(self, sr, hr): 68 | losses = [] 69 | for i, l in enumerate(self.loss): 70 | if l['function'] is not None: 71 | loss = l['function'](sr, hr) 72 | effective_loss = l['weight'] * loss 73 | losses.append(effective_loss) 74 | self.log[-1, i] += effective_loss.item() 75 | elif l['type'] == 'DIS': 76 | self.log[-1, i] += self.loss[i - 1]['function'].loss 77 | 78 | loss_sum = sum(losses) 79 | if len(self.loss) > 1: 80 | self.log[-1, -1] += loss_sum.item() 81 | 82 | return loss_sum 83 | 84 | def step(self): 85 | for l in self.get_loss_module(): 86 | if hasattr(l, 'scheduler'): 87 | l.scheduler.step() 88 | 89 | def start_log(self): 90 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 91 | 92 | def end_log(self, n_batches): 93 | self.log[-1].div_(n_batches) 94 | 95 | def display_loss(self, batch): 96 | n_samples = batch + 1 97 | log = [] 98 | for l, c in zip(self.loss, self.log[-1]): 99 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 100 | 101 | return ''.join(log) 102 | 103 | def plot_loss(self, apath, epoch): 104 | axis = np.linspace(1, epoch, epoch) 105 | for i, l in enumerate(self.loss): 106 | label = '{} Loss'.format(l['type']) 107 | fig = plt.figure() 108 | plt.title(label) 109 | plt.plot(axis, self.log[:, i].numpy(), label=label) 110 | plt.legend() 111 | plt.xlabel('Epochs') 112 | plt.ylabel('Loss') 113 | plt.grid(True) 114 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 115 | plt.close(fig) 116 | 117 | def get_loss_module(self): 118 | if self.n_GPUs == 1: 119 | return self.loss_module 120 | else: 121 | return self.loss_module.module 122 | 123 | def save(self, apath): 124 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 125 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 126 | 127 | def load(self, apath, cpu=False): 128 | if cpu: 129 | kwargs = {'map_location': lambda storage, loc: storage} 130 | else: 131 | kwargs = {} 132 | 133 | self.load_state_dict(torch.load( 134 | os.path.join(apath, 'loss.pt'), 135 | **kwargs 136 | )) 137 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 138 | for l in self.get_loss_module(): 139 | if hasattr(l, 'scheduler'): 140 | for _ in range(len(self.log)): l.scheduler.step() 141 | 142 | -------------------------------------------------------------------------------- /src/loss/__loss__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/loss/__loss__.py -------------------------------------------------------------------------------- /src/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if gan_type == 'WGAN_GP': 19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 20 | optim_dict = { 21 | 'optimizer': 'ADAM', 22 | 'betas': (0, 0.9), 23 | 'epsilon': 1e-8, 24 | 'lr': 1e-5, 25 | 
'weight_decay': args.weight_decay, 26 | 'decay': args.decay, 27 | 'gamma': args.gamma 28 | } 29 | optim_args = SimpleNamespace(**optim_dict) 30 | else: 31 | optim_args = args 32 | 33 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 34 | 35 | def forward(self, fake, real): 36 | # updating discriminator... 37 | self.loss = 0 38 | fake_detach = fake.detach() # do not backpropagate through G 39 | for _ in range(self.gan_k): 40 | self.optimizer.zero_grad() 41 | # d: B x 1 tensor 42 | d_fake = self.dis(fake_detach) 43 | d_real = self.dis(real) 44 | retain_graph = False 45 | if self.gan_type == 'GAN': 46 | loss_d = self.bce(d_real, d_fake) 47 | elif self.gan_type.find('WGAN') >= 0: 48 | loss_d = (d_fake - d_real).mean() 49 | if self.gan_type.find('GP') >= 0: 50 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 51 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 52 | hat.requires_grad = True 53 | d_hat = self.dis(hat) 54 | gradients = torch.autograd.grad( 55 | outputs=d_hat.sum(), inputs=hat, 56 | retain_graph=True, create_graph=True, only_inputs=True 57 | )[0] 58 | gradients = gradients.view(gradients.size(0), -1) 59 | gradient_norm = gradients.norm(2, dim=1) 60 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 61 | loss_d += gradient_penalty 62 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 63 | elif self.gan_type == 'RGAN': 64 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 65 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 66 | loss_d = self.bce(better_real, better_fake) 67 | retain_graph = True 68 | 69 | # Discriminator update 70 | self.loss += loss_d.item() 71 | loss_d.backward(retain_graph=retain_graph) 72 | self.optimizer.step() 73 | 74 | if self.gan_type == 'WGAN': 75 | for p in self.dis.parameters(): 76 | p.data.clamp_(-1, 1) 77 | 78 | self.loss /= self.gan_k 79 | 80 | # updating generator... 
81 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 82 | if self.gan_type == 'GAN': 83 | label_real = torch.ones_like(d_fake_bp) 84 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 85 | elif self.gan_type.find('WGAN') >= 0: 86 | loss_g = -d_fake_bp.mean() 87 | elif self.gan_type == 'RGAN': 88 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) 89 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) 90 | loss_g = self.bce(better_fake, better_real) 91 | 92 | # Generator loss 93 | return loss_g 94 | 95 | def state_dict(self, *args, **kwargs): 96 | state_discriminator = self.dis.state_dict(*args, **kwargs) 97 | state_optimizer = self.optimizer.state_dict() 98 | 99 | return dict(**state_discriminator, **state_optimizer) 100 | 101 | def bce(self, real, fake): 102 | label_real = torch.ones_like(real) 103 | label_fake = torch.zeros_like(fake) 104 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 105 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 106 | bce_loss = bce_real + bce_fake 107 | return bce_loss 108 | 109 | # Some references 110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 111 | # OR 112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 113 | -------------------------------------------------------------------------------- /src/loss/demo.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/loss/demo.sh -------------------------------------------------------------------------------- /src/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /src/loss/hash.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class HASH(nn.Module): 9 | def __init__(self): 10 | super(HASH, 
self).__init__() 11 | self.l1 = nn.L1Loss() 12 | def forward(self, sr, qk, orders, hr, m=3): 13 | #hash loss 14 | qk = F.normalize(qk, p=2, dim=1, eps=5e-5) 15 | N,C,H,W = qk.shape 16 | qk = qk.view(N,C,H*W) 17 | qk_t = qk.permute(0,2,1).contiguous() 18 | similarity_map = F.relu(torch.matmul(qk_t, qk),inplace=True) #[N,H*W,H*W] 19 | 20 | orders = orders.unsqueeze(2).expand_as(similarity_map)#[N,H*W,H*W] 21 | orders_t = torch.transpose(orders,1,2) 22 | dist = torch.pow(orders-orders_t,2) 23 | 24 | ls = torch.mean(similarity_map*torch.log(torch.exp(dist+m)+1)) 25 | ld = torch.mean((1-similarity_map)*torch.log(torch.exp(-dist+m)+1)) 26 | loss = 0.005*(ls+ld)+self.l1(sr,hr) 27 | 28 | return loss 29 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args, checkpoint) 18 | print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 19 | t = VideoTester(args, model, checkpoint) 20 | t.test() 21 | else: 22 | if checkpoint.ok: 23 | loader = data.Data(args) 24 | _model = model.Model(args, checkpoint) 25 | print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0)) 26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, _model, _loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | checkpoint.done() 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /src/model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the 
Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/model/README.md: -------------------------------------------------------------------------------- 1 | # EDSR-PyTorch 2 | ![](/figs/main.png) 3 | 4 | This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017, 2nd NTIRE**. 5 | You can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017). 6 | 7 | If you find our work useful in your research or publication, please cite our work: 8 | 9 | [1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)] 10 | ``` 11 | @InProceedings{Lim_2017_CVPR_Workshops, 12 | author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu}, 13 | title = {Enhanced Deep Residual Networks for Single Image Super-Resolution}, 14 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 15 | month = {July}, 16 | year = {2017} 17 | } 18 | ``` 19 | We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use pre-trained model to enlarge your images. 20 | 21 | **Differences between Torch version** 22 | * Codes are much more compact. (Removed all unnecessary parts.) 23 | * Models are smaller. (About half.) 24 | * Slightly better performances. 25 | * Training and evaluation requires less memory. 26 | * Python-based. 27 | 28 | ## Dependencies 29 | * Python 3.6 30 | * PyTorch >= 0.4.0 31 | * numpy 32 | * skimage 33 | * **imageio** 34 | * matplotlib 35 | * tqdm 36 | 37 | **Recent updates** 38 | 39 | * July 22, 2018 40 | * Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models. 41 | * Now the dataloader is much stable than the previous version. Please erase ``DIV2K/bin`` folder that is created before this commit. Also, please avoid to use ``--ext bin`` argument. Our code will automatically pre-decode png images before training. 
If you do not have enough spaces(~10GB) in your disk, we recommend ``--ext img``(But SLOW!). 42 | 43 | 44 | ## Code 45 | Clone this repository into any place you want. 46 | ```bash 47 | git clone https://github.com/thstkdgus35/EDSR-PyTorch 48 | cd EDSR-PyTorch 49 | ``` 50 | 51 | ## Quick start (Demo) 52 | You can test our super-resolution algorithm with your own images. Place your images in ``test`` folder. (like ``test/``) We support **png** and **jpeg** files. 53 | 54 | Run the script in ``src`` folder. Before you run the demo, please uncomment the appropriate line in ```demo.sh``` that you want to execute. 55 | ```bash 56 | cd src # You are now in */EDSR-PyTorch/src 57 | sh demo.sh 58 | ``` 59 | 60 | You can find the result images from ```experiment/test/results``` folder. 61 | 62 | | Model | Scale | File name (.pt) | Parameters | ****PSNR** | 63 | | --- | --- | --- | --- | --- | 64 | | **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB | 65 | | | | *EDSR_x2 | 40.7 M | 35.03 dB | 66 | | | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB | 67 | | | | *EDSR_x3 | 43.7 M | 31.26 dB | 68 | | | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB | 69 | | | | *EDSR_x4 | 43.1 M | 29.25 dB | 70 | | **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB | 71 | | | | *MDSR | 7.95 M| 34.92 dB | 72 | | | 3 | MDSR_baseline | | 30.94 dB | 73 | | | | *MDSR | | 31.22 dB | 74 | | | 4 | MDSR_baseline | | 28.97 dB | 75 | | | | *MDSR | | 29.24 dB | 76 | 77 | *Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB) 78 | **We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored. 79 | 80 | You can evaluate your models with widely-used benchmark datasets: 81 | 82 | [Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), 83 | 84 | [Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests), 85 | 86 | [B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), 87 | 88 | [Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr). 89 | 90 | For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data `` to evaluate the EDSR and MDSR with the benchmarks. 91 | 92 | ## How to train EDSR and MDSR 93 | We used [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB). 94 | 95 | Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where DIV2K images are located. 96 | 97 | We recommend you to pre-process the images before training. This step will decode all **png** files and save them as binaries. Use ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use saved binaries with ``--ext sep`` argument. 98 | 99 | If you have enough RAM (>= 32GB), you can use ``--ext bin`` argument to pack all DIV2K images in one binary file. 100 | 101 | You can train EDSR and MDSR by yourself. All scripts are provided in the ``src/demo.sh``. Note that EDSR (x3, x4) requires pre-trained EDSR (x2). 
You can ignore this constraint by removing ```--pre_train ``` argument. 102 | 103 | ```bash 104 | cd src # You are now in */EDSR-PyTorch/src 105 | sh demo.sh 106 | ``` 107 | 108 | **Update log** 109 | * Jan 04, 2018 110 | * Many parts are re-written. You cannot use previous scripts and models directly. 111 | * Pre-trained MDSR is temporarily disabled. 112 | * Training details are included. 113 | 114 | * Jan 09, 2018 115 | * Missing files are included (```src/data/MyImage.py```). 116 | * Some links are fixed. 117 | 118 | * Jan 16, 2018 119 | * Memory efficient forward function is implemented. 120 | * Add --chop_forward argument to your script to enable it. 121 | * Basically, this function first split a large image to small patches. Those images are merged after super-resolution. I checked this function with 12GB memory, 4000 x 2000 input image in scale 4. (Therefore, the output will be 16000 x 8000.) 122 | 123 | * Feb 21, 2018 124 | * Fixed the problem when loading pre-trained multi-gpu model. 125 | * Added pre-trained scale 2 baseline model. 126 | * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use --save_models argument to save all the intermediate models. 127 | * PyTorch 0.3.1 changed their implementation of DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on feature/dataloader branch. 128 | 129 | * Feb 23, 2018 130 | * Now PyTorch 0.3.1 is default. Use legacy/0.3.0 branch if you use the old version. 131 | 132 | * With a new ``src/data/DIV2K.py`` code, one can easily create new data class for super-resolution. 133 | * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have.) 134 | * With ``--ext bin``, this code will automatically generates and saves the binary data pack that corresponds to previous ``DIV2K_decoded``. (This requires huge RAM (~45GB, Swap can be used.), so please be careful.) 135 | * If you cannot make the binary pack, just use the default setting (``--ext img``). 136 | 137 | * Fixed a bug that PSNR in the log and PSNR calculated from the saved images does not match. 138 | * Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.) 139 | * Added performance comparison between Torch7 model and PyTorch models. 140 | 141 | * Mar 5, 2018 142 | * All baseline models are uploaded. 143 | * Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images. 144 | 145 | * Mar 11, 2018 146 | * Fixed some typos in the code and script. 147 | * Now --ext img is default setting. Although we recommend you to use --ext bin when training, please use --ext img when you use --test_only. 148 | * Skip_batch operation is implemented. Use --skip_threshold argument to skip the batch that you want to ignore. Although this function is not exactly same with that of Torch7 version, it will work as you expected. 149 | 150 | * Mar 20, 2018 151 | * Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved to the same directory with DIV2K png files. After the first run, you can use ``--ext sep`` to save time. 152 | * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images. 153 | * Changed the behavior of skip_batch. 154 | 155 | * Mar 29, 2018 156 | * We now provide all models from our paper. 
157 | * We also provide ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in original low-resolution image. Please use it if you have any trouble. 158 | * ``MyImage`` dataset is changed to ``Demo`` dataset. Also, it works more efficient than before. 159 | * Some codes and script are re-written. 160 | 161 | * Apr 9, 2018 162 | * VGG and Adversarial loss is implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet. 163 | * Many codes are refactored. If there exists a bug, please report it. 164 | * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. Default setting is D-DBPN-L. 165 | 166 | * Apr 26, 2018 167 | * Compatible with PyTorch 0.4.0 168 | * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch. 169 | * Minor bug fixes 170 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | 23 | module = import_module('model.' 
+ args.model.lower()) 24 | self.model = module.make_model(args).to(self.device) 25 | if args.precision == 'half': self.model.half() 26 | 27 | if not args.cpu and args.n_GPUs > 1: 28 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 29 | 30 | self.load( 31 | ckp.dir, 32 | pre_train=args.pre_train, 33 | resume=args.resume, 34 | cpu=args.cpu 35 | ) 36 | print(self.model, file=ckp.log_file) 37 | 38 | def forward(self, x, idx_scale): 39 | self.idx_scale = idx_scale 40 | target = self.get_model() 41 | if hasattr(target, 'set_scale'): 42 | target.set_scale(idx_scale) 43 | 44 | if self.self_ensemble and not self.training: 45 | if self.chop: 46 | forward_function = self.forward_chop 47 | else: 48 | forward_function = self.model.forward 49 | 50 | return self.forward_x8(x, forward_function) 51 | elif self.chop and not self.training: 52 | return self.forward_chop(x) 53 | else: 54 | return self.model(x) 55 | 56 | def get_model(self): 57 | if self.n_GPUs == 1: 58 | return self.model 59 | else: 60 | return self.model.module 61 | 62 | def state_dict(self, **kwargs): 63 | target = self.get_model() 64 | return target.state_dict(**kwargs) 65 | 66 | def save(self, apath, epoch, is_best=False): 67 | target = self.get_model() 68 | torch.save( 69 | target.state_dict(), 70 | os.path.join(apath, 'model_latest.pt') 71 | ) 72 | if is_best: 73 | torch.save( 74 | target.state_dict(), 75 | os.path.join(apath, 'model_best.pt') 76 | ) 77 | 78 | #if self.save_models: 79 | if self.save_models and epoch % 50 == 0: 80 | torch.save( 81 | target.state_dict(), 82 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 83 | ) 84 | 85 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 86 | if cpu: 87 | kwargs = {'map_location': lambda storage, loc: storage} 88 | else: 89 | kwargs = {} 90 | 91 | if resume == -1: 92 | self.get_model().load_state_dict( 93 | torch.load( 94 | os.path.join(apath, 'model_latest.pt'), 95 | **kwargs 96 | ), 97 | strict=False 98 | ) 99 | elif resume == 0: 100 | if pre_train != '.': 101 | print('Loading model from {}'.format(pre_train)) 102 | self.get_model().load_state_dict( 103 | torch.load(pre_train, **kwargs), 104 | strict=False 105 | ) 106 | else: 107 | self.get_model().load_state_dict( 108 | torch.load( 109 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 110 | **kwargs 111 | ), 112 | strict=False 113 | ) 114 | 115 | def forward_chop(self, x, shave=10, min_size=120000): 116 | scale = self.scale[self.idx_scale] 117 | n_GPUs = min(self.n_GPUs, 4) 118 | b, c, h, w = x.size() 119 | h_half, w_half = h // 2, w // 2 120 | h_size, w_size = h_half + shave, w_half + shave 121 | h_size +=4-h_size%4 122 | w_size +=8-w_size%8 123 | 124 | lr_list = [ 125 | x[:, :, 0:h_size, 0:w_size], 126 | x[:, :, 0:h_size, (w - w_size):w], 127 | x[:, :, (h - h_size):h, 0:w_size], 128 | x[:, :, (h - h_size):h, (w - w_size):w]] 129 | 130 | if w_size * h_size < min_size: 131 | sr_list = [] 132 | for i in range(0, 4, n_GPUs): 133 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 134 | sr_batch = self.model(lr_batch) 135 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 136 | else: 137 | sr_list = [ 138 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 139 | for patch in lr_list 140 | ] 141 | 142 | h, w = scale * h, scale * w 143 | h_half, w_half = scale * h_half, scale * w_half 144 | h_size, w_size = scale * h_size, scale * w_size 145 | shave *= scale 146 | 147 | output = x.new(b, c, h, w) 148 | output[:, :, 0:h_half, 0:w_half] \ 149 | = sr_list[0][:, :, 0:h_half, 0:w_half] 150 | 
output[:, :, 0:h_half, w_half:w] \ 151 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 152 | output[:, :, h_half:h, 0:w_half] \ 153 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 154 | output[:, :, h_half:h, w_half:w] \ 155 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 156 | 157 | return output 158 | 159 | def forward_x8(self, x, forward_function): 160 | def _transform(v, op): 161 | if self.precision != 'single': v = v.float() 162 | 163 | v2np = v.data.cpu().numpy() 164 | if op == 'v': 165 | tfnp = v2np[:, :, :, ::-1].copy() 166 | elif op == 'h': 167 | tfnp = v2np[:, :, ::-1, :].copy() 168 | elif op == 't': 169 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 170 | 171 | ret = torch.Tensor(tfnp).to(self.device) 172 | if self.precision == 'half': ret = ret.half() 173 | 174 | return ret 175 | 176 | lr_list = [x] 177 | for tf in 'v', 'h', 't': 178 | lr_list.extend([_transform(t, tf) for t in lr_list]) 179 | 180 | sr_list = [forward_function(aug) for aug in lr_list] 181 | for i in range(len(sr_list)): 182 | if i > 3: 183 | sr_list[i] = _transform(sr_list[i], 't') 184 | if i % 4 > 1: 185 | sr_list[i] = _transform(sr_list[i], 'h') 186 | if (i % 4) % 2 == 1: 187 | sr_list[i] = _transform(sr_list[i], 'v') 188 | 189 | output_cat = torch.cat(sr_list, dim=0) 190 | output = output_cat.mean(dim=0, keepdim=True) 191 | 192 | return output 193 | 194 | -------------------------------------------------------------------------------- /src/model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model import common 5 | 6 | class NonLocalSparseAttention(nn.Module): 7 | def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1): 8 | super(NonLocalSparseAttention,self).__init__() 9 | self.chunk_size = chunk_size 10 | self.n_hashes = n_hashes 11 | self.reduction = reduction 12 | self.res_scale = res_scale 13 | self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None) 14 | self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None) 15 | 16 | def LSH(self, hash_buckets, x): 17 | #x: [N,H*W,C] 18 | N = x.shape[0] 19 | device = x.device 20 | 21 | #generate random rotation matrix 22 | rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2] 23 | random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2] 24 | 25 | #locality sensitive hashing 26 | rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2] 27 | rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets] 28 | 29 | #get hash codes 30 | hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W] 31 | 32 | #add offsets to avoid hash codes overlapping between hash rounds 33 | offsets = torch.arange(self.n_hashes, device=device) 34 | offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1)) 35 | hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W] 36 | 37 | return hash_codes 38 | 39 | def add_adjacent_buckets(self, x): 40 | x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2) 41 | x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2) 42 | return torch.cat([x, 
x_extra_back,x_extra_forward], dim=3) 43 | 44 | def forward(self, input): 45 | 46 | N,_,H,W = input.shape 47 | x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1) 48 | y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1) 49 | L,C = x_embed.shape[-2:] 50 | 51 | #number of hash buckets/hash bits 52 | hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128) 53 | 54 | #get assigned hash codes/bucket number 55 | hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W] 56 | hash_codes = hash_codes.detach() 57 | 58 | #group elements with same hash code by sorting 59 | _, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W] 60 | _, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order 61 | mod_indices = (indices % L) #now range from (0->H*W) 62 | x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C] 63 | y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C] 64 | 65 | #pad the embedding if it cannot be divided by chunk_size 66 | padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0 67 | x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes,-1, C)) #[N, n_hashes, H*W,C] 68 | y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes,-1, C*self.reduction)) 69 | if padding: 70 | pad_x = x_att_buckets[:,:,-padding:,:].clone() 71 | pad_y = y_att_buckets[:,:,-padding:,:].clone() 72 | x_att_buckets = torch.cat([x_att_buckets,pad_x],dim=2) 73 | y_att_buckets = torch.cat([y_att_buckets,pad_y],dim=2) 74 | 75 | x_att_buckets = torch.reshape(x_att_buckets,(N,self.n_hashes,-1,self.chunk_size,C)) #[N, n_hashes, num_chunks, chunk_size, C] 76 | y_att_buckets = torch.reshape(y_att_buckets,(N,self.n_hashes,-1,self.chunk_size, C*self.reduction)) 77 | 78 | x_match = F.normalize(x_att_buckets, p=2, dim=-1,eps=5e-5) 79 | 80 | #allow attend to adjacent buckets 81 | x_match = self.add_adjacent_buckets(x_match) 82 | y_att_buckets = self.add_adjacent_buckets(y_att_buckets) 83 | 84 | #unormalized attention score 85 | raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3] 86 | 87 | #softmax 88 | bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True) 89 | score = torch.exp(raw_score - bucket_score) #(after softmax) 90 | bucket_score = torch.reshape(bucket_score,[N,self.n_hashes,-1]) 91 | 92 | #attention 93 | ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C] 94 | ret = torch.reshape(ret,(N,self.n_hashes,-1,C*self.reduction)) 95 | 96 | #if padded, then remove extra elements 97 | if padding: 98 | ret = ret[:,:,:-padding,:].clone() 99 | bucket_score = bucket_score[:,:,:-padding].clone() 100 | 101 | #recover the original order 102 | ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W,C] 103 | bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W] 104 | ret = common.batched_index_select(ret, undo_sort)#[N, n_hashes*H*W,C] 105 | bucket_score = bucket_score.gather(1, undo_sort)#[N,n_hashes*H*W] 106 | 107 | #weighted sum multi-round attention 108 | ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes*H*W,C] 109 | bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1)) 110 | probs = nn.functional.softmax(bucket_score,dim=1) 111 | ret = torch.sum(ret * probs, dim=1) 112 | 113 | ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale+input 
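# The multi-round attention results are fused just above: a softmax over the
# per-round logsumexp scores weights each hash round, the weighted sum is reshaped
# back to (N, C*self.reduction, H, W), scaled by res_scale, and added to the input
# as a residual connection. E.g. with channels=64 and reduction=4, C here is 16,
# so the output regains the full 64 channels and matches the input shape.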
114 | return ret 115 | 116 | 117 | class NonLocalAttention(nn.Module): 118 | def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1,conv=common.default_conv): 119 | super(NonLocalAttention, self).__init__() 120 | self.res_scale = res_scale 121 | self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU()) 122 | self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act = nn.PReLU()) 123 | self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,bn=False, act=nn.PReLU()) 124 | 125 | def forward(self, input): 126 | x_embed_1 = self.conv_match1(input) 127 | x_embed_2 = self.conv_match2(input) 128 | x_assembly = self.conv_assembly(input) 129 | 130 | N,C,H,W = x_embed_1.shape 131 | x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C)) 132 | x_embed_2 = x_embed_2.view(N,C,H*W) 133 | score = torch.matmul(x_embed_1, x_embed_2) 134 | score = F.softmax(score, dim=2) 135 | x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1) 136 | x_final = torch.matmul(score, x_assembly) 137 | return x_final.permute(0,2,1).view(N,-1,H,W)+self.res_scale*input 138 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def batched_index_select(values, indices): 9 | last_dim = values.shape[-1] 10 | return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim)) 11 | 12 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True): 13 | return nn.Conv2d( 14 | in_channels, out_channels, kernel_size, 15 | padding=(kernel_size//2),stride=stride, bias=bias) 16 | 17 | class MeanShift(nn.Conv2d): 18 | def __init__( 19 | self, rgb_range, 20 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 21 | 22 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 23 | std = torch.Tensor(rgb_std) 24 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 25 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 26 | for p in self.parameters(): 27 | p.requires_grad = False 28 | 29 | class BasicBlock(nn.Sequential): 30 | def __init__( 31 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 32 | bn=False, act=nn.PReLU()): 33 | 34 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 35 | if bn: 36 | m.append(nn.BatchNorm2d(out_channels)) 37 | if act is not None: 38 | m.append(act) 39 | 40 | super(BasicBlock, self).__init__(*m) 41 | 42 | class ResBlock(nn.Module): 43 | def __init__( 44 | self, conv, n_feats, kernel_size, 45 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 46 | 47 | super(ResBlock, self).__init__() 48 | m = [] 49 | for i in range(2): 50 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 51 | if bn: 52 | m.append(nn.BatchNorm2d(n_feats)) 53 | if i == 0: 54 | m.append(act) 55 | 56 | self.body = nn.Sequential(*m) 57 | self.res_scale = res_scale 58 | 59 | def forward(self, x): 60 | res = self.body(x).mul(self.res_scale) 61 | res += x 62 | 63 | return res 64 | 65 | class ResBlock_AdaDM(nn.Module): 66 | def __init__( 67 | self, conv, n_feats, kernel_size, 68 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 69 | 70 | super(ResBlock_AdaDM, self).__init__() 71 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 72 | self.conv2 
= conv(n_feats, n_feats, kernel_size, bias=bias) 73 | self.act = act 74 | self.res_scale = res_scale 75 | self.phi = nn.Conv2d(1, 1, 1, 1, 0, bias=True) 76 | self.phi.weight.data.fill_(1) 77 | self.phi.bias.data.fill_(0) 78 | self.norm1 = nn.BatchNorm2d(n_feats) 79 | self.norm2 = nn.BatchNorm2d(n_feats) 80 | 81 | def forward(self, x): 82 | s = torch.std(x, dim=[1,2,3], keepdim=True) 83 | x_n = self.norm1(x) 84 | res = self.conv1(x_n) 85 | res = self.act(res) 86 | res = self.norm2(res) 87 | res = self.conv2(res).mul(self.res_scale) 88 | res = res * torch.exp(self.phi(torch.log(s))) 89 | 90 | res += x 91 | 92 | return res 93 | 94 | class Upsampler(nn.Sequential): 95 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 96 | 97 | m = [] 98 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 99 | for _ in range(int(math.log(scale, 2))): 100 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 101 | m.append(nn.PixelShuffle(2)) 102 | if bn: 103 | m.append(nn.BatchNorm2d(n_feats)) 104 | if act == 'relu': 105 | m.append(nn.ReLU(True)) 106 | elif act == 'prelu': 107 | m.append(nn.PReLU(n_feats)) 108 | 109 | elif scale == 3: 110 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 111 | m.append(nn.PixelShuffle(3)) 112 | if bn: 113 | m.append(nn.BatchNorm2d(n_feats)) 114 | if act == 'relu': 115 | m.append(nn.ReLU(True)) 116 | elif act == 'prelu': 117 | m.append(nn.PReLU(n_feats)) 118 | else: 119 | raise NotImplementedError 120 | 121 | super(Upsampler, self).__init__(*m) 122 | 123 | -------------------------------------------------------------------------------- /src/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | 
rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return EDSR(args) 7 | 8 | class EDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(EDSR, self).__init__() 11 | 12 | n_resblocks = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | act = nn.ReLU(True) 17 | 18 | # define head module 19 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 20 | 21 | # define body module 22 | m_body = [ 23 | common.ResBlock_AdaDM( 24 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 25 | ) for _ in range(n_resblocks) 26 | ] 27 | m_body.append(conv(n_feats, n_feats, kernel_size)) 28 | 29 | # define tail module 30 | m_tail = [ 31 | common.Upsampler(conv, scale, n_feats, act=False), 32 | conv(n_feats, args.n_colors, kernel_size) 33 | ] 34 | 35 | self.head = nn.Sequential(*m_head) 36 | self.body = nn.Sequential(*m_body) 37 | self.tail = nn.Sequential(*m_tail) 38 | 39 | def forward(self, x): 40 | x = self.head(x) 41 | 42 | res = self.body(x) 43 | res += x 44 | 45 | x = self.tail(res) 46 | 47 | return x 48 | 49 | def load_state_dict(self, state_dict, strict=True): 50 | own_state = self.state_dict() 51 | for name, param in state_dict.items(): 52 | if name in own_state: 53 | if isinstance(param, nn.Parameter): 54 | param = param.data 55 | try: 56 | own_state[name].copy_(param) 57 | except Exception: 58 | if name.find('tail') == -1: 59 | raise RuntimeError('While copying the parameter named {}, ' 60 | 'whose dimensions in the model are {} and ' 61 | 'whose dimensions in the checkpoint are {}.' 
62 | .format(name, own_state[name].size(), param.size())) 63 | elif strict: 64 | if name.find('tail') == -1: 65 | raise KeyError('unexpected key "{}" in state_dict' 66 | .format(name)) 67 | -------------------------------------------------------------------------------- /src/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /src/model/mssr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch.nn as nn 3 | import torch 4 | from model.attention import ContextualAttention,NonLocalAttention 5 | def make_model(args, parent=False): 6 | return MSSR(args) 7 | 8 | class MultisourceProjection(nn.Module): 9 | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv): 10 | super(MultisourceProjection, self).__init__() 11 | self.up_attention = ContextualAttention(scale=2) 12 | self.down_attention = NonLocalAttention() 13 | self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()]) 14 | self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1) 15 | 16 | def forward(self,x): 17 | down_map = self.upsample(self.down_attention(x)) 18 | up_map = self.up_attention(x) 19 | 20 | err = self.encoder(up_map-down_map) 21 | final_map = down_map + err 22 | 23 | return final_map 24 | 25 | class RecurrentProjection(nn.Module): 26 | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv): 27 | super(RecurrentProjection, self).__init__() 28 | self.multi_source_projection_1 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv) 29 | self.multi_source_projection_2 = 
MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv) 30 | self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()]) 31 | #self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()]) 32 | self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()]) 33 | self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()]) 34 | self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()]) 35 | self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()]) 36 | self.post_conv = common.BasicBlock(conv,in_channel,in_channel,kernel_size,stride=1,bias=True,act=nn.PReLU()) 37 | 38 | 39 | def forward(self, x): 40 | x_up = self.multi_source_projection_1(x) 41 | 42 | x_down = self.down_sample_1(x_up) 43 | error_up = self.error_encode_1(x-x_down) 44 | h_estimate_1 = x_up + error_up 45 | 46 | x_up_2 = self.multi_source_projection_2(h_estimate_1) 47 | x_down_2 = self.down_sample_3(x_up_2) 48 | error_up_2 = self.error_encode_2(x-x_down_2) 49 | h_estimate_2 = x_up_2 + error_up_2 50 | x_final = self.post_conv(self.down_sample_4(h_estimate_2)) 51 | 52 | return x_final, h_estimate_2 53 | 54 | 55 | 56 | 57 | 58 | class MSSR(nn.Module): 59 | def __init__(self, args, conv=common.default_conv): 60 | super(MSSR, self).__init__() 61 | 62 | #n_convblock = args.n_convblocks 63 | n_feats = args.n_feats 64 | self.depth = args.depth 65 | kernel_size = 3 66 | scale = args.scale[0] 67 | 68 | 69 | rgb_mean = (0.4488, 0.4371, 0.4040) 70 | rgb_std = (1.0, 1.0, 1.0) 71 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 72 | 73 | # define head module 74 | m_head = [common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU()), 75 | common.BasicBlock(conv,n_feats, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU())] 76 | 77 | # define multiple reconstruction module 78 | 79 | self.body = RecurrentProjection(n_feats) 80 | 81 | 82 | # define tail module 83 | m_tail = [ 84 | nn.Conv2d( 85 | n_feats*self.depth, args.n_colors, kernel_size, 86 | padding=(kernel_size//2) 87 | ) 88 | ] 89 | 90 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 91 | 92 | self.head = nn.Sequential(*m_head) 93 | self.tail = nn.Sequential(*m_tail) 94 | def forward(self,input): 95 | x = self.sub_mean(input) 96 | x = self.head(x) 97 | bag = [] 98 | for i in range(self.depth): 99 | x, h_estimate = self.body(x) 100 | bag.append(h_estimate) 101 | h_feature = torch.cat(bag,dim=1) 102 | h_final = self.tail(h_feature) 103 | 104 | return self.add_mean(h_final) 105 | -------------------------------------------------------------------------------- /src/model/nlsn.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | if args.dilation: 7 | from model import dilated 8 | return NLSN(args, dilated.dilated_conv) 9 | else: 10 | return NLSN(args) 11 | 12 | 13 | class NLSN(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(NLSN, self).__init__() 16 | 17 | n_resblock = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | scale = args.scale[0] 21 | act = nn.ReLU(True) 22 | 23 | rgb_mean = 
(0.4488, 0.4371, 0.4040) 24 | rgb_std = (1.0, 1.0, 1.0) 25 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 26 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 27 | 28 | # define body module 29 | m_body = [attention.NonLocalSparseAttention( 30 | channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)] 31 | 32 | for i in range(n_resblock): 33 | m_body.append( common.ResBlock_AdaDM( 34 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 35 | )) 36 | if (i+1)%8==0: 37 | m_body.append(attention.NonLocalSparseAttention( 38 | channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)) 39 | m_body.append(conv(n_feats, n_feats, kernel_size)) 40 | 41 | # define tail module 42 | m_tail = [ 43 | common.Upsampler(conv, scale, n_feats, act=False), 44 | nn.Conv2d( 45 | n_feats, args.n_colors, kernel_size, 46 | padding=(kernel_size//2) 47 | ) 48 | ] 49 | 50 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 51 | 52 | self.head = nn.Sequential(*m_head) 53 | self.body = nn.Sequential(*m_body) 54 | self.tail = nn.Sequential(*m_tail) 55 | 56 | def forward(self, x): 57 | x = self.sub_mean(x) 58 | x = self.head(x) 59 | 60 | res = self.body(x) 61 | res += x 62 | 63 | x = self.tail(res) 64 | x = self.add_mean(x) 65 | 66 | return x 67 | 68 | def load_state_dict(self, state_dict, strict=True): 69 | own_state = self.state_dict() 70 | for name, param in state_dict.items(): 71 | if name in own_state: 72 | if isinstance(param, nn.Parameter): 73 | param = param.data 74 | try: 75 | own_state[name].copy_(param) 76 | except Exception: 77 | if name.find('tail') == -1: 78 | raise RuntimeError('While copying the parameter named {}, ' 79 | 'whose dimensions in the model are {} and ' 80 | 'whose dimensions in the checkpoint are {}.' 
81 | .format(name, own_state[name].size(), param.size())) 82 | elif strict: 83 | if name.find('tail') == -1: 84 | raise KeyError('unexpected key "{}" in state_dict' 85 | .format(name)) 86 | 87 | -------------------------------------------------------------------------------- /src/model/rcan.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | import torch 7 | def make_model(args, parent=False): 8 | return RCAN(args) 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | # global average pooling: feature --> point 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | # feature channel downscale and upscale --> channel weight 17 | #self.a = torch.nn.Parameter(torch.Tensor([0])) 18 | #self.a.requires_grad=True 19 | 20 | self.conv_du = nn.Sequential( 21 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 24 | nn.Sigmoid() 25 | ) 26 | 27 | def forward(self, x): 28 | y = self.avg_pool(x) 29 | y = self.conv_du(y) 30 | return x * y 31 | 32 | ## Residual Channel Attention Block (RCAB) 33 | class RCAB(nn.Module): 34 | def __init__( 35 | self, conv, n_feat, kernel_size, reduction, 36 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 37 | 38 | super(RCAB, self).__init__() 39 | modules_body = [] 40 | for i in range(2): 41 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 42 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 43 | if i == 0: modules_body.append(act) 44 | modules_body.append(CALayer(n_feat, reduction)) 45 | self.body = nn.Sequential(*modules_body) 46 | self.res_scale = res_scale 47 | 48 | def forward(self, x): 49 | res = self.body(x) 50 | #res = self.body(x).mul(self.res_scale) 51 | res += x 52 | return res 53 | 54 | ## Residual Group (RG) 55 | class ResidualGroup(nn.Module): 56 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 57 | super(ResidualGroup, self).__init__() 58 | modules_body = [] 59 | modules_body = [ 60 | RCAB( 61 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 62 | for _ in range(n_resblocks)] 63 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 64 | self.body = nn.Sequential(*modules_body) 65 | 66 | def forward(self, x): 67 | res = self.body(x) 68 | res += x 69 | return res 70 | 71 | ## Residual Channel Attention Network (RCAN) 72 | class RCAN(nn.Module): 73 | def __init__(self, args, conv=common.default_conv): 74 | super(RCAN, self).__init__() 75 | self.a = nn.Parameter(torch.Tensor([0])) 76 | self.a.requires_grad=True 77 | n_resgroups = args.n_resgroups 78 | n_resblocks = args.n_resblocks 79 | n_feats = args.n_feats 80 | kernel_size = 3 81 | reduction = args.reduction 82 | scale = args.scale[0] 83 | act = nn.ReLU(True) 84 | 85 | # RGB mean for DIV2K 86 | rgb_mean = (0.4488, 0.4371, 0.4040) 87 | rgb_std = (1.0, 1.0, 1.0) 88 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 89 | 90 | # define head module 91 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 92 | 93 | # define body module 94 | modules_body = [ 95 | ResidualGroup( 96 | conv, n_feats, kernel_size, reduction, act=act, 
res_scale=args.res_scale, n_resblocks=n_resblocks) \ 97 | for _ in range(n_resgroups)] 98 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 99 | 100 | # define tail module 101 | modules_tail = [ 102 | common.Upsampler(conv, scale, n_feats, act=False), 103 | conv(n_feats, args.n_colors, kernel_size)] 104 | 105 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 106 | 107 | self.head = nn.Sequential(*modules_head) 108 | self.body = nn.Sequential(*modules_body) 109 | self.tail = nn.Sequential(*modules_tail) 110 | 111 | def forward(self, x): 112 | x = self.sub_mean(x) 113 | x = self.head(x) 114 | res = self.body(x) 115 | res += x 116 | 117 | x = self.tail(res) 118 | x = self.add_mean(x) 119 | 120 | return x 121 | 122 | def load_state_dict(self, state_dict, strict=False): 123 | own_state = self.state_dict() 124 | for name, param in state_dict.items(): 125 | if name in own_state: 126 | if isinstance(param, nn.Parameter): 127 | param = param.data 128 | try: 129 | own_state[name].copy_(param) 130 | except Exception: 131 | if name.find('msa') or name.find('a') >= 0: 132 | print('Replace pre-trained upsampler to new one...') 133 | else: 134 | raise RuntimeError('While copying the parameter named {}, ' 135 | 'whose dimensions in the model are {} and ' 136 | 'whose dimensions in the checkpoint are {}.' 137 | .format(name, own_state[name].size(), param.size())) 138 | elif strict: 139 | if name.find('msa') == -1: 140 | raise KeyError('unexpected key "{}" in state_dict' 141 | .format(name)) 142 | 143 | if strict: 144 | missing = set(own_state.keys()) - set(state_dict.keys()) 145 | if len(missing) > 0: 146 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 147 | -------------------------------------------------------------------------------- /src/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDB_AdaDM(nn.Module): 46 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 47 | super(RDB_AdaDM, self).__init__() 48 | G0 = growRate0 49 | G = growRate 50 | C = nConvLayers 51 | 52 | convs = [] 53 | for c in range(C): 54 | convs.append(RDB_Conv(G0 + c*G, G)) 55 | self.convs = nn.Sequential(*convs) 56 | 57 | # Local Feature Fusion 58 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 59 | self.phi = nn.Conv2d(1, 1, 1, 1, 0, bias=True) 60 | self.phi.weight.data.fill_(1) 61 | 
self.phi.bias.data.fill_(0) 62 | self.norm = nn.BatchNorm2d(G0) 63 | 64 | def forward(self, x): 65 | s = torch.std(x, dim=[1,2,3], keepdim=True) 66 | x_n = self.norm(x) 67 | F = self.LFF(self.convs(x_n)) 68 | F = F * (torch.exp(self.phi(torch.log(s)))) 69 | 70 | return F + x 71 | 72 | class RDN(nn.Module): 73 | def __init__(self, args): 74 | super(RDN, self).__init__() 75 | r = args.scale[0] 76 | G0 = args.G0 77 | kSize = args.RDNkSize 78 | 79 | # number of RDB blocks, conv layers, out channels 80 | self.D, C, G = { 81 | 'A': (20, 6, 32), 82 | 'B': (16, 8, 64), 83 | }[args.RDNconfig] 84 | 85 | # Shallow feature extraction net 86 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 87 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 88 | 89 | # Redidual dense blocks and dense feature fusion 90 | self.RDBs = nn.ModuleList() 91 | for i in range(self.D): 92 | self.RDBs.append( 93 | RDB_AdaDM(growRate0 = G0, growRate = G, nConvLayers = C) 94 | ) 95 | 96 | # Global Feature Fusion 97 | self.GFF = nn.Sequential(*[ 98 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 99 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 100 | ]) 101 | 102 | # Up-sampling net 103 | if r == 2 or r == 3: 104 | self.UPNet = nn.Sequential(*[ 105 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 106 | nn.PixelShuffle(r), 107 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 108 | ]) 109 | elif r == 4: 110 | self.UPNet = nn.Sequential(*[ 111 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 112 | nn.PixelShuffle(2), 113 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 114 | nn.PixelShuffle(2), 115 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 116 | ]) 117 | else: 118 | raise ValueError("scale must be 2 or 3 or 4.") 119 | 120 | def forward(self, x): 121 | f__1 = self.SFENet1(x) 122 | x = self.SFENet2(f__1) 123 | 124 | RDBs_out = [] 125 | for i in range(self.D): 126 | x = self.RDBs[i](x) 127 | RDBs_out.append(x) 128 | 129 | x = self.GFF(torch.cat(RDBs_out,1)) 130 | x += f__1 131 | 132 | return self.UPNet(x) 133 | -------------------------------------------------------------------------------- /src/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/model/utils/__init__.py -------------------------------------------------------------------------------- /src/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 
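# Any odd remainder of the required padding is assigned to the bottom/right edges
# below, mirroring TensorFlow-style 'SAME' padding so the subsequent unfold in
# extract_image_patches covers every pixel. E.g. rows=cols=50, ksizes=[3,3],
# strides=[1,1], rates=[1,1] gives padding_rows=padding_cols=2, i.e. one zero
# pixel on each side.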
23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /src/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /src/option.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, default=18, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | parser.add_argument('--local_rank',type=int, default=0) 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='../../../', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../Demo', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='DIV2K', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, default='DIV2K', 29 | help='test dataset name') 30 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 31 | help='train/test data range') 32 | parser.add_argument('--ext', type=str, default='sep', 33 | help='dataset file extension') 34 | parser.add_argument('--scale', type=str, default='4', 35 | help='super resolution scale') 36 | parser.add_argument('--patch_size', type=int, default=192, 37 | help='output patch size') 38 | parser.add_argument('--rgb_range', type=int, default=255, 39 | help='maximum value of RGB') 40 | parser.add_argument('--n_colors', type=int, default=3, 41 | help='number of color channels to use') 42 | parser.add_argument('--chunk_size',type=int,default=144, 43 | help='attention bucket size') 44 | parser.add_argument('--n_hashes',type=int,default=4, 45 | help='number of hash rounds') 46 | parser.add_argument('--chop', action='store_true', 47 | help='enable memory-efficient forward') 48 | parser.add_argument('--no_augment', action='store_true', 49 | help='do not use data augmentation') 50 | 51 | # Model specifications 52 | parser.add_argument('--model', default='EDSR', 53 | help='model name') 54 | 55 | parser.add_argument('--act', type=str, default='relu', 56 | help='activation function') 57 | parser.add_argument('--pre_train', type=str, default='.', 58 | help='pre-trained model directory') 59 | parser.add_argument('--extend', type=str, default='.', 60 | help='pre-trained model directory') 61 | parser.add_argument('--n_resblocks', type=int, default=20, 62 | help='number of residual blocks') 63 | parser.add_argument('--n_feats', type=int, default=64, 64 | help='number of feature maps') 65 | parser.add_argument('--res_scale', type=float, default=1, 66 | help='residual scaling') 67 | parser.add_argument('--shift_mean', default=True, 68 | help='subtract pixel mean from the input') 69 | parser.add_argument('--dilation', action='store_true', 70 | help='use dilated convolution') 71 | parser.add_argument('--precision', type=str, default='single', 72 | choices=('single', 'half'), 73 | help='FP precision for test (single | half)') 74 | 75 | # Option for Residual dense network (RDN) 76 | parser.add_argument('--G0', type=int, default=64, 77 | help='default number of filters. 
(Use in RDN)') 78 | parser.add_argument('--RDNkSize', type=int, default=3, 79 | help='default kernel size. (Use in RDN)') 80 | parser.add_argument('--RDNconfig', type=str, default='B', 81 | help='parameters config of RDN. (Use in RDN)') 82 | 83 | parser.add_argument('--depth', type=int, default=12, 84 | help='number of residual groups') 85 | # Option for Residual channel attention network (RCAN) 86 | parser.add_argument('--n_resgroups', type=int, default=10, 87 | help='number of residual groups') 88 | parser.add_argument('--reduction', type=int, default=16, 89 | help='number of feature maps reduction') 90 | 91 | # Training specifications 92 | parser.add_argument('--reset', action='store_true', 93 | help='reset the training') 94 | parser.add_argument('--test_every', type=int, default=1000, 95 | help='do test per every N batches') 96 | parser.add_argument('--epochs', type=int, default=1000, 97 | help='number of epochs to train') 98 | parser.add_argument('--batch_size', type=int, default=16, 99 | help='input batch size for training') 100 | parser.add_argument('--split_batch', type=int, default=1, 101 | help='split the batch into smaller chunks') 102 | parser.add_argument('--self_ensemble', action='store_true', 103 | help='use self-ensemble method for test') 104 | parser.add_argument('--test_only', action='store_true', 105 | help='set this option to test the model') 106 | parser.add_argument('--gan_k', type=int, default=1, 107 | help='k value for adversarial loss') 108 | 109 | # Optimization specifications 110 | parser.add_argument('--lr', type=float, default=1e-4, 111 | help='learning rate') 112 | parser.add_argument('--decay', type=str, default='200', 113 | help='learning rate decay type') 114 | parser.add_argument('--gamma', type=float, default=0.5, 115 | help='learning rate decay factor for step decay') 116 | parser.add_argument('--optimizer', default='ADAM', 117 | choices=('SGD', 'ADAM', 'RMSprop'), 118 | help='optimizer to use (SGD | ADAM | RMSprop)') 119 | parser.add_argument('--momentum', type=float, default=0.9, 120 | help='SGD momentum') 121 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 122 | help='ADAM beta') 123 | parser.add_argument('--epsilon', type=float, default=1e-8, 124 | help='ADAM epsilon for numerical stability') 125 | parser.add_argument('--weight_decay', type=float, default=0, 126 | help='weight decay') 127 | parser.add_argument('--gclip', type=float, default=0, 128 | help='gradient clipping threshold (0 = no clipping)') 129 | 130 | # Loss specifications 131 | parser.add_argument('--loss', type=str, default='1*L1', 132 | help='loss function configuration') 133 | parser.add_argument('--skip_threshold', type=float, default='1e8', 134 | help='skipping batch that has large error') 135 | 136 | # Log specifications 137 | parser.add_argument('--save', type=str, default='test', 138 | help='file name to save') 139 | parser.add_argument('--load', type=str, default='', 140 | help='file name to load') 141 | parser.add_argument('--resume', type=int, default=0, 142 | help='resume from specific checkpoint') 143 | parser.add_argument('--save_models', action='store_true', 144 | help='save all intermediate models') 145 | parser.add_argument('--print_every', type=int, default=100, 146 | help='how many batches to wait before logging training status') 147 | parser.add_argument('--save_results', action='store_true', 148 | help='save output results') 149 | parser.add_argument('--save_gt', action='store_true', 150 | help='save low-resolution and high-resolution images 
together') 151 | 152 | args = parser.parse_args() 153 | template.set_template(args) 154 | 155 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 156 | args.data_train = args.data_train.split('+') 157 | args.data_test = args.data_test.split('+') 158 | 159 | if args.epochs == 0: 160 | args.epochs = 1e8 161 | 162 | for arg in vars(args): 163 | if vars(args)[arg] == 'True': 164 | vars(args)[arg] = True 165 | elif vars(args)[arg] == 'False': 166 | vars(args)[arg] = False 167 | 168 | -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /src/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=3 4 | SCALE=2 5 | TEST_DATASET="Urban100" 6 | TEST_MODEL="EDSR" 7 | 8 | ###################################################################################################### 9 | # EDSR Test 10 | if [[ $TEST_MODEL == "EDSR" ]]; then 11 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale $SCALE\ 12 | --res_scale 0.1 --pre_train ../experiment/test/model/EDSR_AdaDM_DIV2K_X$SCALE.pt\ 13 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET 14 | fi 15 | 16 | ###################################################################################################### 17 | # RDN Test 18 | if [[ $TEST_MODEL == "RDN" ]]; then 19 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale $SCALE\ 20 | --pre_train ../experiment/test/model/RDN_AdaDM_DIV2K_X$SCALE.pt\ 21 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET 22 | fi 23 | 24 | ###################################################################################################### 25 | # NLSN Test 26 | if [[ $TEST_MODEL == "NLSN" ]]; then 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --dir_data ../dataset --model NLSN --chunk_size 144\ 28 | --n_hashes 4 --chop --rgb_range 1 --scale $SCALE --n_feats 256 --n_resblocks 32 --res_scale 0.1\ 29 | --pre_train ../experiment/test/model/NLSN_AdaDM_DIV2K_X$SCALE.pt --test_only --data_test $TEST_DATASET 30 | fi 31 | 32 | 33 | 34 | 35 
| 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /src/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=0 4 | 5 | ###################################################################################################### 6 | # EDSR Train X2 7 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 2\ 8 | --n_GPUs 1 --batch_size 16 --patch_size 96 --rgb_range 255 --res_scale 0.1\ 9 | --save EDSR_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\ 10 | --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results 11 | 12 | # EDSR Train X3 13 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 3\ 14 | # ---n_GPUs 1 -batch_size 16 --patch_size 144 --rgb_range 255 --res_scale 0.1\ 15 | # --save EDSR_AdaDM_Test_DIV2K_X3 --dir_data ../dataset --data_test Urban100\ 16 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\ 17 | # --pre_train ../experiment/EDSR_AdaDM_Test_DIV2K_X2/model/model_best.pt 18 | 19 | # EDSR Train X4 20 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 4\ 21 | # ---n_GPUs 1 -batch_size 16 --patch_size 192 --rgb_range 255 --res_scale 0.1\ 22 | # --save EDSR_AdaDM_Test_DIV2K_X4 --dir_data ../dataset --data_test Urban100\ 23 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\ 24 | # --pre_train ../experiment/EDSR_AdaDM_Test_DIV2K_X2/model/model_best.pt 25 | 26 | ###################################################################################################### 27 | # RDN Train X2 28 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 2\ 29 | # --batch_size 16 --patch_size 96 --rgb_range 255 --n_GPUs 1\ 30 | # --save RDN_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\ 31 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results 32 | 33 | # RDN Train X3 34 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 3\ 35 | # --batch_size 16 --patch_size 144 --rgb_range 255 --n_GPUs 1\ 36 | # --save RDN_AdaDM_Test_DIV2K_X3 --dir_data ../dataset --data_test Urban100\ 37 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\ 38 | # --pre_train ../experiment/RDN_AdaDM_Test_DIV2K_X2/model/model_best.pt 39 | 40 | # RDN Train X4 41 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 4\ 42 | # --batch_size 16 --patch_size 192 --rgb_range 255 --n_GPUs 1\ 43 | # --save RDN_AdaDM_Test_DIV2K_X4 --dir_data ../dataset --data_test Urban100\ 44 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\ 45 | # --pre_train ../experiment/RDN_AdaDM_Test_DIV2K_X2/model/model_best.pt 46 | 47 | 48 | 49 | ###################################################################################################### 50 | # NLSN Train X2 51 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\ 52 | # --chunk_size 144 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\ 53 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 2\ 54 | # --patch_size 96 --save NLSN_AdaDM_Test_DIV2K_X2 --data_test Urban100 --save_models --save_results 55 | 56 | # NLSN Train X3 57 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\ 58 | # --chunk_size 144 
--n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\ 59 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 3\ 60 | # --patch_size 144 --save NLSN_AdaDM_Test_DIV2K_X3 --data_test Urban100 --save_models --save_results\ 61 | # --pre_train ../experiment/NLSN_AdaDM_Test_DIV2K_X2/model/model_best.pt 62 | 63 | # NLSN Train X4 64 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\ 65 | # --chunk_size 144 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\ 66 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 4\ 67 | # --patch_size 192 --save NLSN_AdaDM_Test_DIV2K_X4 --data_test Urban100 --save_models --save_results\ 68 | # --pre_train ../experiment/NLSN_AdaDM_Test_DIV2K_X2/model/model_best.pt 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | import torch.nn.utils as utils 9 | from tqdm import tqdm 10 | 11 | class Trainer(): 12 | def __init__(self, args, loader, my_model, my_loss, ckp): 13 | self.args = args 14 | self.scale = args.scale 15 | 16 | self.ckp = ckp 17 | self.loader_train = loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer = utility.make_optimizer(args, self.model) 22 | 23 | if self.args.load != '': 24 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 25 | 26 | self.error_last = 1e8 27 | 28 | def train(self): 29 | self.loss.step() 30 | epoch = self.optimizer.get_last_epoch() + 1 31 | lr = self.optimizer.get_lr() 32 | 33 | self.ckp.write_log( 34 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 35 | ) 36 | self.loss.start_log() 37 | self.model.train() 38 | 39 | timer_data, timer_model = utility.timer(), utility.timer() 40 | # TEMP 41 | self.loader_train.dataset.set_scale(0) 42 | for batch, (lr, hr, _,) in enumerate(self.loader_train): 43 | lr, hr = self.prepare(lr, hr) 44 | timer_data.hold() 45 | timer_model.tic() 46 | 47 | self.optimizer.zero_grad() 48 | sr = self.model(lr, 0) 49 | loss = self.loss(sr, hr) 50 | loss.backward() 51 | if self.args.gclip > 0: 52 | utils.clip_grad_value_( 53 | self.model.parameters(), 54 | self.args.gclip 55 | ) 56 | self.optimizer.step() 57 | 58 | timer_model.hold() 59 | 60 | if (batch + 1) % self.args.print_every == 0: 61 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 62 | (batch + 1) * self.args.batch_size, 63 | len(self.loader_train.dataset), 64 | self.loss.display_loss(batch), 65 | timer_model.release(), 66 | timer_data.release())) 67 | 68 | timer_data.tic() 69 | 70 | self.loss.end_log(len(self.loader_train)) 71 | self.error_last = self.loss.log[-1, -1] 72 | self.optimizer.schedule() 73 | 74 | def test(self): 75 | torch.set_grad_enabled(False) 76 | 77 | epoch = self.optimizer.get_last_epoch() 78 | self.ckp.write_log('\nEvaluation:') 79 | self.ckp.add_log( 80 | torch.zeros(1, len(self.loader_test), len(self.scale)) 81 | ) 82 | self.model.eval() 83 | 84 | timer_test = utility.timer() 85 | if self.args.save_results: self.ckp.begin_background() 86 | for idx_data, d in enumerate(self.loader_test): 87 | for idx_scale, scale in enumerate(self.scale): 88 | d.dataset.set_scale(idx_scale) 89 | for lr, hr, filename in tqdm(d, ncols=80): 90 | 
lr, hr = self.prepare(lr, hr) 91 | sr = self.model(lr, idx_scale) 92 | sr = utility.quantize(sr, self.args.rgb_range) 93 | 94 | save_list = [sr] 95 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 96 | sr, hr, scale, self.args.rgb_range, dataset=d 97 | ) 98 | if self.args.save_gt: 99 | save_list.extend([lr, hr]) 100 | 101 | if self.args.save_results: 102 | self.ckp.save_results(d, filename[0], save_list, scale) 103 | 104 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 105 | best = self.ckp.log.max(0) 106 | self.ckp.write_log( 107 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 108 | d.dataset.name, 109 | scale, 110 | self.ckp.log[-1, idx_data, idx_scale], 111 | best[0][idx_data, idx_scale], 112 | best[1][idx_data, idx_scale] + 1 113 | ) 114 | ) 115 | 116 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) 117 | self.ckp.write_log('Saving...') 118 | 119 | if self.args.save_results: 120 | self.ckp.end_background() 121 | 122 | if not self.args.test_only: 123 | self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 124 | 125 | self.ckp.write_log( 126 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 127 | ) 128 | 129 | torch.set_grad_enabled(True) 130 | 131 | def prepare(self, *args): 132 | device = torch.device('cpu' if self.args.cpu else 'cuda') 133 | def _prepare(tensor): 134 | if self.args.precision == 'half': tensor = tensor.half() 135 | return tensor.to(device) 136 | 137 | return [_prepare(a) for a in args] 138 | 139 | def terminate(self): 140 | if self.args.test_only: 141 | self.test() 142 | return True 143 | else: 144 | epoch = self.optimizer.get_last_epoch() + 1 145 | return epoch >= self.args.epochs 146 | 147 | -------------------------------------------------------------------------------- /src/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from multiprocessing import Process 6 | from multiprocessing import Queue 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import numpy as np 13 | import imageio 14 | 15 | import torch 16 | import torch.optim as optim 17 | import torch.optim.lr_scheduler as lrs 18 | 19 | class timer(): 20 | def __init__(self): 21 | self.acc = 0 22 | self.tic() 23 | 24 | def tic(self): 25 | self.t0 = time.time() 26 | 27 | def toc(self, restart=False): 28 | diff = time.time() - self.t0 29 | if restart: self.t0 = time.time() 30 | return diff 31 | 32 | def hold(self): 33 | self.acc += self.toc() 34 | 35 | def release(self): 36 | ret = self.acc 37 | self.acc = 0 38 | 39 | return ret 40 | 41 | def reset(self): 42 | self.acc = 0 43 | 44 | class checkpoint(): 45 | def __init__(self, args): 46 | self.args = args 47 | self.ok = True 48 | self.log = torch.Tensor() 49 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 50 | 51 | if not args.load: 52 | if not args.save: 53 | args.save = now 54 | self.dir = os.path.join('..', 'experiment', args.save) 55 | else: 56 | self.dir = os.path.join('..', 'experiment', args.load) 57 | if os.path.exists(self.dir): 58 | self.log = torch.load(self.get_path('psnr_log.pt')) 59 | print('Continue from epoch {}...'.format(len(self.log))) 60 | else: 61 | args.load = '' 62 | 63 | if args.reset: 64 | os.system('rm -rf ' + self.dir) 65 | args.load = '' 66 | 67 | os.makedirs(self.dir, exist_ok=True) 68 | os.makedirs(self.get_path('model'), exist_ok=True) 69 | for d in args.data_test: 70 | 
os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 71 | 72 | open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w' 73 | self.log_file = open(self.get_path('log.txt'), open_type) 74 | with open(self.get_path('config.txt'), open_type) as f: 75 | f.write(now + '\n\n') 76 | for arg in vars(args): 77 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 78 | f.write('\n') 79 | 80 | self.n_processes = 8 81 | 82 | def get_path(self, *subdir): 83 | return os.path.join(self.dir, *subdir) 84 | 85 | def save(self, trainer, epoch, is_best=False): 86 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 87 | trainer.loss.save(self.dir) 88 | trainer.loss.plot_loss(self.dir, epoch) 89 | 90 | self.plot_psnr(epoch) 91 | trainer.optimizer.save(self.dir) 92 | torch.save(self.log, self.get_path('psnr_log.pt')) 93 | 94 | def add_log(self, log): 95 | self.log = torch.cat([self.log, log]) 96 | 97 | def write_log(self, log, refresh=False): 98 | print(log) 99 | self.log_file.write(log + '\n') 100 | if refresh: 101 | self.log_file.close() 102 | self.log_file = open(self.get_path('log.txt'), 'a') 103 | 104 | def done(self): 105 | self.log_file.close() 106 | 107 | def plot_psnr(self, epoch): 108 | axis = np.linspace(1, epoch, epoch) 109 | for idx_data, d in enumerate(self.args.data_test): 110 | label = 'SR on {}'.format(d) 111 | fig = plt.figure() 112 | plt.title(label) 113 | for idx_scale, scale in enumerate(self.args.scale): 114 | plt.plot( 115 | axis, 116 | self.log[:, idx_data, idx_scale].numpy(), 117 | label='Scale {}'.format(scale) 118 | ) 119 | plt.legend() 120 | plt.xlabel('Epochs') 121 | plt.ylabel('PSNR') 122 | plt.grid(True) 123 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 124 | plt.close(fig) 125 | 126 | def begin_background(self): 127 | self.queue = Queue() 128 | 129 | def bg_target(queue): 130 | while True: 131 | if not queue.empty(): 132 | filename, tensor = queue.get() 133 | if filename is None: break 134 | imageio.imwrite(filename, tensor.numpy()) 135 | 136 | self.process = [ 137 | Process(target=bg_target, args=(self.queue,)) \ 138 | for _ in range(self.n_processes) 139 | ] 140 | 141 | for p in self.process: p.start() 142 | 143 | def end_background(self): 144 | for _ in range(self.n_processes): self.queue.put((None, None)) 145 | while not self.queue.empty(): time.sleep(1) 146 | for p in self.process: p.join() 147 | 148 | def save_results(self, dataset, filename, save_list, scale): 149 | if self.args.save_results: 150 | filename = self.get_path( 151 | 'results-{}'.format(dataset.dataset.name), 152 | '{}_x{}_'.format(filename, scale) 153 | ) 154 | 155 | postfix = ('SR', 'LR', 'HR') 156 | for v, p in zip(save_list, postfix): 157 | normalized = v[0].mul(255 / self.args.rgb_range) 158 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 159 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 160 | 161 | def quantize(img, rgb_range): 162 | pixel_range = 255 / rgb_range 163 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 164 | 165 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 166 | if hr.nelement() == 1: return 0 167 | 168 | diff = (sr - hr) / rgb_range 169 | if dataset and dataset.dataset.benchmark: 170 | shave = scale 171 | if diff.size(1) > 1: 172 | gray_coeffs = [65.738, 129.057, 25.064] 173 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 174 | diff = diff.mul(convert).sum(dim=1) 175 | else: 176 | shave = scale + 6 177 | 178 | valid = diff[..., shave:-shave, shave:-shave] 179 | mse = 
valid.pow(2).mean() 180 | 181 | return -10 * math.log10(mse) 182 | 183 | def make_optimizer(args, target): 184 | ''' 185 | make optimizer and scheduler together 186 | ''' 187 | # optimizer 188 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 189 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 190 | 191 | if args.optimizer == 'SGD': 192 | optimizer_class = optim.SGD 193 | kwargs_optimizer['momentum'] = args.momentum 194 | elif args.optimizer == 'ADAM': 195 | optimizer_class = optim.Adam 196 | kwargs_optimizer['betas'] = args.betas 197 | kwargs_optimizer['eps'] = args.epsilon 198 | elif args.optimizer == 'RMSprop': 199 | optimizer_class = optim.RMSprop 200 | kwargs_optimizer['eps'] = args.epsilon 201 | 202 | # scheduler 203 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 204 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 205 | scheduler_class = lrs.MultiStepLR 206 | 207 | class CustomOptimizer(optimizer_class): 208 | def __init__(self, *args, **kwargs): 209 | super(CustomOptimizer, self).__init__(*args, **kwargs) 210 | 211 | def _register_scheduler(self, scheduler_class, **kwargs): 212 | self.scheduler = scheduler_class(self, **kwargs) 213 | 214 | def save(self, save_dir): 215 | torch.save(self.state_dict(), self.get_dir(save_dir)) 216 | 217 | def load(self, load_dir, epoch=1): 218 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 219 | if epoch > 1: 220 | for _ in range(epoch): self.scheduler.step() 221 | 222 | def get_dir(self, dir_path): 223 | return os.path.join(dir_path, 'optimizer.pt') 224 | 225 | def schedule(self): 226 | self.scheduler.step() 227 | 228 | def get_lr(self): 229 | return self.scheduler.get_lr()[0] 230 | 231 | def get_last_epoch(self): 232 | return self.scheduler.last_epoch 233 | 234 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 235 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 236 | return optimizer 237 | 238 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 
23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /src/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = 
self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | --------------------------------------------------------------------------------
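The snippet below is a minimal usage sketch, not a file from this repository, showing how the evaluation helpers defined in `src/utility.py` are combined by `Trainer.test()`: the model output is first snapped back onto the 8-bit grid with `quantize`, then scored with `calc_psnr`, which shaves a border of `scale + 6` pixels when no benchmark dataset object is supplied. It assumes it is run from `AdaDM/src` (so that `utility` is importable) and uses random tensors as stand-ins for real SR/HR images.

```python
import torch
import utility  # src/utility.py

rgb_range, scale = 255, 2

# Random stand-ins for a super-resolved output and its HR ground truth, [B, C, H, W].
sr = torch.rand(1, 3, 96, 96) * rgb_range
hr = torch.rand(1, 3, 96, 96) * rgb_range

# quantize() maps the network output back to valid 8-bit values
# (mul -> clamp -> round -> div), as done in Trainer.test().
sr = utility.quantize(sr, rgb_range)

# calc_psnr() crops a (scale + 6)-pixel border before computing MSE
# when no benchmark dataset is passed, then returns PSNR in dB.
psnr = utility.calc_psnr(sr, hr, scale, rgb_range)
print('PSNR: {:.3f} dB'.format(psnr))
```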
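Likewise, a small sketch (again, not part of the repository) of the patch-extraction helper in `src/utils/tools.py`: with `padding='same'`, `same_padding` zero-pads the input so that `torch.nn.Unfold` yields one patch per spatial location, flattened into the channel dimension. The feature-map shape below is an arbitrary example.

```python
import torch
from utils.tools import extract_image_patches  # src/utils/tools.py

# Random feature map standing in for an intermediate activation, [N, C, H, W].
feats = torch.randn(1, 64, 48, 48)

# 'same' padding keeps one 3x3 patch per spatial position; output is [N, C*k*k, L].
patches = extract_image_patches(feats, ksizes=[3, 3], strides=[1, 1],
                                rates=[1, 1], padding='same')
print(patches.shape)  # torch.Size([1, 576, 2304]): 64*3*3 channels, 48*48 locations
```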