├── src ├── __init__.py ├── .gitignore ├── data │ ├── sr291.py │ ├── div2kjpeg.py │ ├── benchmark.py │ ├── demo.py │ ├── df2k.py │ ├── div2k.py │ ├── video.py │ ├── common.py │ ├── __init__.py │ ├── srdata_no_bin.py │ ├── srdata.py │ └── srdata_bin.py ├── requirements.txt ├── pruner │ ├── __init__.py │ ├── meta_pruner.py │ ├── l1_pruner.py │ ├── utils.py │ └── assl_pruner.py ├── cal_modelsize.py ├── loss │ ├── vgg.py │ ├── discriminator.py │ ├── adversarial.py │ └── __init__.py ├── model │ ├── vdsr.py │ ├── mdsr.py │ ├── edsr.py │ ├── ledsr.py │ ├── rdn.py │ ├── ddbpn.py │ ├── rcan.py │ ├── common.py │ ├── rirsr.py │ └── __init__.py ├── template.py ├── videotester.py ├── layer.py ├── trainer.py ├── dataloader.py ├── main.py ├── demo.sh ├── utility.py └── option.py ├── figs ├── neu.png ├── smile.png ├── psnr_ssim.png ├── NIPS21_ASSL.png └── visual_urban100_x4.png └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figs/neu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/neu.png -------------------------------------------------------------------------------- /figs/smile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/smile.png -------------------------------------------------------------------------------- /figs/psnr_ssim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/psnr_ssim.png -------------------------------------------------------------------------------- /figs/NIPS21_ASSL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/NIPS21_ASSL.png -------------------------------------------------------------------------------- /figs/visual_urban100_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/visual_urban100_x4.png -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .script_history 3 | Debug_Dir 4 | Experiments 5 | data/cifar10* 6 | data/mnist 7 | data/imagenet* 8 | base_models 9 | model/*.th 10 | train_params/ 11 | -------------------------------------------------------------------------------- /src/data/sr291.py: -------------------------------------------------------------------------------- 1 | from data import srdata 2 | 3 | class SR291(srdata.SRData): 4 | def __init__(self, args, name='SR291', train=True, benchmark=False): 5 | super(SR291, self).__init__(args, name=name) 6 | 7 | -------------------------------------------------------------------------------- /src/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | scikit_image==0.16.2 3 | imageio==2.9.0 4 | pandas==1.0.5 5 | matplotlib==3.2.2 6 | numpy==1.18.5 7 | tqdm==4.47.0 8 | scipy==1.5.0 9 | torchvision==0.4.0a0+6b959ee 10 | Pillow==8.1.1 11 | PyYAML==5.4.1 12 | skimage==0.0 13 | torchsummaryX==1.3.0 14 | -------------------------------------------------------------------------------- 
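The next file, src/pruner/__init__.py, keeps a name-to-module registry so the training entry point can pick a pruning method from a command-line string. A minimal sketch of how such a registry is typically consumed (the option name args.method and the Pruner entry-point class are illustrative assumptions, not confirmed by this listing):

    from pruner import pruner_dict

    # look up the pruner module by method name ('L1' or 'ASSL' here),
    # then hand it the model to prune
    pruner_module = pruner_dict[args.method]            # args.method: hypothetical option
    pruner = pruner_module.Pruner(model, args, logger)  # Pruner: assumed entry point
    model = pruner.prune()                              # assumed to return the pruned model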
/src/pruner/__init__.py: -------------------------------------------------------------------------------- 1 | from . import l1_pruner, assl_pruner 2 | 3 | # when new pruner implementation is added in the 'pruner' dir, update this dict to maintain minimal code change. 4 | # key: pruning method name, value: the corresponding pruner 5 | pruner_dict = { 6 | 'L1': l1_pruner, 7 | 'ASSL': assl_pruner, 8 | } -------------------------------------------------------------------------------- /src/cal_modelsize.py: -------------------------------------------------------------------------------- 1 | from torchsummaryX import summary 2 | from importlib import import_module 3 | import torch 4 | import model 5 | from model import edsr 6 | from option import args 7 | import sys 8 | # checkpoint = utility.checkpoint(args) 9 | # my_model = model.Model(args) 10 | 11 | Net = import_module('model.' + args.model.lower()) 12 | net = eval("Net.%s" % args.model.upper())(args).cuda() 13 | 14 | # net = edsr.EDSR(args) 15 | height_lr = 1280 // args.scale[0] 16 | width_lr = 720 // args.scale[0] 17 | input = torch.zeros((1, 3, height_lr, width_lr)).cuda() 18 | 19 | summary(net, input) 20 | -------------------------------------------------------------------------------- /src/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /src/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /src/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 
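        # collect every .png / .jp* (jpg, jpeg) image under args.dir_demo; the list is sorted below for a deterministic order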
21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /src/data/df2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DF2K(srdata.SRData): 5 | def __init__(self, args, name='DF2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DF2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | # def _scan(self): 21 | # names_hr, names_lr = super(DF2K, self)._scan() 22 | # names_hr = names_hr[self.begin - 1:self.end] 23 | # names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | # return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DF2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DF2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DF2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if 
train:
8 |             data_range = data_range[0]
9 |         else:
10 |             if args.test_only and len(data_range) == 1:
11 |                 data_range = data_range[0]
12 |             else:
13 |                 data_range = data_range[1]
14 | 
15 |         self.begin, self.end = list(map(lambda x: int(x), data_range))
16 |         super(DIV2K, self).__init__(
17 |             args, name=name, train=train, benchmark=benchmark
18 |         )
19 | 
20 |     # def _scan(self):
21 |     #     names_hr, names_lr = super(DIV2K, self)._scan()
22 |     #     names_hr = names_hr[self.begin - 1:self.end]
23 |     #     names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 | 
25 |     #     return names_hr, names_lr
26 | 
27 |     def _set_filesystem(self, dir_data):
28 |         super(DIV2K, self)._set_filesystem(dir_data)
29 |         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 |         self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 |         if self.input_large: self.dir_lr += 'L'
32 | 
33 | 
--------------------------------------------------------------------------------
/src/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from data import common
4 | 
5 | import cv2
6 | import numpy as np
7 | import imageio
8 | 
9 | import torch
10 | import torch.utils.data as data
11 | 
12 | class Video(data.Dataset):
13 |     def __init__(self, args, name='Video', train=False, benchmark=False):
14 |         self.args = args
15 |         self.name = name
16 |         self.scale = args.scale
17 |         self.idx_scale = 0
18 |         self.train = False
19 |         self.do_eval = False
20 |         self.benchmark = benchmark
21 | 
22 |         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 |         self.vidcap = cv2.VideoCapture(args.dir_demo)
24 |         self.n_frames = 0
25 |         self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 | 
27 |     def __getitem__(self, idx):
28 |         success, lr = self.vidcap.read()
29 |         if success:
30 |             self.n_frames += 1
31 |             lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 |             lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 | 
34 |             return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 |         else:
36 |             self.vidcap.release()  # fixed: bare 'vidcap' is undefined here (NameError)
37 |             return None
38 | 
39 |     def __len__(self):
40 |         return self.total_frames
41 | 
42 |     def set_scale(self, idx_scale):
43 |         self.idx_scale = idx_scale
44 | 
45 | 
--------------------------------------------------------------------------------
/src/model/vdsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | 
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | 
6 | url = {
7 |     'r20f64': ''
8 | }
9 | 
10 | def make_model(args, parent=False):
11 |     return VDSR(args)
12 | 
13 | class VDSR(nn.Module):
14 |     def __init__(self, args, conv=common.default_conv):
15 |         super(VDSR, self).__init__()
16 | 
17 |         n_resblocks = args.n_resblocks
18 |         n_feats = args.n_feats
19 |         kernel_size = 3
20 |         self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
21 |         self.sub_mean = common.MeanShift(args.rgb_range)
22 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
23 | 
24 |         def basic_block(in_channels, out_channels, act):
25 |             return common.BasicBlock(
26 |                 conv, in_channels, out_channels, kernel_size,
27 |                 bias=True, bn=False, act=act
28 |             )
29 | 
30 |         # define body module
31 |         m_body = []
32 |         m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
33 |         for _ in range(n_resblocks - 2):
34 |             m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
35 |         m_body.append(basic_block(n_feats, args.n_colors, None))
36 | 
37 |         self.body = nn.Sequential(*m_body)
38 | 
39 |     def forward(self, 
x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /src/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 
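        # tp below is the patch size in the HR (target) image and ip = tp // scale the matching
        # LR (input) patch size; with multi-scale training (multi=True), tp = scale * patch_size
        # so that ip stays fixed at patch_size across scales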
12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /src/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt', 7 | 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return MDSR(args) 12 | 13 | class MDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(MDSR, self).__init__() 16 | n_resblocks = args.n_resblocks 17 | n_feats = args.n_feats 18 | kernel_size = 3 19 | act = nn.ReLU(True) 20 | self.scale_idx = 0 21 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 22 | self.sub_mean = common.MeanShift(args.rgb_range) 23 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 24 | 25 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 26 | 27 | self.pre_process = nn.ModuleList([ 28 | nn.Sequential( 29 | common.ResBlock(conv, n_feats, 5, act=act), 30 | common.ResBlock(conv, n_feats, 5, act=act) 31 | ) for _ in args.scale 32 | ]) 33 | 34 | m_body = [ 35 | common.ResBlock( 36 | conv, n_feats, kernel_size, act=act 37 | ) for _ in range(n_resblocks) 38 | ] 39 | m_body.append(conv(n_feats, n_feats, kernel_size)) 40 | 41 | self.upsample = nn.ModuleList([ 42 | common.Upsampler(conv, s, n_feats, act=False) for s in args.scale 43 | ]) 44 | 45 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | x = self.sub_mean(x) 53 | x = self.head(x) 54 | x = self.pre_process[self.scale_idx](x) 55 | 56 | res = self.body(x) 57 | res += x 58 | 59 | x = self.upsample[self.scale_idx](res) 60 | x = self.tail(x) 61 | x = self.add_mean(x) 62 | 63 | return x 64 | 65 | def set_scale(self, scale_idx): 66 | self.scale_idx = scale_idx 67 | 68 | -------------------------------------------------------------------------------- /src/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 
39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return EDSR(args) 16 | 17 | class EDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(EDSR, self).__init__() 20 | 21 | n_resblocks = args.n_resblocks 22 | n_feats = args.n_feats 23 | kernel_size = 3 24 | scale = args.scale[0] 25 | act = nn.ReLU(True) 26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 27 | if url_name in url: 28 | self.url = url[url_name] 29 | else: 30 | self.url = None 31 | self.sub_mean = common.MeanShift(args.rgb_range) 32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 33 | 34 | # use conv with weight normalization 35 | if args.wn: 36 | conv = common.wn_conv 37 | 38 | # define head module 39 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 40 | 41 | # define body module 42 | m_body = [ 43 | common.ResBlock( 44 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale, 45 | ) for _ in range(n_resblocks) 46 | ] 47 | m_body.append(conv(n_feats, n_feats, kernel_size)) 48 | 49 | # define tail module 50 | m_tail = [ 51 | common.Upsampler(conv, scale, n_feats, act=False), 52 | conv(n_feats, args.n_colors, kernel_size) 53 | ] 54 | 55 | self.head = nn.Sequential(*m_head) 56 | self.body = nn.Sequential(*m_body) 57 | self.tail = nn.Sequential(*m_tail) 58 | 59 | def forward(self, x): 60 | x = self.sub_mean(x) 61 | x = self.head(x) 62 | 63 | res = self.body(x) 64 | res += x 65 | 66 | x = self.tail(res) 67 | x = self.add_mean(x) 68 | 69 | return x 70 | 71 | def load_state_dict(self, state_dict, strict=True): 72 | own_state = self.state_dict() 73 | for name, param in state_dict.items(): 74 | if name in own_state: 75 | if isinstance(param, nn.Parameter): 76 | param = param.data 77 | try: 78 | 
own_state[name].copy_(param) 79 | except Exception: 80 | if name.find('tail') == -1: 81 | raise RuntimeError('While copying the parameter named {}, ' 82 | 'whose dimensions in the model are {} and ' 83 | 'whose dimensions in the checkpoint are {}.' 84 | .format(name, own_state[name].size(), param.size())) 85 | elif strict: 86 | if name.find('tail') == -1: 87 | raise KeyError('unexpected key "{}" in state_dict' 88 | .format(name)) 89 | 90 | -------------------------------------------------------------------------------- /src/model/ledsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return LEDSR(args) 16 | 17 | class LEDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(LEDSR, self).__init__() 20 | 21 | n_resblocks = args.n_resblocks 22 | n_feats = args.n_feats 23 | kernel_size = 3 24 | scale = args.scale[0] 25 | act = nn.ReLU(True) 26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 27 | if url_name in url: 28 | self.url = url[url_name] 29 | else: 30 | self.url = None 31 | self.sub_mean = common.MeanShift(args.rgb_range) 32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 33 | 34 | # use conv with weight normalization 35 | if args.wn: 36 | conv = common.wn_conv 37 | 38 | # define head module 39 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 40 | 41 | # define body module 42 | m_body = [ 43 | common.ResBlock( 44 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 45 | ) for _ in range(n_resblocks) 46 | ] 47 | m_body.append(conv(n_feats, n_feats, kernel_size)) 48 | 49 | # define tail module 50 | # m_tail = [ 51 | # common.Upsampler(conv, scale, n_feats, act=False), 52 | # conv(n_feats, args.n_colors, kernel_size) 53 | # ] 54 | m_tail = [ 55 | common.LiteUpsampler(conv, scale, n_feats, args.n_colors, act=False) 56 | ] 57 | self.head = nn.Sequential(*m_head) 58 | self.body = nn.Sequential(*m_body) 59 | self.tail = nn.Sequential(*m_tail) 60 | 61 | def forward(self, x): 62 | x = self.sub_mean(x) 63 | x = self.head(x) 64 | 65 | res = self.body(x) 66 | res += x 67 | 68 | x = self.tail(res) 69 | x = self.add_mean(x) 70 | 71 | return x 72 | 73 | def load_state_dict(self, state_dict, strict=True): 74 | own_state = self.state_dict() 75 | for name, param in state_dict.items(): 76 | if name in own_state: 77 | if isinstance(param, nn.Parameter): 78 | param = param.data 79 | try: 80 | own_state[name].copy_(param) 81 | except Exception: 82 | if name.find('tail') == -1: 83 | raise RuntimeError('While copying the parameter named {}, ' 84 | 'whose dimensions in the model are {} and ' 85 | 'whose dimensions in the checkpoint are {}.' 
86 | .format(name, own_state[name].size(), param.size())) 87 | elif strict: 88 | if name.find('tail') == -1: 89 | raise KeyError('unexpected key "{}" in state_dict' 90 | .format(name)) 91 | 92 | -------------------------------------------------------------------------------- /src/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /src/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn
3 | import copy
4 | import numpy as np
5 | from collections import OrderedDict
6 | tensor2list = lambda x: x.data.cpu().numpy().tolist()
7 | 
8 | class Layer:
9 |     def __init__(self, name, size, layer_index, module, res=False, layer_type=None):
10 |         self.name = name
11 |         self.module = module
12 |         self.size = []  # deprecated in favor of 'shape'
13 |         for x in size:
14 |             self.size.append(x)
15 |         self.shape = self.size
16 |         self.layer_index = layer_index  # deprecated in favor of 'index'
17 |         self.index = layer_index
18 |         self.layer_type = layer_type  # deprecated in favor of 'type'
19 |         self.type = layer_type
20 |         self.is_shortcut = True if "downsample" in name else False
21 |         # if res:
22 |         #     self.stage, self.seq_index, self.block_index = self._get_various_index_by_name(name)
23 | 
24 |     def _get_various_index_by_name(self, name):
25 |         '''Get the indices, including stage, seq_ix, blk_ix.
26 |         The same stage means the same feature map size.
27 |         '''
28 |         global latest_stage  # an awkward implementation, just for now
29 |         if name.startswith('module.'):
30 |             name = name[7:]  # remove the prefix caused by pytorch data parallel
31 | 
32 |         if "conv1" == name:  # TODO: this might not be so safe
33 |             latest_stage = 0
34 |             return 0, None, None
35 |         if "linear" in name or 'fc' in name:  # Note: this can be risky. Check it fully. TODO: @mingsun-tse
36 |             return latest_stage + 1, None, None  # the fc layer should always be the last layer
37 |         else:
38 |             try:
39 |                 stage = int(name.split(".")[0][-1])  # ONLY works for standard resnets. name example: layer2.2.conv1, layer4.0.downsample.0
40 |                 seq_ix = int(name.split(".")[1])
41 |                 if 'conv' in name.split(".")[-1]:
42 |                     blk_ix = int(name[-1]) - 1
43 |                 else:
44 |                     blk_ix = -1  # shortcut layer
45 |                 latest_stage = stage
46 |                 return stage, seq_ix, blk_ix
47 |             except Exception:
48 |                 print('! Parsing the layer name failed: %s. Please check.' % name)
49 | 
50 | class LayerStruct:
51 |     def __init__(self, model, LEARNABLES):
52 |         self.model = model
53 |         self.LEARNABLES = LEARNABLES
54 |         self.register_layers()
55 |         self.get_print_prefix()
56 |         self.print_layer_stats()
57 | 
58 |     def register_layers(self):
59 |         """This will maintain a data structure that can return some useful information by the name of a layer.
60 | TODO-@mst: Update this: https://github.com/MingSun-Tse/Pruning/blob/2bb7012d81e3c8326a2f756782b41c3d7ca9da21/pruner/meta_pruner.py#L65 61 | """ 62 | self.layers = OrderedDict() 63 | self._max_len_name = 0 64 | self._max_len_shape = 0 65 | 66 | ix = -1 # layer index, starts from 0 67 | for name, m in self.model.named_modules(): 68 | if isinstance(m, self.LEARNABLES): 69 | if "downsample" not in name: 70 | ix += 1 71 | self._max_len_name = max(self._max_len_name, len(name)) 72 | self.layers[name] = Layer(name, size=m.weight.size(), layer_index=ix, module=m, layer_type=m.__class__.__name__) 73 | self._max_len_shape = max(self._max_len_shape, len(str(self.layers[name].shape))) 74 | 75 | self._max_len_ix = len(str(ix)) 76 | self.num_layers = ix + 1 77 | 78 | def get_print_prefix(self): 79 | self.print_prefix = OrderedDict() 80 | for name, layer in self.layers.items(): 81 | format_str = f"[%{self._max_len_ix}d] %{self._max_len_name}s %{self._max_len_shape}s" 82 | self.print_prefix[name] = format_str % (layer.index, name, layer.shape) 83 | 84 | def print_layer_stats(self): 85 | print('************************ Layer Statistics ************************') 86 | for name, layer in self.layers.items(): 87 | print(f'{self.print_prefix[name]}') 88 | print('******************************************************************') -------------------------------------------------------------------------------- /src/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | 
nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /src/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if gan_type == 'WGAN_GP': 19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 20 | optim_dict = { 21 | 'optimizer': 'ADAM', 22 | 'betas': (0, 0.9), 23 | 'epsilon': 1e-8, 24 | 'lr': 1e-5, 25 | 'weight_decay': args.weight_decay, 26 | 'decay': args.decay, 27 | 'gamma': args.gamma 28 | } 29 | optim_args = SimpleNamespace(**optim_dict) 30 | else: 31 | optim_args = args 32 | 33 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 34 | 35 | def forward(self, fake, real): 36 | # updating discriminator... 
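        # The discriminator takes gan_k gradient steps per generator update, always on
        # fake.detach() so no gradient reaches G here. For 'WGAN_GP', the Gulrajani et al.
        # penalty 10 * (||grad D(hat)||_2 - 1)^2 is added on random interpolates between
        # real and fake (the 'GP' branch below); plain 'WGAN' instead clamps D's weights.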
37 |         self.loss = 0
38 |         fake_detach = fake.detach()     # do not backpropagate through G
39 |         for _ in range(self.gan_k):
40 |             self.optimizer.zero_grad()
41 |             # d: B x 1 tensor
42 |             d_fake = self.dis(fake_detach)
43 |             d_real = self.dis(real)
44 |             retain_graph = False
45 |             if self.gan_type == 'GAN':
46 |                 loss_d = self.bce(d_real, d_fake)
47 |             elif self.gan_type.find('WGAN') >= 0:
48 |                 loss_d = (d_fake - d_real).mean()
49 |                 if self.gan_type.find('GP') >= 0:
50 |                     epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device)  # fixed: one coefficient per sample; rand_like(fake).view(-1, 1, 1, 1) yields B*C*H*W entries and breaks broadcasting
51 |                     hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
52 |                     hat.requires_grad = True
53 |                     d_hat = self.dis(hat)
54 |                     gradients = torch.autograd.grad(
55 |                         outputs=d_hat.sum(), inputs=hat,
56 |                         retain_graph=True, create_graph=True, only_inputs=True
57 |                     )[0]
58 |                     gradients = gradients.view(gradients.size(0), -1)
59 |                     gradient_norm = gradients.norm(2, dim=1)
60 |                     gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
61 |                     loss_d += gradient_penalty
62 |             # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
63 |             elif self.gan_type == 'RGAN':
64 |                 better_real = d_real - d_fake.mean(dim=0, keepdim=True)
65 |                 better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
66 |                 loss_d = self.bce(better_real, better_fake)
67 |                 retain_graph = True
68 | 
69 |             # Discriminator update
70 |             self.loss += loss_d.item()
71 |             loss_d.backward(retain_graph=retain_graph)
72 |             self.optimizer.step()
73 | 
74 |             if self.gan_type == 'WGAN':
75 |                 for p in self.dis.parameters():
76 |                     p.data.clamp_(-1, 1)
77 | 
78 |         self.loss /= self.gan_k
79 | 
80 |         # updating generator...
81 |         d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is
82 |         if self.gan_type == 'GAN':
83 |             label_real = torch.ones_like(d_fake_bp)
84 |             loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
85 |         elif self.gan_type.find('WGAN') >= 0:
86 |             loss_g = -d_fake_bp.mean()
87 |         elif self.gan_type == 'RGAN':
88 |             better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
89 |             better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
90 |             loss_g = self.bce(better_fake, better_real)
91 | 
92 |         # Generator loss
93 |         return loss_g
94 | 
95 |     def state_dict(self, *args, **kwargs):
96 |         state_discriminator = self.dis.state_dict(*args, **kwargs)
97 |         state_optimizer = self.optimizer.state_dict()
98 | 
99 |         return dict(**state_discriminator, **state_optimizer)
100 | 
101 |     def bce(self, real, fake):
102 |         label_real = torch.ones_like(real)
103 |         label_fake = torch.zeros_like(fake)
104 |         bce_real = F.binary_cross_entropy_with_logits(real, label_real)
105 |         bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
106 |         bce_loss = bce_real + bce_fake
107 |         return bce_loss
108 | 
109 | # Some references
110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
111 | # OR
112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
113 | 
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from decimal import Decimal
4 | 
5 | import utility
6 | 
7 | import torch
8 | import torch.nn.utils as utils
9 | from tqdm import tqdm
10 | 
11 | class Trainer():
12 |     def __init__(self, args, loader, my_model, my_loss, ckp):
13 |         self.args = args
14 |         self.scale = args.scale
15 | 
16 |         self.ckp = ckp
17 |         self.loader_train = loader.loader_train
18 |         self.loader_test = loader.loader_test
19 |         self.model = my_model
20 |         self.loss = my_loss
21 |         self.optimizer 
= utility.make_optimizer(args, self.model) 22 | 23 | if self.args.load != '': 24 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 25 | 26 | self.error_last = 1e8 27 | 28 | def train(self): 29 | self.loss.step() 30 | epoch = self.optimizer.get_last_epoch() + 1 31 | lr = self.optimizer.get_lr() 32 | 33 | self.ckp.write_log( 34 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 35 | ) 36 | self.loss.start_log() 37 | self.model.train() 38 | 39 | timer_data, timer_model = utility.timer(), utility.timer() 40 | # TEMP 41 | self.loader_train.dataset.set_scale(0) 42 | for batch, (lr, hr, _,) in enumerate(self.loader_train): 43 | lr, hr = self.prepare(lr, hr) 44 | timer_data.hold() 45 | timer_model.tic() 46 | 47 | self.optimizer.zero_grad() 48 | sr = self.model(lr, 0) 49 | loss = self.loss(sr, hr) 50 | loss.backward() 51 | if self.args.gclip > 0: 52 | utils.clip_grad_value_( 53 | self.model.parameters(), 54 | self.args.gclip 55 | ) 56 | self.optimizer.step() 57 | 58 | timer_model.hold() 59 | 60 | if (batch + 1) % self.args.print_every == 0: 61 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 62 | (batch + 1) * self.args.batch_size, 63 | len(self.loader_train.dataset), 64 | self.loss.display_loss(batch), 65 | timer_model.release(), 66 | timer_data.release())) 67 | 68 | timer_data.tic() 69 | 70 | self.loss.end_log(len(self.loader_train)) 71 | self.error_last = self.loss.log[-1, -1] 72 | self.optimizer.schedule() 73 | 74 | def test(self): 75 | torch.set_grad_enabled(False) 76 | 77 | epoch = self.optimizer.get_last_epoch() 78 | self.ckp.write_log('\nEvaluation:') 79 | self.ckp.add_log( 80 | torch.zeros(1, len(self.loader_test), len(self.scale)) 81 | ) 82 | self.model.eval() 83 | 84 | timer_test = utility.timer() 85 | if self.args.save_results: self.ckp.begin_background() 86 | for idx_data, d in enumerate(self.loader_test): 87 | for idx_scale, scale in enumerate(self.scale): 88 | d.dataset.set_scale(idx_scale) 89 | for lr, hr, filename in tqdm(d, ncols=80): 90 | lr, hr = self.prepare(lr, hr) 91 | sr = self.model(lr, idx_scale) 92 | sr = utility.quantize(sr, self.args.rgb_range) 93 | 94 | save_list = [sr] 95 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 96 | sr, hr, scale, self.args.rgb_range, dataset=d 97 | ) 98 | if self.args.save_gt: 99 | save_list.extend([lr, hr]) 100 | 101 | if self.args.save_results: 102 | self.ckp.save_results(d, filename[0], save_list, scale) 103 | 104 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 105 | best = self.ckp.log.max(0) 106 | self.ckp.write_log( 107 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 108 | d.dataset.name, 109 | scale, 110 | self.ckp.log[-1, idx_data, idx_scale], 111 | best[0][idx_data, idx_scale], 112 | best[1][idx_data, idx_scale] + 1 113 | ) 114 | ) 115 | 116 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) 117 | self.ckp.write_log('Saving...') 118 | 119 | if self.args.save_results: 120 | self.ckp.end_background() 121 | 122 | if not self.args.test_only: 123 | self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 124 | 125 | self.ckp.write_log( 126 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 127 | ) 128 | 129 | torch.set_grad_enabled(True) 130 | 131 | def prepare(self, *args): 132 | device = torch.device('cpu' if self.args.cpu else 'cuda') 133 | def _prepare(tensor): 134 | if self.args.precision == 'half': tensor = tensor.half() 135 | return tensor.to(device) 136 | 137 | return [_prepare(a) for a in args] 138 | 139 | def terminate(self): 140 | if 
self.args.test_only: 141 | self.test() 142 | return True 143 | else: 144 | epoch = self.optimizer.get_last_epoch() + 1 145 | return epoch >= self.args.epochs 146 | 147 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | print('Preparing loss function:') 18 | 19 | self.n_GPUs = args.n_GPUs 20 | self.loss = [] 21 | self.loss_module = nn.ModuleList() 22 | for loss in args.loss.split('+'): 23 | weight, loss_type = loss.split('*') 24 | if loss_type == 'MSE': 25 | loss_function = nn.MSELoss() 26 | elif loss_type == 'L1': 27 | loss_function = nn.L1Loss() 28 | elif loss_type.find('VGG') >= 0: 29 | module = import_module('loss.vgg') 30 | loss_function = getattr(module, 'VGG')( 31 | loss_type[3:], 32 | rgb_range=args.rgb_range 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | 41 | self.loss.append({ 42 | 'type': loss_type, 43 | 'weight': float(weight), 44 | 'function': loss_function} 45 | ) 46 | if loss_type.find('GAN') >= 0: 47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 48 | 49 | if len(self.loss) > 1: 50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 51 | 52 | for l in self.loss: 53 | if l['function'] is not None: 54 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 55 | self.loss_module.append(l['function']) 56 | 57 | self.log = torch.Tensor() 58 | 59 | device = torch.device('cpu' if args.cpu else 'cuda') 60 | self.loss_module.to(device) 61 | if args.precision == 'half': self.loss_module.half() 62 | if not args.cpu and args.n_GPUs > 1: 63 | self.loss_module = nn.DataParallel( 64 | self.loss_module, range(args.n_GPUs) 65 | ) 66 | 67 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 68 | 69 | def forward(self, sr, hr): 70 | losses = [] 71 | for i, l in enumerate(self.loss): 72 | if l['function'] is not None: 73 | loss = l['function'](sr, hr) 74 | effective_loss = l['weight'] * loss 75 | losses.append(effective_loss) 76 | self.log[-1, i] += effective_loss.item() 77 | elif l['type'] == 'DIS': 78 | self.log[-1, i] += self.loss[i - 1]['function'].loss 79 | 80 | loss_sum = sum(losses) 81 | if len(self.loss) > 1: 82 | self.log[-1, -1] += loss_sum.item() 83 | 84 | return loss_sum 85 | 86 | def step(self): 87 | for l in self.get_loss_module(): 88 | if hasattr(l, 'scheduler'): 89 | l.scheduler.step() 90 | 91 | def start_log(self): 92 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 93 | 94 | def end_log(self, n_batches): 95 | self.log[-1].div_(n_batches) 96 | 97 | def display_loss(self, batch): 98 | n_samples = batch + 1 99 | log = [] 100 | for l, c in zip(self.loss, self.log[-1]): 101 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 102 | 103 | return ''.join(log) 104 | 105 | def plot_loss(self, apath, epoch): 106 | if epoch > 0: 107 | axis = np.linspace(1, epoch, epoch) 108 | for i, l in enumerate(self.loss): 109 | label = '{} Loss'.format(l['type']) 110 | fig = plt.figure() 111 | 
plt.title(label) 112 | plt.plot(axis, self.log[:, i].numpy(), label=label) 113 | plt.legend() 114 | plt.xlabel('Epochs') 115 | plt.ylabel('Loss') 116 | plt.grid(True) 117 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 118 | plt.close(fig) 119 | 120 | def get_loss_module(self): 121 | if self.n_GPUs == 1: 122 | return self.loss_module 123 | else: 124 | return self.loss_module.module 125 | 126 | def save(self, apath): 127 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 128 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 129 | 130 | def load(self, apath, cpu=False): 131 | if cpu: 132 | kwargs = {'map_location': lambda storage, loc: storage} 133 | else: 134 | kwargs = {} 135 | 136 | self.load_state_dict(torch.load( 137 | os.path.join(apath, 'loss.pt'), 138 | **kwargs 139 | )) 140 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 141 | for l in self.get_loss_module(): 142 | if hasattr(l, 'scheduler'): 143 | for _ in range(len(self.log)): l.scheduler.step() 144 | 145 | -------------------------------------------------------------------------------- /src/model/rcan.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | 7 | def make_model(args, parent=False): 8 | return RCAN(args) 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | # global average pooling: feature --> point 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | # feature channel downscale and upscale --> channel weight 17 | self.conv_du = nn.Sequential( 18 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | y = self.avg_pool(x) 26 | y = self.conv_du(y) 27 | return x * y 28 | 29 | ## Residual Channel Attention Block (RCAB) 30 | class RCAB(nn.Module): 31 | def __init__( 32 | self, conv, n_feat, kernel_size, reduction, 33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 34 | 35 | super(RCAB, self).__init__() 36 | modules_body = [] 37 | for i in range(2): 38 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 39 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 40 | if i == 0: modules_body.append(act) 41 | modules_body.append(CALayer(n_feat, reduction)) 42 | self.body = nn.Sequential(*modules_body) 43 | self.res_scale = res_scale 44 | 45 | def forward(self, x): 46 | res = self.body(x) 47 | #res = self.body(x).mul(self.res_scale) 48 | res += x 49 | return res 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | ## Residual Channel Attention Network (RCAN) 69 | class RCAN(nn.Module): 70 | def __init__(self, args, 
conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = args.n_resgroups 74 | n_resblocks = args.n_resblocks 75 | n_feats = args.n_feats 76 | kernel_size = 3 77 | reduction = args.reduction 78 | scale = args.scale[0] 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 82 | self.sub_mean = common.MeanShift(args.rgb_range) 83 | 84 | # define head module 85 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 86 | 87 | # define body module 88 | modules_body = [ 89 | ResidualGroup( 90 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 91 | for _ in range(n_resgroups)] 92 | 93 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 94 | 95 | # define tail module 96 | modules_tail = [ 97 | common.Upsampler(conv, scale, n_feats, act=False), 98 | conv(n_feats, args.n_colors, kernel_size)] 99 | 100 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 101 | 102 | self.head = nn.Sequential(*modules_head) 103 | self.body = nn.Sequential(*modules_body) 104 | self.tail = nn.Sequential(*modules_tail) 105 | 106 | def forward(self, x): 107 | x = self.sub_mean(x) 108 | x = self.head(x) 109 | 110 | res = self.body(x) 111 | res += x 112 | 113 | x = self.tail(res) 114 | x = self.add_mean(x) 115 | 116 | return x 117 | 118 | def load_state_dict(self, state_dict, strict=False): 119 | own_state = self.state_dict() 120 | for name, param in state_dict.items(): 121 | if name in own_state: 122 | if isinstance(param, nn.Parameter): 123 | param = param.data 124 | try: 125 | own_state[name].copy_(param) 126 | except Exception: 127 | if name.find('tail') >= 0: 128 | print('Replace pre-trained upsampler to new one...') 129 | else: 130 | raise RuntimeError('While copying the parameter named {}, ' 131 | 'whose dimensions in the model are {} and ' 132 | 'whose dimensions in the checkpoint are {}.' 
133 | .format(name, own_state[name].size(), param.size())) 134 | elif strict: 135 | if name.find('tail') == -1: 136 | raise KeyError('unexpected key "{}" in state_dict' 137 | .format(name)) 138 | 139 | if strict: 140 | missing = set(own_state.keys()) - set(state_dict.keys()) 141 | if len(missing) > 0: 142 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 143 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | import sys  # used below: ExceptionWrapper(sys.exc_info()) in _ms_loop 4 | import torch 5 | import torch.multiprocessing as multiprocessing 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SequentialSampler 8 | from torch.utils.data import RandomSampler 9 | from torch.utils.data import BatchSampler 10 | from torch.utils.data import _utils 11 | from torch.utils.data.dataloader import _DataLoaderIter 12 | 13 | from torch.utils.data._utils import collate 14 | from torch.utils.data._utils import signal_handling 15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 16 | from torch.utils.data._utils import ExceptionWrapper 17 | from torch.utils.data._utils import IS_WINDOWS 18 | from torch.utils.data._utils.worker import ManagerWatchdog 19 | 20 | from torch._six import queue 21 | 22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 23 | try: 24 | collate._use_shared_memory = True 25 | signal_handling._set_worker_signal_handlers() 26 | 27 | torch.set_num_threads(1) 28 | random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | data_queue.cancel_join_thread() 32 | 33 | if init_fn is not None: 34 | init_fn(worker_id) 35 | 36 | watchdog = ManagerWatchdog() 37 | 38 | while watchdog.is_alive(): 39 | try: 40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 41 | except queue.Empty: 42 | continue 43 | 44 | if r is None: 45 | assert done_event.is_set() 46 | return 47 | elif done_event.is_set(): 48 | continue 49 | 50 | idx, batch_indices = r 51 | try: 52 | idx_scale = 0 53 | if len(scale) > 1 and dataset.train: 54 | idx_scale = random.randrange(0, len(scale)) 55 | dataset.set_scale(idx_scale) 56 | 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | samples.append(idx_scale) 59 | except Exception: 60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 61 | else: 62 | data_queue.put((idx, samples)) 63 | del samples 64 | 65 | except KeyboardInterrupt: 66 | pass 67 | 68 | class _MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | base_seed = torch.LongTensor(1).random_()[0] 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in 
range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 | self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | import utility 5 | import data 6 | import model 7 | import loss 8 | from option import args 9 | from utils import get_n_flops_, get_n_params_ 10 | from torchsummaryX import summary 11 | 12 | torch.manual_seed(args.seed) 13 | checkpoint = utility.checkpoint(args) 14 | 15 | # @mst: select different trainers corresponding to different methods 16 | if args.method in ['']: 17 | from trainer import Trainer 18 | elif args.method in ['KD']: 19 | from trainer_kd import TrainerKD as Trainer 20 | elif args.method in ['L1', 'GReg-1', 'ASSL']: 21 | from trainer import Trainer 22 | from pruner import pruner_dict 23 | 24 | # @mst: KD 25 | def set_up_teacher(args, checkpoint, T_model, T_weights, T_n_resblocks, T_n_feats): 26 | # update args 27 | args = copy.deepcopy(args) # avoid modifying the original args 28 | args.model = T_model 29 | args.n_resblocks = T_n_resblocks 30 | args.n_feats = T_n_feats 31 | 32 | # set up model 33 | global model 34 | model = model.Model(args, ckp=None) 35 | 36 | # load pretraiend weights 37 | ckpt = torch.load(T_weights) 38 | model.model.load_state_dict(ckpt) 39 | checkpoint.write_log('==> Set up teacher successfully, pretrained weights: "%s"' % T_weights) 40 | return model 41 | 42 | def main(): 43 | global model, checkpoint 44 | if args.data_test == ['video']: 45 | from videotester import VideoTester 46 | model = model.Model(args, checkpoint) 47 | t = VideoTester(args, model, checkpoint) 48 | t.test() 49 | else: 50 | if checkpoint.ok: 51 | loader = data.Data(args) 52 | 53 | # @mst: different methods require different model settings 54 | if args.method in ['']: # original setting 55 | _model = model.Model(args, checkpoint) 56 | elif args.method in ['KD']: 57 | _model_S = model.Model(args, checkpoint) 58 | _model_T = 
set_up_teacher(args, checkpoint, args.T_model, args.T_weights, args.T_n_resblocks, args.T_n_feats) 59 | _model = [_model_T, _model_S] 60 | elif args.method in ['L1', 'GReg-1', 'ASSL']: 61 | _model = model.Model(args, checkpoint) 62 | class passer: pass 63 | passer.ckp = checkpoint 64 | passer.loss = loss.Loss(args, checkpoint) if not args.test_only else None 65 | passer.loader = loader 66 | pruner = pruner_dict[args.method].Pruner(_model, args, logger=None, passer=passer) 67 | 68 | # get the statistics of unpruned model 69 | # height_lr = 1280 // args.scale[0] 70 | # width_lr = 720 // args.scale[0] 71 | # n_params_original_v2 = get_n_params_(_model) 72 | # n_flops_original_v2 = get_n_flops_(_model, img_size=(height_lr, width_lr), n_channel=3, count_adds=False, idx_scale=args.scale[0]) 73 | 74 | _model = pruner.prune() # get the pruned model as initialization for later finetuning 75 | 76 | # get the statistics of pruned model and print 77 | height_lr = 1280 // args.scale[0] 78 | width_lr = 720 // args.scale[0] 79 | dummy_input = torch.zeros((1, 3, height_lr, width_lr)).cuda() 80 | 81 | # temporarily change the print fn 82 | import builtins, functools 83 | flops_f = open(checkpoint.get_path('model_complexity.txt'), 'w+') 84 | original_print = builtins.print 85 | builtins.print = functools.partial(print, file=flops_f, flush=True) 86 | summary(_model, dummy_input, {'idx_scale': args.scale[0]}) 87 | builtins.print = original_print 88 | 89 | # @mst: old code for printing FLOPs etc. Deprecated now. Will be removed. 90 | # n_params_now_v2 = get_n_params_(_model) 91 | # n_flops_now_v2 = get_n_flops_(_model, img_size=(height_lr, width_lr), n_channel=3, count_adds=False, idx_scale=args.scale[0]) 92 | # checkpoint.write_log_prune("==> n_params_original_v2: {:>7.4f}M, n_flops_original_v2: {:>7.4f}G".format(n_params_original_v2/1e6, n_flops_original_v2/1e9)) 93 | # checkpoint.write_log_prune("==> n_params_now_v2: {:>7.4f}M, n_flops_now_v2: {:>7.4f}G".format(n_params_now_v2/1e6, n_flops_now_v2/1e9)) 94 | # ratio_param = (n_params_original_v2 - n_params_now_v2) / n_params_original_v2 95 | # ratio_flops = (n_flops_original_v2 - n_flops_now_v2) / n_flops_original_v2 96 | # compression_ratio = 1.0 / (1 - ratio_param) 97 | # speedup_ratio = 1.0 / (1 - ratio_flops) 98 | # checkpoint.write_log_prune("==> reduction ratio -- params: {:>5.2f}% (compression {:>.2f}x), flops: {:>5.2f}% (speedup {:>.2f}x)".format(ratio_param*100, compression_ratio, ratio_flops*100, speedup_ratio)) 99 | 100 | # reset checkpoint and loss 101 | args.save = args.save + "_FT" 102 | checkpoint = utility.checkpoint(args) 103 | 104 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 105 | t = Trainer(args, loader, _model, _loss, checkpoint) 106 | 107 | if args.greg_mode in ['all']: 108 | checkpoint.save(t, epoch=0, is_best=False) 109 | print(f'==> Regularizing all wn_scale parameters done. Checkpoint saved. 
Exit!') 110 | import shutil 111 | shutil.rmtree(checkpoint.dir) # this folder is not really used, so removed here 112 | exit(0) 113 | 114 | while not t.terminate(): 115 | t.train() 116 | t.test() 117 | 118 | checkpoint.done() 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Conv2D_WN(nn.Conv2d): 8 | '''Conv2D with weight normalization. 9 | ''' 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | padding_mode='zeros', # TODO: refine this type 20 | device=None, 21 | dtype=None 22 | ): 23 | super(Conv2D_WN, self).__init__(in_channels, out_channels, kernel_size, 24 | stride=stride, padding=padding, dilation=dilation, groups=groups, 25 | bias=bias, padding_mode=padding_mode) 26 | 27 | # set up the scale variable in weight normalization 28 | self.wn_scale = nn.Parameter(torch.ones(out_channels), requires_grad=True) 29 | self.init_wn() 30 | 31 | def init_wn(self): 32 | """initialize the wn parameters""" 33 | for i in range(self.weight.size(0)): 34 | self.wn_scale.data[i] = torch.norm(self.weight.data[i]) 35 | 36 | def forward(self, input): 37 | w = F.normalize(self.weight, dim=(1,2,3)) 38 | w = w * self.wn_scale.view(-1,1,1,1) 39 | return F.conv2d(input, w, self.bias, self.stride, 40 | self.padding, self.dilation, self.groups) 41 | 42 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 43 | return nn.Conv2d( 44 | in_channels, out_channels, kernel_size, 45 | padding=(kernel_size//2), bias=bias) 46 | 47 | def wn_conv(in_channels, out_channels, kernel_size, bias=True): 48 | return Conv2D_WN( 49 | in_channels, out_channels, kernel_size, 50 | padding=(kernel_size//2), bias=bias) 51 | 52 | class MeanShift(nn.Conv2d): 53 | def __init__( 54 | self, rgb_range, 55 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 56 | 57 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 58 | std = torch.Tensor(rgb_std) 59 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 60 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 61 | for p in self.parameters(): 62 | p.requires_grad = False 63 | 64 | class BasicBlock(nn.Sequential): 65 | def __init__( 66 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 67 | bn=True, act=nn.ReLU(True)): 68 | 69 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 70 | if bn: 71 | m.append(nn.BatchNorm2d(out_channels)) 72 | if act is not None: 73 | m.append(act) 74 | 75 | super(BasicBlock, self).__init__(*m) 76 | 77 | class ResBlock(nn.Module): 78 | def __init__( 79 | self, conv, n_feats, kernel_size, 80 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 81 | 82 | super(ResBlock, self).__init__() 83 | m = [] 84 | for i in range(2): 85 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 86 | if bn: 87 | m.append(nn.BatchNorm2d(n_feats)) 88 | if i == 0: 89 | m.append(act) 90 | 91 | self.body = nn.Sequential(*m) 92 | self.res_scale = res_scale 93 | 94 | def forward(self, x): 95 | res = self.body(x).mul(self.res_scale) 96 | res += x 97 | 98 | return res 99 | 100 | class Upsampler(nn.Sequential): 101 | def __init__(self, conv, scale, n_feats, bn=False, 
act=False, bias=True): 102 | 103 | m = [] 104 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 105 | for _ in range(int(math.log(scale, 2))): 106 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 107 | m.append(nn.PixelShuffle(2)) 108 | if bn: 109 | m.append(nn.BatchNorm2d(n_feats)) 110 | if act == 'relu': 111 | m.append(nn.ReLU(True)) 112 | elif act == 'prelu': 113 | m.append(nn.PReLU(n_feats)) 114 | 115 | elif scale == 3: 116 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 117 | m.append(nn.PixelShuffle(3)) 118 | if bn: 119 | m.append(nn.BatchNorm2d(n_feats)) 120 | if act == 'relu': 121 | m.append(nn.ReLU(True)) 122 | elif act == 'prelu': 123 | m.append(nn.PReLU(n_feats)) 124 | else: 125 | raise NotImplementedError 126 | 127 | super(Upsampler, self).__init__(*m) 128 | 129 | 130 | 131 | 132 | class LiteUpsampler(nn.Sequential): 133 | def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True): 134 | 135 | m = [] 136 | m.append(conv(n_feats, n_out*(scale ** 2), 3, bias)) 137 | m.append(nn.PixelShuffle(scale)) 138 | # if (scale & (scale - 1)) == 0: # Is scale = 2^n? 139 | # for _ in range(int(math.log(scale, 2))): 140 | # m.append(conv(n_feats, 4 * n_out, 3, bias)) 141 | # m.append(nn.PixelShuffle(2)) 142 | # if bn: 143 | # m.append(nn.BatchNorm2d(n_out)) 144 | # if act == 'relu': 145 | # m.append(nn.ReLU(True)) 146 | # elif act == 'prelu': 147 | # m.append(nn.PReLU(n_out)) 148 | 149 | # elif scale == 3: 150 | # m.append(conv(n_feats, 9 * n_out, 3, bias)) 151 | # m.append(nn.PixelShuffle(3)) 152 | # if bn: 153 | # m.append(nn.BatchNorm2d(n_out)) 154 | # if act == 'relu': 155 | # m.append(nn.ReLU(True)) 156 | # elif act == 'prelu': 157 | # m.append(nn.PReLU(n_out)) 158 | # else: 159 | # raise NotImplementedError 160 | 161 | super(LiteUpsampler, self).__init__(*m) 162 | 163 | -------------------------------------------------------------------------------- /src/pruner/meta_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | from math import ceil, sqrt 7 | from collections import OrderedDict 8 | from utils import strdict_to_dict 9 | from fnmatch import fnmatch, fnmatchcase 10 | from layer import LayerStruct 11 | from .utils import get_pr_model, get_constrained_layers, pick_pruned_model, get_kept_filter_channel, replace_module, get_next_bn 12 | from .utils import get_masks 13 | 14 | class MetaPruner: 15 | def __init__(self, model, args, logger, passer): 16 | self.model = model 17 | self.args = args 18 | self.logger = logger 19 | self.logprint = logger.log_printer.logprint if logger else print 20 | self.netprint = logger.log_printer.netprint if logger else print 21 | 22 | # set up layers 23 | self.LEARNABLES = (nn.Conv2d, nn.Linear) # the layers we focus on for pruning 24 | layer_struct = LayerStruct(model, self.LEARNABLES) 25 | self.layers = layer_struct.layers 26 | self._max_len_ix = layer_struct._max_len_ix 27 | self._max_len_name = layer_struct._max_len_name 28 | self.layer_print_prefix = layer_struct.print_prefix 29 | 30 | # set up pr for each layer 31 | self.raw_pr = get_pr_model(self.layers, args.stage_pr, skip=args.skip_layers, compare_mode=args.compare_mode) 32 | self.pr = copy.deepcopy(self.raw_pr) 33 | 34 | # pick pruned and kept weight groups 35 | self.constrained_layers = get_constrained_layers(self.layers, self.args.same_pruned_wg_layers) 36 | print(f'Constrained layers: {self.constrained_layers}') 37 | 38 | def 
_get_kept_wg_L1(self, align_constrained=False): 39 | # ************************* core pruning function ************************** 40 | self.pr, self.pruned_wg, self.kept_wg = pick_pruned_model(self.model, self.layers, self.raw_pr, 41 | wg=self.args.wg, 42 | criterion=self.args.prune_criterion, 43 | compare_mode=self.args.compare_mode, 44 | sort_mode=self.args.pick_pruned, 45 | constrained=self.constrained_layers, 46 | align_constrained=align_constrained) 47 | # *************************************************************************** 48 | 49 | # print 50 | print(f'*********** Get pruned wg ***********') 51 | for name, layer in self.layers.items(): 52 | logtmp = f'{self.layer_print_prefix[name]} -- Got pruned wg by L1 sorting ({self.args.pick_pruned}), pr {self.pr[name]}' 53 | ext = f' -- This is a constrained layer. Its pruned/kept indices have been adjusted.' if name in self.constrained_layers else '' 54 | self.netprint(logtmp + ext) 55 | print(f'*************************************') 56 | 57 | def _prune_and_build_new_model(self): 58 | if self.args.wg == 'weight': 59 | self.masks = get_masks(self.layers, self.pruned_wg) 60 | return 61 | 62 | new_model = copy.deepcopy(self.model) 63 | for name, m in self.model.named_modules(): 64 | if isinstance(m, self.LEARNABLES): 65 | kept_filter, kept_chl = get_kept_filter_channel(self.layers, name, pr=self.pr, kept_wg=self.kept_wg, wg=self.args.wg) 66 | 67 | # decide if renit the current layer 68 | reinit = False 69 | for rl in self.args.reinit_layers: 70 | if fnmatch(name, rl): 71 | reinit = True 72 | break 73 | 74 | # get number of channels (can be manually assigned) 75 | num_chl = self.args.layer_chl[name] if name in self.args.layer_chl else len(kept_chl) 76 | 77 | # copy weight and bias 78 | bias = False if isinstance(m.bias, type(None)) else True 79 | if isinstance(m, nn.Conv2d): 80 | new_layer = nn.Conv2d(num_chl, len(kept_filter), m.kernel_size, 81 | m.stride, m.padding, m.dilation, m.groups, bias).cuda() 82 | if not reinit: 83 | kept_weights = m.weight.data[kept_filter][:, kept_chl, :, :] 84 | 85 | elif isinstance(m, nn.Linear): 86 | kept_weights = m.weight.data[kept_filter][:, kept_chl] 87 | new_layer = nn.Linear(in_features=len(kept_chl), out_features=len(kept_filter), bias=bias).cuda() 88 | 89 | if not reinit: 90 | new_layer.weight.data.copy_(kept_weights) # load weights into the new module 91 | if bias: 92 | kept_bias = m.bias.data[kept_filter] 93 | new_layer.bias.data.copy_(kept_bias) 94 | 95 | # load the new conv 96 | replace_module(new_model, name, new_layer) 97 | 98 | # get the corresponding bn (if any) for later use 99 | next_bn = get_next_bn(self.model, m) 100 | 101 | elif isinstance(m, nn.BatchNorm2d) and m == next_bn: 102 | new_bn = nn.BatchNorm2d(len(kept_filter), eps=m.eps, momentum=m.momentum, 103 | affine=m.affine, track_running_stats=m.track_running_stats).cuda() 104 | 105 | # copy bn weight and bias 106 | if self.args.copy_bn_w: 107 | weight = m.weight.data[kept_filter] 108 | new_bn.weight.data.copy_(weight) 109 | if self.args.copy_bn_b: 110 | bias = m.bias.data[kept_filter] 111 | new_bn.bias.data.copy_(bias) 112 | 113 | # copy bn running stats 114 | new_bn.running_mean.data.copy_(m.running_mean[kept_filter]) 115 | new_bn.running_var.data.copy_(m.running_var[kept_filter]) 116 | new_bn.num_batches_tracked.data.copy_(m.num_batches_tracked) 117 | 118 | # load the new bn 119 | replace_module(new_model, name, new_bn) 120 | self.model = new_model 121 | 122 | # print the layer shape of pruned model 123 | LayerStruct(new_model, 
self.LEARNABLES) 124 | return new_model -------------------------------------------------------------------------------- /src/data/srdata_no_bin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 48 | for h in list_hr: 49 | b = h.replace(self.apath, path_bin) 50 | b = b.replace(self.ext[0], '.pt') 51 | self.images_hr.append(b) 52 | self._check_and_load(args.ext, h, b, verbose=True) 53 | for i, ll in enumerate(list_lr): 54 | for l in ll: 55 | b = l.replace(self.apath, path_bin) 56 | b = b.replace(self.ext[1], '.pt') 57 | self.images_lr[i].append(b) 58 | self._check_and_load(args.ext, l, b, verbose=True) 59 | if train: 60 | n_patches = args.batch_size * args.test_every 61 | n_images = len(args.data_train) * len(self.images_hr) 62 | if n_images == 0: 63 | self.repeat = 0 64 | else: 65 | self.repeat = max(n_patches // n_images, 1) 66 | 67 | # # Below functions as used to prepare images 68 | # def _scan(self): 69 | # names_hr = sorted( 70 | # glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 71 | # ) 72 | # names_lr = [[] for _ in self.scale] 73 | # for f in names_hr: 74 | # filename, _ = os.path.splitext(os.path.basename(f)) 75 | # for si, s in enumerate(self.scale): 76 | # names_lr[si].append(os.path.join( 77 | # self.dir_lr, 'X{}/{}x{}{}'.format( 78 | # s, filename, s, self.ext[1] 79 | # ) 80 | # )) 81 | 82 | # return names_hr, names_lr 83 | 84 | def _scan(self): 85 | list_hr = [] 86 | list_lr = [[] for _ in self.scale] 87 | 88 | for i in range(self.begin, self.end + 1): 89 | filename = '{:0>4}'.format(i) 90 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext[0])) 91 | for si, s in enumerate(self.scale): 92 | list_lr[si].append(os.path.join( 93 | self.dir_lr, 94 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext[1]) 95 | )) 96 | 97 | return list_hr, list_lr 98 | 99 | def _set_filesystem(self, dir_data): 100 | self.apath = os.path.join(dir_data, self.name) 101 | self.dir_hr = os.path.join(self.apath, 'HR') 102 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 103 | if self.input_large: self.dir_lr += 'L' 104 | self.ext = ('.png', '.png') 105 | 106 | def _check_and_load(self, ext, img, f, verbose=True): 107 | if not os.path.isfile(f) or ext.find('reset') >= 0: 108 | if verbose: 109 | print('Making a 
binary: {}'.format(f)) 110 | with open(f, 'wb') as _f: 111 | pickle.dump(imageio.imread(img), _f) 112 | 113 | def __getitem__(self, idx): 114 | lr, hr, filename = self._load_file(idx) 115 | pair = self.get_patch(lr, hr) 116 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 117 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 118 | 119 | return pair_t[0], pair_t[1], filename 120 | 121 | def __len__(self): 122 | if self.train: 123 | return len(self.images_hr) * self.repeat 124 | else: 125 | return len(self.images_hr) 126 | 127 | def _get_index(self, idx): 128 | if self.train: 129 | return idx % len(self.images_hr) 130 | else: 131 | return idx 132 | 133 | def _load_file(self, idx): 134 | idx = self._get_index(idx) 135 | f_hr = self.images_hr[idx] 136 | f_lr = self.images_lr[self.idx_scale][idx] 137 | 138 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 139 | if self.args.ext == 'img' or self.benchmark: 140 | hr = imageio.imread(f_hr) 141 | lr = imageio.imread(f_lr) 142 | elif self.args.ext.find('sep') >= 0: 143 | with open(f_hr, 'rb') as _f: 144 | hr = pickle.load(_f) 145 | with open(f_lr, 'rb') as _f: 146 | lr = pickle.load(_f) 147 | 148 | return lr, hr, filename 149 | 150 | def get_patch(self, lr, hr): 151 | scale = self.scale[self.idx_scale] 152 | if self.train: 153 | lr, hr = common.get_patch( 154 | lr, hr, 155 | patch_size=self.args.patch_size, 156 | scale=scale, 157 | multi=(len(self.scale) > 1), 158 | input_large=self.input_large 159 | ) 160 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 161 | else: 162 | ih, iw = lr.shape[:2] 163 | hr = hr[0:ih * scale, 0:iw * scale] 164 | 165 | return lr, hr 166 | 167 | def set_scale(self, idx_scale): 168 | if not self.input_large: 169 | self.idx_scale = idx_scale 170 | else: 171 | self.idx_scale = random.randint(0, len(self.scale) - 1) 172 | 173 | -------------------------------------------------------------------------------- /src/model/rirsr.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | # residual in residual (RIR) from RCAN. 7 | # we remove channel attention for easy implementation, when we conduct online pruning. 
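# How this file gets used (a short sketch; the names below are taken from
# src/model/__init__.py and src/demo.sh in this repo, not invented):
#
#     from importlib import import_module
#     module = import_module('model.' + args.model.lower())  # '--model RIRSR' -> 'model.rirsr'
#     net = module.make_model(args).to(device)               # calls make_model() below
#
# Every architecture file under src/model/ exposes the same make_model(args)
# factory, so adding a new network only requires dropping a new file into model/.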
8 | def make_model(args, parent=False): 9 | return RIRSR(args) 10 | 11 | ## Channel Attention (CA) Layer 12 | class CALayer(nn.Module): 13 | def __init__(self, channel, reduction=16): 14 | super(CALayer, self).__init__() 15 | # global average pooling: feature --> point 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | # feature channel downscale and upscale --> channel weight 18 | self.conv_du = nn.Sequential( 19 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def forward(self, x): 26 | y = self.avg_pool(x) 27 | y = self.conv_du(y) 28 | return x * y 29 | 30 | ## Residual Channel Attention Block (RCAB) 31 | class RCAB(nn.Module): 32 | def __init__( 33 | self, conv, n_feat, kernel_size, reduction, 34 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 35 | 36 | super(RCAB, self).__init__() 37 | modules_body = [] 38 | for i in range(2): 39 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 40 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 41 | if i == 0: modules_body.append(act) 42 | modules_body.append(CALayer(n_feat, reduction)) 43 | self.body = nn.Sequential(*modules_body) 44 | self.res_scale = res_scale 45 | 46 | def forward(self, x): 47 | res = self.body(x) 48 | #res = self.body(x).mul(self.res_scale) 49 | res += x 50 | return res 51 | 52 | ## Residual Group (RG) 53 | # class ResidualGroup(nn.Module): 54 | # def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 55 | # super(ResidualGroup, self).__init__() 56 | # modules_body = [] 57 | # modules_body = [ 58 | # RCAB( 59 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 60 | # for _ in range(n_resblocks)] 61 | # modules_body.append(conv(n_feat, n_feat, kernel_size)) 62 | # self.body = nn.Sequential(*modules_body) 63 | 64 | # def forward(self, x): 65 | # res = self.body(x) 66 | # res += x 67 | # return res 68 | 69 | class ResidualGroup(nn.Module): 70 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 71 | super(ResidualGroup, self).__init__() 72 | modules_body = [] 73 | # for RCAN 74 | # modules_body = [ 75 | # RCAB( 76 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 77 | # for _ in range(n_resblocks)] 78 | modules_body = [ 79 | common.ResBlock( 80 | conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 81 | for _ in range(n_resblocks)] 82 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 83 | self.body = nn.Sequential(*modules_body) 84 | 85 | def forward(self, x): 86 | res = self.body(x) 87 | res += x 88 | return res 89 | 90 | ## Residual Channel Attention Network (RCAN) 91 | ## Residual in Residual Super-Resolution (RIRSR) 92 | class RIRSR(nn.Module): 93 | def __init__(self, args, conv=common.default_conv): 94 | super(RIRSR, self).__init__() 95 | 96 | n_resgroups = args.n_resgroups 97 | n_resblocks = args.n_resblocks 98 | n_feats = args.n_feats 99 | kernel_size = 3 100 | reduction = args.reduction 101 | scale = args.scale[0] 102 | act = nn.ReLU(True) 103 | 104 | # RGB mean for DIV2K 105 | self.sub_mean = common.MeanShift(args.rgb_range) 106 | 107 | # define head module 108 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 109 | 110 | # define body module 111 | modules_body = [ 112 | ResidualGroup( 113 | conv, n_feats, kernel_size, reduction, act=act, 
res_scale=args.res_scale, n_resblocks=n_resblocks) \ 114 | for _ in range(n_resgroups)] 115 | 116 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 117 | 118 | # define tail module 119 | modules_tail = [ 120 | common.Upsampler(conv, scale, n_feats, act=False), 121 | conv(n_feats, args.n_colors, kernel_size)] 122 | 123 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 124 | 125 | self.head = nn.Sequential(*modules_head) 126 | self.body = nn.Sequential(*modules_body) 127 | self.tail = nn.Sequential(*modules_tail) 128 | 129 | def forward(self, x): 130 | x = self.sub_mean(x) 131 | x = self.head(x) 132 | 133 | res = self.body(x) 134 | res += x 135 | 136 | x = self.tail(res) 137 | x = self.add_mean(x) 138 | 139 | return x 140 | 141 | def load_state_dict(self, state_dict, strict=False): 142 | own_state = self.state_dict() 143 | for name, param in state_dict.items(): 144 | if name in own_state: 145 | if isinstance(param, nn.Parameter): 146 | param = param.data 147 | try: 148 | own_state[name].copy_(param) 149 | except Exception: 150 | if name.find('tail') >= 0: 151 | print('Replace pre-trained upsampler to new one...') 152 | else: 153 | raise RuntimeError('While copying the parameter named {}, ' 154 | 'whose dimensions in the model are {} and ' 155 | 'whose dimensions in the checkpoint are {}.' 156 | .format(name, own_state[name].size(), param.size())) 157 | elif strict: 158 | if name.find('tail') == -1: 159 | raise KeyError('unexpected key "{}" in state_dict' 160 | .format(name)) 161 | 162 | if strict: 163 | missing = set(own_state.keys()) - set(state_dict.keys()) 164 | if len(missing) > 0: 165 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 166 | 167 | -------------------------------------------------------------------------------- /src/pruner/l1_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | import utility 7 | from tqdm import tqdm 8 | from utils import _weights_init, _weights_init_orthogonal, orthogonalize_weights 9 | from .meta_pruner import MetaPruner 10 | 11 | 12 | # refer to: A Signal Propagation Perspective for Pruning Neural Networks at Initialization (ICLR 2020). 
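# The function below post-processes a pruned network so that the kept weights are
# approximately orthogonal ("approximate isometry"): for each Conv/Linear layer it
# flattens the weight to W of shape [n_filter, -1] and minimizes the mean-squared
# error between W @ W.t() and the identity with Adam, re-applying the pruning mask
# after every step so that pruned entries stay exactly zero.
# Reference implementation of the paper cited above: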
13 | # https://github.com/namhoonlee/spp-public 14 | def approximate_isometry_optimize(model, mask, lr, n_iter, wg='weight', print=print): 15 | def optimize(w): 16 | '''Approximate Isometry for sparse weights by iterative optimization 17 | ''' 18 | flattened = w.view(w.size(0), -1) # [n_filter, -1] 19 | identity = torch.eye(w.size(0)).cuda() # identity matrix 20 | w_ = torch.autograd.Variable(flattened, requires_grad=True) 21 | optim = torch.optim.Adam([w_], lr) 22 | for i in range(n_iter): 23 | loss = nn.MSELoss()(torch.matmul(w_, w_.t()), identity) 24 | optim.zero_grad() 25 | loss.backward() 26 | optim.step() 27 | if not isinstance(mask, type(None)): 28 | w_ = torch.mul(w_, mask[name]) # not update the pruned params 29 | w_ = torch.autograd.Variable(w_, requires_grad=True) 30 | optim = torch.optim.Adam([w_], lr) 31 | if i % 10 == 0: 32 | print('[%d/%d] approximate_isometry_optimize for layer "%s", loss %.6f' % (i, n_iter, name, loss.item())) 33 | return w_.view(m.weight.shape) 34 | 35 | for name, m in model.named_modules(): 36 | if isinstance(m, (nn.Conv2d, nn.Linear)): 37 | w_ = optimize(m.weight) 38 | m.weight.data.copy_(w_) 39 | print('Finished approximate_isometry_optimize for layer "%s"' % name) 40 | 41 | def exact_isometry_based_on_existing_weights(model, print=print): 42 | for name, m in model.named_modules(): 43 | if isinstance(m, (nn.Conv2d, nn.Linear)): 44 | w_ = orthogonalize_weights(m.weight) 45 | m.weight.data.copy_(w_) 46 | print('Finished exact_isometry for layer "%s"' % name) 47 | 48 | class Pruner(MetaPruner): 49 | def __init__(self, model, args, logger, passer): 50 | super(Pruner, self).__init__(model, args, logger, passer) 51 | ckp = passer.ckp 52 | self.logprint = ckp.write_log_prune # use another log file specifically for pruning logs 53 | self.netprint = ckp.write_log_prune 54 | 55 | # ************************** variables from RCAN ************************** 56 | loader = passer.loader 57 | self.scale = args.scale 58 | 59 | self.ckp = ckp 60 | self.loader_train = loader.loader_train 61 | self.loader_test = loader.loader_test 62 | self.model = model 63 | 64 | self.error_last = 1e8 65 | # ************************************************************************** 66 | 67 | def prune(self): 68 | self._get_kept_wg_L1() 69 | self.logprint(f"==> Before _prune_and_build_new_model. Testing...") 70 | self.test() 71 | self._prune_and_build_new_model() 72 | self.logprint(f"==> Pruned and built a new model. 
Testing...") 73 | self.test() 74 | mask = self.mask if self.args.wg == 'weight' else None 75 | 76 | if self.args.reinit: 77 | if self.args.reinit in ['default', 'kaiming_normal']: 78 | self.model.apply(_weights_init) # completely reinit weights via 'kaiming_normal' 79 | self.logprint("==> Reinit model: default ('kaiming_normal' for Conv/FC; 0 mean, 1 std for BN)") 80 | 81 | elif self.args.reinit in ['orth', 'exact_isometry_from_scratch']: 82 | self.model.apply(lambda m: _weights_init_orthogonal(m, act=self.args.activation)) # reinit weights via 'orthogonal_' from scratch 83 | self.logprint("==> Reinit model: exact_isometry ('orthogonal_' for Conv/FC; 0 mean, 1 std for BN)") 84 | 85 | elif self.args.reinit == 'exact_isometry_based_on_existing': 86 | exact_isometry_based_on_existing_weights(self.model, print=self.logprint) # orthogonalize weights based on existing weights 87 | self.logprint("==> Reinit model: exact_isometry (orthogonalize Conv/FC weights based on existing weights)") 88 | 89 | elif self.args.reinit == 'approximate_isometry': # A Signal Propagation Perspective for Pruning Neural Networks at Initialization (ICLR 2020) 90 | approximate_isometry_optimize(self.model, mask=mask, lr=self.args.lr_AI, n_iter=10000, print=self.logprint) # 10000 refers to the paper above; lr in the paper is 0.1, but not converged here 91 | self.logprint("==> Reinit model: approximate_isometry") 92 | 93 | else: 94 | raise NotImplementedError 95 | 96 | return copy.deepcopy(self.model) 97 | 98 | def test(self): 99 | is_train = self.model.training 100 | torch.set_grad_enabled(False) 101 | 102 | self.ckp.write_log('Evaluation:') 103 | self.ckp.add_log( 104 | torch.zeros(1, len(self.loader_test), len(self.scale)) 105 | ) 106 | self.model.eval() 107 | 108 | timer_test = utility.timer() 109 | if self.args.save_results: self.ckp.begin_background() 110 | for idx_data, d in enumerate(self.loader_test): 111 | for idx_scale, scale in enumerate(self.scale): 112 | d.dataset.set_scale(idx_scale) 113 | for lr, hr, filename in tqdm(d, ncols=80): 114 | lr, hr = self.prepare(lr, hr) 115 | sr = self.model(lr, idx_scale) 116 | sr = utility.quantize(sr, self.args.rgb_range) 117 | 118 | save_list = [sr] 119 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 120 | sr, hr, scale, self.args.rgb_range, dataset=d 121 | ) 122 | if self.args.save_gt: 123 | save_list.extend([lr, hr]) 124 | 125 | if self.args.save_results: 126 | self.ckp.save_results(d, filename[0], save_list, scale) 127 | 128 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 129 | best = self.ckp.log.max(0) 130 | logstr = '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {}) [method: {} compare_mode: {}]'.format( 131 | d.dataset.name, 132 | scale, 133 | self.ckp.log[-1, idx_data, idx_scale], 134 | best[0][idx_data, idx_scale], 135 | best[1][idx_data, idx_scale] + 1, 136 | self.args.method, 137 | self.args.compare_mode, 138 | ) 139 | self.ckp.write_log(logstr) 140 | self.logprint(logstr) 141 | 142 | self.ckp.write_log('Forward: {:.2f}s'.format(timer_test.toc())) 143 | self.ckp.write_log('Saving...') 144 | 145 | if self.args.save_results: 146 | self.ckp.end_background() 147 | 148 | # if not self.args.test_only: 149 | # self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 150 | 151 | self.ckp.write_log( 152 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 153 | ) 154 | 155 | torch.set_grad_enabled(True) 156 | 157 | if is_train: 158 | self.model.train() 159 | 160 | def prepare(self, *args): 161 | device = torch.device('cpu' if self.args.cpu 
else 'cuda') 162 | def _prepare(tensor): 163 | if self.args.precision == 'half': tensor = tensor.half() 164 | return tensor.to(device) 165 | 166 | return [_prepare(a) for a in args] -------------------------------------------------------------------------------- /src/demo.sh: -------------------------------------------------------------------------------- 1 | # EDSR baseline model (x2) + JPEG augmentation 2 | python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset 3 | #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 4 | 5 | # EDSR baseline model (x3) - from EDSR baseline model (x2) 6 | #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 7 | 8 | # EDSR baseline model (x4) - from EDSR baseline model (x2) 9 | #python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 10 | 11 | # EDSR in the paper (x2) 12 | #python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset 13 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.2.0 --n_resblocks 80 --n_feats 64 14 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.2.0_2 --n_resblocks 80 --n_feats 64 15 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0 --n_resblocks 80 --n_feats 64 16 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_2 --n_resblocks 80 --n_feats 64 17 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_3 --n_resblocks 80 --n_feats 64 18 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_4 --n_resblocks 80 --n_feats 64 19 | # EDSR in the paper (x3) - from EDSR (x2) 20 | #python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] 21 | 22 | # EDSR in the paper (x4) - from EDSR (x2) 23 | #python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] 24 | 25 | # MDSR baseline model 26 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models 27 | 28 | # MDSR in the paper 29 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models 30 | 31 | # Standard benchmarks (Ex. 
EDSR_baseline_x4) 32 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble 33 | 34 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble 35 | 36 | # Test your own images 37 | #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results 38 | 39 | # Advanced - Test with JPEG images 40 | #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results 41 | 42 | # Advanced - Training with adversarial loss 43 | #python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download 44 | 45 | # RDN BI model (x2) 46 | #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset 47 | # RDN BI model (x3) 48 | #python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset 49 | # RDN BI model (x4) 50 | #python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset 51 | 52 | # RCAN_BIX2_G10R20P48, input=48x48, output=96x96 53 | # pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0 54 | #python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96 55 | # RCAN_BIX3_G10R20P48, input=48x48, output=144x144 56 | #python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt 57 | # RCAN_BIX4_G10R20P48, input=48x48, output=192x192 58 | #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt 59 | # RCAN_BIX8_G10R20P48, input=48x48, output=384x384 60 | #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt 61 | 62 | 63 | 64 | 65 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2KBIX2 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 66 | 67 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KBIX3 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 68 | 69 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 4 --patch_size 192 --save EDSR_R16F64P48B16_DF2KBIX4 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 70 | 71 | # debug 72 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2KBIX2_dtep_decay --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3551-3555 --test_every 100 --lr_decay 5 73 | 74 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KBDX3_dtep_decay --n_resblocks 
16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BDX3 --data_train DF2K --data_test DF2K --data_range 1-3450/3551-3555 --test_every 100 --lr_decay 5 --chop 75 | 76 | 77 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KDNX3_dtep_decay --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/DNX3 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --test_every 100 --lr_decay 5 --chop --save_results --save_gt 78 | 79 | 80 | 81 | #### 82 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2K_BIX2 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 83 | 84 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R80F64P48B16_DIV2K_BIX2 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DIV2K --data_test DIV2K --data_range 1-800/801-810 --chop 85 | 86 | 87 | CUDA_VISIBLE_DEVICES=3 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R80F64P48B16_DF2K_BIX2 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 88 | 89 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R32F256P48B16_DF2K_BIX2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 90 | 91 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R80F64P48B16_DF2K_BIX3 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 92 | 93 | # RCAN 94 | CUDA_VISIBLE_DEVICES=0 python main.py --model RCAN --scale 2 --patch_size 96 --save RCAN_G10R20P48B16_DF2K_BIX2 --n_resgroups 10 --n_resblocks 20 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.input_large = (args.model == 'VDSR') 17 | self.self_ensemble = args.self_ensemble 18 | self.chop = args.chop 19 | self.precision = args.precision 20 | self.cpu = args.cpu 21 | self.device = torch.device('cpu' if args.cpu else 'cuda') 22 | self.n_GPUs = args.n_GPUs 23 | self.save_models = args.save_models 24 | 25 | module = import_module('model.' 
+ args.model.lower()) 26 | self.model = module.make_model(args).to(self.device) 27 | if args.precision == 'half': 28 | self.model.half() 29 | 30 | if ckp: 31 | load_from = self.load( 32 | ckp.get_path('model'), 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | if load_from and args.wn: # @mst 38 | for _, module in self.model.named_modules(): 39 | if hasattr(module, 'wn_scale'): 40 | module.init_wn() 41 | print(f'==> Weights are reloaded, reinit wn_scale.') 42 | 43 | 44 | def forward(self, x, idx_scale): 45 | self.idx_scale = idx_scale 46 | if hasattr(self.model, 'set_scale'): 47 | self.model.set_scale(idx_scale) 48 | 49 | if self.training: 50 | if self.n_GPUs > 1: 51 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 52 | else: 53 | return self.model(x) 54 | else: 55 | if self.chop: 56 | forward_function = self.forward_chop 57 | else: 58 | forward_function = self.model.forward 59 | 60 | if self.self_ensemble: 61 | return self.forward_x8(x, forward_function=forward_function) 62 | else: 63 | return forward_function(x) 64 | 65 | def save(self, apath, epoch, is_best=False): 66 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 67 | 68 | if is_best: 69 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 70 | if self.save_models: 71 | save_dirs.append( 72 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 73 | ) 74 | 75 | for s in save_dirs: 76 | to_save = { 77 | 'state_dict': self.model.state_dict(), 78 | 'arch': self.model, 79 | } 80 | torch.save(to_save, s) 81 | 82 | def load(self, apath, pre_train='', resume=-1, cpu=False): 83 | load_from = None 84 | kwargs = {} 85 | if cpu: 86 | kwargs = {'map_location': lambda storage, loc: storage} 87 | 88 | if resume == -1: 89 | load_from = torch.load( 90 | os.path.join(apath, 'model_latest.pt'), 91 | **kwargs 92 | ) 93 | elif resume == 0: 94 | if pre_train == 'download': 95 | print('Download the model') 96 | dir_model = os.path.join('..', 'models') 97 | os.makedirs(dir_model, exist_ok=True) 98 | load_from = torch.utils.model_zoo.load_url( 99 | self.model.url, 100 | model_dir=dir_model, 101 | **kwargs 102 | ) 103 | elif pre_train: 104 | print('Load the model from {}'.format(pre_train)) 105 | load_from = torch.load(pre_train, **kwargs) 106 | else: 107 | load_from = torch.load( 108 | os.path.join(apath, 'model_{}.pt'.format(resume)), 109 | **kwargs 110 | ) 111 | 112 | if load_from: 113 | if 'state_dict' in load_from: # @mst: for pruned models, load the pruned model arch as 'self.model' 114 | self.model = load_from['arch'] 115 | self.model.load_state_dict(load_from['state_dict'], strict=False) 116 | else: 117 | self.model.load_state_dict(load_from, strict=False) 118 | return load_from 119 | 120 | # shave = 10, min_size=160000 121 | def forward_chop(self, x, shave=10, min_size=160000): 122 | scale = self.scale[self.idx_scale] 123 | n_GPUs = min(self.n_GPUs, 4) 124 | b, c, h, w = x.size() 125 | h_half, w_half = h // 2, w // 2 126 | h_size, w_size = h_half + shave, w_half + shave 127 | lr_list = [ 128 | x[:, :, 0:h_size, 0:w_size], 129 | x[:, :, 0:h_size, (w - w_size):w], 130 | x[:, :, (h - h_size):h, 0:w_size], 131 | x[:, :, (h - h_size):h, (w - w_size):w]] 132 | 133 | if w_size * h_size < min_size: 134 | sr_list = [] 135 | for i in range(0, 4, n_GPUs): 136 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 137 | sr_batch = self.model(lr_batch) 138 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 139 | else: 140 | sr_list = [ 141 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 142 | for 
patch in lr_list 143 | ] 144 | 145 | h, w = scale * h, scale * w 146 | h_half, w_half = scale * h_half, scale * w_half 147 | h_size, w_size = scale * h_size, scale * w_size 148 | shave *= scale 149 | 150 | output = x.new(b, c, h, w) 151 | output[:, :, 0:h_half, 0:w_half] \ 152 | = sr_list[0][:, :, 0:h_half, 0:w_half] 153 | output[:, :, 0:h_half, w_half:w] \ 154 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 155 | output[:, :, h_half:h, 0:w_half] \ 156 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 157 | output[:, :, h_half:h, w_half:w] \ 158 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 159 | 160 | return output 161 | 162 | def forward_x8(self, *args, forward_function=None): 163 | def _transform(v, op): 164 | if self.precision != 'single': v = v.float() 165 | 166 | v2np = v.data.cpu().numpy() 167 | if op == 'v': 168 | tfnp = v2np[:, :, :, ::-1].copy() 169 | elif op == 'h': 170 | tfnp = v2np[:, :, ::-1, :].copy() 171 | elif op == 't': 172 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 173 | 174 | ret = torch.Tensor(tfnp).to(self.device) 175 | if self.precision == 'half': ret = ret.half() 176 | 177 | return ret 178 | 179 | list_x = [] 180 | for a in args: 181 | x = [a] 182 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 183 | 184 | list_x.append(x) 185 | 186 | list_y = [] 187 | for x in zip(*list_x): 188 | y = forward_function(*x) 189 | if not isinstance(y, list): y = [y] 190 | if not list_y: 191 | list_y = [[_y] for _y in y] 192 | else: 193 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 194 | 195 | for _list_y in list_y: 196 | for i in range(len(_list_y)): 197 | if i > 3: 198 | _list_y[i] = _transform(_list_y[i], 't') 199 | if i % 4 > 1: 200 | _list_y[i] = _transform(_list_y[i], 'h') 201 | if (i % 4) % 2 == 1: 202 | _list_y[i] = _transform(_list_y[i], 'v') 203 | 204 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 205 | if len(y) == 1: y = y[0] 206 | 207 | return y 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ASSL (NeurIPS'21 Spotlight) 2 | 3 |
8 | 9 | This repository is for a new network pruning method (`Aligned Structured Sparsity Learning, ASSL`) for efficient single image super-resolution (SR), introduced in our NeurIPS 2021 **Spotlight** paper: 10 | > **Aligned Structured Sparsity Learning for Efficient Image Super-Resolution [[Camera Ready](https://papers.nips.cc/paper/2021/file/15de21c670ae7c3f6f3f1f37029303c9-Paper.pdf)] [[Visual Results](https://github.com/MingSun-Tse/ASSL/releases)]** \ 11 | > [Yulun Zhang*](http://yulunzhang.com/), [Huan Wang*](http://huanwang.tech/), [Can Qin](http://canqin.tech/), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/) (*equal contribution) \ 12 | > Northeastern University, Boston, MA, USA 13 | 14 | 15 | ## Introduction 16 |
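In short, ASSL learns which filters of an SR network to prune while keeping the sparsity pattern aligned across layers that feed the same residual (skip) connection, so that the element-wise additions in the pruned network still line up; in this code, layers whose pruned indices must match are handled as "constrained layers" (see `same_pruned_wg_layers` in `src/pruner/meta_pruner.py`). Pruning is wired in as a front end to the standard EDSR-style training loop. The sketch below is condensed from `src/main.py`; it reuses that file's variable names and is an outline rather than a standalone script:

```python
# Condensed from src/main.py: prune once, then finetune the smaller dense model.
from pruner import pruner_dict  # {'L1': l1_pruner, 'ASSL': assl_pruner}

pruner = pruner_dict[args.method].Pruner(_model, args, logger=None, passer=passer)
_model = pruner.prune()  # pruned model, used as the initialization for finetuning

t = Trainer(args, loader, _model, _loss, checkpoint)
while not t.terminate():
    t.train()
    t.test()
```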
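## Weight normalization and filter scoring (sketch)

The networks pruned in this repo are built on weight-normalized convolutions (`Conv2D_WN` in `src/model/common.py`): each filter carries a learnable scalar `wn_scale`, initialized to the filter's L2 norm, so ranking and regularizing filters reduces to operating on one scalar per filter. The snippet below illustrates that mechanism with the repository's own module; the penalty term is a schematic stand-in for, not a copy of, the regularizer implemented in `src/pruner/assl_pruner.py`:

```python
import torch
from model.common import Conv2D_WN  # run from src/; the weight-normalized conv of this repo

conv = Conv2D_WN(64, 64, 3, padding=1)  # conv.wn_scale[i] == ||W_i||_2 at init
unimportant = torch.argsort(conv.wn_scale)[:16]  # e.g., the 16 smallest-scale filters

# Schematic sparsity term: push the scales of the to-be-pruned filters toward zero.
# (ASSL additionally keeps the selected indices aligned across skip-connected layers.)
penalty = 1e-4 * conv.wn_scale[unimportant].pow(2).sum()
```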