├── src
│   ├── __init__.py
│   ├── .gitignore
│   ├── data
│   │   ├── sr291.py
│   │   ├── div2kjpeg.py
│   │   ├── benchmark.py
│   │   ├── demo.py
│   │   ├── df2k.py
│   │   ├── div2k.py
│   │   ├── video.py
│   │   ├── common.py
│   │   ├── __init__.py
│   │   ├── srdata_no_bin.py
│   │   ├── srdata.py
│   │   └── srdata_bin.py
│   ├── requirements.txt
│   ├── pruner
│   │   ├── __init__.py
│   │   ├── meta_pruner.py
│   │   ├── l1_pruner.py
│   │   ├── utils.py
│   │   └── assl_pruner.py
│   ├── cal_modelsize.py
│   ├── loss
│   │   ├── vgg.py
│   │   ├── discriminator.py
│   │   ├── adversarial.py
│   │   └── __init__.py
│   ├── model
│   │   ├── vdsr.py
│   │   ├── mdsr.py
│   │   ├── edsr.py
│   │   ├── ledsr.py
│   │   ├── rdn.py
│   │   ├── ddbpn.py
│   │   ├── rcan.py
│   │   ├── common.py
│   │   ├── rirsr.py
│   │   └── __init__.py
│   ├── template.py
│   ├── videotester.py
│   ├── layer.py
│   ├── trainer.py
│   ├── dataloader.py
│   ├── main.py
│   ├── demo.sh
│   ├── utility.py
│   └── option.py
├── figs
│   ├── neu.png
│   ├── smile.png
│   ├── psnr_ssim.png
│   ├── NIPS21_ASSL.png
│   └── visual_urban100_x4.png
└── README.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figs/neu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/neu.png
--------------------------------------------------------------------------------
/figs/smile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/smile.png
--------------------------------------------------------------------------------
/figs/psnr_ssim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/psnr_ssim.png
--------------------------------------------------------------------------------
/figs/NIPS21_ASSL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/NIPS21_ASSL.png
--------------------------------------------------------------------------------
/figs/visual_urban100_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MingSun-Tse/ASSL/HEAD/figs/visual_urban100_x4.png
--------------------------------------------------------------------------------
/src/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .script_history
3 | Debug_Dir
4 | Experiments
5 | data/cifar10*
6 | data/mnist
7 | data/imagenet*
8 | base_models
9 | model/*.th
10 | train_params/
--------------------------------------------------------------------------------
/src/data/sr291.py:
--------------------------------------------------------------------------------
1 | from data import srdata
2 | 
3 | class SR291(srdata.SRData):
4 |     def __init__(self, args, name='SR291', train=True, benchmark=False):
5 |         super(SR291, self).__init__(args, name=name, train=train, benchmark=benchmark)
6 | 
--------------------------------------------------------------------------------
/src/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.2.0
2 | scikit_image==0.16.2
3 | imageio==2.9.0
4 | pandas==1.0.5
5 | matplotlib==3.2.2
6 | numpy==1.18.5
7 | tqdm==4.47.0
8 | scipy==1.5.0
9 | torchvision==0.4.0a0+6b959ee
10 | Pillow==8.1.1
11 | PyYAML==5.4.1
12 | torchsummaryX==1.3.0
--------------------------------------------------------------------------------
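The pruner package listed in the tree above registers each pruning method in a name-to-module dict, pruner_dict, defined in /src/pruner/__init__.py shown next. Below is a minimal sketch of how such a registry is typically consumed; the 'args.method' field and the module-level 'Pruner' class are assumed names for illustration, not confirmed by the files shown here.

# Hypothetical consumer of pruner_dict; 'args.method' and the per-module
# 'Pruner' class are assumptions, not taken from the files in this dump.
from pruner import pruner_dict

def build_pruner(args, model, logger=None):
    # Look up the module registered for the chosen method name.
    module = pruner_dict[args.method]          # e.g. 'L1' or 'ASSL'
    return module.Pruner(model, args, logger)  # assumed entry-point class

The dict lookup keeps call sites stable: supporting a new pruner only requires a new module plus one registry entry, which is what the comment in the file below asks for.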
/src/pruner/__init__.py: -------------------------------------------------------------------------------- 1 | from . import l1_pruner, assl_pruner 2 | 3 | # when new pruner implementation is added in the 'pruner' dir, update this dict to maintain minimal code change. 4 | # key: pruning method name, value: the corresponding pruner 5 | pruner_dict = { 6 | 'L1': l1_pruner, 7 | 'ASSL': assl_pruner, 8 | } -------------------------------------------------------------------------------- /src/cal_modelsize.py: -------------------------------------------------------------------------------- 1 | from torchsummaryX import summary 2 | from importlib import import_module 3 | import torch 4 | import model 5 | from model import edsr 6 | from option import args 7 | import sys 8 | # checkpoint = utility.checkpoint(args) 9 | # my_model = model.Model(args) 10 | 11 | Net = import_module('model.' + args.model.lower()) 12 | net = eval("Net.%s" % args.model.upper())(args).cuda() 13 | 14 | # net = edsr.EDSR(args) 15 | height_lr = 1280 // args.scale[0] 16 | width_lr = 720 // args.scale[0] 17 | input = torch.zeros((1, 3, height_lr, width_lr)).cuda() 18 | 19 | summary(net, input) 20 | -------------------------------------------------------------------------------- /src/data/div2kjpeg.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | from data import div2k 4 | 5 | class DIV2KJPEG(div2k.DIV2K): 6 | def __init__(self, args, name='', train=True, benchmark=False): 7 | self.q_factor = int(name.replace('DIV2K-Q', '')) 8 | super(DIV2KJPEG, self).__init__( 9 | args, name=name, train=train, benchmark=benchmark 10 | ) 11 | 12 | def _set_filesystem(self, dir_data): 13 | self.apath = os.path.join(dir_data, 'DIV2K') 14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 15 | self.dir_lr = os.path.join( 16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor) 17 | ) 18 | if self.input_large: self.dir_lr += 'L' 19 | self.ext = ('.png', '.jpg') 20 | 21 | -------------------------------------------------------------------------------- /src/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /src/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import imageio 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, name='Demo', train=False, benchmark=False): 13 | self.args = args 14 | self.name = name 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = benchmark 19 | 20 | self.filelist = [] 
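        # Collect every .png / .jpg / .jpeg image under args.dir_demo ('.jp' matches both JPEG extensions); sorting keeps the output order deterministic.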
21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] 28 | lr = imageio.imread(self.filelist[idx]) 29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 31 | 32 | return lr_t, -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /src/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | class VGG(nn.Module): 9 | def __init__(self, conv_index, rgb_range=1): 10 | super(VGG, self).__init__() 11 | vgg_features = models.vgg19(pretrained=True).features 12 | modules = [m for m in vgg_features] 13 | if conv_index.find('22') >= 0: 14 | self.vgg = nn.Sequential(*modules[:8]) 15 | elif conv_index.find('54') >= 0: 16 | self.vgg = nn.Sequential(*modules[:35]) 17 | 18 | vgg_mean = (0.485, 0.456, 0.406) 19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /src/data/df2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DF2K(srdata.SRData): 5 | def __init__(self, args, name='DF2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DF2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | # def _scan(self): 21 | # names_hr, names_lr = super(DF2K, self)._scan() 22 | # names_hr = names_hr[self.begin - 1:self.end] 23 | # names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | # return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DF2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DF2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DF2K_train_LR_bicubic') 31 | if self.input_large: self.dir_lr += 'L' 32 | 33 | -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if 
train:
8 |             data_range = data_range[0]
9 |         else:
10 |             if args.test_only and len(data_range) == 1:
11 |                 data_range = data_range[0]
12 |             else:
13 |                 data_range = data_range[1]
14 | 
15 |         self.begin, self.end = list(map(lambda x: int(x), data_range))
16 |         super(DIV2K, self).__init__(
17 |             args, name=name, train=train, benchmark=benchmark
18 |         )
19 | 
20 |     # def _scan(self):
21 |     #     names_hr, names_lr = super(DIV2K, self)._scan()
22 |     #     names_hr = names_hr[self.begin - 1:self.end]
23 |     #     names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 | 
25 |     #     return names_hr, names_lr
26 | 
27 |     def _set_filesystem(self, dir_data):
28 |         super(DIV2K, self)._set_filesystem(dir_data)
29 |         self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 |         self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 |         if self.input_large: self.dir_lr += 'L'
32 | 
33 | 
--------------------------------------------------------------------------------
/src/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from data import common
4 | 
5 | import cv2
6 | import numpy as np
7 | import imageio
8 | 
9 | import torch
10 | import torch.utils.data as data
11 | 
12 | class Video(data.Dataset):
13 |     def __init__(self, args, name='Video', train=False, benchmark=False):
14 |         self.args = args
15 |         self.name = name
16 |         self.scale = args.scale
17 |         self.idx_scale = 0
18 |         self.train = False
19 |         self.do_eval = False
20 |         self.benchmark = benchmark
21 | 
22 |         self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 |         self.vidcap = cv2.VideoCapture(args.dir_demo)
24 |         self.n_frames = 0
25 |         self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 | 
27 |     def __getitem__(self, idx):
28 |         success, lr = self.vidcap.read()
29 |         if success:
30 |             self.n_frames += 1
31 |             lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 |             lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 | 
34 |             return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 |         else:
36 |             self.vidcap.release()  # release the capture once all frames are read
37 |             return None
38 | 
39 |     def __len__(self):
40 |         return self.total_frames
41 | 
42 |     def set_scale(self, idx_scale):
43 |         self.idx_scale = idx_scale
44 | 
45 | 
--------------------------------------------------------------------------------
/src/model/vdsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | 
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | 
6 | url = {
7 |     'r20f64': ''
8 | }
9 | 
10 | def make_model(args, parent=False):
11 |     return VDSR(args)
12 | 
13 | class VDSR(nn.Module):
14 |     def __init__(self, args, conv=common.default_conv):
15 |         super(VDSR, self).__init__()
16 | 
17 |         n_resblocks = args.n_resblocks
18 |         n_feats = args.n_feats
19 |         kernel_size = 3
20 |         self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
21 |         self.sub_mean = common.MeanShift(args.rgb_range)
22 |         self.add_mean = common.MeanShift(args.rgb_range, sign=1)
23 | 
24 |         def basic_block(in_channels, out_channels, act):
25 |             return common.BasicBlock(
26 |                 conv, in_channels, out_channels, kernel_size,
27 |                 bias=True, bn=False, act=act
28 |             )
29 | 
30 |         # define body module
31 |         m_body = []
32 |         m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
33 |         for _ in range(n_resblocks - 2):
34 |             m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
35 |         m_body.append(basic_block(n_feats, args.n_colors, None))
36 | 
37 |         self.body = nn.Sequential(*m_body)
38 | 
39 |     def forward(self,
x): 40 | x = self.sub_mean(x) 41 | res = self.body(x) 42 | res += x 43 | x = self.add_mean(res) 44 | 45 | return x 46 | 47 | -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.decay = '100' 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.decay = '500' 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.decay = '150' 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | if args.template.find('VDSR') >= 0: 48 | args.model = 'VDSR' 49 | args.n_resblocks = 20 50 | args.n_feats = 64 51 | args.patch_size = 41 52 | args.lr = 1e-1 53 | 54 | -------------------------------------------------------------------------------- /src/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | ''' 7 | output is not normalized 8 | ''' 9 | def __init__(self, args): 10 | super(Discriminator, self).__init__() 11 | 12 | in_channels = args.n_colors 13 | out_channels = 64 14 | depth = 7 15 | 16 | def _block(_in_channels, _out_channels, stride=1): 17 | return nn.Sequential( 18 | nn.Conv2d( 19 | _in_channels, 20 | _out_channels, 21 | 3, 22 | padding=1, 23 | stride=stride, 24 | bias=False 25 | ), 26 | nn.BatchNorm2d(_out_channels), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 28 | ) 29 | 30 | m_features = [_block(in_channels, out_channels)] 31 | for i in range(depth): 32 | in_channels = out_channels 33 | if i % 2 == 1: 34 | stride = 1 35 | out_channels *= 2 36 | else: 37 | stride = 2 38 | m_features.append(_block(in_channels, out_channels, stride=stride)) 39 | 40 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 41 | m_classifier = [ 42 | nn.Linear(out_channels * patch_size**2, 1024), 43 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 44 | nn.Linear(1024, 1) 45 | ] 46 | 47 | self.features = nn.Sequential(*m_features) 48 | self.classifier = nn.Sequential(*m_classifier) 49 | 50 | def forward(self, x): 51 | features = self.features(x) 52 | output = self.classifier(features.view(features.size(0), -1)) 53 | 54 | return output 55 | 56 | -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 
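            # patch_size is the HR patch size for single-scale training; with multi-scale
            # training (multi=True) it is the LR patch size, so the HR crop tp and the
            # LR crop ip stay aligned across scales: ip = tp // scale.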
12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /src/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt', 7 | 'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt' 8 | } 9 | 10 | def make_model(args, parent=False): 11 | return MDSR(args) 12 | 13 | class MDSR(nn.Module): 14 | def __init__(self, args, conv=common.default_conv): 15 | super(MDSR, self).__init__() 16 | n_resblocks = args.n_resblocks 17 | n_feats = args.n_feats 18 | kernel_size = 3 19 | act = nn.ReLU(True) 20 | self.scale_idx = 0 21 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)] 22 | self.sub_mean = common.MeanShift(args.rgb_range) 23 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 24 | 25 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 26 | 27 | self.pre_process = nn.ModuleList([ 28 | nn.Sequential( 29 | common.ResBlock(conv, n_feats, 5, act=act), 30 | common.ResBlock(conv, n_feats, 5, act=act) 31 | ) for _ in args.scale 32 | ]) 33 | 34 | m_body = [ 35 | common.ResBlock( 36 | conv, n_feats, kernel_size, act=act 37 | ) for _ in range(n_resblocks) 38 | ] 39 | m_body.append(conv(n_feats, n_feats, kernel_size)) 40 | 41 | self.upsample = nn.ModuleList([ 42 | common.Upsampler(conv, s, n_feats, act=False) for s in args.scale 43 | ]) 44 | 45 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 46 | 47 | self.head = nn.Sequential(*m_head) 48 | self.body = nn.Sequential(*m_body) 49 | self.tail = nn.Sequential(*m_tail) 50 | 51 | def forward(self, x): 52 | x = self.sub_mean(x) 53 | x = self.head(x) 54 | x = self.pre_process[self.scale_idx](x) 55 | 56 | res = self.body(x) 57 | res += x 58 | 59 | x = self.upsample[self.scale_idx](res) 60 | x = self.tail(x) 61 | x = self.add_mean(x) 62 | 63 | return x 64 | 65 | def set_scale(self, scale_idx): 66 | self.scale_idx = scale_idx 67 | 68 | -------------------------------------------------------------------------------- /src/videotester.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import utility 5 | from data import common 6 | 7 | import torch 8 | import cv2 9 | 10 | from tqdm import tqdm 11 | 12 | class VideoTester(): 13 | def __init__(self, args, my_model, ckp): 14 | self.args = args 15 | self.scale = args.scale 16 | 17 | self.ckp = ckp 18 | self.model = my_model 19 | 20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo)) 21 | 22 | def test(self): 23 | torch.set_grad_enabled(False) 24 | 25 | self.ckp.write_log('\nEvaluation on video:') 26 | self.model.eval() 27 | 28 | timer_test = utility.timer() 29 | for idx_scale, scale in enumerate(self.scale): 30 | vidcap = cv2.VideoCapture(self.args.dir_demo) 31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) 32 | vidwri = cv2.VideoWriter( 33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)), 34 | cv2.VideoWriter_fourcc(*'XVID'), 35 | vidcap.get(cv2.CAP_PROP_FPS), 36 | ( 37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)), 38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 
39 | ) 40 | ) 41 | 42 | tqdm_test = tqdm(range(total_frames), ncols=80) 43 | for _ in tqdm_test: 44 | success, lr = vidcap.read() 45 | if not success: break 46 | 47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors) 48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) 49 | lr, = self.prepare(lr.unsqueeze(0)) 50 | sr = self.model(lr, idx_scale) 51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0) 52 | 53 | normalized = sr * 255 / self.args.rgb_range 54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 55 | vidwri.write(ndarr) 56 | 57 | vidcap.release() 58 | vidwri.release() 59 | 60 | self.ckp.write_log( 61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 62 | ) 63 | torch.set_grad_enabled(True) 64 | 65 | def prepare(self, *args): 66 | device = torch.device('cpu' if self.args.cpu else 'cuda') 67 | def _prepare(tensor): 68 | if self.args.precision == 'half': tensor = tensor.half() 69 | return tensor.to(device) 70 | 71 | return [_prepare(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return EDSR(args) 16 | 17 | class EDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(EDSR, self).__init__() 20 | 21 | n_resblocks = args.n_resblocks 22 | n_feats = args.n_feats 23 | kernel_size = 3 24 | scale = args.scale[0] 25 | act = nn.ReLU(True) 26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 27 | if url_name in url: 28 | self.url = url[url_name] 29 | else: 30 | self.url = None 31 | self.sub_mean = common.MeanShift(args.rgb_range) 32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 33 | 34 | # use conv with weight normalization 35 | if args.wn: 36 | conv = common.wn_conv 37 | 38 | # define head module 39 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 40 | 41 | # define body module 42 | m_body = [ 43 | common.ResBlock( 44 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale, 45 | ) for _ in range(n_resblocks) 46 | ] 47 | m_body.append(conv(n_feats, n_feats, kernel_size)) 48 | 49 | # define tail module 50 | m_tail = [ 51 | common.Upsampler(conv, scale, n_feats, act=False), 52 | conv(n_feats, args.n_colors, kernel_size) 53 | ] 54 | 55 | self.head = nn.Sequential(*m_head) 56 | self.body = nn.Sequential(*m_body) 57 | self.tail = nn.Sequential(*m_tail) 58 | 59 | def forward(self, x): 60 | x = self.sub_mean(x) 61 | x = self.head(x) 62 | 63 | res = self.body(x) 64 | res += x 65 | 66 | x = self.tail(res) 67 | x = self.add_mean(x) 68 | 69 | return x 70 | 71 | def load_state_dict(self, state_dict, strict=True): 72 | own_state = self.state_dict() 73 | for name, param in state_dict.items(): 74 | if name in own_state: 75 | if isinstance(param, nn.Parameter): 76 | param = param.data 77 | try: 78 | 
own_state[name].copy_(param) 79 | except Exception: 80 | if name.find('tail') == -1: 81 | raise RuntimeError('While copying the parameter named {}, ' 82 | 'whose dimensions in the model are {} and ' 83 | 'whose dimensions in the checkpoint are {}.' 84 | .format(name, own_state[name].size(), param.size())) 85 | elif strict: 86 | if name.find('tail') == -1: 87 | raise KeyError('unexpected key "{}" in state_dict' 88 | .format(name)) 89 | 90 | -------------------------------------------------------------------------------- /src/model/ledsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | url = { 6 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 7 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 8 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 9 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 10 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 11 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 12 | } 13 | 14 | def make_model(args, parent=False): 15 | return LEDSR(args) 16 | 17 | class LEDSR(nn.Module): 18 | def __init__(self, args, conv=common.default_conv): 19 | super(LEDSR, self).__init__() 20 | 21 | n_resblocks = args.n_resblocks 22 | n_feats = args.n_feats 23 | kernel_size = 3 24 | scale = args.scale[0] 25 | act = nn.ReLU(True) 26 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 27 | if url_name in url: 28 | self.url = url[url_name] 29 | else: 30 | self.url = None 31 | self.sub_mean = common.MeanShift(args.rgb_range) 32 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 33 | 34 | # use conv with weight normalization 35 | if args.wn: 36 | conv = common.wn_conv 37 | 38 | # define head module 39 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 40 | 41 | # define body module 42 | m_body = [ 43 | common.ResBlock( 44 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 45 | ) for _ in range(n_resblocks) 46 | ] 47 | m_body.append(conv(n_feats, n_feats, kernel_size)) 48 | 49 | # define tail module 50 | # m_tail = [ 51 | # common.Upsampler(conv, scale, n_feats, act=False), 52 | # conv(n_feats, args.n_colors, kernel_size) 53 | # ] 54 | m_tail = [ 55 | common.LiteUpsampler(conv, scale, n_feats, args.n_colors, act=False) 56 | ] 57 | self.head = nn.Sequential(*m_head) 58 | self.body = nn.Sequential(*m_body) 59 | self.tail = nn.Sequential(*m_tail) 60 | 61 | def forward(self, x): 62 | x = self.sub_mean(x) 63 | x = self.head(x) 64 | 65 | res = self.body(x) 66 | res += x 67 | 68 | x = self.tail(res) 69 | x = self.add_mean(x) 70 | 71 | return x 72 | 73 | def load_state_dict(self, state_dict, strict=True): 74 | own_state = self.state_dict() 75 | for name, param in state_dict.items(): 76 | if name in own_state: 77 | if isinstance(param, nn.Parameter): 78 | param = param.data 79 | try: 80 | own_state[name].copy_(param) 81 | except Exception: 82 | if name.find('tail') == -1: 83 | raise RuntimeError('While copying the parameter named {}, ' 84 | 'whose dimensions in the model are {} and ' 85 | 'whose dimensions in the checkpoint are {}.' 
86 | .format(name, own_state[name].size(), param.size())) 87 | elif strict: 88 | if name.find('tail') == -1: 89 | raise KeyError('unexpected key "{}" in state_dict' 90 | .format(name)) 91 | 92 | -------------------------------------------------------------------------------- /src/model/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return RDN(args) 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | r = args.scale[0] 49 | G0 = args.G0 50 | kSize = args.RDNkSize 51 | 52 | # number of RDB blocks, conv layers, out channels 53 | self.D, C, G = { 54 | 'A': (20, 6, 32), 55 | 'B': (16, 8, 64), 56 | }[args.RDNconfig] 57 | 58 | # Shallow feature extraction net 59 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 60 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self.RDBs = nn.ModuleList() 64 | for i in range(self.D): 65 | self.RDBs.append( 66 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 73 | ]) 74 | 75 | # Up-sampling net 76 | if r == 2 or r == 3: 77 | self.UPNet = nn.Sequential(*[ 78 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 79 | nn.PixelShuffle(r), 80 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 81 | ]) 82 | elif r == 4: 83 | self.UPNet = nn.Sequential(*[ 84 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 85 | nn.PixelShuffle(2), 86 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 87 | nn.PixelShuffle(2), 88 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 89 | ]) 90 | else: 91 | raise ValueError("scale must be 2 or 3 or 4.") 92 | 93 | def forward(self, x): 94 | f__1 = self.SFENet1(x) 95 | x = self.SFENet2(f__1) 96 | 97 | RDBs_out = [] 98 | for i in range(self.D): 99 | x = self.RDBs[i](x) 100 | RDBs_out.append(x) 101 | 102 | x = self.GFF(torch.cat(RDBs_out,1)) 103 | x += f__1 104 | 105 | return self.UPNet(x) 106 | -------------------------------------------------------------------------------- /src/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn
3 | import copy
4 | import numpy as np
5 | from collections import OrderedDict
6 | tensor2list = lambda x: x.data.cpu().numpy().tolist()
7 | 
8 | class Layer:
9 |     def __init__(self, name, size, layer_index, module, res=False, layer_type=None):
10 |         self.name = name
11 |         self.module = module
12 |         self.size = [] # deprecated in favor of 'shape'
13 |         for x in size:
14 |             self.size.append(x)
15 |         self.shape = self.size
16 |         self.layer_index = layer_index # deprecated in favor of 'index'
17 |         self.index = layer_index
18 |         self.layer_type = layer_type # deprecated in favor of 'type'
19 |         self.type = layer_type
20 |         self.is_shortcut = True if "downsample" in name else False
21 |         # if res:
22 |         #     self.stage, self.seq_index, self.block_index = self._get_various_index_by_name(name)
23 | 
24 |     def _get_various_index_by_name(self, name):
25 |         '''Get the indices including stage, seq_ix, blk_ix.
26 |         Same stage means the same feature map size.
27 |         '''
28 |         global latest_stage # an awkward implementation, just for now
29 |         if name.startswith('module.'):
30 |             name = name[7:] # remove the prefix caused by pytorch data parallel
31 | 
32 |         if "conv1" == name: # TODO: this might not be so safe
33 |             latest_stage = 0
34 |             return 0, None, None
35 |         if "linear" in name or 'fc' in name: # Note: this can be risky. Check it fully. TODO: @mingsun-tse
36 |             return latest_stage + 1, None, None # fc layer should always be the last layer
37 |         else:
38 |             try:
39 |                 stage = int(name.split(".")[0][-1]) # ONLY works for standard resnets. name example: layer2.2.conv1, layer4.0.downsample.0
40 |                 seq_ix = int(name.split(".")[1])
41 |                 if 'conv' in name.split(".")[-1]:
42 |                     blk_ix = int(name[-1]) - 1
43 |                 else:
44 |                     blk_ix = -1 # shortcut layer
45 |                 latest_stage = stage
46 |                 return stage, seq_ix, blk_ix
47 |             except Exception:
48 |                 print('! Parsing the layer name failed: %s. Please check.' % name)
49 | 
50 | class LayerStruct:
51 |     def __init__(self, model, LEARNABLES):
52 |         self.model = model
53 |         self.LEARNABLES = LEARNABLES
54 |         self.register_layers()
55 |         self.get_print_prefix()
56 |         self.print_layer_stats()
57 | 
58 |     def register_layers(self):
59 |         """This will maintain a data structure that can return some useful information by the name of a layer.
60 | TODO-@mst: Update this: https://github.com/MingSun-Tse/Pruning/blob/2bb7012d81e3c8326a2f756782b41c3d7ca9da21/pruner/meta_pruner.py#L65 61 | """ 62 | self.layers = OrderedDict() 63 | self._max_len_name = 0 64 | self._max_len_shape = 0 65 | 66 | ix = -1 # layer index, starts from 0 67 | for name, m in self.model.named_modules(): 68 | if isinstance(m, self.LEARNABLES): 69 | if "downsample" not in name: 70 | ix += 1 71 | self._max_len_name = max(self._max_len_name, len(name)) 72 | self.layers[name] = Layer(name, size=m.weight.size(), layer_index=ix, module=m, layer_type=m.__class__.__name__) 73 | self._max_len_shape = max(self._max_len_shape, len(str(self.layers[name].shape))) 74 | 75 | self._max_len_ix = len(str(ix)) 76 | self.num_layers = ix + 1 77 | 78 | def get_print_prefix(self): 79 | self.print_prefix = OrderedDict() 80 | for name, layer in self.layers.items(): 81 | format_str = f"[%{self._max_len_ix}d] %{self._max_len_name}s %{self._max_len_shape}s" 82 | self.print_prefix[name] = format_str % (layer.index, name, layer.shape) 83 | 84 | def print_layer_stats(self): 85 | print('************************ Layer Statistics ************************') 86 | for name, layer in self.layers.items(): 87 | print(f'{self.print_prefix[name]}') 88 | print('******************************************************************') -------------------------------------------------------------------------------- /src/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | 
nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /src/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from types import SimpleNamespace 3 | 4 | from model import common 5 | from loss import discriminator 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | 12 | class Adversarial(nn.Module): 13 | def __init__(self, args, gan_type): 14 | super(Adversarial, self).__init__() 15 | self.gan_type = gan_type 16 | self.gan_k = args.gan_k 17 | self.dis = discriminator.Discriminator(args) 18 | if gan_type == 'WGAN_GP': 19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4 20 | optim_dict = { 21 | 'optimizer': 'ADAM', 22 | 'betas': (0, 0.9), 23 | 'epsilon': 1e-8, 24 | 'lr': 1e-5, 25 | 'weight_decay': args.weight_decay, 26 | 'decay': args.decay, 27 | 'gamma': args.gamma 28 | } 29 | optim_args = SimpleNamespace(**optim_dict) 30 | else: 31 | optim_args = args 32 | 33 | self.optimizer = utility.make_optimizer(optim_args, self.dis) 34 | 35 | def forward(self, fake, real): 36 | # updating discriminator... 
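        # The discriminator takes gan_k update steps per generator step; the generator
        # output is detached (fake_detach) so these steps do not backpropagate into G.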
37 | self.loss = 0 38 | fake_detach = fake.detach() # do not backpropagate through G 39 | for _ in range(self.gan_k): 40 | self.optimizer.zero_grad() 41 | # d: B x 1 tensor 42 | d_fake = self.dis(fake_detach) 43 | d_real = self.dis(real) 44 | retain_graph = False 45 | if self.gan_type == 'GAN': 46 | loss_d = self.bce(d_real, d_fake) 47 | elif self.gan_type.find('WGAN') >= 0: 48 | loss_d = (d_fake - d_real).mean() 49 | if self.gan_type.find('GP') >= 0: 50 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 51 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 52 | hat.requires_grad = True 53 | d_hat = self.dis(hat) 54 | gradients = torch.autograd.grad( 55 | outputs=d_hat.sum(), inputs=hat, 56 | retain_graph=True, create_graph=True, only_inputs=True 57 | )[0] 58 | gradients = gradients.view(gradients.size(0), -1) 59 | gradient_norm = gradients.norm(2, dim=1) 60 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 61 | loss_d += gradient_penalty 62 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks 63 | elif self.gan_type == 'RGAN': 64 | better_real = d_real - d_fake.mean(dim=0, keepdim=True) 65 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True) 66 | loss_d = self.bce(better_real, better_fake) 67 | retain_graph = True 68 | 69 | # Discriminator update 70 | self.loss += loss_d.item() 71 | loss_d.backward(retain_graph=retain_graph) 72 | self.optimizer.step() 73 | 74 | if self.gan_type == 'WGAN': 75 | for p in self.dis.parameters(): 76 | p.data.clamp_(-1, 1) 77 | 78 | self.loss /= self.gan_k 79 | 80 | # updating generator... 81 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is 82 | if self.gan_type == 'GAN': 83 | label_real = torch.ones_like(d_fake_bp) 84 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real) 85 | elif self.gan_type.find('WGAN') >= 0: 86 | loss_g = -d_fake_bp.mean() 87 | elif self.gan_type == 'RGAN': 88 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True) 89 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True) 90 | loss_g = self.bce(better_fake, better_real) 91 | 92 | # Generator loss 93 | return loss_g 94 | 95 | def state_dict(self, *args, **kwargs): 96 | state_discriminator = self.dis.state_dict(*args, **kwargs) 97 | state_optimizer = self.optimizer.state_dict() 98 | 99 | return dict(**state_discriminator, **state_optimizer) 100 | 101 | def bce(self, real, fake): 102 | label_real = torch.ones_like(real) 103 | label_fake = torch.zeros_like(fake) 104 | bce_real = F.binary_cross_entropy_with_logits(real, label_real) 105 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake) 106 | bce_loss = bce_real + bce_fake 107 | return bce_loss 108 | 109 | # Some references 110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 111 | # OR 112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 113 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | import torch.nn.utils as utils 9 | from tqdm import tqdm 10 | 11 | class Trainer(): 12 | def __init__(self, args, loader, my_model, my_loss, ckp): 13 | self.args = args 14 | self.scale = args.scale 15 | 16 | self.ckp = ckp 17 | self.loader_train = loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer 
= utility.make_optimizer(args, self.model) 22 | 23 | if self.args.load != '': 24 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 25 | 26 | self.error_last = 1e8 27 | 28 | def train(self): 29 | self.loss.step() 30 | epoch = self.optimizer.get_last_epoch() + 1 31 | lr = self.optimizer.get_lr() 32 | 33 | self.ckp.write_log( 34 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 35 | ) 36 | self.loss.start_log() 37 | self.model.train() 38 | 39 | timer_data, timer_model = utility.timer(), utility.timer() 40 | # TEMP 41 | self.loader_train.dataset.set_scale(0) 42 | for batch, (lr, hr, _,) in enumerate(self.loader_train): 43 | lr, hr = self.prepare(lr, hr) 44 | timer_data.hold() 45 | timer_model.tic() 46 | 47 | self.optimizer.zero_grad() 48 | sr = self.model(lr, 0) 49 | loss = self.loss(sr, hr) 50 | loss.backward() 51 | if self.args.gclip > 0: 52 | utils.clip_grad_value_( 53 | self.model.parameters(), 54 | self.args.gclip 55 | ) 56 | self.optimizer.step() 57 | 58 | timer_model.hold() 59 | 60 | if (batch + 1) % self.args.print_every == 0: 61 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 62 | (batch + 1) * self.args.batch_size, 63 | len(self.loader_train.dataset), 64 | self.loss.display_loss(batch), 65 | timer_model.release(), 66 | timer_data.release())) 67 | 68 | timer_data.tic() 69 | 70 | self.loss.end_log(len(self.loader_train)) 71 | self.error_last = self.loss.log[-1, -1] 72 | self.optimizer.schedule() 73 | 74 | def test(self): 75 | torch.set_grad_enabled(False) 76 | 77 | epoch = self.optimizer.get_last_epoch() 78 | self.ckp.write_log('\nEvaluation:') 79 | self.ckp.add_log( 80 | torch.zeros(1, len(self.loader_test), len(self.scale)) 81 | ) 82 | self.model.eval() 83 | 84 | timer_test = utility.timer() 85 | if self.args.save_results: self.ckp.begin_background() 86 | for idx_data, d in enumerate(self.loader_test): 87 | for idx_scale, scale in enumerate(self.scale): 88 | d.dataset.set_scale(idx_scale) 89 | for lr, hr, filename in tqdm(d, ncols=80): 90 | lr, hr = self.prepare(lr, hr) 91 | sr = self.model(lr, idx_scale) 92 | sr = utility.quantize(sr, self.args.rgb_range) 93 | 94 | save_list = [sr] 95 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 96 | sr, hr, scale, self.args.rgb_range, dataset=d 97 | ) 98 | if self.args.save_gt: 99 | save_list.extend([lr, hr]) 100 | 101 | if self.args.save_results: 102 | self.ckp.save_results(d, filename[0], save_list, scale) 103 | 104 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 105 | best = self.ckp.log.max(0) 106 | self.ckp.write_log( 107 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 108 | d.dataset.name, 109 | scale, 110 | self.ckp.log[-1, idx_data, idx_scale], 111 | best[0][idx_data, idx_scale], 112 | best[1][idx_data, idx_scale] + 1 113 | ) 114 | ) 115 | 116 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc())) 117 | self.ckp.write_log('Saving...') 118 | 119 | if self.args.save_results: 120 | self.ckp.end_background() 121 | 122 | if not self.args.test_only: 123 | self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 124 | 125 | self.ckp.write_log( 126 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 127 | ) 128 | 129 | torch.set_grad_enabled(True) 130 | 131 | def prepare(self, *args): 132 | device = torch.device('cpu' if self.args.cpu else 'cuda') 133 | def _prepare(tensor): 134 | if self.args.precision == 'half': tensor = tensor.half() 135 | return tensor.to(device) 136 | 137 | return [_prepare(a) for a in args] 138 | 139 | def terminate(self): 140 | if 
self.args.test_only: 141 | self.test() 142 | return True 143 | else: 144 | epoch = self.optimizer.get_last_epoch() + 1 145 | return epoch >= self.args.epochs 146 | 147 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | print('Preparing loss function:') 18 | 19 | self.n_GPUs = args.n_GPUs 20 | self.loss = [] 21 | self.loss_module = nn.ModuleList() 22 | for loss in args.loss.split('+'): 23 | weight, loss_type = loss.split('*') 24 | if loss_type == 'MSE': 25 | loss_function = nn.MSELoss() 26 | elif loss_type == 'L1': 27 | loss_function = nn.L1Loss() 28 | elif loss_type.find('VGG') >= 0: 29 | module = import_module('loss.vgg') 30 | loss_function = getattr(module, 'VGG')( 31 | loss_type[3:], 32 | rgb_range=args.rgb_range 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | 41 | self.loss.append({ 42 | 'type': loss_type, 43 | 'weight': float(weight), 44 | 'function': loss_function} 45 | ) 46 | if loss_type.find('GAN') >= 0: 47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 48 | 49 | if len(self.loss) > 1: 50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 51 | 52 | for l in self.loss: 53 | if l['function'] is not None: 54 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 55 | self.loss_module.append(l['function']) 56 | 57 | self.log = torch.Tensor() 58 | 59 | device = torch.device('cpu' if args.cpu else 'cuda') 60 | self.loss_module.to(device) 61 | if args.precision == 'half': self.loss_module.half() 62 | if not args.cpu and args.n_GPUs > 1: 63 | self.loss_module = nn.DataParallel( 64 | self.loss_module, range(args.n_GPUs) 65 | ) 66 | 67 | if args.load != '': self.load(ckp.dir, cpu=args.cpu) 68 | 69 | def forward(self, sr, hr): 70 | losses = [] 71 | for i, l in enumerate(self.loss): 72 | if l['function'] is not None: 73 | loss = l['function'](sr, hr) 74 | effective_loss = l['weight'] * loss 75 | losses.append(effective_loss) 76 | self.log[-1, i] += effective_loss.item() 77 | elif l['type'] == 'DIS': 78 | self.log[-1, i] += self.loss[i - 1]['function'].loss 79 | 80 | loss_sum = sum(losses) 81 | if len(self.loss) > 1: 82 | self.log[-1, -1] += loss_sum.item() 83 | 84 | return loss_sum 85 | 86 | def step(self): 87 | for l in self.get_loss_module(): 88 | if hasattr(l, 'scheduler'): 89 | l.scheduler.step() 90 | 91 | def start_log(self): 92 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 93 | 94 | def end_log(self, n_batches): 95 | self.log[-1].div_(n_batches) 96 | 97 | def display_loss(self, batch): 98 | n_samples = batch + 1 99 | log = [] 100 | for l, c in zip(self.loss, self.log[-1]): 101 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 102 | 103 | return ''.join(log) 104 | 105 | def plot_loss(self, apath, epoch): 106 | if epoch > 0: 107 | axis = np.linspace(1, epoch, epoch) 108 | for i, l in enumerate(self.loss): 109 | label = '{} Loss'.format(l['type']) 110 | fig = plt.figure() 111 | 
plt.title(label) 112 | plt.plot(axis, self.log[:, i].numpy(), label=label) 113 | plt.legend() 114 | plt.xlabel('Epochs') 115 | plt.ylabel('Loss') 116 | plt.grid(True) 117 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) 118 | plt.close(fig) 119 | 120 | def get_loss_module(self): 121 | if self.n_GPUs == 1: 122 | return self.loss_module 123 | else: 124 | return self.loss_module.module 125 | 126 | def save(self, apath): 127 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 128 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 129 | 130 | def load(self, apath, cpu=False): 131 | if cpu: 132 | kwargs = {'map_location': lambda storage, loc: storage} 133 | else: 134 | kwargs = {} 135 | 136 | self.load_state_dict(torch.load( 137 | os.path.join(apath, 'loss.pt'), 138 | **kwargs 139 | )) 140 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 141 | for l in self.get_loss_module(): 142 | if hasattr(l, 'scheduler'): 143 | for _ in range(len(self.log)): l.scheduler.step() 144 | 145 | -------------------------------------------------------------------------------- /src/model/rcan.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | 7 | def make_model(args, parent=False): 8 | return RCAN(args) 9 | 10 | ## Channel Attention (CA) Layer 11 | class CALayer(nn.Module): 12 | def __init__(self, channel, reduction=16): 13 | super(CALayer, self).__init__() 14 | # global average pooling: feature --> point 15 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 16 | # feature channel downscale and upscale --> channel weight 17 | self.conv_du = nn.Sequential( 18 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 21 | nn.Sigmoid() 22 | ) 23 | 24 | def forward(self, x): 25 | y = self.avg_pool(x) 26 | y = self.conv_du(y) 27 | return x * y 28 | 29 | ## Residual Channel Attention Block (RCAB) 30 | class RCAB(nn.Module): 31 | def __init__( 32 | self, conv, n_feat, kernel_size, reduction, 33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 34 | 35 | super(RCAB, self).__init__() 36 | modules_body = [] 37 | for i in range(2): 38 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 39 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 40 | if i == 0: modules_body.append(act) 41 | modules_body.append(CALayer(n_feat, reduction)) 42 | self.body = nn.Sequential(*modules_body) 43 | self.res_scale = res_scale 44 | 45 | def forward(self, x): 46 | res = self.body(x) 47 | #res = self.body(x).mul(self.res_scale) 48 | res += x 49 | return res 50 | 51 | ## Residual Group (RG) 52 | class ResidualGroup(nn.Module): 53 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 54 | super(ResidualGroup, self).__init__() 55 | modules_body = [] 56 | modules_body = [ 57 | RCAB( 58 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 59 | for _ in range(n_resblocks)] 60 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 61 | self.body = nn.Sequential(*modules_body) 62 | 63 | def forward(self, x): 64 | res = self.body(x) 65 | res += x 66 | return res 67 | 68 | ## Residual Channel Attention Network (RCAN) 69 | class RCAN(nn.Module): 70 | def __init__(self, args, 
conv=common.default_conv): 71 | super(RCAN, self).__init__() 72 | 73 | n_resgroups = args.n_resgroups 74 | n_resblocks = args.n_resblocks 75 | n_feats = args.n_feats 76 | kernel_size = 3 77 | reduction = args.reduction 78 | scale = args.scale[0] 79 | act = nn.ReLU(True) 80 | 81 | # RGB mean for DIV2K 82 | self.sub_mean = common.MeanShift(args.rgb_range) 83 | 84 | # define head module 85 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 86 | 87 | # define body module 88 | modules_body = [ 89 | ResidualGroup( 90 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 91 | for _ in range(n_resgroups)] 92 | 93 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 94 | 95 | # define tail module 96 | modules_tail = [ 97 | common.Upsampler(conv, scale, n_feats, act=False), 98 | conv(n_feats, args.n_colors, kernel_size)] 99 | 100 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 101 | 102 | self.head = nn.Sequential(*modules_head) 103 | self.body = nn.Sequential(*modules_body) 104 | self.tail = nn.Sequential(*modules_tail) 105 | 106 | def forward(self, x): 107 | x = self.sub_mean(x) 108 | x = self.head(x) 109 | 110 | res = self.body(x) 111 | res += x 112 | 113 | x = self.tail(res) 114 | x = self.add_mean(x) 115 | 116 | return x 117 | 118 | def load_state_dict(self, state_dict, strict=False): 119 | own_state = self.state_dict() 120 | for name, param in state_dict.items(): 121 | if name in own_state: 122 | if isinstance(param, nn.Parameter): 123 | param = param.data 124 | try: 125 | own_state[name].copy_(param) 126 | except Exception: 127 | if name.find('tail') >= 0: 128 | print('Replace pre-trained upsampler to new one...') 129 | else: 130 | raise RuntimeError('While copying the parameter named {}, ' 131 | 'whose dimensions in the model are {} and ' 132 | 'whose dimensions in the checkpoint are {}.' 
133 | .format(name, own_state[name].size(), param.size())) 134 | elif strict: 135 | if name.find('tail') == -1: 136 | raise KeyError('unexpected key "{}" in state_dict' 137 | .format(name)) 138 | 139 | if strict: 140 | missing = set(own_state.keys()) - set(state_dict.keys()) 141 | if len(missing) > 0: 142 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 143 | -------------------------------------------------------------------------------- /src/dataloader.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import random 3 | 4 | import torch 5 | import torch.multiprocessing as multiprocessing 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SequentialSampler 8 | from torch.utils.data import RandomSampler 9 | from torch.utils.data import BatchSampler 10 | from torch.utils.data import _utils 11 | from torch.utils.data.dataloader import _DataLoaderIter 12 | 13 | from torch.utils.data._utils import collate 14 | from torch.utils.data._utils import signal_handling 15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL 16 | from torch.utils.data._utils import ExceptionWrapper 17 | from torch.utils.data._utils import IS_WINDOWS 18 | from torch.utils.data._utils.worker import ManagerWatchdog 19 | 20 | from torch._six import queue 21 | 22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 23 | try: 24 | collate._use_shared_memory = True 25 | signal_handling._set_worker_signal_handlers() 26 | 27 | torch.set_num_threads(1) 28 | random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | data_queue.cancel_join_thread() 32 | 33 | if init_fn is not None: 34 | init_fn(worker_id) 35 | 36 | watchdog = ManagerWatchdog() 37 | 38 | while watchdog.is_alive(): 39 | try: 40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 41 | except queue.Empty: 42 | continue 43 | 44 | if r is None: 45 | assert done_event.is_set() 46 | return 47 | elif done_event.is_set(): 48 | continue 49 | 50 | idx, batch_indices = r 51 | try: 52 | idx_scale = 0 53 | if len(scale) > 1 and dataset.train: 54 | idx_scale = random.randrange(0, len(scale)) 55 | dataset.set_scale(idx_scale) 56 | 57 | samples = collate_fn([dataset[i] for i in batch_indices]) 58 | samples.append(idx_scale) 59 | except Exception: 60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 61 | else: 62 | data_queue.put((idx, samples)) 63 | del samples 64 | 65 | except KeyboardInterrupt: 66 | pass 67 | 68 | class _MSDataLoaderIter(_DataLoaderIter): 69 | 70 | def __init__(self, loader): 71 | self.dataset = loader.dataset 72 | self.scale = loader.scale 73 | self.collate_fn = loader.collate_fn 74 | self.batch_sampler = loader.batch_sampler 75 | self.num_workers = loader.num_workers 76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 77 | self.timeout = loader.timeout 78 | 79 | self.sample_iter = iter(self.batch_sampler) 80 | 81 | base_seed = torch.LongTensor(1).random_().item() 82 | 83 | if self.num_workers > 0: 84 | self.worker_init_fn = loader.worker_init_fn 85 | self.worker_queue_idx = 0 86 | self.worker_result_queue = multiprocessing.Queue() 87 | self.batches_outstanding = 0 88 | self.worker_pids_set = False 89 | self.shutdown = False 90 | self.send_idx = 0 91 | self.rcvd_idx = 0 92 | self.reorder_dict = {} 93 | self.done_event = multiprocessing.Event() 94 | 95 | base_seed = torch.LongTensor(1).random_()[0] 96 | 97 | self.index_queues = [] 98 | self.workers = [] 99 | for i in 
range(self.num_workers): 100 | index_queue = multiprocessing.Queue() 101 | index_queue.cancel_join_thread() 102 | w = multiprocessing.Process( 103 | target=_ms_loop, 104 | args=( 105 | self.dataset, 106 | index_queue, 107 | self.worker_result_queue, 108 | self.done_event, 109 | self.collate_fn, 110 | self.scale, 111 | base_seed + i, 112 | self.worker_init_fn, 113 | i 114 | ) 115 | ) 116 | w.daemon = True 117 | w.start() 118 | self.index_queues.append(index_queue) 119 | self.workers.append(w) 120 | 121 | if self.pin_memory: 122 | self.data_queue = queue.Queue() 123 | pin_memory_thread = threading.Thread( 124 | target=_utils.pin_memory._pin_memory_loop, 125 | args=( 126 | self.worker_result_queue, 127 | self.data_queue, 128 | torch.cuda.current_device(), 129 | self.done_event 130 | ) 131 | ) 132 | pin_memory_thread.daemon = True 133 | pin_memory_thread.start() 134 | self.pin_memory_thread = pin_memory_thread 135 | else: 136 | self.data_queue = self.worker_result_queue 137 | 138 | _utils.signal_handling._set_worker_pids( 139 | id(self), tuple(w.pid for w in self.workers) 140 | ) 141 | _utils.signal_handling._set_SIGCHLD_handler() 142 | self.worker_pids_set = True 143 | 144 | for _ in range(2 * self.num_workers): 145 | self._put_indices() 146 | 147 | 148 | class MSDataLoader(DataLoader): 149 | 150 | def __init__(self, cfg, *args, **kwargs): 151 | super(MSDataLoader, self).__init__( 152 | *args, **kwargs, num_workers=cfg.n_threads 153 | ) 154 | self.scale = cfg.scale 155 | 156 | def __iter__(self): 157 | return _MSDataLoaderIter(self) 158 | 159 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | import utility 5 | import data 6 | import model 7 | import loss 8 | from option import args 9 | from utils import get_n_flops_, get_n_params_ 10 | from torchsummaryX import summary 11 | 12 | torch.manual_seed(args.seed) 13 | checkpoint = utility.checkpoint(args) 14 | 15 | # @mst: select different trainers corresponding to different methods 16 | if args.method in ['']: 17 | from trainer import Trainer 18 | elif args.method in ['KD']: 19 | from trainer_kd import TrainerKD as Trainer 20 | elif args.method in ['L1', 'GReg-1', 'ASSL']: 21 | from trainer import Trainer 22 | from pruner import pruner_dict 23 | 24 | # @mst: KD 25 | def set_up_teacher(args, checkpoint, T_model, T_weights, T_n_resblocks, T_n_feats): 26 | # update args 27 | args = copy.deepcopy(args) # avoid modifying the original args 28 | args.model = T_model 29 | args.n_resblocks = T_n_resblocks 30 | args.n_feats = T_n_feats 31 | 32 | # set up model 33 | global model 34 | model = model.Model(args, ckp=None) 35 | 36 | # load pretraiend weights 37 | ckpt = torch.load(T_weights) 38 | model.model.load_state_dict(ckpt) 39 | checkpoint.write_log('==> Set up teacher successfully, pretrained weights: "%s"' % T_weights) 40 | return model 41 | 42 | def main(): 43 | global model, checkpoint 44 | if args.data_test == ['video']: 45 | from videotester import VideoTester 46 | model = model.Model(args, checkpoint) 47 | t = VideoTester(args, model, checkpoint) 48 | t.test() 49 | else: 50 | if checkpoint.ok: 51 | loader = data.Data(args) 52 | 53 | # @mst: different methods require different model settings 54 | if args.method in ['']: # original setting 55 | _model = model.Model(args, checkpoint) 56 | elif args.method in ['KD']: 57 | _model_S = model.Model(args, checkpoint) 58 | _model_T = 
set_up_teacher(args, checkpoint, args.T_model, args.T_weights, args.T_n_resblocks, args.T_n_feats) 59 | _model = [_model_T, _model_S] 60 | elif args.method in ['L1', 'GReg-1', 'ASSL']: 61 | _model = model.Model(args, checkpoint) 62 | class passer: pass 63 | passer.ckp = checkpoint 64 | passer.loss = loss.Loss(args, checkpoint) if not args.test_only else None 65 | passer.loader = loader 66 | pruner = pruner_dict[args.method].Pruner(_model, args, logger=None, passer=passer) 67 | 68 | # get the statistics of unpruned model 69 | # height_lr = 1280 // args.scale[0] 70 | # width_lr = 720 // args.scale[0] 71 | # n_params_original_v2 = get_n_params_(_model) 72 | # n_flops_original_v2 = get_n_flops_(_model, img_size=(height_lr, width_lr), n_channel=3, count_adds=False, idx_scale=args.scale[0]) 73 | 74 | _model = pruner.prune() # get the pruned model as initialization for later finetuning 75 | 76 | # get the statistics of pruned model and print 77 | height_lr = 1280 // args.scale[0] 78 | width_lr = 720 // args.scale[0] 79 | dummy_input = torch.zeros((1, 3, height_lr, width_lr)).cuda() 80 | 81 | # temporarily change the print fn 82 | import builtins, functools 83 | flops_f = open(checkpoint.get_path('model_complexity.txt'), 'w+') 84 | original_print = builtins.print 85 | builtins.print = functools.partial(print, file=flops_f, flush=True) 86 | summary(_model, dummy_input, {'idx_scale': args.scale[0]}) 87 | builtins.print = original_print 88 | 89 | # @mst: old code for printing FLOPs etc. Deprecated now. Will be removed. 90 | # n_params_now_v2 = get_n_params_(_model) 91 | # n_flops_now_v2 = get_n_flops_(_model, img_size=(height_lr, width_lr), n_channel=3, count_adds=False, idx_scale=args.scale[0]) 92 | # checkpoint.write_log_prune("==> n_params_original_v2: {:>7.4f}M, n_flops_original_v2: {:>7.4f}G".format(n_params_original_v2/1e6, n_flops_original_v2/1e9)) 93 | # checkpoint.write_log_prune("==> n_params_now_v2: {:>7.4f}M, n_flops_now_v2: {:>7.4f}G".format(n_params_now_v2/1e6, n_flops_now_v2/1e9)) 94 | # ratio_param = (n_params_original_v2 - n_params_now_v2) / n_params_original_v2 95 | # ratio_flops = (n_flops_original_v2 - n_flops_now_v2) / n_flops_original_v2 96 | # compression_ratio = 1.0 / (1 - ratio_param) 97 | # speedup_ratio = 1.0 / (1 - ratio_flops) 98 | # checkpoint.write_log_prune("==> reduction ratio -- params: {:>5.2f}% (compression {:>.2f}x), flops: {:>5.2f}% (speedup {:>.2f}x)".format(ratio_param*100, compression_ratio, ratio_flops*100, speedup_ratio)) 99 | 100 | # reset checkpoint and loss 101 | args.save = args.save + "_FT" 102 | checkpoint = utility.checkpoint(args) 103 | 104 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None 105 | t = Trainer(args, loader, _model, _loss, checkpoint) 106 | 107 | if args.greg_mode in ['all']: 108 | checkpoint.save(t, epoch=0, is_best=False) 109 | print(f'==> Regularizing all wn_scale parameters done. Checkpoint saved. 
Exit!') 110 | import shutil 111 | shutil.rmtree(checkpoint.dir) # this folder is not really used, so removed here 112 | exit(0) 113 | 114 | while not t.terminate(): 115 | t.train() 116 | t.test() 117 | 118 | checkpoint.done() 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Conv2D_WN(nn.Conv2d): 8 | '''Conv2D with weight normalization. 9 | ''' 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | kernel_size, 14 | stride=1, 15 | padding=0, 16 | dilation=1, 17 | groups=1, 18 | bias=True, 19 | padding_mode='zeros', # TODO: refine this type 20 | device=None, 21 | dtype=None 22 | ): 23 | super(Conv2D_WN, self).__init__(in_channels, out_channels, kernel_size, 24 | stride=stride, padding=padding, dilation=dilation, groups=groups, 25 | bias=bias, padding_mode=padding_mode) 26 | 27 | # set up the scale variable in weight normalization 28 | self.wn_scale = nn.Parameter(torch.ones(out_channels), requires_grad=True) 29 | self.init_wn() 30 | 31 | def init_wn(self): 32 | """initialize the wn parameters""" 33 | for i in range(self.weight.size(0)): 34 | self.wn_scale.data[i] = torch.norm(self.weight.data[i]) 35 | 36 | def forward(self, input): 37 | w = F.normalize(self.weight, dim=(1,2,3)) 38 | w = w * self.wn_scale.view(-1,1,1,1) 39 | return F.conv2d(input, w, self.bias, self.stride, 40 | self.padding, self.dilation, self.groups) 41 | 42 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 43 | return nn.Conv2d( 44 | in_channels, out_channels, kernel_size, 45 | padding=(kernel_size//2), bias=bias) 46 | 47 | def wn_conv(in_channels, out_channels, kernel_size, bias=True): 48 | return Conv2D_WN( 49 | in_channels, out_channels, kernel_size, 50 | padding=(kernel_size//2), bias=bias) 51 | 52 | class MeanShift(nn.Conv2d): 53 | def __init__( 54 | self, rgb_range, 55 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 56 | 57 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 58 | std = torch.Tensor(rgb_std) 59 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 60 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 61 | for p in self.parameters(): 62 | p.requires_grad = False 63 | 64 | class BasicBlock(nn.Sequential): 65 | def __init__( 66 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 67 | bn=True, act=nn.ReLU(True)): 68 | 69 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 70 | if bn: 71 | m.append(nn.BatchNorm2d(out_channels)) 72 | if act is not None: 73 | m.append(act) 74 | 75 | super(BasicBlock, self).__init__(*m) 76 | 77 | class ResBlock(nn.Module): 78 | def __init__( 79 | self, conv, n_feats, kernel_size, 80 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 81 | 82 | super(ResBlock, self).__init__() 83 | m = [] 84 | for i in range(2): 85 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 86 | if bn: 87 | m.append(nn.BatchNorm2d(n_feats)) 88 | if i == 0: 89 | m.append(act) 90 | 91 | self.body = nn.Sequential(*m) 92 | self.res_scale = res_scale 93 | 94 | def forward(self, x): 95 | res = self.body(x).mul(self.res_scale) 96 | res += x 97 | 98 | return res 99 | 100 | class Upsampler(nn.Sequential): 101 | def __init__(self, conv, scale, n_feats, bn=False, 
act=False, bias=True): 102 | 103 | m = [] 104 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 105 | for _ in range(int(math.log(scale, 2))): 106 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 107 | m.append(nn.PixelShuffle(2)) 108 | if bn: 109 | m.append(nn.BatchNorm2d(n_feats)) 110 | if act == 'relu': 111 | m.append(nn.ReLU(True)) 112 | elif act == 'prelu': 113 | m.append(nn.PReLU(n_feats)) 114 | 115 | elif scale == 3: 116 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 117 | m.append(nn.PixelShuffle(3)) 118 | if bn: 119 | m.append(nn.BatchNorm2d(n_feats)) 120 | if act == 'relu': 121 | m.append(nn.ReLU(True)) 122 | elif act == 'prelu': 123 | m.append(nn.PReLU(n_feats)) 124 | else: 125 | raise NotImplementedError 126 | 127 | super(Upsampler, self).__init__(*m) 128 | 129 | 130 | 131 | 132 | class LiteUpsampler(nn.Sequential): 133 | def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True): 134 | 135 | m = [] 136 | m.append(conv(n_feats, n_out*(scale ** 2), 3, bias)) 137 | m.append(nn.PixelShuffle(scale)) 138 | # if (scale & (scale - 1)) == 0: # Is scale = 2^n? 139 | # for _ in range(int(math.log(scale, 2))): 140 | # m.append(conv(n_feats, 4 * n_out, 3, bias)) 141 | # m.append(nn.PixelShuffle(2)) 142 | # if bn: 143 | # m.append(nn.BatchNorm2d(n_out)) 144 | # if act == 'relu': 145 | # m.append(nn.ReLU(True)) 146 | # elif act == 'prelu': 147 | # m.append(nn.PReLU(n_out)) 148 | 149 | # elif scale == 3: 150 | # m.append(conv(n_feats, 9 * n_out, 3, bias)) 151 | # m.append(nn.PixelShuffle(3)) 152 | # if bn: 153 | # m.append(nn.BatchNorm2d(n_out)) 154 | # if act == 'relu': 155 | # m.append(nn.ReLU(True)) 156 | # elif act == 'prelu': 157 | # m.append(nn.PReLU(n_out)) 158 | # else: 159 | # raise NotImplementedError 160 | 161 | super(LiteUpsampler, self).__init__(*m) 162 | 163 | -------------------------------------------------------------------------------- /src/pruner/meta_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | from math import ceil, sqrt 7 | from collections import OrderedDict 8 | from utils import strdict_to_dict 9 | from fnmatch import fnmatch, fnmatchcase 10 | from layer import LayerStruct 11 | from .utils import get_pr_model, get_constrained_layers, pick_pruned_model, get_kept_filter_channel, replace_module, get_next_bn 12 | from .utils import get_masks 13 | 14 | class MetaPruner: 15 | def __init__(self, model, args, logger, passer): 16 | self.model = model 17 | self.args = args 18 | self.logger = logger 19 | self.logprint = logger.log_printer.logprint if logger else print 20 | self.netprint = logger.log_printer.netprint if logger else print 21 | 22 | # set up layers 23 | self.LEARNABLES = (nn.Conv2d, nn.Linear) # the layers we focus on for pruning 24 | layer_struct = LayerStruct(model, self.LEARNABLES) 25 | self.layers = layer_struct.layers 26 | self._max_len_ix = layer_struct._max_len_ix 27 | self._max_len_name = layer_struct._max_len_name 28 | self.layer_print_prefix = layer_struct.print_prefix 29 | 30 | # set up pr for each layer 31 | self.raw_pr = get_pr_model(self.layers, args.stage_pr, skip=args.skip_layers, compare_mode=args.compare_mode) 32 | self.pr = copy.deepcopy(self.raw_pr) 33 | 34 | # pick pruned and kept weight groups 35 | self.constrained_layers = get_constrained_layers(self.layers, self.args.same_pruned_wg_layers) 36 | print(f'Constrained layers: {self.constrained_layers}') 37 | 38 | def 
_get_kept_wg_L1(self, align_constrained=False): 39 | # ************************* core pruning function ************************** 40 | self.pr, self.pruned_wg, self.kept_wg = pick_pruned_model(self.model, self.layers, self.raw_pr, 41 | wg=self.args.wg, 42 | criterion=self.args.prune_criterion, 43 | compare_mode=self.args.compare_mode, 44 | sort_mode=self.args.pick_pruned, 45 | constrained=self.constrained_layers, 46 | align_constrained=align_constrained) 47 | # *************************************************************************** 48 | 49 | # print 50 | print(f'*********** Get pruned wg ***********') 51 | for name, layer in self.layers.items(): 52 | logtmp = f'{self.layer_print_prefix[name]} -- Got pruned wg by L1 sorting ({self.args.pick_pruned}), pr {self.pr[name]}' 53 | ext = f' -- This is a constrained layer. Its pruned/kept indices have been adjusted.' if name in self.constrained_layers else '' 54 | self.netprint(logtmp + ext) 55 | print(f'*************************************') 56 | 57 | def _prune_and_build_new_model(self): 58 | if self.args.wg == 'weight': 59 | self.masks = get_masks(self.layers, self.pruned_wg) 60 | return 61 | 62 | new_model = copy.deepcopy(self.model) 63 | for name, m in self.model.named_modules(): 64 | if isinstance(m, self.LEARNABLES): 65 | kept_filter, kept_chl = get_kept_filter_channel(self.layers, name, pr=self.pr, kept_wg=self.kept_wg, wg=self.args.wg) 66 | 67 | # decide if renit the current layer 68 | reinit = False 69 | for rl in self.args.reinit_layers: 70 | if fnmatch(name, rl): 71 | reinit = True 72 | break 73 | 74 | # get number of channels (can be manually assigned) 75 | num_chl = self.args.layer_chl[name] if name in self.args.layer_chl else len(kept_chl) 76 | 77 | # copy weight and bias 78 | bias = False if isinstance(m.bias, type(None)) else True 79 | if isinstance(m, nn.Conv2d): 80 | new_layer = nn.Conv2d(num_chl, len(kept_filter), m.kernel_size, 81 | m.stride, m.padding, m.dilation, m.groups, bias).cuda() 82 | if not reinit: 83 | kept_weights = m.weight.data[kept_filter][:, kept_chl, :, :] 84 | 85 | elif isinstance(m, nn.Linear): 86 | kept_weights = m.weight.data[kept_filter][:, kept_chl] 87 | new_layer = nn.Linear(in_features=len(kept_chl), out_features=len(kept_filter), bias=bias).cuda() 88 | 89 | if not reinit: 90 | new_layer.weight.data.copy_(kept_weights) # load weights into the new module 91 | if bias: 92 | kept_bias = m.bias.data[kept_filter] 93 | new_layer.bias.data.copy_(kept_bias) 94 | 95 | # load the new conv 96 | replace_module(new_model, name, new_layer) 97 | 98 | # get the corresponding bn (if any) for later use 99 | next_bn = get_next_bn(self.model, m) 100 | 101 | elif isinstance(m, nn.BatchNorm2d) and m == next_bn: 102 | new_bn = nn.BatchNorm2d(len(kept_filter), eps=m.eps, momentum=m.momentum, 103 | affine=m.affine, track_running_stats=m.track_running_stats).cuda() 104 | 105 | # copy bn weight and bias 106 | if self.args.copy_bn_w: 107 | weight = m.weight.data[kept_filter] 108 | new_bn.weight.data.copy_(weight) 109 | if self.args.copy_bn_b: 110 | bias = m.bias.data[kept_filter] 111 | new_bn.bias.data.copy_(bias) 112 | 113 | # copy bn running stats 114 | new_bn.running_mean.data.copy_(m.running_mean[kept_filter]) 115 | new_bn.running_var.data.copy_(m.running_var[kept_filter]) 116 | new_bn.num_batches_tracked.data.copy_(m.num_batches_tracked) 117 | 118 | # load the new bn 119 | replace_module(new_model, name, new_bn) 120 | self.model = new_model 121 | 122 | # print the layer shape of pruned model 123 | LayerStruct(new_model, 
self.LEARNABLES) 124 | return new_model -------------------------------------------------------------------------------- /src/data/srdata_no_bin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('img') >= 0 or benchmark: 32 | self.images_hr, self.images_lr = list_hr, list_lr 33 | elif args.ext.find('sep') >= 0: 34 | os.makedirs( 35 | self.dir_hr.replace(self.apath, path_bin), 36 | exist_ok=True 37 | ) 38 | for s in self.scale: 39 | os.makedirs( 40 | os.path.join( 41 | self.dir_lr.replace(self.apath, path_bin), 42 | 'X{}'.format(s) 43 | ), 44 | exist_ok=True 45 | ) 46 | 47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 48 | for h in list_hr: 49 | b = h.replace(self.apath, path_bin) 50 | b = b.replace(self.ext[0], '.pt') 51 | self.images_hr.append(b) 52 | self._check_and_load(args.ext, h, b, verbose=True) 53 | for i, ll in enumerate(list_lr): 54 | for l in ll: 55 | b = l.replace(self.apath, path_bin) 56 | b = b.replace(self.ext[1], '.pt') 57 | self.images_lr[i].append(b) 58 | self._check_and_load(args.ext, l, b, verbose=True) 59 | if train: 60 | n_patches = args.batch_size * args.test_every 61 | n_images = len(args.data_train) * len(self.images_hr) 62 | if n_images == 0: 63 | self.repeat = 0 64 | else: 65 | self.repeat = max(n_patches // n_images, 1) 66 | 67 | # # Below functions as used to prepare images 68 | # def _scan(self): 69 | # names_hr = sorted( 70 | # glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 71 | # ) 72 | # names_lr = [[] for _ in self.scale] 73 | # for f in names_hr: 74 | # filename, _ = os.path.splitext(os.path.basename(f)) 75 | # for si, s in enumerate(self.scale): 76 | # names_lr[si].append(os.path.join( 77 | # self.dir_lr, 'X{}/{}x{}{}'.format( 78 | # s, filename, s, self.ext[1] 79 | # ) 80 | # )) 81 | 82 | # return names_hr, names_lr 83 | 84 | def _scan(self): 85 | list_hr = [] 86 | list_lr = [[] for _ in self.scale] 87 | 88 | for i in range(self.begin, self.end + 1): 89 | filename = '{:0>4}'.format(i) 90 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext[0])) 91 | for si, s in enumerate(self.scale): 92 | list_lr[si].append(os.path.join( 93 | self.dir_lr, 94 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext[1]) 95 | )) 96 | 97 | return list_hr, list_lr 98 | 99 | def _set_filesystem(self, dir_data): 100 | self.apath = os.path.join(dir_data, self.name) 101 | self.dir_hr = os.path.join(self.apath, 'HR') 102 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 103 | if self.input_large: self.dir_lr += 'L' 104 | self.ext = ('.png', '.png') 105 | 106 | def _check_and_load(self, ext, img, f, verbose=True): 107 | if not os.path.isfile(f) or ext.find('reset') >= 0: 108 | if verbose: 109 | print('Making a 
binary: {}'.format(f)) 110 | with open(f, 'wb') as _f: 111 | pickle.dump(imageio.imread(img), _f) 112 | 113 | def __getitem__(self, idx): 114 | lr, hr, filename = self._load_file(idx) 115 | pair = self.get_patch(lr, hr) 116 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 117 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 118 | 119 | return pair_t[0], pair_t[1], filename 120 | 121 | def __len__(self): 122 | if self.train: 123 | return len(self.images_hr) * self.repeat 124 | else: 125 | return len(self.images_hr) 126 | 127 | def _get_index(self, idx): 128 | if self.train: 129 | return idx % len(self.images_hr) 130 | else: 131 | return idx 132 | 133 | def _load_file(self, idx): 134 | idx = self._get_index(idx) 135 | f_hr = self.images_hr[idx] 136 | f_lr = self.images_lr[self.idx_scale][idx] 137 | 138 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 139 | if self.args.ext == 'img' or self.benchmark: 140 | hr = imageio.imread(f_hr) 141 | lr = imageio.imread(f_lr) 142 | elif self.args.ext.find('sep') >= 0: 143 | with open(f_hr, 'rb') as _f: 144 | hr = pickle.load(_f) 145 | with open(f_lr, 'rb') as _f: 146 | lr = pickle.load(_f) 147 | 148 | return lr, hr, filename 149 | 150 | def get_patch(self, lr, hr): 151 | scale = self.scale[self.idx_scale] 152 | if self.train: 153 | lr, hr = common.get_patch( 154 | lr, hr, 155 | patch_size=self.args.patch_size, 156 | scale=scale, 157 | multi=(len(self.scale) > 1), 158 | input_large=self.input_large 159 | ) 160 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 161 | else: 162 | ih, iw = lr.shape[:2] 163 | hr = hr[0:ih * scale, 0:iw * scale] 164 | 165 | return lr, hr 166 | 167 | def set_scale(self, idx_scale): 168 | if not self.input_large: 169 | self.idx_scale = idx_scale 170 | else: 171 | self.idx_scale = random.randint(0, len(self.scale) - 1) 172 | 173 | -------------------------------------------------------------------------------- /src/model/rirsr.py: -------------------------------------------------------------------------------- 1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks 2 | ## https://arxiv.org/abs/1807.02758 3 | from model import common 4 | 5 | import torch.nn as nn 6 | # residual in residual (RIR) from RCAN. 7 | # we remove channel attention for easy implementation, when we conduct online pruning. 
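# Note added for clarity (not in the original source): "remove channel attention"
# concretely means that ResidualGroup below stacks plain common.ResBlock modules
# instead of RCAB, so every layer to be pruned is a simple conv and the pruner
# never has to trace filter indices through the CA branch. The CALayer and RCAB
# definitions are kept in this file only for reference.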
8 | def make_model(args, parent=False): 9 | return RIRSR(args) 10 | 11 | ## Channel Attention (CA) Layer 12 | class CALayer(nn.Module): 13 | def __init__(self, channel, reduction=16): 14 | super(CALayer, self).__init__() 15 | # global average pooling: feature --> point 16 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 17 | # feature channel downscale and upscale --> channel weight 18 | self.conv_du = nn.Sequential( 19 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def forward(self, x): 26 | y = self.avg_pool(x) 27 | y = self.conv_du(y) 28 | return x * y 29 | 30 | ## Residual Channel Attention Block (RCAB) 31 | class RCAB(nn.Module): 32 | def __init__( 33 | self, conv, n_feat, kernel_size, reduction, 34 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 35 | 36 | super(RCAB, self).__init__() 37 | modules_body = [] 38 | for i in range(2): 39 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 40 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 41 | if i == 0: modules_body.append(act) 42 | modules_body.append(CALayer(n_feat, reduction)) 43 | self.body = nn.Sequential(*modules_body) 44 | self.res_scale = res_scale 45 | 46 | def forward(self, x): 47 | res = self.body(x) 48 | #res = self.body(x).mul(self.res_scale) 49 | res += x 50 | return res 51 | 52 | ## Residual Group (RG) 53 | # class ResidualGroup(nn.Module): 54 | # def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 55 | # super(ResidualGroup, self).__init__() 56 | # modules_body = [] 57 | # modules_body = [ 58 | # RCAB( 59 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 60 | # for _ in range(n_resblocks)] 61 | # modules_body.append(conv(n_feat, n_feat, kernel_size)) 62 | # self.body = nn.Sequential(*modules_body) 63 | 64 | # def forward(self, x): 65 | # res = self.body(x) 66 | # res += x 67 | # return res 68 | 69 | class ResidualGroup(nn.Module): 70 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 71 | super(ResidualGroup, self).__init__() 72 | modules_body = [] 73 | # for RCAN 74 | # modules_body = [ 75 | # RCAB( 76 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 77 | # for _ in range(n_resblocks)] 78 | modules_body = [ 79 | common.ResBlock( 80 | conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 81 | for _ in range(n_resblocks)] 82 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 83 | self.body = nn.Sequential(*modules_body) 84 | 85 | def forward(self, x): 86 | res = self.body(x) 87 | res += x 88 | return res 89 | 90 | ## Residual Channel Attention Network (RCAN) 91 | ## Residual in Residual Super-Resolution (RIRSR) 92 | class RIRSR(nn.Module): 93 | def __init__(self, args, conv=common.default_conv): 94 | super(RIRSR, self).__init__() 95 | 96 | n_resgroups = args.n_resgroups 97 | n_resblocks = args.n_resblocks 98 | n_feats = args.n_feats 99 | kernel_size = 3 100 | reduction = args.reduction 101 | scale = args.scale[0] 102 | act = nn.ReLU(True) 103 | 104 | # RGB mean for DIV2K 105 | self.sub_mean = common.MeanShift(args.rgb_range) 106 | 107 | # define head module 108 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 109 | 110 | # define body module 111 | modules_body = [ 112 | ResidualGroup( 113 | conv, n_feats, kernel_size, reduction, act=act, 
res_scale=args.res_scale, n_resblocks=n_resblocks) \ 114 | for _ in range(n_resgroups)] 115 | 116 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 117 | 118 | # define tail module 119 | modules_tail = [ 120 | common.Upsampler(conv, scale, n_feats, act=False), 121 | conv(n_feats, args.n_colors, kernel_size)] 122 | 123 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 124 | 125 | self.head = nn.Sequential(*modules_head) 126 | self.body = nn.Sequential(*modules_body) 127 | self.tail = nn.Sequential(*modules_tail) 128 | 129 | def forward(self, x): 130 | x = self.sub_mean(x) 131 | x = self.head(x) 132 | 133 | res = self.body(x) 134 | res += x 135 | 136 | x = self.tail(res) 137 | x = self.add_mean(x) 138 | 139 | return x 140 | 141 | def load_state_dict(self, state_dict, strict=False): 142 | own_state = self.state_dict() 143 | for name, param in state_dict.items(): 144 | if name in own_state: 145 | if isinstance(param, nn.Parameter): 146 | param = param.data 147 | try: 148 | own_state[name].copy_(param) 149 | except Exception: 150 | if name.find('tail') >= 0: 151 | print('Replace pre-trained upsampler to new one...') 152 | else: 153 | raise RuntimeError('While copying the parameter named {}, ' 154 | 'whose dimensions in the model are {} and ' 155 | 'whose dimensions in the checkpoint are {}.' 156 | .format(name, own_state[name].size(), param.size())) 157 | elif strict: 158 | if name.find('tail') == -1: 159 | raise KeyError('unexpected key "{}" in state_dict' 160 | .format(name)) 161 | 162 | if strict: 163 | missing = set(own_state.keys()) - set(state_dict.keys()) 164 | if len(missing) > 0: 165 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 166 | 167 | -------------------------------------------------------------------------------- /src/pruner/l1_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | import time 5 | import numpy as np 6 | import utility 7 | from tqdm import tqdm 8 | from utils import _weights_init, _weights_init_orthogonal, orthogonalize_weights 9 | from .meta_pruner import MetaPruner 10 | 11 | 12 | # refer to: A Signal Propagation Perspective for Pruning Neural Networks at Initialization (ICLR 2020). 
13 | # https://github.com/namhoonlee/spp-public 14 | def approximate_isometry_optimize(model, mask, lr, n_iter, wg='weight', print=print): 15 | def optimize(w): 16 | '''Approximate Isometry for sparse weights by iterative optimization 17 | ''' 18 | flattened = w.view(w.size(0), -1) # [n_filter, -1] 19 | identity = torch.eye(w.size(0)).cuda() # identity matrix 20 | w_ = torch.autograd.Variable(flattened, requires_grad=True) 21 | optim = torch.optim.Adam([w_], lr) 22 | for i in range(n_iter): 23 | loss = nn.MSELoss()(torch.matmul(w_, w_.t()), identity) 24 | optim.zero_grad() 25 | loss.backward() 26 | optim.step() 27 | if not isinstance(mask, type(None)): 28 | w_ = torch.mul(w_, mask[name]) # not update the pruned params 29 | w_ = torch.autograd.Variable(w_, requires_grad=True) 30 | optim = torch.optim.Adam([w_], lr) 31 | if i % 10 == 0: 32 | print('[%d/%d] approximate_isometry_optimize for layer "%s", loss %.6f' % (i, n_iter, name, loss.item())) 33 | return w_.view(m.weight.shape) 34 | 35 | for name, m in model.named_modules(): 36 | if isinstance(m, (nn.Conv2d, nn.Linear)): 37 | w_ = optimize(m.weight) 38 | m.weight.data.copy_(w_) 39 | print('Finished approximate_isometry_optimize for layer "%s"' % name) 40 | 41 | def exact_isometry_based_on_existing_weights(model, print=print): 42 | for name, m in model.named_modules(): 43 | if isinstance(m, (nn.Conv2d, nn.Linear)): 44 | w_ = orthogonalize_weights(m.weight) 45 | m.weight.data.copy_(w_) 46 | print('Finished exact_isometry for layer "%s"' % name) 47 | 48 | class Pruner(MetaPruner): 49 | def __init__(self, model, args, logger, passer): 50 | super(Pruner, self).__init__(model, args, logger, passer) 51 | ckp = passer.ckp 52 | self.logprint = ckp.write_log_prune # use another log file specifically for pruning logs 53 | self.netprint = ckp.write_log_prune 54 | 55 | # ************************** variables from RCAN ************************** 56 | loader = passer.loader 57 | self.scale = args.scale 58 | 59 | self.ckp = ckp 60 | self.loader_train = loader.loader_train 61 | self.loader_test = loader.loader_test 62 | self.model = model 63 | 64 | self.error_last = 1e8 65 | # ************************************************************************** 66 | 67 | def prune(self): 68 | self._get_kept_wg_L1() 69 | self.logprint(f"==> Before _prune_and_build_new_model. Testing...") 70 | self.test() 71 | self._prune_and_build_new_model() 72 | self.logprint(f"==> Pruned and built a new model. 
Testing...") 73 | self.test() 74 | mask = self.mask if self.args.wg == 'weight' else None 75 | 76 | if self.args.reinit: 77 | if self.args.reinit in ['default', 'kaiming_normal']: 78 | self.model.apply(_weights_init) # completely reinit weights via 'kaiming_normal' 79 | self.logprint("==> Reinit model: default ('kaiming_normal' for Conv/FC; 0 mean, 1 std for BN)") 80 | 81 | elif self.args.reinit in ['orth', 'exact_isometry_from_scratch']: 82 | self.model.apply(lambda m: _weights_init_orthogonal(m, act=self.args.activation)) # reinit weights via 'orthogonal_' from scratch 83 | self.logprint("==> Reinit model: exact_isometry ('orthogonal_' for Conv/FC; 0 mean, 1 std for BN)") 84 | 85 | elif self.args.reinit == 'exact_isometry_based_on_existing': 86 | exact_isometry_based_on_existing_weights(self.model, print=self.logprint) # orthogonalize weights based on existing weights 87 | self.logprint("==> Reinit model: exact_isometry (orthogonalize Conv/FC weights based on existing weights)") 88 | 89 | elif self.args.reinit == 'approximate_isometry': # A Signal Propagation Perspective for Pruning Neural Networks at Initialization (ICLR 2020) 90 | approximate_isometry_optimize(self.model, mask=mask, lr=self.args.lr_AI, n_iter=10000, print=self.logprint) # 10000 refers to the paper above; lr in the paper is 0.1, but not converged here 91 | self.logprint("==> Reinit model: approximate_isometry") 92 | 93 | else: 94 | raise NotImplementedError 95 | 96 | return copy.deepcopy(self.model) 97 | 98 | def test(self): 99 | is_train = self.model.training 100 | torch.set_grad_enabled(False) 101 | 102 | self.ckp.write_log('Evaluation:') 103 | self.ckp.add_log( 104 | torch.zeros(1, len(self.loader_test), len(self.scale)) 105 | ) 106 | self.model.eval() 107 | 108 | timer_test = utility.timer() 109 | if self.args.save_results: self.ckp.begin_background() 110 | for idx_data, d in enumerate(self.loader_test): 111 | for idx_scale, scale in enumerate(self.scale): 112 | d.dataset.set_scale(idx_scale) 113 | for lr, hr, filename in tqdm(d, ncols=80): 114 | lr, hr = self.prepare(lr, hr) 115 | sr = self.model(lr, idx_scale) 116 | sr = utility.quantize(sr, self.args.rgb_range) 117 | 118 | save_list = [sr] 119 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 120 | sr, hr, scale, self.args.rgb_range, dataset=d 121 | ) 122 | if self.args.save_gt: 123 | save_list.extend([lr, hr]) 124 | 125 | if self.args.save_results: 126 | self.ckp.save_results(d, filename[0], save_list, scale) 127 | 128 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 129 | best = self.ckp.log.max(0) 130 | logstr = '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {}) [method: {} compare_mode: {}]'.format( 131 | d.dataset.name, 132 | scale, 133 | self.ckp.log[-1, idx_data, idx_scale], 134 | best[0][idx_data, idx_scale], 135 | best[1][idx_data, idx_scale] + 1, 136 | self.args.method, 137 | self.args.compare_mode, 138 | ) 139 | self.ckp.write_log(logstr) 140 | self.logprint(logstr) 141 | 142 | self.ckp.write_log('Forward: {:.2f}s'.format(timer_test.toc())) 143 | self.ckp.write_log('Saving...') 144 | 145 | if self.args.save_results: 146 | self.ckp.end_background() 147 | 148 | # if not self.args.test_only: 149 | # self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 150 | 151 | self.ckp.write_log( 152 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 153 | ) 154 | 155 | torch.set_grad_enabled(True) 156 | 157 | if is_train: 158 | self.model.train() 159 | 160 | def prepare(self, *args): 161 | device = torch.device('cpu' if self.args.cpu 
else 'cuda') 162 | def _prepare(tensor): 163 | if self.args.precision == 'half': tensor = tensor.half() 164 | return tensor.to(device) 165 | 166 | return [_prepare(a) for a in args] -------------------------------------------------------------------------------- /src/demo.sh: -------------------------------------------------------------------------------- 1 | # EDSR baseline model (x2) + JPEG augmentation 2 | python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset 3 | #python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75 4 | 5 | # EDSR baseline model (x3) - from EDSR baseline model (x2) 6 | #python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 7 | 8 | # EDSR baseline model (x4) - from EDSR baseline model (x2) 9 | #python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir] 10 | 11 | # EDSR in the paper (x2) 12 | #python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset 13 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.2.0 --n_resblocks 80 --n_feats 64 14 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.2.0_2 --n_resblocks 80 --n_feats 64 15 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0 --n_resblocks 80 --n_feats 64 16 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_2 --n_resblocks 80 --n_feats 64 17 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_3 --n_resblocks 80 --n_feats 64 18 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --save EDSR_R80F64_BIx2_Xp_pt1.6.0_4 --n_resblocks 80 --n_feats 64 19 | # EDSR in the paper (x3) - from EDSR (x2) 20 | #python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir] 21 | 22 | # EDSR in the paper (x4) - from EDSR (x2) 23 | #python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir] 24 | 25 | # MDSR baseline model 26 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models 27 | 28 | # MDSR in the paper 29 | #python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models 30 | 31 | # Standard benchmarks (Ex. 
EDSR_baseline_x4) 32 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble 33 | 34 | #python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble 35 | 36 | # Test your own images 37 | #python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results 38 | 39 | # Advanced - Test with JPEG images 40 | #python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results 41 | 42 | # Advanced - Training with adversarial loss 43 | #python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download 44 | 45 | # RDN BI model (x2) 46 | #python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset 47 | # RDN BI model (x3) 48 | #python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset 49 | # RDN BI model (x4) 50 | #python3.6 main.py --scale 4 --save RDN_D16C8G64_BIx4 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 128 --reset 51 | 52 | # RCAN_BIX2_G10R20P48, input=48x48, output=96x96 53 | # pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0 54 | #python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96 55 | # RCAN_BIX3_G10R20P48, input=48x48, output=144x144 56 | #python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt 57 | # RCAN_BIX4_G10R20P48, input=48x48, output=192x192 58 | #python main.py --template RCAN --save RCAN_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt 59 | # RCAN_BIX8_G10R20P48, input=48x48, output=384x384 60 | #python main.py --template RCAN --save RCAN_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt 61 | 62 | 63 | 64 | 65 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2KBIX2 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 66 | 67 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KBIX3 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 68 | 69 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 4 --patch_size 192 --save EDSR_R16F64P48B16_DF2KBIX4 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 70 | 71 | # debug 72 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2KBIX2_dtep_decay --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/img/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3551-3555 --test_every 100 --lr_decay 5 73 | 74 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KBDX3_dtep_decay --n_resblocks 
16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BDX3 --data_train DF2K --data_test DF2K --data_range 1-3450/3551-3555 --test_every 100 --lr_decay 5 --chop 75 | 76 | 77 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R16F64P48B16_DF2KDNX3_dtep_decay --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/DNX3 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --test_every 100 --lr_decay 5 --chop --save_results --save_gt 78 | 79 | 80 | 81 | #### 82 | CUDA_VISIBLE_DEVICES=1 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R16F64P48B16_DF2K_BIX2 --n_resblocks 16 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 83 | 84 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R80F64P48B16_DIV2K_BIX2 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DIV2K --data_test DIV2K --data_range 1-800/801-810 --chop 85 | 86 | 87 | CUDA_VISIBLE_DEVICES=3 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R80F64P48B16_DF2K_BIX2 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 88 | 89 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 2 --patch_size 96 --save EDSR_R32F256P48B16_DF2K_BIX2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 90 | 91 | CUDA_VISIBLE_DEVICES=2 python main.py --model EDSR --scale 3 --patch_size 144 --save EDSR_R80F64P48B16_DF2K_BIX3 --n_resblocks 80 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 92 | 93 | # RCAN 94 | CUDA_VISIBLE_DEVICES=0 python main.py --model RCAN --scale 2 --patch_size 96 --save RCAN_G10R20P48B16_DF2K_BIX2 --n_resgroups 10 --n_resblocks 20 --n_feats 64 --dir_data /home/yulun/data/SR/RGB/pt/BIX2X3X4 --data_train DF2K --data_test DF2K --data_range 1-3450/3451-3460 --chop 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.input_large = (args.model == 'VDSR') 17 | self.self_ensemble = args.self_ensemble 18 | self.chop = args.chop 19 | self.precision = args.precision 20 | self.cpu = args.cpu 21 | self.device = torch.device('cpu' if args.cpu else 'cuda') 22 | self.n_GPUs = args.n_GPUs 23 | self.save_models = args.save_models 24 | 25 | module = import_module('model.' 
+ args.model.lower()) 26 | self.model = module.make_model(args).to(self.device) 27 | if args.precision == 'half': 28 | self.model.half() 29 | 30 | if ckp: 31 | load_from = self.load( 32 | ckp.get_path('model'), 33 | pre_train=args.pre_train, 34 | resume=args.resume, 35 | cpu=args.cpu 36 | ) 37 | if load_from and args.wn: # @mst 38 | for _, module in self.model.named_modules(): 39 | if hasattr(module, 'wn_scale'): 40 | module.init_wn() 41 | print(f'==> Weights are reloaded, reinit wn_scale.') 42 | 43 | 44 | def forward(self, x, idx_scale): 45 | self.idx_scale = idx_scale 46 | if hasattr(self.model, 'set_scale'): 47 | self.model.set_scale(idx_scale) 48 | 49 | if self.training: 50 | if self.n_GPUs > 1: 51 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 52 | else: 53 | return self.model(x) 54 | else: 55 | if self.chop: 56 | forward_function = self.forward_chop 57 | else: 58 | forward_function = self.model.forward 59 | 60 | if self.self_ensemble: 61 | return self.forward_x8(x, forward_function=forward_function) 62 | else: 63 | return forward_function(x) 64 | 65 | def save(self, apath, epoch, is_best=False): 66 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 67 | 68 | if is_best: 69 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 70 | if self.save_models: 71 | save_dirs.append( 72 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 73 | ) 74 | 75 | for s in save_dirs: 76 | to_save = { 77 | 'state_dict': self.model.state_dict(), 78 | 'arch': self.model, 79 | } 80 | torch.save(to_save, s) 81 | 82 | def load(self, apath, pre_train='', resume=-1, cpu=False): 83 | load_from = None 84 | kwargs = {} 85 | if cpu: 86 | kwargs = {'map_location': lambda storage, loc: storage} 87 | 88 | if resume == -1: 89 | load_from = torch.load( 90 | os.path.join(apath, 'model_latest.pt'), 91 | **kwargs 92 | ) 93 | elif resume == 0: 94 | if pre_train == 'download': 95 | print('Download the model') 96 | dir_model = os.path.join('..', 'models') 97 | os.makedirs(dir_model, exist_ok=True) 98 | load_from = torch.utils.model_zoo.load_url( 99 | self.model.url, 100 | model_dir=dir_model, 101 | **kwargs 102 | ) 103 | elif pre_train: 104 | print('Load the model from {}'.format(pre_train)) 105 | load_from = torch.load(pre_train, **kwargs) 106 | else: 107 | load_from = torch.load( 108 | os.path.join(apath, 'model_{}.pt'.format(resume)), 109 | **kwargs 110 | ) 111 | 112 | if load_from: 113 | if 'state_dict' in load_from: # @mst: for pruned models, load the pruned model arch as 'self.model' 114 | self.model = load_from['arch'] 115 | self.model.load_state_dict(load_from['state_dict'], strict=False) 116 | else: 117 | self.model.load_state_dict(load_from, strict=False) 118 | return load_from 119 | 120 | # shave = 10, min_size=160000 121 | def forward_chop(self, x, shave=10, min_size=160000): 122 | scale = self.scale[self.idx_scale] 123 | n_GPUs = min(self.n_GPUs, 4) 124 | b, c, h, w = x.size() 125 | h_half, w_half = h // 2, w // 2 126 | h_size, w_size = h_half + shave, w_half + shave 127 | lr_list = [ 128 | x[:, :, 0:h_size, 0:w_size], 129 | x[:, :, 0:h_size, (w - w_size):w], 130 | x[:, :, (h - h_size):h, 0:w_size], 131 | x[:, :, (h - h_size):h, (w - w_size):w]] 132 | 133 | if w_size * h_size < min_size: 134 | sr_list = [] 135 | for i in range(0, 4, n_GPUs): 136 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 137 | sr_batch = self.model(lr_batch) 138 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 139 | else: 140 | sr_list = [ 141 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 142 | for 
patch in lr_list 143 | ] 144 | 145 | h, w = scale * h, scale * w 146 | h_half, w_half = scale * h_half, scale * w_half 147 | h_size, w_size = scale * h_size, scale * w_size 148 | shave *= scale 149 | 150 | output = x.new(b, c, h, w) 151 | output[:, :, 0:h_half, 0:w_half] \ 152 | = sr_list[0][:, :, 0:h_half, 0:w_half] 153 | output[:, :, 0:h_half, w_half:w] \ 154 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 155 | output[:, :, h_half:h, 0:w_half] \ 156 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 157 | output[:, :, h_half:h, w_half:w] \ 158 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 159 | 160 | return output 161 | 162 | def forward_x8(self, *args, forward_function=None): 163 | def _transform(v, op): 164 | if self.precision != 'single': v = v.float() 165 | 166 | v2np = v.data.cpu().numpy() 167 | if op == 'v': 168 | tfnp = v2np[:, :, :, ::-1].copy() 169 | elif op == 'h': 170 | tfnp = v2np[:, :, ::-1, :].copy() 171 | elif op == 't': 172 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 173 | 174 | ret = torch.Tensor(tfnp).to(self.device) 175 | if self.precision == 'half': ret = ret.half() 176 | 177 | return ret 178 | 179 | list_x = [] 180 | for a in args: 181 | x = [a] 182 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 183 | 184 | list_x.append(x) 185 | 186 | list_y = [] 187 | for x in zip(*list_x): 188 | y = forward_function(*x) 189 | if not isinstance(y, list): y = [y] 190 | if not list_y: 191 | list_y = [[_y] for _y in y] 192 | else: 193 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 194 | 195 | for _list_y in list_y: 196 | for i in range(len(_list_y)): 197 | if i > 3: 198 | _list_y[i] = _transform(_list_y[i], 't') 199 | if i % 4 > 1: 200 | _list_y[i] = _transform(_list_y[i], 'h') 201 | if (i % 4) % 2 == 1: 202 | _list_y[i] = _transform(_list_y[i], 'v') 203 | 204 | y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y] 205 | if len(y) == 1: y = y[0] 206 | 207 | return y 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ASSL (NeurIPS'21 Spotlight) 2 | 3 |
4 | 5 | <img src="figs/neu.png">&nbsp;&nbsp;&nbsp;&nbsp;<img src="figs/smile.png"> 6 | 7 |
8 | 9 | This repository is for a new network pruning method (`Aligned Structured Sparsity Learning, ASSL`) for efficient single image super-resolution (SR), introduced in our NeurIPS 2021 **Spotlight** paper: 10 | > **Aligned Structured Sparsity Learning for Efficient Image Super-Resolution [[Camera Ready](https://papers.nips.cc/paper/2021/file/15de21c670ae7c3f6f3f1f37029303c9-Paper.pdf)] [[Visual Results](https://github.com/MingSun-Tse/ASSL/releases)]** \ 11 | > [Yulun Zhang*](http://yulunzhang.com/), [Huan Wang*](http://huanwang.tech/), [Can Qin](http://canqin.tech/), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/) (*equal contribution) \ 12 | > Northeastern University, Boston, MA, USA 13 | 14 | 15 | ## Introduction 16 |
17 | <img src="figs/NIPS21_ASSL.png"> 18 |
19 | Lightweight image super-resolution (SR) networks have obtained promising results with moderate model size. Many SR methods have focused on designing lightweight architectures, but neglect to further reduce the redundancy of network parameters. On the other hand, model compression techniques, like neural architecture search and knowledge distillation, typically consume considerable memory and computation resources. In contrast, network pruning is a cheap and effective model compression technique. However, it is hard to apply pruning to SR networks directly, because filter pruning for residual blocks is notoriously tricky. To address the above issues, we propose aligned structured sparsity learning (ASSL), which introduces a weight normalization layer and applies L2 regularization to the scale parameters for sparsity. To align the pruned filter locations across different layers, we propose a sparsity structure alignment penalty term, which minimizes the norm of the soft mask Gram matrix. We apply the aligned structured sparsity learning strategy to train an efficient image SR network, named ASSLN, with smaller model size and lower computation than state-of-the-art methods. We conduct extensive comparisons with lightweight SR networks. Our ASSLN achieves superior performance gains over recent methods, both quantitatively and visually. 20 | 21 | ## Install 22 | ```bash 23 | git clone git@github.com:mingsun-tse/ASSL.git -b master 24 | cd ASSL/src 25 | 26 | # install dependencies (PyTorch 1.2.0 is used; Anaconda is strongly recommended) 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | 31 | ## Train 32 | ### Prepare training data 33 | 34 | 1. Download the DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar), and the [Flickr2K dataset](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) from SNU_CVLab. 35 | 36 | 2. Specify '--dir_data' based on the path of the HR and LR images. In option.py, '--ext' is set to 'sep_reset', which first converts the .png files to .npy. Once all the training images (.png) have been converted to .npy files, set '--ext sep' to skip the conversion. 37 | 38 | For more information, please refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch). 39 | 40 | ### Prepare pretrained dense model 41 | Neural network pruning is typically conducted on a *pretrained* model, and our method follows this common practice. Before running the pruning scripts below, set up the pretrained dense models: download `pretrained_models.zip` from our [releases](https://github.com/MingSun-Tse/ASSL/releases) and unzip it as follows: 42 | ```bash 43 | wget https://github.com/MingSun-Tse/ASSL/releases/download/v0.1/pretrained_models.zip 44 | unzip pretrained_models.zip 45 | mv pretrained_models .. 46 | ```
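The pruning scripts in the next section are driven by a growing regularization on the `wn_scale` parameters of the weight-normalized convolutions (`Conv2D_WN` in `src/model/common.py`); the flags `--reg_granularity_prune`, `--update_reg_interval`, and `--reg_upper_limit` control the schedule. As a rough mental model, here is a minimal sketch of the sparsity-inducing part (illustrative only; the actual logic, including the sparsity structure alignment penalty, lives in `src/pruner/assl_pruner.py`, and all names except `wn_scale` are made up for this example):
```python
import torch

def sparsity_reg_loss(model, pruned_idx, reg, granularity=1e-4, upper=0.5):
    """Grow an L2 penalty on the wn_scale of filters selected for pruning.

    pruned_idx: {layer_name: LongTensor of filter indices to prune}
    reg:        {layer_name: current penalty strength}, updated in place
    """
    loss = 0.
    for name, m in model.named_modules():
        if hasattr(m, 'wn_scale') and name in pruned_idx:
            # In the real schedule the increment happens once every
            # --update_reg_interval iterations, by --reg_granularity_prune,
            # until --reg_upper_limit is reached.
            reg[name] = min(reg[name] + granularity, upper)
            loss = loss + reg[name] * m.wn_scale[pruned_idx[name]].pow(2).sum()
    return loss  # added to the SR loss; penalized scales decay toward zero
```
Once the penalized scales are negligible, the corresponding filters are physically removed (`_prune_and_build_new_model` in `src/pruner/meta_pruner.py`) and the compact model is finetuned.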
46 | ```
47 | 
48 | ### Run
49 | ```bash
50 | # Prune from 256 to 49 filters (pr = 0.80859375), x2
51 | python main.py --model LEDSR --scale 2 --patch_size 96 --ext sep --dir_data <data_path> --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 --chop --save_results --n_resblocks 16 --n_feats 256 --method ASSL --wn --stage_pr [0-1000:0.80859375] --skip_layers *mean*,*tail* --same_pruned_wg_layers model.head.0,model.body.16,*body.2 --reg_upper_limit 0.5 --reg_granularity_prune 0.0001 --update_reg_interval 20 --stabilize_reg_interval 43150 --pre_train ../pretrained_models/LEDSR_F256R16BIX2_DF2K_M311.pt --same_pruned_wg_criterion reg --save main/SR/LEDSR_F256R16BIX2_DF2K_ASSL0.80859375_RGP0.0001_RUL0.5_Pretrain
52 | 
53 | # Prune from 256 to 49 filters (pr = 0.80859375), x3
54 | python main.py --model LEDSR --scale 3 --patch_size 144 --ext sep --dir_data <data_path> --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 --chop --save_results --n_resblocks 16 --n_feats 256 --method ASSL --wn --stage_pr [0-1000:0.80859375] --skip_layers *mean*,*tail* --same_pruned_wg_layers model.head.0,model.body.16,*body.2 --reg_upper_limit 0.5 --reg_granularity_prune 0.0001 --update_reg_interval 20 --stabilize_reg_interval 43150 --pre_train ../pretrained_models/LEDSR_F256R16BIX3_DF2K_M230.pt --same_pruned_wg_criterion reg --save main/SR/LEDSR_F256R16BIX3_DF2K_ASSL0.80859375_RGP0.0001_RUL0.5_Pretrain
55 | 
56 | # Prune from 256 to 49 filters (pr = 0.80859375), x4
57 | python main.py --model LEDSR --scale 4 --patch_size 192 --ext sep --dir_data <data_path> --data_train DF2K --data_test DF2K --data_range 1-3550/3551-3555 --chop --save_results --n_resblocks 16 --n_feats 256 --method ASSL --wn --stage_pr [0-1000:0.80859375] --skip_layers *mean*,*tail* --same_pruned_wg_layers model.head.0,model.body.16,*body.2 --reg_upper_limit 0.5 --reg_granularity_prune 0.0001 --update_reg_interval 20 --stabilize_reg_interval 43150 --pre_train ../pretrained_models/LEDSR_F256R16BIX4_DF2K_M231.pt --same_pruned_wg_criterion reg --save main/SR/LEDSR_F256R16BIX4_DF2K_ASSL0.80859375_RGP0.0001_RUL0.5_Pretrain
58 | ```
59 | where `<data_path>` refers to the data directory path. One example on our PC is: `/home/yulun/data/SR/RGB/BIX2X3X4/pt_bin`.
60 | 
61 | 
62 | ## Test
63 | After training, you can use the trained models to generate HR images with the following snippet. For a quick start, you can also test with our [final models](https://github.com/MingSun-Tse/ASSL/releases) first:
64 | ```bash
65 | wget https://github.com/MingSun-Tse/ASSL/releases/download/v0.1/final_models.zip
66 | unzip final_models.zip
67 | mv final_models ..
68 | python main.py --data_test Demo --scale 4 --dir_demo <test_data_path> --test_only --save_results --pre_train ../final_models/ASSLN_F49_X4.pt --save Test_ASSLN_F49_X4
69 | ```
70 | where `<test_data_path>` refers to the test data path on your computer. One example on our PC is: `/media/yulun/10THD1/data/super-resolution/LRBI/Set5/x4`.
71 | 
72 | 
73 | ## Results
74 | ### Quantitative Results
75 | PSNR/SSIM comparison on popular SR benchmark datasets is shown below (best in red, second best in blue).
76 | 
77 | <img src="figs/psnr_ssim.png">
78 | 
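For reference, the PSNR numbers in this codebase come from `calc_psnr` in `src/utility.py`: images are quantized, converted to the Y channel, and a border of `scale` pixels is shaved before taking the MSE. Below is a minimal sketch of calling it directly; the random tensors are only stand-ins for a real SR output and its HR ground truth:

```python
import torch
import utility  # src/utility.py (run this from ASSL/src)

# Stand-ins for a super-resolved image and its ground truth,
# shaped [N, C, H, W] with pixel values in [0, 255].
sr = torch.rand(1, 3, 256, 256) * 255
hr = torch.rand(1, 3, 256, 256) * 255

sr = utility.quantize(sr, rgb_range=255)  # clamp and round to valid pixel values
psnr = utility.calc_psnr(sr, hr, scale=4, rgb_range=255)  # Y-channel PSNR, border shaved
print(f'PSNR: {psnr:.2f} dB')
```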
79 | 80 | ### Visual Results 81 | Visual comparison (x4) among lightweight SR approaches on the Urban100 dataset is shown below. Please see our [releases](https://github.com/MingSun-Tse/ASSL/releases) for the complete visual results on Set5/Set14/B100/Urban100/Manga109. 82 |
83 | <img src="figs/visual_urban100_x4.png">
84 | 
85 | 
86 | ## Citation
87 | If you find the code helpful in your research or work, please cite the following papers.
88 | ```
89 | @InProceedings{Lim_2017_CVPR_Workshops,
90 |     author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
91 |     title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
92 |     booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
93 |     month = {July},
94 |     year = {2017}
95 | }
96 | 
97 | @inproceedings{zhang2021aligned,
98 |     title={Aligned Structured Sparsity Learning for Efficient Image Super-Resolution},
99 |     author={Zhang, Yulun and Wang, Huan and Qin, Can and Fu, Yun},
100 |     booktitle={NeurIPS},
101 |     year={2021}
102 | }
103 | ```
104 | 
105 | ## Acknowledgements
106 | We referred to the following implementations when developing this code: [EDSR-PyTorch](https://github.com/thstkdgus35/EDSR-PyTorch), [RCAN](https://github.com/yulunzhang/RCAN), [Regularization-Pruning](https://github.com/MingSun-Tse/Regularization-Pruning). Great thanks to them!
107 | 
--------------------------------------------------------------------------------
/src/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import pickle
5 | 
6 | from data import common
7 | 
8 | import numpy as np
9 | import imageio
10 | import torch
11 | import torch.utils.data as data
12 | 
13 | class SRData(data.Dataset):
14 |     def __init__(self, args, name='', train=True, benchmark=False):
15 |         self.args = args
16 |         self.name = name
17 |         self.train = train
18 |         self.split = 'train' if train else 'test'
19 |         self.do_eval = True
20 |         self.benchmark = benchmark
21 |         self.input_large = (args.model == 'VDSR')
22 |         self.scale = args.scale
23 |         self.idx_scale = 0
24 | 
25 |         self._set_filesystem(args.dir_data)
26 |         if args.ext.find('img') < 0:
27 |             path_bin = os.path.join(self.apath, 'bin')
28 |             os.makedirs(path_bin, exist_ok=True)
29 | 
30 |         list_hr, list_lr = self._scan()
31 |         if args.ext.find('bin') >= 0:
32 |             # Binary files are stored in 'bin' folder
33 |             # If the binary file exists, load it. If not, make it.
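            # (Note on the 'args.ext' modes handled here and below: 'img' reads
            # the .png files directly; 'sep' pre-decodes each image into its own
            # pickled .pt file; 'bin' packs a whole split into one .pt file per
            # HR/LR set; a 'reset' suffix, e.g. 'sep_reset', forces the binary
            # files to be rebuilt. See _check_and_load.)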
34 | list_hr, list_lr = self._scan() 35 | self.images_hr = self._check_and_load( 36 | args.ext, list_hr, self._name_hrbin() 37 | ) 38 | self.images_lr = [ 39 | self._check_and_load(args.ext, l, self._name_lrbin(s)) \ 40 | for s, l in zip(self.scale, list_lr) 41 | ] 42 | else: 43 | if args.ext.find('img') >= 0 or benchmark: 44 | self.images_hr, self.images_lr = list_hr, list_lr 45 | elif args.ext.find('sep') >= 0: 46 | os.makedirs( 47 | self.dir_hr.replace(self.apath, path_bin), 48 | exist_ok=True 49 | ) 50 | for s in self.scale: 51 | os.makedirs( 52 | os.path.join( 53 | self.dir_lr.replace(self.apath, path_bin), 54 | 'X{}'.format(s) 55 | ), 56 | exist_ok=True 57 | ) 58 | 59 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 60 | for h in list_hr: 61 | b = h.replace(self.apath, path_bin) 62 | b = b.replace(self.ext[0], '.pt') 63 | self.images_hr.append(b) 64 | self._check_and_load( 65 | args.ext, [h], b, verbose=True, load=False 66 | ) 67 | 68 | for i, ll in enumerate(list_lr): 69 | for l in ll: 70 | b = l.replace(self.apath, path_bin) 71 | b = b.replace(self.ext[1], '.pt') 72 | self.images_lr[i].append(b) 73 | self._check_and_load( 74 | args.ext, [l], b, verbose=True, load=False 75 | ) 76 | 77 | if train: 78 | n_patches = args.batch_size * args.test_every 79 | n_images = len(args.data_train) * len(self.images_hr) 80 | if n_images == 0: 81 | self.repeat = 0 82 | else: 83 | self.repeat = max(n_patches // n_images, 1) 84 | 85 | # # Below functions as used to prepare images 86 | # def _scan(self): 87 | # names_hr = sorted( 88 | # glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 89 | # ) 90 | # names_lr = [[] for _ in self.scale] 91 | # for f in names_hr: 92 | # filename, _ = os.path.splitext(os.path.basename(f)) 93 | # for si, s in enumerate(self.scale): 94 | # names_lr[si].append(os.path.join( 95 | # self.dir_lr, 'X{}/{}x{}{}'.format( 96 | # s, filename, s, self.ext[1] 97 | # ) 98 | # )) 99 | 100 | # return names_hr, names_lr 101 | 102 | def _scan(self): 103 | list_hr = [] 104 | list_lr = [[] for _ in self.scale] 105 | 106 | for i in range(self.begin, self.end + 1): 107 | filename = '{:0>4}'.format(i) 108 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext[0])) 109 | for si, s in enumerate(self.scale): 110 | list_lr[si].append(os.path.join( 111 | self.dir_lr, 112 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext[1]) 113 | )) 114 | 115 | return list_hr, list_lr 116 | 117 | def _set_filesystem(self, dir_data): 118 | self.apath = os.path.join(dir_data, self.name) 119 | self.dir_hr = os.path.join(self.apath, 'HR') 120 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 121 | if self.input_large: self.dir_lr += 'L' 122 | self.ext = ('.png', '.png') 123 | 124 | def _name_hrbin(self): 125 | return os.path.join( 126 | self.apath, 127 | 'bin', 128 | '{}_bin_HR.pt'.format(self.split) 129 | ) 130 | 131 | def _name_lrbin(self, scale): 132 | return os.path.join( 133 | self.apath, 134 | 'bin', 135 | '{}_bin_LR_X{}.pt'.format(self.split, scale) 136 | ) 137 | 138 | def _check_and_load(self, ext, l, f, verbose=True, load=True): 139 | if os.path.isfile(f) and ext.find('reset') < 0: 140 | if load: 141 | if verbose: print('Loading {}...'.format(f)) 142 | with open(f, 'rb') as _f: ret = pickle.load(_f) 143 | return ret 144 | else: 145 | return None 146 | else: 147 | if verbose: 148 | if ext.find('reset') >= 0: 149 | print('Making a new binary: {}'.format(f)) 150 | else: 151 | print('{} does not exist. 
Now making binary...'.format(f)) 152 | if ext.find('bin') >= 0: 153 | print('Bin pt file with name and image') 154 | b = [{ 155 | 'name': os.path.splitext(os.path.basename(_l))[0], 156 | 'image': imageio.imread(_l) 157 | } for _l in l] 158 | with open(f, 'wb') as _f: pickle.dump(b, _f) 159 | 160 | return b 161 | else: 162 | print('Direct pt file without name or image') 163 | # import pdb 164 | # pdb.set_trace() 165 | b = imageio.imread(l[0]) 166 | with open(f, 'wb') as _f: pickle.dump(b, _f) 167 | 168 | # return b 169 | 170 | def __getitem__(self, idx): 171 | lr, hr, filename = self._load_file(idx) 172 | pair = self.get_patch(lr, hr) 173 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 174 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 175 | 176 | return pair_t[0], pair_t[1], filename 177 | 178 | def __len__(self): 179 | if self.train: 180 | return len(self.images_hr) * self.repeat 181 | else: 182 | return len(self.images_hr) 183 | 184 | def _get_index(self, idx): 185 | if self.train: 186 | return idx % len(self.images_hr) 187 | else: 188 | return idx 189 | 190 | def _load_file(self, idx): 191 | idx = self._get_index(idx) 192 | f_hr = self.images_hr[idx] 193 | f_lr = self.images_lr[self.idx_scale][idx] 194 | 195 | if self.args.ext.find('bin') >= 0: 196 | filename = f_hr['name'] 197 | hr = f_hr['image'] 198 | lr = f_lr['image'] 199 | else: 200 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 201 | if self.args.ext == 'img' or self.benchmark: 202 | hr = imageio.imread(f_hr) 203 | lr = imageio.imread(f_lr) 204 | elif self.args.ext.find('sep') >= 0: 205 | # For each pt file, use 'image' to load it 206 | # with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image'] 207 | # with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image'] 208 | # For each pt file, directly load it 209 | with open(f_hr, 'rb') as _f: hr = pickle.load(_f) 210 | with open(f_lr, 'rb') as _f: lr = pickle.load(_f) 211 | 212 | return lr, hr, filename 213 | 214 | def get_patch(self, lr, hr): 215 | scale = self.scale[self.idx_scale] 216 | if self.train: 217 | lr, hr = common.get_patch( 218 | lr, hr, 219 | patch_size=self.args.patch_size, 220 | scale=scale, 221 | multi=(len(self.scale) > 1), 222 | input_large=self.input_large 223 | ) 224 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 225 | else: 226 | ih, iw = lr.shape[:2] 227 | hr = hr[0:ih * scale, 0:iw * scale] 228 | 229 | return lr, hr 230 | 231 | def set_scale(self, idx_scale): 232 | if not self.input_large: 233 | self.idx_scale = idx_scale 234 | else: 235 | self.idx_scale = random.randint(0, len(self.scale) - 1) 236 | 237 | -------------------------------------------------------------------------------- /src/data/srdata_bin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset): 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR') 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | 25 | self._set_filesystem(args.dir_data) 26 | if args.ext.find('img') < 0: 27 | path_bin = os.path.join(self.apath, 'bin') 28 | 
os.makedirs(path_bin, exist_ok=True) 29 | 30 | list_hr, list_lr = self._scan() 31 | if args.ext.find('bin') >= 0: 32 | # Binary files are stored in 'bin' folder 33 | # If the binary file exists, load it. If not, make it. 34 | list_hr, list_lr = self._scan() 35 | self.images_hr = self._check_and_load( 36 | args.ext, list_hr, self._name_hrbin() 37 | ) 38 | self.images_lr = [ 39 | self._check_and_load(args.ext, l, self._name_lrbin(s)) \ 40 | for s, l in zip(self.scale, list_lr) 41 | ] 42 | else: 43 | if args.ext.find('img') >= 0 or benchmark: 44 | self.images_hr, self.images_lr = list_hr, list_lr 45 | elif args.ext.find('sep') >= 0: 46 | os.makedirs( 47 | self.dir_hr.replace(self.apath, path_bin), 48 | exist_ok=True 49 | ) 50 | for s in self.scale: 51 | os.makedirs( 52 | os.path.join( 53 | self.dir_lr.replace(self.apath, path_bin), 54 | 'X{}'.format(s) 55 | ), 56 | exist_ok=True 57 | ) 58 | 59 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 60 | for h in list_hr: 61 | b = h.replace(self.apath, path_bin) 62 | b = b.replace(self.ext[0], '.pt') 63 | self.images_hr.append(b) 64 | self._check_and_load( 65 | args.ext, [h], b, verbose=True, load=False 66 | ) 67 | 68 | for i, ll in enumerate(list_lr): 69 | for l in ll: 70 | b = l.replace(self.apath, path_bin) 71 | b = b.replace(self.ext[1], '.pt') 72 | self.images_lr[i].append(b) 73 | self._check_and_load( 74 | args.ext, [l], b, verbose=True, load=False 75 | ) 76 | 77 | if train: 78 | n_patches = args.batch_size * args.test_every 79 | n_images = len(args.data_train) * len(self.images_hr) 80 | if n_images == 0: 81 | self.repeat = 0 82 | else: 83 | self.repeat = max(n_patches // n_images, 1) 84 | 85 | # # Below functions as used to prepare images 86 | # def _scan(self): 87 | # names_hr = sorted( 88 | # glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 89 | # ) 90 | # names_lr = [[] for _ in self.scale] 91 | # for f in names_hr: 92 | # filename, _ = os.path.splitext(os.path.basename(f)) 93 | # for si, s in enumerate(self.scale): 94 | # names_lr[si].append(os.path.join( 95 | # self.dir_lr, 'X{}/{}x{}{}'.format( 96 | # s, filename, s, self.ext[1] 97 | # ) 98 | # )) 99 | 100 | # return names_hr, names_lr 101 | 102 | def _scan(self): 103 | list_hr = [] 104 | list_lr = [[] for _ in self.scale] 105 | 106 | for i in range(self.begin, self.end + 1): 107 | filename = '{:0>4}'.format(i) 108 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext[0])) 109 | for si, s in enumerate(self.scale): 110 | list_lr[si].append(os.path.join( 111 | self.dir_lr, 112 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext[1]) 113 | )) 114 | 115 | return list_hr, list_lr 116 | 117 | def _set_filesystem(self, dir_data): 118 | self.apath = os.path.join(dir_data, self.name) 119 | self.dir_hr = os.path.join(self.apath, 'HR') 120 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 121 | if self.input_large: self.dir_lr += 'L' 122 | self.ext = ('.png', '.png') 123 | 124 | def _name_hrbin(self): 125 | return os.path.join( 126 | self.apath, 127 | 'bin', 128 | '{}_bin_HR.pt'.format(self.split) 129 | ) 130 | 131 | def _name_lrbin(self, scale): 132 | return os.path.join( 133 | self.apath, 134 | 'bin', 135 | '{}_bin_LR_X{}.pt'.format(self.split, scale) 136 | ) 137 | 138 | def _check_and_load(self, ext, l, f, verbose=True, load=True): 139 | if os.path.isfile(f) and ext.find('reset') < 0: 140 | if load: 141 | if verbose: print('Loading {}...'.format(f)) 142 | with open(f, 'rb') as _f: ret = pickle.load(_f) 143 | return ret 144 | else: 145 | return None 146 | else: 147 
| if verbose: 148 | if ext.find('reset') >= 0: 149 | print('Making a new binary: {}'.format(f)) 150 | else: 151 | print('{} does not exist. Now making binary...'.format(f)) 152 | if ext.find('bin') >= 0: 153 | print('Bin pt file with name and image') 154 | b = [{ 155 | 'name': os.path.splitext(os.path.basename(_l))[0], 156 | 'image': imageio.imread(_l) 157 | } for _l in l] 158 | with open(f, 'wb') as _f: pickle.dump(b, _f) 159 | 160 | return b 161 | else: 162 | print('Direct pt file without name or image') 163 | # import pdb 164 | # pdb.set_trace() 165 | b = imageio.imread(l[0]) 166 | with open(f, 'wb') as _f: pickle.dump(b, _f) 167 | 168 | # return b 169 | 170 | def __getitem__(self, idx): 171 | lr, hr, filename = self._load_file(idx) 172 | pair = self.get_patch(lr, hr) 173 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 174 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 175 | 176 | return pair_t[0], pair_t[1], filename 177 | 178 | def __len__(self): 179 | if self.train: 180 | return len(self.images_hr) * self.repeat 181 | else: 182 | return len(self.images_hr) 183 | 184 | def _get_index(self, idx): 185 | if self.train: 186 | return idx % len(self.images_hr) 187 | else: 188 | return idx 189 | 190 | def _load_file(self, idx): 191 | idx = self._get_index(idx) 192 | f_hr = self.images_hr[idx] 193 | f_lr = self.images_lr[self.idx_scale][idx] 194 | 195 | if self.args.ext.find('bin') >= 0: 196 | filename = f_hr['name'] 197 | hr = f_hr['image'] 198 | lr = f_lr['image'] 199 | else: 200 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 201 | if self.args.ext == 'img' or self.benchmark: 202 | hr = imageio.imread(f_hr) 203 | lr = imageio.imread(f_lr) 204 | elif self.args.ext.find('sep') >= 0: 205 | # For each pt file, use 'image' to load it 206 | # with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image'] 207 | # with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image'] 208 | # For each pt file, directly load it 209 | with open(f_hr, 'rb') as _f: hr = pickle.load(_f) 210 | with open(f_lr, 'rb') as _f: lr = pickle.load(_f) 211 | 212 | return lr, hr, filename 213 | 214 | def get_patch(self, lr, hr): 215 | scale = self.scale[self.idx_scale] 216 | if self.train: 217 | lr, hr = common.get_patch( 218 | lr, hr, 219 | patch_size=self.args.patch_size, 220 | scale=scale, 221 | multi=(len(self.scale) > 1), 222 | input_large=self.input_large 223 | ) 224 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 225 | else: 226 | ih, iw = lr.shape[:2] 227 | hr = hr[0:ih * scale, 0:iw * scale] 228 | 229 | return lr, hr 230 | 231 | def set_scale(self, idx_scale): 232 | if not self.input_large: 233 | self.idx_scale = idx_scale 234 | else: 235 | self.idx_scale = random.randint(0, len(self.scale) - 1) 236 | 237 | -------------------------------------------------------------------------------- /src/utility.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import math 3 | import time 4 | import datetime 5 | from multiprocessing import Process 6 | from multiprocessing import Queue 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | 12 | import numpy as np 13 | import imageio 14 | 15 | import torch 16 | import torch.optim as optim 17 | import torch.optim.lr_scheduler as lrs 18 | 19 | class timer(): 20 | def __init__(self): 21 | self.acc = 0 22 | self.tic() 23 | 24 | def tic(self): 25 | self.t0 = time.time() 26 | 27 | def toc(self, restart=False): 28 | diff = 
time.time() - self.t0 29 | if restart: self.t0 = time.time() 30 | return diff 31 | 32 | def hold(self): 33 | self.acc += self.toc() 34 | 35 | def release(self): 36 | ret = self.acc 37 | self.acc = 0 38 | 39 | return ret 40 | 41 | def reset(self): 42 | self.acc = 0 43 | 44 | class checkpoint(): 45 | def __init__(self, args): 46 | self.args = args 47 | self.ok = True 48 | self.log = torch.Tensor() 49 | now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') 50 | 51 | if not args.load: 52 | if not args.save: 53 | args.save = now 54 | self.dir = os.path.join('..', 'experiment', args.save) 55 | else: 56 | self.dir = os.path.join('..', 'experiment', args.load) 57 | if os.path.exists(self.dir): 58 | self.log = torch.load(self.get_path('psnr_log.pt')) 59 | print('Continue from epoch {}...'.format(len(self.log))) 60 | else: 61 | args.load = '' 62 | 63 | # # add a reminder 64 | # if os.path.exists(self.dir): 65 | # val = input('Warning: directory "%s" already exists, is there any potential problem with this? Type [yes/no] to continue: ' % self.dir) 66 | # if val.lower() == 'no': 67 | # val = input('Warning: Are you sure? We cannot be too careful. Type [yes/no] to continue: ') 68 | # if val.lower() == 'yes': 69 | # print("You've responded with 'yes'. This program is about to terminate. Please check and run again.") 70 | # exit(1) 71 | # else: 72 | # print("You are very positive that there is NO problem. The program continues running. Have a nice day!") 73 | 74 | if args.reset: 75 | os.system('rm -rf ' + self.dir) 76 | args.load = '' 77 | 78 | os.makedirs(self.dir, exist_ok=True) 79 | os.makedirs(self.get_path('model'), exist_ok=True) 80 | for d in args.data_test: 81 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True) 82 | 83 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w' 84 | self.log_file = open(self.get_path('log.txt'), open_type) 85 | self.log_file_prune = open(self.get_path('log_prune.txt'), open_type) # @mst: use another log file specifically for pruning logs 86 | with open(self.get_path('config.txt'), open_type) as f: 87 | f.write(now + '\n\n') 88 | for arg in vars(args): 89 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 90 | f.write('\n') 91 | self.print_script() 92 | 93 | self.n_processes = 8 94 | 95 | def get_path(self, *subdir): 96 | return os.path.join(self.dir, *subdir) 97 | 98 | def print_script(self): 99 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 100 | gpu_id = os.environ['CUDA_VISIBLE_DEVICES'] 101 | script = ' '.join(['CUDA_VISIBLE_DEVICES=%s python' % gpu_id, *sys.argv]) 102 | else: 103 | script = ' '.join(['python', *sys.argv]) 104 | with open(self.get_path('config.txt'), 'a+') as f: 105 | print(script, file=f, flush=True) 106 | print(script, file=sys.stdout, flush=True) 107 | 108 | def save(self, trainer, epoch, is_best=False): 109 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best) 110 | trainer.loss.save(self.dir) 111 | trainer.loss.plot_loss(self.dir, epoch) 112 | 113 | self.plot_psnr(epoch) 114 | trainer.optimizer.save(self.dir) 115 | torch.save(self.log, self.get_path('psnr_log.pt')) 116 | 117 | def add_log(self, log): 118 | self.log = torch.cat([self.log, log]) 119 | 120 | def write_log(self, log, refresh=False): 121 | print(log) 122 | self.log_file.write(log + '\n') 123 | self.log_file.flush() 124 | if refresh: 125 | self.log_file.close() 126 | self.log_file = open(self.get_path('log.txt'), 'a') 127 | 128 | # @mst: use another log file specifically for pruning logs 129 | def write_log_prune(self, log, 
refresh=False): 130 | print(log) 131 | self.log_file_prune.write(log + '\n') 132 | self.log_file_prune.flush() 133 | if refresh: 134 | self.log_file_prune.close() 135 | self.log_file_prune = open(self.get_path('log_prune.txt'), 'a') 136 | 137 | def done(self): 138 | self.log_file.close() 139 | 140 | def plot_psnr(self, epoch): 141 | if epoch > 0: 142 | axis = np.linspace(1, epoch, epoch) 143 | for idx_data, d in enumerate(self.args.data_test): 144 | label = 'SR on {}'.format(d) 145 | fig = plt.figure() 146 | plt.title(label) 147 | for idx_scale, scale in enumerate(self.args.scale): 148 | plt.plot( 149 | axis, 150 | self.log[:, idx_data, idx_scale].numpy(), 151 | label='Scale {}'.format(scale) 152 | ) 153 | plt.legend() 154 | plt.xlabel('Epochs') 155 | plt.ylabel('PSNR') 156 | plt.grid(True) 157 | plt.savefig(self.get_path('test_{}.pdf'.format(d))) 158 | plt.close(fig) 159 | 160 | def begin_background(self): 161 | self.queue = Queue() 162 | 163 | def bg_target(queue): 164 | while True: 165 | if not queue.empty(): 166 | filename, tensor = queue.get() 167 | if filename is None: break 168 | imageio.imwrite(filename, tensor.numpy()) 169 | 170 | self.process = [ 171 | Process(target=bg_target, args=(self.queue,)) \ 172 | for _ in range(self.n_processes) 173 | ] 174 | 175 | for p in self.process: p.start() 176 | 177 | def end_background(self): 178 | for _ in range(self.n_processes): self.queue.put((None, None)) 179 | while not self.queue.empty(): time.sleep(1) 180 | for p in self.process: p.join() 181 | 182 | def save_results(self, dataset, filename, save_list, scale): 183 | if self.args.save_results: 184 | filename = self.get_path( 185 | 'results-{}'.format(dataset.dataset.name), 186 | '{}_x{}_'.format(filename, scale) 187 | ) 188 | 189 | postfix = ('SR', 'LR', 'HR') 190 | for v, p in zip(save_list, postfix): 191 | normalized = v[0].mul(255 / self.args.rgb_range) 192 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu() 193 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu)) 194 | 195 | def quantize(img, rgb_range): 196 | pixel_range = 255 / rgb_range 197 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 198 | 199 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None): 200 | if hr.nelement() == 1: return 0 201 | 202 | diff = (sr - hr) / rgb_range 203 | # if dataset and dataset.dataset.benchmark: 204 | # shave = scale 205 | # if diff.size(1) > 1: 206 | # gray_coeffs = [65.738, 129.057, 25.064] 207 | # convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 208 | # diff = diff.mul(convert).sum(dim=1) 209 | # else: 210 | # shave = scale + 6 211 | 212 | if diff.size(1) > 1: 213 | gray_coeffs = [65.738, 129.057, 25.064] 214 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 215 | diff = diff.mul(convert).sum(dim=1) 216 | 217 | shave = scale 218 | valid = diff[..., shave:-shave, shave:-shave] 219 | mse = valid.pow(2).mean() 220 | 221 | return -10 * math.log10(mse) 222 | 223 | def make_optimizer(args, target): 224 | ''' 225 | make optimizer and scheduler together 226 | ''' 227 | # optimizer 228 | trainable = filter(lambda x: x.requires_grad, target.parameters()) 229 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay} 230 | 231 | if args.optimizer == 'SGD': 232 | optimizer_class = optim.SGD 233 | kwargs_optimizer['momentum'] = args.momentum 234 | elif args.optimizer == 'ADAM': 235 | optimizer_class = optim.Adam 236 | kwargs_optimizer['betas'] = args.betas 237 | kwargs_optimizer['eps'] = args.epsilon 238 | elif args.optimizer == 
'RMSprop': 239 | optimizer_class = optim.RMSprop 240 | kwargs_optimizer['eps'] = args.epsilon 241 | 242 | # scheduler 243 | # milestones = list(map(lambda x: int(x), args.decay.split('-'))) 244 | # kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 245 | # scheduler_class = lrs.MultiStepLR 246 | 247 | if args.decay_type == 'step': 248 | # step 249 | kwargs_scheduler = {'step_size': args.lr_decay, 'gamma': args.gamma} 250 | scheduler_class = lrs.StepLR 251 | else: 252 | # multi_step, specified by args.decay 253 | milestones = list(map(lambda x: int(x), args.decay.split('-'))) 254 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma} 255 | scheduler_class = lrs.MultiStepLR 256 | 257 | class CustomOptimizer(optimizer_class): 258 | def __init__(self, *args, **kwargs): 259 | super(CustomOptimizer, self).__init__(*args, **kwargs) 260 | 261 | def _register_scheduler(self, scheduler_class, **kwargs): 262 | self.scheduler = scheduler_class(self, **kwargs) 263 | 264 | def save(self, save_dir): 265 | torch.save(self.state_dict(), self.get_dir(save_dir)) 266 | 267 | def load(self, load_dir, epoch=1): 268 | self.load_state_dict(torch.load(self.get_dir(load_dir))) 269 | if epoch > 1: 270 | for _ in range(epoch): self.scheduler.step() 271 | 272 | def get_dir(self, dir_path): 273 | return os.path.join(dir_path, 'optimizer.pt') 274 | 275 | def schedule(self): 276 | self.scheduler.step() 277 | 278 | def get_lr(self): 279 | return self.scheduler.get_lr()[0] 280 | 281 | def get_last_epoch(self): 282 | return self.scheduler.last_epoch 283 | 284 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 285 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 286 | return optimizer -------------------------------------------------------------------------------- /src/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | import glob 4 | from utils import strlist_to_list, strdict_to_dict, check_path, parse_prune_ratio_vgg 5 | 6 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 7 | 8 | parser.add_argument('--debug', action='store_true', 9 | help='Enables debug mode') 10 | parser.add_argument('--template', default='.', 11 | help='You can set various templates in option.py') 12 | 13 | # Hardware specifications 14 | parser.add_argument('--n_threads', type=int, default=6, 15 | help='number of threads for data loading') 16 | parser.add_argument('--cpu', action='store_true', 17 | help='use cpu only') 18 | parser.add_argument('--n_GPUs', type=int, default=1, 19 | help='number of GPUs') 20 | parser.add_argument('--seed', type=int, default=1, 21 | help='random seed') 22 | 23 | # Data specifications 24 | parser.add_argument('--dir_data', type=str, default='/home/yulun/data/SR/RGB/BIX2X3X4/pt_bin', 25 | help='dataset directory') 26 | parser.add_argument('--dir_demo', type=str, default='../test', 27 | help='demo image directory') 28 | parser.add_argument('--data_train', type=str, default='DF2K', 29 | help='train dataset name') 30 | parser.add_argument('--data_test', type=str, default='DF2K', 31 | help='test dataset name') 32 | parser.add_argument('--data_range', type=str, default='1-3550/3551-3555', 33 | help='train/test data range') 34 | parser.add_argument('--ext', type=str, default='sep', 35 | help='dataset file extension') 36 | parser.add_argument('--scale', type=str, default='4', 37 | help='super resolution scale') 38 | parser.add_argument('--patch_size', type=int, default=192, 
39 | help='output patch size') 40 | parser.add_argument('--rgb_range', type=int, default=255, 41 | help='maximum value of RGB') 42 | parser.add_argument('--n_colors', type=int, default=3, 43 | help='number of color channels to use') 44 | parser.add_argument('--chop', action='store_true', 45 | help='enable memory-efficient forward') 46 | parser.add_argument('--no_augment', action='store_true', 47 | help='do not use data augmentation') 48 | 49 | # Model specifications 50 | parser.add_argument('--model', default='EDSR', 51 | help='model name') 52 | 53 | parser.add_argument('--act', type=str, default='relu', 54 | help='activation function') 55 | parser.add_argument('--pre_train', type=str, default='', 56 | help='pre-trained model directory') 57 | parser.add_argument('--extend', type=str, default='.', 58 | help='pre-trained model directory') 59 | parser.add_argument('--n_resblocks', type=int, default=16, 60 | help='number of residual blocks') 61 | parser.add_argument('--n_feats', type=int, default=64, 62 | help='number of feature maps') 63 | parser.add_argument('--res_scale', type=float, default=1, 64 | help='residual scaling') 65 | parser.add_argument('--shift_mean', default=True, 66 | help='subtract pixel mean from the input') 67 | parser.add_argument('--dilation', action='store_true', 68 | help='use dilated convolution') 69 | parser.add_argument('--precision', type=str, default='single', 70 | choices=('single', 'half'), 71 | help='FP precision for test (single | half)') 72 | 73 | # Option for Residual dense network (RDN) 74 | parser.add_argument('--G0', type=int, default=64, 75 | help='default number of filters. (Use in RDN)') 76 | parser.add_argument('--RDNkSize', type=int, default=3, 77 | help='default kernel size. (Use in RDN)') 78 | parser.add_argument('--RDNconfig', type=str, default='B', 79 | help='parameters config of RDN. 
(Use in RDN)') 80 | 81 | # Option for Residual channel attention network (RCAN) 82 | parser.add_argument('--n_resgroups', type=int, default=10, 83 | help='number of residual groups') 84 | parser.add_argument('--reduction', type=int, default=16, 85 | help='number of feature maps reduction') 86 | 87 | # Training specifications 88 | parser.add_argument('--reset', action='store_true', 89 | help='reset the training') 90 | parser.add_argument('--test_every', type=int, default=1000, 91 | help='do test per every N batches') 92 | parser.add_argument('--epochs', type=int, default=5000, 93 | help='number of epochs to train') 94 | parser.add_argument('--batch_size', type=int, default=16, 95 | help='input batch size for training') 96 | parser.add_argument('--split_batch', type=int, default=1, 97 | help='split the batch into smaller chunks') 98 | parser.add_argument('--self_ensemble', action='store_true', 99 | help='use self-ensemble method for test') 100 | parser.add_argument('--test_only', action='store_true', 101 | help='set this option to test the model') 102 | parser.add_argument('--gan_k', type=int, default=1, 103 | help='k value for adversarial loss') 104 | 105 | # Optimization specifications 106 | parser.add_argument('--lr', type=float, default=1e-4, 107 | help='learning rate') 108 | parser.add_argument('--lr_decay', type=int, default=200, 109 | help='learning rate decay per N epochs') 110 | parser.add_argument('--decay_type', type=str, default='step', 111 | help='learning rate decay type') 112 | parser.add_argument('--decay', type=str, default='200', 113 | help='learning rate decay type, multiple_step, 200-400-600-800-1000') 114 | parser.add_argument('--gamma', type=float, default=0.5, 115 | help='learning rate decay factor for step decay') 116 | parser.add_argument('--optimizer', default='ADAM', 117 | choices=('SGD', 'ADAM', 'RMSprop'), 118 | help='optimizer to use (SGD | ADAM | RMSprop)') 119 | parser.add_argument('--momentum', type=float, default=0.9, 120 | help='SGD momentum') 121 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 122 | help='ADAM beta') 123 | parser.add_argument('--epsilon', type=float, default=1e-8, 124 | help='ADAM epsilon for numerical stability') 125 | parser.add_argument('--weight_decay', type=float, default=0, 126 | help='weight decay') 127 | parser.add_argument('--gclip', type=float, default=0, 128 | help='gradient clipping threshold (0 = no clipping)') 129 | 130 | # Loss specifications 131 | parser.add_argument('--loss', type=str, default='1*L1', 132 | help='loss function configuration') 133 | parser.add_argument('--skip_threshold', type=float, default='1e8', 134 | help='skipping batch that has large error') 135 | 136 | # Log specifications 137 | parser.add_argument('--save', type=str, default='test', 138 | help='file name to save') 139 | parser.add_argument('--load', type=str, default='', 140 | help='file name to load') 141 | parser.add_argument('--resume', type=int, default=0, 142 | help='resume from specific checkpoint') 143 | parser.add_argument('--save_models', action='store_true', 144 | help='save all intermediate models') 145 | parser.add_argument('--print_every', type=int, default=100, 146 | help='how many batches to wait before logging training status') 147 | parser.add_argument('--save_results', action='store_true', 148 | help='save output results') 149 | parser.add_argument('--save_gt', action='store_true', 150 | help='save low-resolution and high-resolution images together') 151 | 152 | # Routine arguments to set up experiment dir 153 | 
parser.add_argument('--project_name', type=str, default="") 154 | parser.add_argument('--screen_print', action="store_true") 155 | parser.add_argument('--print_interval', type=int, default=100) 156 | 157 | # Lightweight SR 158 | parser.add_argument('--method', type=str, default='', choices=['', 'ASSL', 'L1'], help='method name') 159 | parser.add_argument('--wg', type=str, default='filter', choices=['filter', 'weight'], help='weight group to prune') 160 | parser.add_argument('--stage_pr', type=str, default="", help='to appoint layer-wise pruning ratio') 161 | parser.add_argument('--skip_layers', type=str, default="", help='layers to skip when pruning') 162 | parser.add_argument('--reinit_layers', type=str, default="", help='layers to reinit (not inherit weights)') 163 | parser.add_argument('--same_pruned_wg_layers', type=str, default='', help='layers to be set with the same pruned weight group') 164 | parser.add_argument('--same_pruned_wg_criterion', type=str, default='rand', choices=['rand', 'reg'], help='use which criterion to select pruned wg') 165 | parser.add_argument('--num_layers', type=int, default=1000, help='num of layers in the network') 166 | parser.add_argument('--resume_path', type=str, default='', help='path of the checkpoint to resume') 167 | 168 | # ASSL 169 | parser.add_argument('--update_reg_interval', type=int, default=5) 170 | parser.add_argument('--stabilize_reg_interval', type=int, default=40000) 171 | parser.add_argument('--reg_upper_limit', type=float, default=1) 172 | parser.add_argument('--reg_granularity_prune', type=float, default=1e-4) 173 | parser.add_argument('--pick_pruned', type=str, default='min', choices=['min', 'max', 'rand']) 174 | parser.add_argument('--not_apply_reg', dest='apply_reg', action='store_false', default=True) 175 | parser.add_argument('--layer_chl', type=str, default='', help='manually assign the number of channels for some layers. 
A not so beautiful scheme.')
176 | parser.add_argument('--greg_mode', type=str, default='part', choices=['part', 'all'])
177 | parser.add_argument('--compare_mode', type=str, default='local', choices=['local', 'global'])
178 | parser.add_argument('--prune_criterion', type=str, default='wn_scale', choices=['l1-norm', 'wn_scale'])
179 | parser.add_argument('--wn', action='store_true', help='whether to use weight normalization')
180 | parser.add_argument('--lw_spr', type=float, default=1, help='loss weight of the sparsity pattern regularization')
181 | parser.add_argument('--iter_finish_spr', '--iter_ssa', dest='iter_ssa', type=int, default=17260, help='863x20 = 20 epochs')
182 | parser.add_argument('--lr_prune', type=float, default=0.0002)
183 | 
184 | args = parser.parse_args()
185 | template.set_template(args)
186 | 
187 | args.scale = list(map(lambda x: int(x), args.scale.split('+')))
188 | args.data_train = args.data_train.split('+')
189 | args.data_test = args.data_test.split('+')
190 | 
191 | if args.epochs == 0:
192 |     args.epochs = 1e8
193 | 
194 | for arg in vars(args):
195 |     if vars(args)[arg] == 'True':
196 |         vars(args)[arg] = True
197 |     elif vars(args)[arg] == 'False':
198 |         vars(args)[arg] = False
199 | 
200 | # parse for layer-wise prune ratio
201 | # stage_pr is a list of floats, skip_layers is a list of strings
202 | if args.method in ['L1', 'ASSL']:
203 |     assert args.stage_pr
204 |     if glob.glob(args.stage_pr): # 'stage_pr' is a path
205 |         args.stage_pr = check_path(args.stage_pr)
206 |     else:
207 |         if args.compare_mode in ['global']: # 'stage_pr' is a float
208 |             args.stage_pr = float(args.stage_pr)
209 |         elif args.compare_mode in ['local']: # 'stage_pr' is a list
210 |             args.stage_pr = parse_prune_ratio_vgg(args.stage_pr, num_layers=args.num_layers)
211 |     args.skip_layers = strlist_to_list(args.skip_layers, str)
212 |     args.reinit_layers = strlist_to_list(args.reinit_layers, str)
213 |     args.same_pruned_wg_layers = strlist_to_list(args.same_pruned_wg_layers, str)
214 |     args.layer_chl = strdict_to_dict(args.layer_chl, int)
215 | 
216 | # directly appoint some values to maintain compatibility
217 | args.reinit = False
218 | args.project_name = args.save
--------------------------------------------------------------------------------
/src/pruner/utils.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn
2 | from collections import OrderedDict
3 | from fnmatch import fnmatch, fnmatchcase
4 | import math, numpy as np, copy
5 | tensor2list = lambda x: x.data.cpu().numpy().tolist()
6 | tensor2array = lambda x: x.data.cpu().numpy()
7 | totensor = lambda x: torch.Tensor(x)
8 | 
9 | def get_pr_layer(base_pr, layer_name, layer_index, skip=[], compare_mode='local'):
10 |     """'base_pr' example: '[0-4:0.5, 5:0.6, 8-10:0.2]'; layers 6 and 7 are not mentioned, so their pruning ratio defaults to 0.
11 |     """
12 |     if compare_mode in ['global']:
13 |         pr = 1e-20 # a small positive value to indicate this layer will be considered for pruning, will be replaced
14 |     elif compare_mode in ['local']:
15 |         pr = base_pr[layer_index]
16 | 
17 |     # if the layer name matches a pattern pre-specified in 'skip', skip it (i.e., pr = 0)
18 |     for p in skip:
19 |         if fnmatch(layer_name, p):
20 |             pr = 0
21 |     return pr
22 | 
23 | def get_pr_model(layers, base_pr, skip=[], compare_mode='local'):
24 |     """Get layer-wise pruning ratio for a model.
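    Example (hypothetical values): with compare_mode='local', base_pr is a
    per-layer list, e.g., [0, 0.5, 0.5, ...], and layer i receives base_pr[i]
    (0 means the layer is not pruned); with compare_mode='global', base_pr is a
    single float shared by the whole model. base_pr may also be a checkpoint
    path, in which case the layer-wise ratios are inherited from that model.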
25 | """ 26 | pr = OrderedDict() 27 | if isinstance(base_pr, str): 28 | ckpt = torch.load(base_pr) 29 | pruned, kept = ckpt['pruned_wg'], ckpt['kept_wg'] 30 | for name in pruned: 31 | num_pruned, num_kept = len(pruned[name]), len(kept[name]) 32 | pr[name] = float(num_pruned) / (num_pruned + num_kept) 33 | print(f"==> Load base_pr model successfully and inherit its pruning ratio: '{base_pr}'.") 34 | elif isinstance(base_pr, (float, list)): 35 | if compare_mode in ['global']: 36 | assert isinstance(base_pr, float) 37 | pr['model'] = base_pr 38 | for name, layer in layers.items(): 39 | pr[name] = get_pr_layer(base_pr, name, layer.index, skip=skip, compare_mode=compare_mode) 40 | print(f"==> Get pr (pruning ratio) for pruning the model, done (pr may be updated later).") 41 | else: 42 | raise NotImplementedError 43 | return pr 44 | 45 | def get_constrained_layers(layers, constrained_pattern): 46 | constrained_layers = [] 47 | for name, _ in layers.items(): 48 | for p in constrained_pattern: 49 | if fnmatch(name, p): 50 | constrained_layers += [name] 51 | return constrained_layers 52 | 53 | def adjust_pr(layers, pr, pruned, kept, num_pruned_constrained, constrained): 54 | """The real pr of a layer may not be exactly equal to the assigned one (i.e., raw pr) due to various reasons (e.g., constrained layers). 55 | Adjust it here, e.g., averaging the prs for all constrained layers. 56 | """ 57 | pr, pruned, kept = copy.deepcopy(pr), copy.deepcopy(pruned), copy.deepcopy(kept) 58 | for name, layer in layers.items(): 59 | if name in constrained: 60 | # -- averaging within all constrained layers to keep the total num of pruned weight groups still the same 61 | num_pruned = int(num_pruned_constrained / len(constrained)) 62 | # -- 63 | pr[name] = num_pruned / len(layer.score) 64 | order = pruned[name] + kept[name] 65 | pruned[name], kept[name] = order[:num_pruned], order[num_pruned:] 66 | else: 67 | num_pruned = len(pruned[name]) 68 | pr[name] = num_pruned / len(layer.score) 69 | return pr, pruned, kept 70 | 71 | def set_same_pruned(model, pr, pruned_wg, kept_wg, constrained, wg='filter', criterion='l1-norm', sort_mode='min'): 72 | """Set pruned wgs of some layers to the same indices. 73 | """ 74 | pruned_wg, kept_wg = copy.deepcopy(pruned_wg), copy.deepcopy(kept_wg) 75 | pruned = None 76 | for name, m in model.named_modules(): 77 | if name in constrained: 78 | if pruned is None: 79 | score = get_score_layer(m, wg=wg, criterion=criterion)['score'] 80 | pruned, kept = pick_pruned_layer(score=score, pr=pr[name], sort_mode=sort_mode) 81 | pr_first_constrained = pr[name] 82 | assert pr[name] == pr_first_constrained 83 | pruned_wg[name], kept_wg[name] = pruned, kept 84 | return pruned_wg, kept_wg 85 | 86 | def get_score_layer(module, wg='filter', criterion='l1-norm'): 87 | r"""Get importance score for a layer. 
88 | 89 | Return: 90 | out (dict): A dict that has key 'score', whose value is a numpy array 91 | """ 92 | # -- define any scoring scheme here as you like 93 | shape = module.weight.data.shape 94 | if wg == "channel": 95 | l1 = module.weight.abs().mean(dim=[0, 2, 3]) if len(shape) == 4 else module.weight.abs().mean(dim=0) 96 | elif wg == "filter": 97 | l1 = module.weight.abs().mean(dim=[1, 2, 3]) if len(shape) == 4 else module.weight.abs().mean(dim=1) 98 | elif wg == "weight": 99 | l1 = module.weight.abs().flatten() 100 | # -- 101 | 102 | out = {} 103 | out['l1-norm'] = tensor2array(l1) 104 | out['wn_scale'] = tensor2array(module.wn_scale.abs()) if hasattr(module, 'wn_scale') else [1e30] * module.weight.size(0) 105 | # 1e30 to indicate this layer will not be pruned because of its unusually high scores 106 | out['score'] = out[criterion] 107 | return out 108 | 109 | def pick_pruned_layer(score, pr=None, threshold=None, sort_mode='min'): 110 | r"""Get the indices of pruned weight groups in a layer. 111 | 112 | Return: 113 | pruned (list) 114 | kept (list) 115 | """ 116 | assert sort_mode in ['min', 'rand', 'max'] 117 | score = np.array(score) 118 | num_total = len(score) 119 | if sort_mode in ['rand']: 120 | assert pr is not None 121 | num_pruned = min(math.ceil(pr * num_total), num_total - 1) # do not prune all 122 | order = np.random.permutation(num_total).tolist() 123 | else: 124 | num_pruned = math.ceil(pr * num_total) if threshold is None else len(np.where(score < threshold)[0]) 125 | num_pruned = min(num_pruned, num_total - 1) # do not prune all 126 | if sort_mode in ['min', 'ascending']: 127 | order = np.argsort(score).tolist() 128 | elif sort_mode in ['max', 'descending']: 129 | order = np.argsort(score)[::-1].tolist() 130 | pruned, kept = order[:num_pruned], order[num_pruned:] 131 | return pruned, kept 132 | 133 | def pick_pruned_model(model, layers, raw_pr, wg='filter', criterion='l1-norm', compare_mode='local', sort_mode='min', constrained=[], align_constrained=False): 134 | r"""Pick pruned weight groups for a model. 
135 |     Args:
136 |         layers: an OrderedDict, key is layer name
137 | 
138 |     Return:
139 |         pruned (OrderedDict): key is layer name, value is the pruned indices for the layer
140 |         kept (OrderedDict): key is layer name, value is the kept indices for the layer
141 |     """
142 |     assert sort_mode in ['rand', 'min', 'max'] and compare_mode in ['global', 'local']
143 |     pruned_wg, kept_wg = OrderedDict(), OrderedDict()
144 |     all_scores, num_pruned_constrained = [], 0
145 | 
146 |     # iter to get importance score for each layer
147 |     for name, module in model.named_modules():
148 |         if name in layers:
149 |             layer = layers[name]
150 |             out = get_score_layer(module, wg=wg, criterion=criterion)
151 |             score = out['score']
152 |             layer.score = score
153 |             if raw_pr[name] > 0: # pr > 0 indicates this layer will be pruned, so its score is included in the global score pool
154 |                 all_scores = np.append(all_scores, score)
155 | 
156 |             # local pruning
157 |             if compare_mode in ['local']:
158 |                 assert isinstance(raw_pr, dict)
159 |                 pruned_wg[name], kept_wg[name] = pick_pruned_layer(score, raw_pr[name], sort_mode=sort_mode)
160 |                 if name in constrained:
161 |                     num_pruned_constrained += len(pruned_wg[name])
162 | 
163 |     # global pruning
164 |     if compare_mode in ['global']:
165 |         num_total = len(all_scores)
166 |         num_pruned = min(math.ceil(raw_pr['model'] * num_total), num_total - 1) # do not prune all
167 |         if sort_mode == 'min':
168 |             threshold = sorted(all_scores)[num_pruned] # in ascending order
169 |         elif sort_mode == 'max':
170 |             threshold = sorted(all_scores)[::-1][num_pruned] # in descending order
171 |         print(f'#all_scores: {len(all_scores)} threshold:{threshold:.6f}')
172 | 
173 |         for name, layer in layers.items():
174 |             if raw_pr[name] > 0:
175 |                 if sort_mode in ['rand']:
176 |                     pass
177 |                 elif sort_mode in ['min', 'max']:
178 |                     pruned_wg[name], kept_wg[name] = pick_pruned_layer(layer.score, pr=None, threshold=threshold, sort_mode=sort_mode)
179 |             else:
180 |                 pruned_wg[name], kept_wg[name] = [], list(range(len(layer.score)))
181 |             if name in constrained:
182 |                 num_pruned_constrained += len(pruned_wg[name])
183 | 
184 |     # adjust pr/pruned/kept
185 |     pr, pruned_wg, kept_wg = adjust_pr(layers, raw_pr, pruned_wg, kept_wg, num_pruned_constrained, constrained)
186 |     print(f'==> Adjust pr/pruned/kept, done.')
187 | 
188 |     if align_constrained:
189 |         pruned_wg, kept_wg = set_same_pruned(model, pr, pruned_wg, kept_wg, constrained,
190 |                                              wg=wg, criterion=criterion, sort_mode=sort_mode)
191 | 
192 |     return pr, pruned_wg, kept_wg
193 | 
194 | def get_next_learnable(layers, layer_name, n_conv_within_block=3):
195 |     r"""Get the next learnable layer for the layer of 'layer_name', chosen from 'layers'.
196 |     """
197 |     current_layer = layers[layer_name]
198 | 
199 |     # for standard ResNets on ImageNet
200 |     if hasattr(current_layer, 'block_index'):
201 |         block_index = current_layer.block_index
202 |         if block_index == n_conv_within_block - 1:
203 |             return None
204 | 
205 |     for name, layer in layers.items():
206 |         if layer.type == current_layer.type and layer.index == current_layer.index + 1:
207 |             return name
208 |     return None
209 | 
210 | def get_prev_learnable(layers, layer_name):
211 |     r"""Get the previous learnable layer for the layer of 'layer_name', chosen from 'layers'.
212 | """ 213 | current_layer = layers[layer_name] 214 | 215 | # for standard ResNets on ImageNet 216 | if hasattr(current_layer, 'block_index'): 217 | block_index = current_layer.block_index 218 | if block_index in [None, 0, -1]: # 1st conv, 1st conv in a block, 1x1 shortcut layer 219 | return None 220 | 221 | for name, layer in layers.items(): 222 | if layer.index == current_layer.index - 1: 223 | return name 224 | return None 225 | 226 | def get_next_bn(model, layer_name): 227 | r"""Get the next bn layer for the layer of 'layer_name', chosen from 'model'. 228 | Return the bn module instead of its name. 229 | """ 230 | just_passed = False 231 | for name, module in model.named_modules(): 232 | if name == layer_name: 233 | just_passed = True 234 | if just_passed and isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)): 235 | return module 236 | return None 237 | 238 | def replace_module(model, name, new_m): 239 | """Replace the module in with 240 | E.g., 'module.layer1.0.conv1' ==> model.__getattr__('module').__getattr__("layer1").__getitem__(0).__setattr__('conv1', new_m) 241 | """ 242 | obj = model 243 | segs = name.split(".") 244 | for ix in range(len(segs)): 245 | s = segs[ix] 246 | if ix == len(segs) - 1: # the last one 247 | if s.isdigit(): 248 | obj.__setitem__(int(s), new_m) 249 | else: 250 | obj.__setattr__(s, new_m) 251 | return 252 | if s.isdigit(): 253 | obj = obj.__getitem__(int(s)) 254 | else: 255 | obj = obj.__getattr__(s) 256 | 257 | def get_kept_filter_channel(layers, layer_name, pr, kept_wg, wg='filter'): 258 | """Considering layer dependency, get the kept filters and channels for the layer of 'layer_name'. 259 | """ 260 | current_layer = layers[layer_name] 261 | if wg in ["channel"]: 262 | kept_chl = kept_wg[layer_name] 263 | next_learnable = get_next_learnable(layers, layer_name) 264 | kept_filter = list(range(current_layer.module.weight.size(0))) if next_learnable is None else kept_wg[next_learnable] 265 | elif wg in ["filter"]: 266 | kept_filter = kept_wg[layer_name] 267 | prev_learnable = get_prev_learnable(layers, layer_name) 268 | if (prev_learnable is None) or pr[prev_learnable] == 0: 269 | # In the case of SR networks, tail, there is an upsampling via sub-pixel. 'self.pr[prev_learnable_layer] == 0' can help avoid it. 270 | # Not using this, the code will report error. 271 | kept_chl = list(range(current_layer.module.weight.size(1))) 272 | else: 273 | kept_chl = kept_wg[prev_learnable] 274 | 275 | # sort to make the indices be in ascending order 276 | kept_filter.sort() 277 | kept_chl.sort() 278 | return kept_filter, kept_chl 279 | 280 | def get_masks(layers, pruned_wg): 281 | """Get masks for unstructured pruning. 
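    Each returned mask has the same shape as the layer's weight:
    0 at the indices listed in pruned_wg[name], 1 elsewhere.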
282 | """ 283 | masks = OrderedDict() 284 | for name, layer in layers.items(): 285 | mask = torch.ones(layer.shape).cuda().flatten() 286 | mask[pruned_wg[name]] = 0 287 | masks[name] = mask.view(layer.shape) 288 | return masks -------------------------------------------------------------------------------- /src/pruner/assl_pruner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import torch.nn.utils as utils 6 | from decimal import Decimal 7 | import os, copy, time, pickle, numpy as np, math 8 | from .meta_pruner import MetaPruner 9 | import utility 10 | import matplotlib.pyplot as plt 11 | from tqdm import tqdm 12 | from fnmatch import fnmatch, fnmatchcase 13 | from .utils import get_score_layer, pick_pruned_layer 14 | pjoin = os.path.join 15 | tensor2list = lambda x: x.data.cpu().numpy().tolist() 16 | tensor2array = lambda x: x.data.cpu().numpy() 17 | totensor = lambda x: torch.Tensor(x) 18 | 19 | class Pruner(MetaPruner): 20 | def __init__(self, model, args, logger, passer): 21 | super(Pruner, self).__init__(model, args, logger, passer) 22 | loader = passer.loader 23 | ckp = passer.ckp 24 | loss = passer.loss 25 | self.logprint = ckp.write_log_prune # use another log file specifically for pruning logs 26 | self.netprint = ckp.write_log_prune 27 | 28 | # ************************** variables from RCAN ************************** 29 | self.scale = args.scale 30 | 31 | self.ckp = ckp 32 | self.loader_train = loader.loader_train 33 | self.loader_test = loader.loader_test 34 | self.model = model 35 | self.loss = loss 36 | self.optimizer = utility.make_optimizer(args, self.model) 37 | 38 | if self.args.load != '': 39 | self.optimizer.load(ckp.dir, epoch=len(ckp.log)) 40 | 41 | self.error_last = 1e8 42 | # ************************************************************************** 43 | 44 | # Reg related variables 45 | self.reg = {} 46 | self.delta_reg = {} 47 | self._init_reg() 48 | self.iter_update_reg_finished = {} 49 | self.iter_finish_pick = {} 50 | self.iter_stabilize_reg = math.inf 51 | self.hist_mag_ratio = {} 52 | self.w_abs = {} 53 | self.wn_scale = {} 54 | 55 | # init prune_state 56 | self.prune_state = 'update_reg' 57 | if args.greg_mode in ['part'] and args.same_pruned_wg_layers and args.same_pruned_wg_criterion in ['reg']: 58 | self.prune_state = "ssa" # sparsity structure alignment 59 | self._get_kept_wg_L1(align_constrained=True) 60 | 61 | # init pruned_wg/kept_wg if they can be determined right at the begining 62 | if args.greg_mode in ['part'] and self.prune_state in ['update_reg']: 63 | self._get_kept_wg_L1(align_constrained=True) # this will update the 'self.kept_wg', 'self.pruned_wg', 'self.pr' 64 | 65 | def _init_reg(self): 66 | for name, m in self.model.named_modules(): 67 | if name in self.layers: 68 | if self.args.wg == 'weight': 69 | self.reg[name] = torch.zeros_like(m.weight.data).flatten().cuda() 70 | else: 71 | shape = m.weight.data.shape 72 | self.reg[name] = torch.zeros(shape[0], shape[1]).cuda() 73 | 74 | def _greg_1(self, m, name): 75 | if self.pr[name] == 0: 76 | return True 77 | 78 | pruned = self.pruned_wg[name] 79 | if self.args.wg == "channel": 80 | self.reg[name][:, pruned] += self.args.reg_granularity_prune 81 | elif self.args.wg == "filter": 82 | self.reg[name][pruned, :] += self.args.reg_granularity_prune 83 | elif self.args.wg == 'weight': 84 | self.reg[name][pruned] += self.args.reg_granularity_prune 85 | 
else: 86 | raise NotImplementedError 87 | 88 | # when all layers are pushed hard enough, stop 89 | return self.reg[name].max() > self.args.reg_upper_limit 90 | 91 | def _greg_penalize_all(self, m, name): 92 | if self.pr[name] == 0: 93 | return True 94 | 95 | if self.args.wg == "channel": 96 | self.reg[name] += self.args.reg_granularity_prune 97 | elif self.args.wg == "filter": 98 | self.reg[name] += self.args.reg_granularity_prune 99 | elif self.args.wg == 'weight': 100 | self.reg[name] += self.args.reg_granularity_prune 101 | else: 102 | raise NotImplementedError 103 | 104 | # when all layers are pushed hard enough, stop 105 | return self.reg[name].max() > self.args.reg_upper_limit 106 | 107 | def _update_reg(self, skip=[]): 108 | for name, m in self.model.named_modules(): 109 | if name in self.layers: 110 | if name in self.iter_update_reg_finished.keys(): 111 | continue 112 | if name in skip: 113 | continue 114 | 115 | # get the importance score (L1-norm in this case) 116 | out = get_score_layer(m, wg='filter', criterion='wn_scale') 117 | self.w_abs[name], self.wn_scale[name] = out['l1-norm'], out['wn_scale'] 118 | 119 | # update reg functions, two things: 120 | # (1) update reg of this layer (2) determine if it is time to stop update reg 121 | if self.args.greg_mode in ['part']: 122 | finish_update_reg = self._greg_1(m, name) 123 | elif self.args.greg_mode in ['all']: 124 | finish_update_reg = self._greg_penalize_all(m, name) 125 | 126 | # check prune state 127 | if finish_update_reg: 128 | # after 'update_reg' stage, keep the reg to stabilize weight magnitude 129 | self.iter_update_reg_finished[name] = self.total_iter 130 | self.logprint(f"==> {self.layer_print_prefix[name]} -- Just finished 'update_reg'. Iter {self.total_iter}. pr {self.pr[name]}") 131 | 132 | # check if all layers finish 'update_reg' 133 | prune_state = "stabilize_reg" 134 | for n, mm in self.model.named_modules(): 135 | if isinstance(mm, self.LEARNABLES): 136 | if n not in self.iter_update_reg_finished: 137 | prune_state = '' 138 | break 139 | if prune_state == "stabilize_reg": 140 | self.prune_state = 'stabilize_reg' 141 | self.iter_stabilize_reg = self.total_iter 142 | self.logprint("==> All layers just finished 'update_reg', go to 'stabilize_reg'. Iter = %d" % self.total_iter) 143 | 144 | def _apply_reg(self): 145 | for name, m in self.model.named_modules(): 146 | if name in self.layers and self.pr[name] > 0: 147 | reg = self.reg[name] # [N, C] 148 | m.wn_scale.grad += reg[:, 0] * m.wn_scale 149 | bias = False if isinstance(m.bias, type(None)) else True 150 | if bias: 151 | m.bias.grad += reg[:, 0] * m.bias 152 | 153 | def _merge_wn_scale_to_weights(self): 154 | '''Merge the learned weight normalization scale to the weights. 
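        Concretely, for every layer that has a 'wn_scale' parameter:
        weight <- F.normalize(weight, dim=(1, 2, 3)) * wn_scale.view(-1, 1, 1, 1),
        so the learned scale is folded into the conv weights and the
        weight-normalization wrapper is no longer needed at inference.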

    def _merge_wn_scale_to_weights(self):
        '''Merge the learned weight-normalization scale into the weights.
        '''
        for name, m in self.model.named_modules():
            if name in self.layers and hasattr(m, 'wn_scale'):
                m.weight.data = F.normalize(m.weight.data, dim=(1, 2, 3)) * m.wn_scale.view(-1, 1, 1, 1)
                self.logprint(f'Merged weight normalization scale to weights: {name}')

    def _resume_prune_status(self, ckpt_path):
        raise NotImplementedError

    def _save_model(self, filename):
        savepath = f'{self.ckp.dir}/model/{filename}'
        ckpt = {
            'pruned_wg': self.pruned_wg,
            'kept_wg': self.kept_wg,
            'model': self.model,
            'state_dict': self.model.state_dict(),
        }
        torch.save(ckpt, savepath)
        return savepath

    def prune(self):
        self.total_iter = 0
        if self.args.resume_path:
            self._resume_prune_status(self.args.resume_path)
            self._get_kept_wg_L1()  # get pruned and kept wg from the resumed model
            self.model = self.model.train()
            self.logprint("Resumed model successfully: '{}'. Iter = {}. prune_state = {}".format(
                self.args.resume_path, self.total_iter, self.prune_state))

        while True:
            finish_prune = self.train()  # train() returns True once pruning is finished, which exits this infinite loop
            if finish_prune:
                return copy.deepcopy(self.model)
            self.test()

    # ************************************************ The code below refers to RCAN ************************************************ #
    def train(self):
        self.loss.step()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.args.lr_prune  # use a fixed LR during pruning

        epoch = self.optimizer.get_last_epoch() + 1
        for param_group in self.optimizer.param_groups:
            learning_rate = param_group['lr']

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(learning_rate))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP
        self.loader_train.dataset.set_scale(0)
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            self.total_iter += 1

            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, 0)
            loss = self.loss(sr, hr)

            # @mst: print
            if self.total_iter % self.args.print_interval == 0:
                self.logprint("")
                self.logprint(f"Iter {self.total_iter} [prune_state: {self.prune_state} method: {self.args.method} compare_mode: {self.args.compare_mode} greg_mode: {self.args.greg_mode}] LR: {learning_rate} " + "-" * 40)

            # @mst: regularization loss: sparsity structure alignment (SSA)
            if self.prune_state in ['ssa']:
                n = len(self.constrained_layers)
                soft_masks = torch.zeros(n, self.args.n_feats, requires_grad=True).cuda()   # [n_constrained_layers, n_feats]
                hard_masks = torch.zeros(n, self.args.n_feats, requires_grad=False).cuda()  # [n_constrained_layers, n_feats]
                cnt = -1
                for name, m in self.model.named_modules():
                    if name in self.constrained_layers:
                        cnt += 1
                        _, indices = torch.sort(m.wn_scale.data)
                        n_wg = m.weight.size(0)
                        n_pruned = min(math.ceil(self.pr[name] * n_wg), n_wg - 1)  # do not prune all weight groups
                        thre = m.wn_scale[indices[n_pruned]]
                        soft_masks[cnt] = torch.sigmoid(m.wn_scale - thre)
                        hard_masks[cnt] = m.wn_scale >= thre
                loss_reg = -torch.mm(soft_masks, soft_masks.t()).mean()
                loss_reg_hard = -torch.mm(hard_masks, hard_masks.t()).mean().data  # only used as an analysis metric, not optimized
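
                # How the SSA loss above rewards alignment (a toy sketch with
                # assumed sizes: 2 constrained layers x 4 features; each row is
                # a mask over features). -mm(masks, masks.t()).mean() is lower
                # when the rows agree on which features to keep:
                #
                #   aligned = torch.tensor([[1., 0., 1., 0.],
                #                           [1., 0., 1., 0.]])
                #   shifted = torch.tensor([[1., 0., 1., 0.],
                #                           [0., 1., 0., 1.]])
                #   -torch.mm(aligned, aligned.t()).mean()  # -2.0, lower loss
                #   -torch.mm(shifted, shifted.t()).mean()  # -1.0, higher loss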
                if self.total_iter % self.args.print_interval == 0:
                    logstr = f'Iter {self.total_iter} loss_recon {loss.item():.4f} loss_reg (*{self.args.lw_spr}) {loss_reg.item():.6f} (loss_reg_hard {loss_reg_hard.item():.6f})'
                    self.logprint(logstr)
                loss += loss_reg * self.args.lw_spr

                # for constrained Conv layers, during prune_state 'ssa', do not update their regularization coefficients
                if self.total_iter % self.args.update_reg_interval == 0:
                    self._update_reg(skip=self.constrained_layers)

            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )

            # @mst: update reg factors and apply them before the optimizer updates
            if self.prune_state in ['update_reg'] and self.total_iter % self.args.update_reg_interval == 0:
                self._update_reg()

            # after reg is updated, print to check
            if self.total_iter % self.args.print_interval == 0:
                self._print_reg_status()

            if self.args.apply_reg:  # reg can also be left unapplied, as a baseline for comparison
                self._apply_reg()
            # --

            self.optimizer.step()
            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))
            timer_data.tic()

            # @mst: at the end of 'ssa', switch prune_state to 'update_reg'
            if self.prune_state in ['ssa'] and self.total_iter == self.args.iter_ssa:
                self._get_kept_wg_L1(align_constrained=True)  # this will update the pruned_wg/kept_wg for constrained Conv layers
                self.prune_state = 'update_reg'
                self.logprint(f'==> Iter {self.total_iter} prune_state "ssa" is done, got pruned_wg/kept_wg, switch to "{self.prune_state}".')

            # @mst: exit of the reg pruning loop
            if self.prune_state in ["stabilize_reg"] and self.total_iter - self.iter_stabilize_reg == self.args.stabilize_reg_interval:
                self.logprint(f"==> 'stabilize_reg' is done. Iter {self.total_iter}. About to prune and build the new model. Testing...")
                self.test()

                if self.args.greg_mode in ['all']:
                    self._get_kept_wg_L1(align_constrained=True)
                    self.logprint('==> Got pruned_wg/kept_wg.')

                self._merge_wn_scale_to_weights()
                self._prune_and_build_new_model()
                path = self._save_model('model_just_finished_prune.pt')
                self.logprint(f"==> Pruned and built a new model. Ckpt saved: '{path}'. Testing...")
                self.test()
                return True

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        # self.optimizer.schedule()  # use a fixed LR during pruning
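
    # A compact summary of the prune_state machine implemented by prune() and
    # train() above (states and transitions as coded in this file):
    #
    #   'ssa'           --(total_iter reaches args.iter_ssa)----------> 'update_reg'
    #   'update_reg'    --(every layer's reg exceeds reg_upper_limit)-> 'stabilize_reg'
    #   'stabilize_reg' --(after args.stabilize_reg_interval iters)---> merge wn_scale,
    #                     prune and build the compact model, save, test, return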
Testing...") 301 | self.test() 302 | return True 303 | 304 | self.loss.end_log(len(self.loader_train)) 305 | self.error_last = self.loss.log[-1, -1] 306 | # self.optimizer.schedule() # use fixed LR in pruning 307 | 308 | def _print_reg_status(self): 309 | self.logprint('************* Regularization Status *************') 310 | for name, m in self.model.named_modules(): 311 | if name in self.layers and self.pr[name] > 0: 312 | logstr = [self.layer_print_prefix[name]] 313 | logstr += [f"reg_status: min {self.reg[name].min():.5f} ave {self.reg[name].mean():.5f} max {self.reg[name].max():.5f}"] 314 | out = get_score_layer(m, wg='filter', criterion='wn_scale') 315 | w_abs, wn_scale = out['l1-norm'], out['wn_scale'] 316 | pruned, kept = pick_pruned_layer(score=wn_scale, pr=self.pr[name], sort_mode='min') 317 | avg_mag_pruned, avg_mag_kept = np.mean(w_abs[pruned]), np.mean(w_abs[kept]) 318 | avg_scale_pruned, avg_scale_kept = np.mean(wn_scale[pruned]), np.mean(wn_scale[kept]) 319 | logstr += ["average w_mag: pruned %.6f kept %.6f" % (avg_mag_pruned, avg_mag_kept)] 320 | logstr += ["average wn_scale: pruned %.6f kept %.6f" % (avg_scale_pruned, avg_scale_kept)] 321 | logstr += [f'Iter {self.total_iter}'] 322 | logstr += [f'cstn' if name in self.constrained_layers else 'free'] 323 | logstr += [f'pr {self.pr[name]}'] 324 | self.logprint(' | '.join(logstr)) 325 | self.logprint('*************************************************') 326 | 327 | def test(self): 328 | is_train = self.model.training 329 | torch.set_grad_enabled(False) 330 | 331 | epoch = self.optimizer.get_last_epoch() 332 | self.ckp.write_log('Evaluation:') 333 | self.ckp.add_log( 334 | torch.zeros(1, len(self.loader_test), len(self.scale)) 335 | ) 336 | self.model.eval() 337 | 338 | timer_test = utility.timer() 339 | if self.args.save_results: self.ckp.begin_background() 340 | for idx_data, d in enumerate(self.loader_test): 341 | for idx_scale, scale in enumerate(self.scale): 342 | d.dataset.set_scale(idx_scale) 343 | for lr, hr, filename in tqdm(d, ncols=80): 344 | lr, hr = self.prepare(lr, hr) 345 | sr = self.model(lr, idx_scale) 346 | sr = utility.quantize(sr, self.args.rgb_range) 347 | 348 | save_list = [sr] 349 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr( 350 | sr, hr, scale, self.args.rgb_range, dataset=d 351 | ) 352 | if self.args.save_gt: 353 | save_list.extend([lr, hr]) 354 | 355 | if self.args.save_results: 356 | self.ckp.save_results(d, filename[0], save_list, scale) 357 | 358 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 359 | best = self.ckp.log.max(0) 360 | logstr = '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {}) [prune_state: {} method: {} compare_mode: {} greg_mode: {}]'.format( 361 | d.dataset.name, 362 | scale, 363 | self.ckp.log[-1, idx_data, idx_scale], 364 | best[0][idx_data, idx_scale], 365 | best[1][idx_data, idx_scale] + 1, 366 | self.prune_state, 367 | self.args.method, 368 | self.args.compare_mode, 369 | self.args.greg_mode, 370 | ) 371 | self.ckp.write_log(logstr) 372 | self.logprint(logstr) 373 | 374 | self.ckp.write_log('Forward: {:.2f}s'.format(timer_test.toc())) 375 | self.ckp.write_log('Saving...') 376 | 377 | if self.args.save_results: 378 | self.ckp.end_background() 379 | 380 | # if not self.args.test_only: 381 | # self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch)) 382 | 383 | self.ckp.write_log( 384 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 385 | ) 386 | 387 | torch.set_grad_enabled(True) 388 | 389 | if is_train: 390 | self.model.train() 391 | 392 | 

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]
--------------------------------------------------------------------------------