├── CAR ├── README.md ├── code │ ├── LICENSE │ ├── __init__.py │ ├── __pycache__ │ │ ├── option.cpython-36.pyc │ │ ├── template.cpython-36.pyc │ │ ├── trainer.cpython-36.pyc │ │ └── utility.cpython-36.pyc │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── benchmark.cpython-36.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── div2k.cpython-36.pyc │ │ │ └── srdata.cpython-36.pyc │ │ ├── benchmark.py │ │ ├── common.py │ │ ├── demo.py │ │ ├── div2k.py │ │ ├── div2kjpeg.py │ │ ├── sr291.py │ │ ├── srdata.py │ │ └── video.py │ ├── dataloader.py │ ├── demo.sb │ ├── demo.sh │ ├── loss │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── adversarial.py │ │ ├── discriminator.py │ │ └── vgg.py │ ├── main.py │ ├── model │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── attention.cpython-36.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── edsr.cpython-36.pyc │ │ │ ├── edsr2.cpython-36.pyc │ │ │ ├── edsrl1.cpython-36.pyc │ │ │ ├── edsrl1t.cpython-36.pyc │ │ │ └── mssr.cpython-36.pyc │ │ ├── attention.py │ │ ├── common.py │ │ ├── ddbpn.py │ │ ├── mdsr.py │ │ ├── panet.py │ │ ├── rcan.py │ │ ├── rdn.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-35.pyc │ │ │ │ └── tools.cpython-35.pyc │ │ │ └── tools.py │ │ └── vdsr.py │ ├── option.py │ ├── template.py │ ├── trainer.py │ ├── utility.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── tools.cpython-35.pyc │ │ │ └── tools.cpython-36.pyc │ │ └── tools.py │ └── videotester.py └── experiment │ └── PP_Evaluate_CAR_Y_PSNR_SSIM.m ├── DN_RGB ├── README.md ├── code │ ├── LICENSE │ ├── __init__.py │ ├── __pycache__ │ │ ├── option.cpython-36.pyc │ │ ├── template.cpython-36.pyc │ │ ├── trainer.cpython-36.pyc │ │ └── utility.cpython-36.pyc │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── benchmark.cpython-36.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── div2k.cpython-36.pyc │ │ │ └── srdata.cpython-36.pyc │ │ ├── benchmark.py │ │ ├── common.py │ │ ├── demo.py │ │ ├── div2k.py │ │ ├── div2kjpeg.py │ │ ├── sr291.py │ │ ├── srdata.py │ │ └── video.py │ ├── dataloader.py │ ├── demo.sb │ ├── loss │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ ├── adversarial.py │ │ ├── discriminator.py │ │ └── vgg.py │ ├── main.py │ ├── model │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── attention.cpython-36.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── edsr.cpython-36.pyc │ │ │ ├── edsr2.cpython-36.pyc │ │ │ ├── edsrl1.cpython-36.pyc │ │ │ ├── edsrl1t.cpython-36.pyc │ │ │ ├── mssr.cpython-36.pyc │ │ │ └── panet.cpython-36.pyc │ │ ├── attention.py │ │ ├── common.py │ │ ├── ddbpn.py │ │ ├── mdsr.py │ │ ├── panet.py │ │ ├── rcan.py │ │ ├── rdn.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-35.pyc │ │ │ │ └── tools.cpython-35.pyc │ │ │ └── tools.py │ │ └── vdsr.py │ ├── option.py │ ├── template.py │ ├── trainer.py │ ├── utility.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── tools.cpython-35.pyc │ │ │ └── tools.cpython-36.pyc │ │ └── tools.py │ └── videotester.py └── experiment │ └── PP_Evaluate_DN_RGB_PSNR_SSIM.m ├── Demosaic ├── README.md ├── code │ ├── LICENSE │ ├── __init__.py │ ├── __pycache__ │ │ 
├── option.cpython-35.pyc │ │ ├── option.cpython-36.pyc │ │ ├── template.cpython-35.pyc │ │ ├── template.cpython-36.pyc │ │ ├── trainer.cpython-35.pyc │ │ ├── trainer.cpython-36.pyc │ │ ├── utility.cpython-35.pyc │ │ └── utility.cpython-36.pyc │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── benchmark.cpython-36.pyc │ │ │ ├── common.cpython-35.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── div2k.cpython-35.pyc │ │ │ ├── div2k.cpython-36.pyc │ │ │ ├── srdata.cpython-35.pyc │ │ │ └── srdata.cpython-36.pyc │ │ ├── benchmark.py │ │ ├── common.py │ │ ├── demo.py │ │ ├── div2k.py │ │ ├── div2kjpeg.py │ │ ├── sr291.py │ │ ├── srdata.py │ │ └── video.py │ ├── dataloader.py │ ├── demo.sb │ ├── loss │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ └── __init__.cpython-36.pyc │ │ ├── adversarial.py │ │ ├── discriminator.py │ │ └── vgg.py │ ├── main.py │ ├── model │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── attention.cpython-35.pyc │ │ │ ├── attention.cpython-36.pyc │ │ │ ├── common.cpython-35.pyc │ │ │ ├── common.cpython-36.pyc │ │ │ ├── edsr.cpython-35.pyc │ │ │ ├── edsr.cpython-36.pyc │ │ │ ├── edsr2.cpython-36.pyc │ │ │ ├── edsrl1.cpython-36.pyc │ │ │ ├── edsrl1t.cpython-36.pyc │ │ │ └── mssr.cpython-36.pyc │ │ ├── attention.py │ │ ├── common.py │ │ ├── ddbpn.py │ │ ├── mdsr.py │ │ ├── panet.py │ │ ├── rcan.py │ │ ├── rdn.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-35.pyc │ │ │ │ └── tools.cpython-35.pyc │ │ │ └── tools.py │ │ └── vdsr.py │ ├── option.py │ ├── template.py │ ├── trainer.py │ ├── utility.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── tools.cpython-35.pyc │ │ │ └── tools.cpython-36.pyc │ │ └── tools.py │ └── videotester.py └── experiment │ └── PP_Evaluate_Demosaic_RGB_PSNR_SSIM.m ├── Figs ├── PSNR_CAR.png ├── PSNR_DN_RGB.png ├── PSNR_Demosaic.png ├── PSNR_SR.png ├── Screenshot from 2020-04-23 18-56-03.png ├── Visual_CAR.png ├── Visual_DN_RGB.png ├── Visual_Demosaic.png ├── Visual_SR.png └── block.png ├── LICENSE ├── README.md └── SR ├── README.md ├── code ├── __init__.py ├── __pycache__ │ ├── option.cpython-35.pyc │ ├── option.cpython-36.pyc │ ├── template.cpython-35.pyc │ ├── template.cpython-36.pyc │ ├── trainer.cpython-35.pyc │ ├── trainer.cpython-36.pyc │ ├── utility.cpython-35.pyc │ └── utility.cpython-36.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── benchmark.cpython-35.pyc │ │ ├── benchmark.cpython-36.pyc │ │ ├── common.cpython-35.pyc │ │ ├── common.cpython-36.pyc │ │ ├── demo.cpython-36.pyc │ │ ├── div2k.cpython-35.pyc │ │ ├── div2k.cpython-36.pyc │ │ ├── srdata.cpython-35.pyc │ │ └── srdata.cpython-36.pyc │ ├── benchmark.py │ ├── common.py │ ├── demo.py │ ├── div2k.py │ ├── div2kjpeg.py │ ├── sr291.py │ ├── srdata.py │ └── video.py ├── dataloader.py ├── demo.sb ├── loss │ ├── __init__.py │ ├── __loss__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ └── __init__.cpython-36.pyc │ ├── adversarial.py │ ├── demo.sh │ ├── discriminator.py │ └── vgg.py ├── main.py ├── model │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── attention.cpython-35.pyc │ │ ├── attention.cpython-36.pyc │ │ 
├── common.cpython-35.pyc │ │ ├── common.cpython-36.pyc │ │ ├── edsr.cpython-35.pyc │ │ ├── edsr.cpython-36.pyc │ │ ├── mssr.cpython-35.pyc │ │ ├── mssr.cpython-36.pyc │ │ └── rcan.cpython-36.pyc │ ├── attention.py │ ├── common.py │ ├── ddbpn.py │ ├── edsr.py │ ├── mdsr.py │ ├── mssr.py │ ├── paedsr.py │ ├── rcan.py │ ├── rdn.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-35.pyc │ │ │ └── tools.cpython-35.pyc │ │ └── tools.py │ └── vdsr.py ├── option.py ├── template.py ├── trainer.py ├── utility.py ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-35.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── tools.cpython-35.pyc │ │ └── tools.cpython-36.pyc │ └── tools.py └── videotester.py └── experiment └── Evaluate_PSNR_SSIM.m /CAR/code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /CAR/code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/__init__.py -------------------------------------------------------------------------------- /CAR/code/__pycache__/option.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/__pycache__/option.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/__pycache__/template.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/__pycache__/template.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/__pycache__/utility.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/__pycache__/utility.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['CBSD68','classic5','LIVE1','Kodak24','Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /CAR/code/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/__pycache__/benchmark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/data/__pycache__/benchmark.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/data/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/__pycache__/div2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/data/__pycache__/div2k.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/__pycache__/srdata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/data/__pycache__/srdata.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('','.jpg') 25 | 26 | -------------------------------------------------------------------------------- /CAR/code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | print('heelo') 11 | print(args[0].shape) 12 | 13 | if not input_large: 14 | p = 1 if 
multi else 1 15 | tp = p * patch_size 16 | ip = tp // 1 17 | else: 18 | tp = patch_size 19 | ip = patch_size 20 | 21 | ix = random.randrange(0, iw - ip + 1) 22 | iy = random.randrange(0, ih - ip + 1) 23 | 24 | if not input_large: 25 | tx, ty = 1 * ix, 1 * iy 26 | else: 27 | tx, ty = ix, iy 28 | 29 | ret = [ 30 | args[0][iy:iy + ip, ix:ix + ip], 31 | *[a[ty:ty + tp, tx:tx + tp] for a in args[1:]] 32 | ] 33 | 34 | return ret 35 | 36 | def set_channel(*args, n_channels=3): 37 | def _set_channel(img): 38 | if img.ndim == 2: 39 | img = np.expand_dims(img, axis=2) 40 | 41 | c = img.shape[2] 42 | if n_channels == 1 and c == 3: 43 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 44 | elif n_channels == 3 and c == 1: 45 | img = np.concatenate([img] * n_channels, 2) 46 | 47 | return img 48 | 49 | return [_set_channel(a) for a in args] 50 | 51 | def np2Tensor(*args, rgb_range=255): 52 | def _np2Tensor(img): 53 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 54 | tensor = torch.from_numpy(np_transpose).float() 55 | tensor.mul_(rgb_range / 255) 56 | 57 | return tensor 58 | 59 | return [_np2Tensor(a) for a in args] 60 | 61 | def augment(*args, hflip=True, rot=True): 62 | hflip = hflip and random.random() < 0.5 63 | vflip = rot and random.random() < 0.5 64 | rot90 = rot and random.random() < 0.5 65 | 66 | def _augment(img): 67 | if hflip: img = img[:, ::-1] 68 | if vflip: img = img[::-1, :] 69 | if rot90: img = img.transpose(1, 0) 70 | 71 | return img 72 | 73 | return [_augment(a) for a in args] 74 | 75 | -------------------------------------------------------------------------------- /CAR/code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /CAR/code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 
18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /CAR/code/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /CAR/code/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /CAR/code/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | self.vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /CAR/code/demo.sb: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # PANet Train 3 | #python main.py --n_GPUs 2 --batch_size 16 --lr 1e-4 --decay 200-400-600-800 --save_models --model PANET --scale 10 --patch_size 48 --save PANET_Q10 --n_feats 64 --data_train DIV2K --chop 4 | 5
| # Test 6 | python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test classic5+LIVE1 --scale 40 --n_resblocks 80 --n_feats 64 --pre_train ../Q40.pt --test_only 7 | -------------------------------------------------------------------------------- /CAR/code/demo.sh: -------------------------------------------------------------------------------- 1 | # EDSR baseline model (x2) + JPEG augmentation 2 | python main.py --n_GPUs 2 --batch_size 16 --reset --save_models --model EDSRL1T --scale 1 --patch_size 48 --save EDSR_R16F64P48_CSAL1TK20 --n_feats 64 --data_train DIV2K --chop 3 | #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 4 | 5 | # EDSR baseline model (x3) - from EDSR baseline model (x2) 6 | #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 7 | 8 | # EDSR baseline model (x4) - from EDSR baseline model (x2) 9 | #python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train pre-trained EDSR_baseline_x2 model dir 10 | 11 | # EDSR in the paper (x2) 12 | #python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset 13 | 14 | # EDSR in the paper (x3) - from EDSR (x2) 15 | #python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] 16 | 17 | # EDSR in the paper (x4) - from EDSR (x2) 18 | #python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] 19 | 20 | # MDSR baseline model 21 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models 22 | 23 | # MDSR in the paper 24 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models 25 | 26 | # Standard benchmarks (Ex. 
EDSR_baseline_x4) 27 | #python main.py --model EDSRL1 --save_results --n_GPUs 2 --chop --data_test Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../experiment/EDSR_R16F64P48_CSALK20/model/model_best.pt --test_only 28 | 29 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble 30 | 31 | # Test your own images 32 | #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results 33 | 34 | # Advanced - Test with JPEG images 35 | #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results 36 | 37 | # Advanced - Training with adversarial loss 38 | #python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download 39 | 40 | # RDN BI model (x2) 41 | #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset 42 | # RDN BI model (x3) 43 | #python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset 44 | # RDN BI model (x4) 45 | #python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset 46 | 47 | # RCAN_BIX2_G10R20P48, input=48x48, output=96x96 48 | # pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0 49 | #python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96 50 | # RCAN_BIX3_G10R20P48, input=48x48, output=144x144 51 | #python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt 52 | # RCAN_BIX4_G10R20P48, input=48x48, output=192x192 53 | #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt 54 | # RCAN_BIX8_G10R20P48, input=48x48, output=384x384 55 | #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt 56 | 57 | -------------------------------------------------------------------------------- /CAR/code/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = 
[_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /CAR/code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /CAR/code/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args,checkpoint) 18 | print('total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 19 | t = VideoTester(args, model, checkpoint) 20 | t.test() 21 | else: 22 | if checkpoint.ok: 23 | loader = data.Data(args) 24 | _model = model.Model(args, checkpoint) 25 | print('total params:%.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0)) 26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, _model, _loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | checkpoint.done() 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /CAR/code/model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/edsr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/edsr.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/edsr2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/edsr2.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/edsrl1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/edsrl1.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/edsrl1t.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/edsrl1t.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/__pycache__/mssr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/__pycache__/mssr.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2),stride=stride, bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 27 | bn=False, act=nn.PReLU()): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /CAR/code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /CAR/code/model/panet.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return PANET(args) 7 | 8 | class PANET(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(PANET, self).__init__() 11 | 12 | n_resblocks = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | 17 | rgb_mean = (0.4488, 0.4371, 0.4040) 18 | rgb_std = (1.0, 1.0, 1.0) 19 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 20 | msa = attention.PyramidAttention() 21 | # define head module 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | # define body module 25 | m_body = [ 26 | common.ResBlock( 27 | conv, n_feats, 
kernel_size, nn.PReLU(), res_scale=args.res_scale 28 | ) for _ in range(n_resblocks//2) 29 | ] 30 | m_body.append(msa) 31 | for i in range(n_resblocks//2): 32 | m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale)) 33 | 34 | m_body.append(conv(n_feats, n_feats, kernel_size)) 35 | 36 | # define tail module 37 | #m_tail = [ 38 | # common.Upsampler(conv, scale, n_feats, act=False), 39 | # conv(n_feats, args.n_colors, kernel_size) 40 | #] 41 | m_tail = [ 42 | conv(n_feats, args.n_colors, kernel_size) 43 | ] 44 | 45 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | #x = self.sub_mean(x) 53 | x = self.head(x) 54 | 55 | res = self.body(x) 56 | 57 | res += x 58 | 59 | x = self.tail(res) 60 | #x = self.add_mean(x) 61 | 62 | return x 63 | 64 | def load_state_dict(self, state_dict, strict=True): 65 | own_state = self.state_dict() 66 | for name, param in state_dict.items(): 67 | if name in own_state: 68 | if isinstance(param, nn.Parameter): 69 | param = param.data 70 | try: 71 | own_state[name].copy_(param) 72 | except Exception: 73 | if name.find('tail') == -1: 74 | raise RuntimeError('While copying the parameter named {}, ' 75 | 'whose dimensions in the model are {} and ' 76 | 'whose dimensions in the checkpoint are {}.' 77 | .format(name, own_state[name].size(), param.size())) 78 | elif strict: 79 | if name.find('tail') == -1: 80 | raise KeyError('unexpected key "{}" in state_dict' 81 | .format(name)) 82 | 83 | -------------------------------------------------------------------------------- /CAR/code/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = 
nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /CAR/code/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/utils/__init__.py -------------------------------------------------------------------------------- /CAR/code/model/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /CAR/code/model/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/model/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /CAR/code/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 
23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /CAR/code/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /CAR/code/template.py: 
-------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /CAR/code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/utils/__init__.py -------------------------------------------------------------------------------- /CAR/code/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /CAR/code/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /CAR/code/utils/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/CAR/code/utils/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /CAR/code/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def 
same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /CAR/code/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 
32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /DN_RGB/code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /DN_RGB/code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/__init__.py -------------------------------------------------------------------------------- /DN_RGB/code/__pycache__/option.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/__pycache__/option.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/__pycache__/template.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/__pycache__/template.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/__pycache__/utility.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/__pycache__/utility.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['CBSD68','Kodak24','Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /DN_RGB/code/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/__pycache__/benchmark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/data/__pycache__/benchmark.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/data/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/__pycache__/div2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/data/__pycache__/div2k.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/__pycache__/srdata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/data/__pycache__/srdata.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /DN_RGB/code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = 1 if multi else 1 
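# For the denoising task the effective scale is 1, so the input and target patches
# share the same size and coordinates; `p` is a leftover of the super-resolution
# version of this helper, where it would carry the upscaling factor.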
13 | tp = p * patch_size 14 | ip = tp // 1 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = 1 * ix, 1 * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /DN_RGB/code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /DN_RGB/code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, 
benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /DN_RGB/code/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /DN_RGB/code/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /DN_RGB/code/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /DN_RGB/code/demo.sb: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # PANET Train 3 | #python main.py --n_GPUs 1 --lr 1e-5 --batch_size 16 --save_models --epoch 100 --model PANET --scale 50 --patch_size 48 --reset --save MDSR_att_N50 --n_feats 64 
--data_train DIV2K --chop 4 | 5 | #PANET Test 6 | python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test Kodak24+CBSD68+Urban100 --scale 10 --pre_train ../model_N10.pt --test_only 7 | -------------------------------------------------------------------------------- /DN_RGB/code/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /DN_RGB/code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /DN_RGB/code/main.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args,checkpoint) 18 | print('total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 19 | t = VideoTester(args, model, checkpoint) 20 | t.test() 21 | else: 22 | if checkpoint.ok: 23 | loader = data.Data(args) 24 | _model = model.Model(args, checkpoint) 25 | #print('total params:%.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, _model, _loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | checkpoint.done() 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /DN_RGB/code/model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/edsr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/edsr.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/edsr2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/edsr2.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/edsrl1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/edsrl1.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/edsrl1t.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/edsrl1t.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/mssr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/mssr.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/__pycache__/panet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/__pycache__/panet.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/common.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2),stride=stride, bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 27 | bn=False, act=nn.PReLU()): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
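# Power-of-two scales are built from log2(scale) stages of conv + PixelShuffle(2)
# (so x4 stacks two x2 stages); x3 uses a single conv + PixelShuffle(3) below,
# and any other factor raises NotImplementedError.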
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /DN_RGB/code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /DN_RGB/code/model/panet.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return PANET(args) 7 | 8 | class PANET(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(PANET, self).__init__() 11 | 12 | n_resblocks = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | 17 | rgb_mean = (0.4488, 0.4371, 0.4040) 18 | rgb_std = (1.0, 1.0, 1.0) 19 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 20 | msa = attention.PyramidAttention() 21 | # define head module 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | # define body module 25 | m_body = [ 26 | common.ResBlock( 27 | conv, n_feats, 
kernel_size, nn.PReLU(), res_scale=args.res_scale 28 | ) for _ in range(n_resblocks//2) 29 | ] 30 | m_body.append(msa) 31 | for i in range(n_resblocks//2): 32 | m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale)) 33 | 34 | m_body.append(conv(n_feats, n_feats, kernel_size)) 35 | 36 | # define tail module 37 | #m_tail = [ 38 | # common.Upsampler(conv, scale, n_feats, act=False), 39 | # conv(n_feats, args.n_colors, kernel_size) 40 | #] 41 | m_tail = [ 42 | conv(n_feats, args.n_colors, kernel_size) 43 | ] 44 | 45 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | #x = self.sub_mean(x) 53 | x = self.head(x) 54 | 55 | res = self.body(x) 56 | 57 | res += x 58 | 59 | x = self.tail(res) 60 | #x = self.add_mean(x) 61 | 62 | return x 63 | 64 | def load_state_dict(self, state_dict, strict=True): 65 | own_state = self.state_dict() 66 | for name, param in state_dict.items(): 67 | if name in own_state: 68 | if isinstance(param, nn.Parameter): 69 | param = param.data 70 | try: 71 | own_state[name].copy_(param) 72 | except Exception: 73 | if name.find('tail') == -1: 74 | raise RuntimeError('While copying the parameter named {}, ' 75 | 'whose dimensions in the model are {} and ' 76 | 'whose dimensions in the checkpoint are {}.' 77 | .format(name, own_state[name].size(), param.size())) 78 | elif strict: 79 | if name.find('tail') == -1: 80 | raise KeyError('unexpected key "{}" in state_dict' 81 | .format(name)) 82 | 83 | -------------------------------------------------------------------------------- /DN_RGB/code/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = 
nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /DN_RGB/code/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/utils/__init__.py -------------------------------------------------------------------------------- /DN_RGB/code/model/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/model/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /DN_RGB/code/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 
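# When the total padding is odd, the extra row/column goes to the bottom/right
# (top/left take the floor), matching TensorFlow's 'SAME' padding convention.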
23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /DN_RGB/code/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- 
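The `same_padding` / `extract_image_patches` helpers duplicated across the `tools.py` files above feed the pyramid attention module: `same_padding` reproduces TensorFlow-style 'SAME' zero padding, and `extract_image_patches` wraps `torch.nn.Unfold` to return patches as a [N, C*k*k, L] tensor. A minimal usage sketch follows; it is not part of the repository, and the import path is an assumption that should point at whichever copy of tools.py is on the Python path.

import torch
from utils.tools import extract_image_patches  # assumed import path

feat = torch.randn(1, 64, 48, 48)                     # [N, C, H, W] feature map
patches = extract_image_patches(
    feat, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same'
)                                                     # -> [1, 64*3*3, 48*48]
n, ckk, l = patches.size()
patches = patches.view(n, 64, 3, 3, l).permute(0, 4, 1, 2, 3)
print(patches.shape)                                  # torch.Size([1, 2304, 64, 3, 3])

In the attention module this extraction is repeated over downscaled copies of the feature map, so that a query patch can be matched against key/value patches from every level of the pyramid.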
/DN_RGB/code/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /DN_RGB/code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/utils/__init__.py -------------------------------------------------------------------------------- /DN_RGB/code/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /DN_RGB/code/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /DN_RGB/code/utils/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/DN_RGB/code/utils/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /DN_RGB/code/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def 
normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /DN_RGB/code/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 
| total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /Demosaic/code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Demosaic/code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__init__.py -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/option.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/option.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/option.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/option.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/template.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/template.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/template.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/template.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/utility.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/utility.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/__pycache__/utility.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/__pycache__/utility.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 
#from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['CBSD68','Kodak24','McM','Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' + module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/benchmark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/benchmark.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/common.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/common.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/div2k.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/div2k.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/div2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/div2k.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/srdata.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/srdata.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/__pycache__/srdata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/data/__pycache__/srdata.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /Demosaic/code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=1, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = 1 if multi else 1 13 | tp = p * patch_size 14 | ip = tp // 1 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = 1 * ix, 1 * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def 
_set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /Demosaic/code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /Demosaic/code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 
31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /Demosaic/code/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /Demosaic/code/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /Demosaic/code/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /Demosaic/code/demo.sb: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # PANET Train 4 | #python main.py --n_GPUs 4 --lr 1e-4 --decay 200-400-600-800 --epoch 1000 --batch_size 16 --n_resblocks 80 --save_models --model PANET --scale 1 --patch_size 48 --save PANET_DEMOSAIC --n_feats 64 --data_train DIV2K --chop 5 | # Test 6 | python main.py --model PANET --save_results --n_GPUs 1 --chop --data_test McM+Kodak24+CBSD68+Urban100 --scale 1 --pre_train ../model_best.pt --test_only -------------------------------------------------------------------------------- /Demosaic/code/loss/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/loss/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /Demosaic/code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /Demosaic/code/main.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args,checkpoint) 18 | print('total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 19 | t = VideoTester(args, model, checkpoint) 20 | t.test() 21 | else: 22 | if checkpoint.ok: 23 | loader = data.Data(args) 24 | _model = model.Model(args, checkpoint) 25 | print('total params:%.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0)) 26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, _model, _loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | checkpoint.done() 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /Demosaic/code/model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/attention.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/common.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/common.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/edsr.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/edsr.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/edsr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/edsr.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/edsr2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/edsr2.cpython-36.pyc -------------------------------------------------------------------------------- 
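A quick, self-contained illustration of the parameter-count expression printed by Demosaic/code/main.py above, sum(p.numel() for p in model.parameters()) / 1e6, which reports the model size in millions of parameters. The two-layer module here is a hypothetical stand-in used only to make the arithmetic concrete; it is not part of this repository.

import torch.nn as nn

# Hypothetical toy network, used only to check the counting expression.
net = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),   # 3*64*3*3 weights + 64 biases = 1792
    nn.Conv2d(64, 3, 3, padding=1),   # 64*3*3*3 weights + 3 biases  = 1731
)

# Same expression as in main.py: total parameters, reported in millions.
print('total params: %.2fM' % (sum(p.numel() for p in net.parameters()) / 1e6))
# -> total params: 0.00M (3,523 parameters for this toy example)
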
/Demosaic/code/model/__pycache__/edsrl1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/edsrl1.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/edsrl1t.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/edsrl1t.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/__pycache__/mssr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/__pycache__/mssr.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2),stride=stride, bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 27 | bn=False, act=nn.PReLU()): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /Demosaic/code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /Demosaic/code/model/panet.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return PANET(args) 7 | 8 | class PANET(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(PANET, self).__init__() 11 | 12 | n_resblocks = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | 17 | rgb_mean = (0.4488, 0.4371, 0.4040) 18 | rgb_std = (1.0, 1.0, 1.0) 19 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 20 | msa = attention.PyramidAttention() 21 | # define head module 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | # define body module 25 | m_body = [ 26 | common.ResBlock( 27 | conv, n_feats, 
kernel_size, nn.PReLU(), res_scale=args.res_scale 28 | ) for _ in range(n_resblocks//2) 29 | ] 30 | m_body.append(msa) 31 | for i in range(n_resblocks//2): 32 | m_body.append(common.ResBlock(conv,n_feats,kernel_size,nn.PReLU(),res_scale=args.res_scale)) 33 | 34 | m_body.append(conv(n_feats, n_feats, kernel_size)) 35 | 36 | # define tail module 37 | #m_tail = [ 38 | # common.Upsampler(conv, scale, n_feats, act=False), 39 | # conv(n_feats, args.n_colors, kernel_size) 40 | #] 41 | m_tail = [ 42 | conv(n_feats, args.n_colors, kernel_size) 43 | ] 44 | 45 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | #x = self.sub_mean(x) 53 | x = self.head(x) 54 | 55 | res = self.body(x) 56 | 57 | res += x 58 | 59 | x = self.tail(res) 60 | #x = self.add_mean(x) 61 | 62 | return x 63 | 64 | def load_state_dict(self, state_dict, strict=True): 65 | own_state = self.state_dict() 66 | for name, param in state_dict.items(): 67 | if name in own_state: 68 | if isinstance(param, nn.Parameter): 69 | param = param.data 70 | try: 71 | own_state[name].copy_(param) 72 | except Exception: 73 | if name.find('tail') == -1: 74 | raise RuntimeError('While copying the parameter named {}, ' 75 | 'whose dimensions in the model are {} and ' 76 | 'whose dimensions in the checkpoint are {}.' 77 | .format(name, own_state[name].size(), param.size())) 78 | elif strict: 79 | if name.find('tail') == -1: 80 | raise KeyError('unexpected key "{}" in state_dict' 81 | .format(name)) 82 | 83 | -------------------------------------------------------------------------------- /Demosaic/code/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = 
nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /Demosaic/code/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/utils/__init__.py -------------------------------------------------------------------------------- /Demosaic/code/model/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/model/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 
23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /Demosaic/code/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- 
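Before the template.py listing, a brief usage sketch for extract_image_patches() defined in Demosaic/code/model/utils/tools.py above: with padding='same', a [N, C, H, W] input yields a [N, C*k*k, H*W] patch tensor. The snippet below reproduces the same padding-plus-unfold steps with plain torch primitives so it runs standalone; the tensor sizes are arbitrary example values, not settings used by the repository.

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 32, 32)            # [N, C, H, W]
k, stride, rate = 3, 1, 1

# 'same' padding for stride 1: pad so the unfolded grid matches H x W,
# mirroring same_padding() in tools.py (top/left get the smaller half).
pad = (k - 1) * rate
x_p = F.pad(x, (pad // 2, pad - pad // 2, pad // 2, pad - pad // 2))

# Same extraction step as extract_image_patches(): one column per patch.
unfold = torch.nn.Unfold(kernel_size=k, dilation=rate, padding=0, stride=stride)
patches = unfold(x_p)
print(patches.shape)                      # torch.Size([1, 576, 1024]) = [N, C*k*k, H*W]
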
/Demosaic/code/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /Demosaic/code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/utils/__init__.py -------------------------------------------------------------------------------- /Demosaic/code/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /Demosaic/code/utils/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Demosaic/code/utils/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /Demosaic/code/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as 
F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /Demosaic/code/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = 
cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /Figs/PSNR_CAR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/PSNR_CAR.png -------------------------------------------------------------------------------- /Figs/PSNR_DN_RGB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/PSNR_DN_RGB.png -------------------------------------------------------------------------------- /Figs/PSNR_Demosaic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/PSNR_Demosaic.png -------------------------------------------------------------------------------- /Figs/PSNR_SR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/PSNR_SR.png -------------------------------------------------------------------------------- /Figs/Screenshot from 2020-04-23 18-56-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/Screenshot from 2020-04-23 18-56-03.png -------------------------------------------------------------------------------- /Figs/Visual_CAR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/Visual_CAR.png -------------------------------------------------------------------------------- /Figs/Visual_DN_RGB.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/Visual_DN_RGB.png -------------------------------------------------------------------------------- /Figs/Visual_Demosaic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/Visual_Demosaic.png -------------------------------------------------------------------------------- /Figs/Visual_SR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/Visual_SR.png -------------------------------------------------------------------------------- /Figs/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/Figs/block.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yiqun Mei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /SR/code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__init__.py -------------------------------------------------------------------------------- /SR/code/__pycache__/option.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/option.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/option.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/option.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/template.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/template.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/template.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/template.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/trainer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/trainer.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/utility.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/utility.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/__pycache__/utility.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/__pycache__/utility.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import 
ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' + module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /SR/code/data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/benchmark.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/benchmark.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/benchmark.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/benchmark.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/common.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/common.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/common.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/demo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/demo.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/div2k.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/div2k.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/div2k.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/div2k.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/srdata.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/srdata.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/data/__pycache__/srdata.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/data/__pycache__/srdata.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /SR/code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 
| ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /SR/code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /SR/code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | 
names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /SR/code/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /SR/code/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /SR/code/data/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import cv2 6 | import numpy as np 7 | import imageio 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class Video(data.Dataset): 13 | def __init__(self, args, name='Video', train=False, benchmark=False): 14 | self.args = args 15 | self.name = name 16 | self.scale = args.scale 17 | self.idx_scale = 0 18 | self.train = False 19 | self.do_eval = False 20 | self.benchmark = benchmark 21 | 22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 23 | self.vidcap = cv2.VideoCapture(args.dir_demo) 24 | self.n_frames = 0 25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 26 | 27 | def __getitem__(self, idx): 28 | success, lr = self.vidcap.read() 29 | if success: 30 | self.n_frames += 1 31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 33 | 34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames) 35 | else: 36 | self.vidcap.release() 37 | return None 38 | 39 | def __len__(self): 40 | return self.total_frames 41 | 42 | def set_scale(self, idx_scale): 43 | self.idx_scale = idx_scale 44 | 45 | -------------------------------------------------------------------------------- /SR/code/demo.sb: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #Train x2 3 | python main.py --n_GPUs 4 --rgb_range 1 --reset --save_models --lr 1e-4 --decay 200-400-600-800 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model PAEDSR --scale 2 --patch_size 96 --save EDSR_PA_x2 --data_train DIV2K 4 | #Test 5 | python main.py --model 
PAEDSR --data_test Set5+Set14+B100+Urban100 --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train ../model_x2.pt --test_only --chop 6 | -------------------------------------------------------------------------------- /SR/code/loss/__loss__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/loss/__loss__.py -------------------------------------------------------------------------------- /SR/code/loss/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/loss/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/loss/demo.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/loss/demo.sh -------------------------------------------------------------------------------- /SR/code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /SR/code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, 
conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /SR/code/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer import Trainer 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | def main(): 14 | global model 15 | if args.data_test == ['video']: 16 | from videotester import VideoTester 17 | model = model.Model(args, checkpoint) 18 | print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) 19 | t = VideoTester(args, model, checkpoint) 20 | t.test() 21 | else: 22 | if checkpoint.ok: 23 | loader = data.Data(args) 24 | _model = model.Model(args, checkpoint) 25 | print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0)) 26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 27 | t = Trainer(args, loader, _model, _loss, checkpoint) 28 | while not t.terminate(): 29 | t.train() 30 | t.test() 31 | 32 | checkpoint.done() 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /SR/code/model/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Sanghyun Son 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /SR/code/model/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/attention.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/attention.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/common.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/common.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/edsr.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/edsr.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/edsr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/edsr.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/mssr.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/mssr.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/mssr.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/mssr.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/__pycache__/rcan.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/__pycache__/rcan.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2),stride=stride, bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True, 27 | bn=False, act=nn.PReLU()): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /SR/code/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | if args.dilation: 7 | from model import dilated 8 | return PAEDSR(args, dilated.dilated_conv) 9 | else: 10 | return PAEDSR(args) 11 | 12 | class PAEDSR(nn.Module): 13 | def __init__(self, args, conv=common.default_conv): 14 | super(PAEDSR, self).__init__() 15 | 16 | n_resblock = args.n_resblocks 17 | n_feats = args.n_feats 18 | kernel_size = 3 19 | scale = args.scale[0] 20 | act = nn.ReLU(True) 21 | 22 | rgb_mean = (0.4488, 0.4371, 0.4040) 23 | rgb_std = (1.0, 1.0, 1.0) 24 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 25 | self.msa = attention.PyramidAttention(channel=256, reduction=8,res_scale=args.res_scale); 26 | # define head module 27 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 28 | 29 | # define body module 30 | m_body = [ 31 | common.ResBlock( 32 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 33 | ) for _ in range(n_resblock//2) 34 | ] 35 | m_body.append(self.msa) 36 | for _ in range(n_resblock//2): 37 | m_body.append( common.ResBlock( 38 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 39 | )) 40 | m_body.append(conv(n_feats, n_feats, kernel_size)) 41 | 42 | # define tail module 43 | m_tail = [ 44 | common.Upsampler(conv, scale, n_feats, act=False), 45 | nn.Conv2d( 46 | n_feats, args.n_colors, kernel_size, 47 | padding=(kernel_size//2) 48 | ) 49 | ] 50 | 51 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 52 | 53 | self.head = nn.Sequential(*m_head) 54 | self.body = nn.Sequential(*m_body) 55 | self.tail = nn.Sequential(*m_tail) 56 | 57 | def forward(self, x): 58 | x = self.sub_mean(x) 59 | x = self.head(x) 60 | 61 | res = self.body(x) 62 | res += x 63 | 64 | x = self.tail(res) 65 | x = self.add_mean(x) 66 | 67 | return x 68 | 69 | def load_state_dict(self, state_dict, strict=True): 70 | own_state = self.state_dict() 71 | for name, param in state_dict.items(): 72 | if name in own_state: 73 | if isinstance(param, nn.Parameter): 74 | param = param.data 75 | try: 76 | own_state[name].copy_(param) 77 | except Exception: 78 | if name.find('tail') == -1: 79 | raise RuntimeError('While copying the parameter named {}, ' 80 | 'whose dimensions in the model are {} and ' 81 | 'whose dimensions in the checkpoint are {}.' 
82 | .format(name, own_state[name].size(), param.size())) 83 | elif strict: 84 | if name.find('tail') == -1: 85 | raise KeyError('unexpected key "{}" in state_dict' 86 | .format(name)) 87 | 88 | -------------------------------------------------------------------------------- /SR/code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /SR/code/model/paedsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | from model import attention 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | if args.dilation: 7 | from model import dilated 8 | return PAEDSR(args, dilated.dilated_conv) 9 | else: 10 | return PAEDSR(args) 11 | 12 | class PAEDSR(nn.Module): 13 | def __init__(self, args, conv=common.default_conv): 14 | super(PAEDSR, self).__init__() 15 | 16 | n_resblock = args.n_resblocks 17 | n_feats = args.n_feats 18 | kernel_size = 3 19 | scale = args.scale[0] 20 | act = nn.ReLU(True) 21 | 22 | rgb_mean = (0.4488, 0.4371, 0.4040) 23 | rgb_std = (1.0, 1.0, 1.0) 24 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 25 | self.msa = attention.PyramidAttention(channel=256, reduction=8,res_scale=args.res_scale); 26 | # define head module 27 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 28 | 29 | # define body module 30 | m_body = [ 31 | common.ResBlock( 32 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 33 | ) for _ in range(n_resblock//2) 34 | ] 35 | m_body.append(self.msa) 36 | for _ in range(n_resblock//2): 37 | m_body.append( common.ResBlock( 38 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 39 | )) 40 | 
m_body.append(conv(n_feats, n_feats, kernel_size)) 41 | 42 | # define tail module 43 | m_tail = [ 44 | common.Upsampler(conv, scale, n_feats, act=False), 45 | nn.Conv2d( 46 | n_feats, args.n_colors, kernel_size, 47 | padding=(kernel_size//2) 48 | ) 49 | ] 50 | 51 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 52 | 53 | self.head = nn.Sequential(*m_head) 54 | self.body = nn.Sequential(*m_body) 55 | self.tail = nn.Sequential(*m_tail) 56 | 57 | def forward(self, x): 58 | x = self.sub_mean(x) 59 | x = self.head(x) 60 | 61 | res = self.body(x) 62 | res += x 63 | 64 | x = self.tail(res) 65 | x = self.add_mean(x) 66 | 67 | return x 68 | 69 | def load_state_dict(self, state_dict, strict=True): 70 | own_state = self.state_dict() 71 | for name, param in state_dict.items(): 72 | if name in own_state: 73 | if isinstance(param, nn.Parameter): 74 | param = param.data 75 | try: 76 | own_state[name].copy_(param) 77 | except Exception: 78 | if name.find('tail') == -1: 79 | raise RuntimeError('While copying the parameter named {}, ' 80 | 'whose dimensions in the model are {} and ' 81 | 'whose dimensions in the checkpoint are {}.' 82 | .format(name, own_state[name].size(), param.size())) 83 | elif strict: 84 | if name.find('tail') == -1: 85 | raise KeyError('unexpected key "{}" in state_dict' 86 | .format(name)) 87 | 88 | -------------------------------------------------------------------------------- /SR/code/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, 
padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /SR/code/model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/utils/__init__.py -------------------------------------------------------------------------------- /SR/code/model/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/model/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/model/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 
33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /SR/code/model/vdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | 6 | url = { 7 | 'r20f64': '' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return VDSR(args) 12 | 13 | class VDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(VDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 21 | self.sub_mean = common.MeanShift(args.rgb_range) 22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 23 | 24 | def basic_block(in_channels, out_channels, act): 25 | return common.BasicBlock( 26 | conv, in_channels, out_channels, kernel_size, 27 | bias=True, bn=False, act=act 28 | ) 29 | 30 | # define body module 31 | m_body = [] 32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True))) 33 | for _ in range(n_resblocks - 2): 34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True))) 35 | m_body.append(basic_block(n_feats, args.n_colors, None)) 36 | 37 | self.body = nn.Sequential(*m_body) 38 | 39 | def forward(self, x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /SR/code/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 
256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /SR/code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/utils/__init__.py -------------------------------------------------------------------------------- /SR/code/utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/utils/__pycache__/tools.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/utils/__pycache__/tools.cpython-35.pyc -------------------------------------------------------------------------------- /SR/code/utils/__pycache__/tools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Pyramid-Attention-Networks/a267cf8ef663212c1e1f7238b6313ff1d780ae94/SR/code/utils/__pycache__/tools.cpython-36.pyc -------------------------------------------------------------------------------- /SR/code/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | 
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | def reduce_mean(x, axis=None, keepdim=False): 60 | if not axis: 61 | axis = range(len(x.shape)) 62 | for i in sorted(axis, reverse=True): 63 | x = torch.mean(x, dim=i, keepdim=keepdim) 64 | return x 65 | 66 | 67 | def reduce_std(x, axis=None, keepdim=False): 68 | if not axis: 69 | axis = range(len(x.shape)) 70 | for i in sorted(axis, reverse=True): 71 | x = torch.std(x, dim=i, keepdim=keepdim) 72 | return x 73 | 74 | 75 | def reduce_sum(x, axis=None, keepdim=False): 76 | if not axis: 77 | axis = range(len(x.shape)) 78 | for i in sorted(axis, reverse=True): 79 | x = torch.sum(x, dim=i, keepdim=keepdim) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /SR/code/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not 
success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | --------------------------------------------------------------------------------
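extract_image_patches in SR/code/utils/tools.py (duplicated in SR/code/model/utils/tools.py) wraps torch.nn.Unfold: with padding='same' it first zero-pads the input via same_padding so that one patch is produced per stride position, and it returns a tensor of shape [N, C*k*k, L]. Below is a minimal usage sketch, not taken from the repository, assuming the interpreter is started from SR/code so that utils.tools is importable; the input sizes are arbitrary.

import torch
from utils.tools import extract_image_patches

# Dummy feature map, [N, C, H, W] = [1, 3, 32, 48] (illustrative values only).
x = torch.randn(1, 3, 32, 48)

# 3x3 patches, stride 1, no dilation, 'same' padding -> one patch per spatial position.
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same')

print(patches.shape)  # torch.Size([1, 27, 1536]) = [N, C*k*k, L] with L = 32*48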