├── README.md ├── __init__.py ├── data ├── __init__.py ├── benchmark.py ├── common.py ├── div2k.py ├── div2k_valid.py ├── srdata.py ├── test2k.py └── test4k.py ├── main_cadyq.py ├── main_getconfig.py ├── main_inference.py ├── main_org.py ├── main_pams.py ├── metrics ├── calculate_PSNR_SSIM.m └── run.sh ├── model ├── __init__.py ├── cadyq.py ├── carn.py ├── carn_cadyq.py ├── carn_cadyq_inference.py ├── carn_pams.py ├── common.py ├── edge.py ├── edsr.py ├── edsr_cadyq.py ├── edsr_cadyq_inference.py ├── edsr_pams.py ├── idn.py ├── idn_cadyq.py ├── idn_cadyq_inference.py ├── idn_pams.py ├── quant_ops.py ├── srresnet.py ├── srresnet_cadyq.py ├── srresnet_cadyq_inference.py └── srresnet_pams.py ├── option.py ├── test_edsrbaseline_cabm_simple.sh ├── test_edsrbaseline_get_cabm_config.sh ├── train_edsrbaseline_cabm_simple.sh ├── train_edsrbaseline_cadyq.sh ├── train_edsrbaseline_org.sh ├── train_edsrbaseline_pams.sh ├── utility.py └── utils ├── __init__.py ├── common.py ├── logger.py └── utility.py /README.md: -------------------------------------------------------------------------------- 1 | # CABM: Content-Aware Bit Mapping for Single Image Super-Resolution Network with Large Input (CVPR 2023) 2 | 3 | This repository is the official implementation of our CVPR 2023 paper. 4 | [paper](https://arxiv.org/abs/2304.06454). 5 | 6 | 7 | Our implementation is based on [CADyQ (PyTorch)](https://github.com/Cheeun/CADyQ) and [PAMS (PyTorch)](https://github.com/colorjam/PAMS). 8 | 9 | Because our paper covers many experimental settings, this repository provides only a simplified version of the training and testing code. 10 | 11 | 12 | ### Dependencies 13 | * kornia (pip install kornia) 14 | * Python >= 3.6 15 | * PyTorch >= 1.10.0 16 | * Other packages imported by the code 17 | 18 | 19 | ### Datasets 20 | * For training, we use the [DIV2K dataset](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
21 | 22 | * For testing, we use the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) and [Test2K, Test4K, and Test8K](https://github.com/Cheeun/CADyQ). 23 | 24 | ``` 25 | # for training 26 | DIV2K 27 | 28 | # for testing 29 | benchmark 30 | Test2K 31 | Test4K 32 | ``` 33 | 34 | 35 | ### How to train CABM step by step 36 | ``` 37 | # Using EDSR as an example 38 | 39 | # Step-1 40 | # Train full-precision models 41 | sh train_edsrbaseline_org.sh 42 | 43 | # Step-2 44 | # Train 8-bit PAMS models 45 | sh train_edsrbaseline_pams.sh 46 | 47 | # Step-3 48 | # Train CADyQ models 49 | sh train_edsrbaseline_cadyq.sh 50 | 51 | # Step-4 52 | # Get edge-to-bit tables 53 | sh test_edsrbaseline_get_cabm_config.sh 54 | 55 | # Step-5 56 | # Get CABM models 57 | sh train_edsrbaseline_cabm_simple.sh 58 | ``` 59 | 60 | ### How to test CABM 61 | ``` 62 | sh test_edsrbaseline_cabm_simple.sh 63 | ``` 64 | 65 | ### How to sample patches during training 66 | Please refer to [SamplingAug](https://github.com/littlepure2333/SamplingAug).
67 | 68 | 69 | ### Citation 70 | ``` 71 | @article{Tian2023CABMCB, 72 | title={CABM: Content-Aware Bit Mapping for Single Image Super-Resolution Network with Large Input}, 73 | author={Senmao Tian and Ming Lu and Jiaming Liu and Yandong Guo and Yurong Chen and Shunli Zhang}, 74 | journal={ArXiv}, 75 | year={2023}, 76 | volume={abs/2304.06454} 77 | } 78 | ``` 79 | 80 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sheldon04/CABM-pytorch/0634f7e9539fba97f094d172b65651ea14c5c4f8/__init__.py -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | import sys 3 | sys.path.append('./') 4 | from torch.utils.data import dataloader 5 | from torch.utils.data import ConcatDataset 6 | 7 | # ConcatDataset subclass that exposes the first dataset's `train` flag and forwards set_scale() to every sub-dataset that defines it 8 | class MyConcatDataset(ConcatDataset): 9 | def __init__(self, datasets): 10 | super(MyConcatDataset, self).__init__(datasets) 11 | self.train = datasets[0].train  # NOTE(review): assumes all sub-datasets share the same train flag 12 | 13 | def set_scale(self, idx_scale): 14 | for d in self.datasets: 15 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 16 | 17 | class Data:  # builds loader_train (None when args.test_only) and loader_test (one batch-size-1 DataLoader per args.data_test entry) 18 | def __init__(self, args): 19 | self.loader_train = None 20 | if not args.test_only: 21 | datasets = [] 22 | for d in args.data_train: 23 | if d == 'DIV2K_partion':  # same module lookup as the else branch, but the dataset is constructed with name='DIV2K' 24 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 25 | m = import_module('data.' + module_name.lower()) 26 | datasets.append(getattr(m, module_name)(args, name='DIV2K')) 27 | else: 28 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 29 | m = import_module('data.'
+ module_name.lower()) 30 | datasets.append(getattr(m, module_name)(args, name=d)) 31 | 32 | # self.loader_train = MSDataLoader( 33 | self.loader_train = dataloader.DataLoader( 34 | # args, 35 | MyConcatDataset(datasets), 36 | batch_size=args.batch_size, 37 | shuffle=True, 38 | pin_memory=not args.cpu, 39 | num_workers=args.n_threads, 40 | ) 41 | self.loader_test = [] 42 | for d in args.data_test: 43 | if d in ['Set5', 'Set14', 'B100', 'Urban100']:  # these names are served by data/benchmark.py (Benchmark) 44 | m = import_module('data.benchmark') 45 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 46 | else: 47 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 48 | m = import_module('data.' + module_name.lower())  # module file is the lower-cased dataset name, e.g. data/test2k.py 49 | print(m)  # NOTE(review): leftover debug print of the imported module 50 | testset = getattr(m, module_name)(args, train=False, name=d) 51 | 52 | self.loader_test.append( 53 | dataloader.DataLoader( 54 | testset, 55 | batch_size=1, 56 | shuffle=False, 57 | pin_memory=not args.cpu, 58 | num_workers=args.n_threads, 59 | ) 60 | ) 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(Benchmark, self).__init__( 14 | args, name=name, train=train, benchmark=True  # the benchmark argument is ignored; always True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'benchmark', self.name) 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | if self.input_large: 21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') 22 | else: 23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 24 | self.ext = ('', '.png') 25 | 26 | -------------------------------------------------------------------------------- /data/common.py:
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | # Draw one element of some_list, each weighted by the matching entry of probabilities. NOTE(review): if the weights sum to no more than the drawn x, the last element is returned; an empty some_list raises UnboundLocalError 8 | def random_pick(some_list, probabilities): 9 | x = random.uniform(0,1) 10 | cumulative_probability = 0.0 11 | for item, item_probability in zip(some_list, probabilities): 12 | cumulative_probability += item_probability 13 | if x < cumulative_probability: 14 | break 15 | return item 16 | # Crop aligned random patches from (H, W, C) arrays: args[0] is the LR image and args[1:] are cropped at the matching (scaled) location; returns a list in the same order 17 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 18 | # print(args[0][0]['image'].shape[:2]) 19 | # print("new") 20 | 21 | # lr shape 22 | ih, iw = args[0].shape[:2] # original 23 | # ih, iw = args[0][0]['image'].shape[:2] # fixed 210916 24 | # print(args[0].shape) 25 | 26 | 27 | if not input_large: 28 | p = scale if multi else 1 29 | tp = p * patch_size 30 | ip = tp // scale 31 | else: 32 | tp = patch_size 33 | ip = patch_size 34 | 35 | ix = random.randrange(0, iw - ip + 1)  # raises ValueError if the LR image is smaller than ip along this axis 36 | iy = random.randrange(0, ih - ip + 1) 37 | 38 | if not input_large: 39 | tx, ty = scale * ix, scale * iy 40 | else: 41 | tx, ty = ix, iy 42 | 43 | ret = [ 44 | args[0][iy:iy + ip, ix:ix + ip, :], 45 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 46 | ] 47 | return ret 48 | # Convert (H, W, C) images to n_channels: 2-D inputs gain a channel axis, RGB becomes the Y channel of YCbCr when n_channels == 1, single-channel input is tiled when n_channels == 3 49 | def set_channel(*args, n_channels=3): 50 | def _set_channel(img): 51 | if img.ndim == 2: 52 | img = np.expand_dims(img, axis=2) 53 | 54 | c = img.shape[2] 55 | if n_channels == 1 and c == 3: 56 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 57 | elif n_channels == 3 and c == 1: 58 | img = np.concatenate([img] * n_channels, 2) 59 | 60 | return img 61 | 62 | return [_set_channel(a) for a in args] 63 | # (H, W, C) numpy arrays -> (C, H, W) float tensors, multiplied by rgb_range / 255 64 | def np2Tensor(*args, rgb_range=255): 65 | def _np2Tensor(img): 66 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 67 | tensor = torch.from_numpy(np_transpose).float() 68 | tensor.mul_(rgb_range / 255) 69 | 70 | return tensor 71 | 72 | return [_np2Tensor(a) for a in args] 73 | # Apply one shared random horizontal flip / vertical flip / H-W transpose to every image in args 74 | def
augment(*args, hflip=True, rot=True): 75 | hflip = hflip and random.random() < 0.5 76 | vflip = rot and random.random() < 0.5 77 | rot90 = rot and random.random() < 0.5 78 | 79 | def _augment(img): 80 | if hflip: img = img[:, ::-1, :] 81 | if vflip: img = img[::-1, :, :] 82 | if rot90: img = img.transpose(1, 0, 2) 83 | 84 | return img 85 | 86 | return [_augment(a) for a in args] 87 | 88 | -------------------------------------------------------------------------------- /data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')]  # 'begin-end[/begin-end]': first range for training, second for testing 7 | # import pdb; pdb.set_trace() 8 | if train: 9 | data_range = data_range[0] 10 | else: 11 | if args.test_only and len(data_range) == 1: 12 | data_range = data_range[0] 13 | else: 14 | data_range = data_range[1] 15 | 16 | self.begin, self.end = list(map(lambda x: int(x), data_range))  # 1-based inclusive bounds used by _scan; set before super().__init__(), which calls _scan() 17 | super(DIV2K, self).__init__( 18 | args, name=name, train=train, benchmark=benchmark 19 | ) 20 | 21 | def _scan(self): 22 | names_hr, names_lr = super(DIV2K, self)._scan() 23 | names_hr = names_hr[self.begin - 1:self.end] 24 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 25 | 26 | return names_hr, names_lr 27 | 28 | def _set_filesystem(self, dir_data): 29 | super(DIV2K, self)._set_filesystem(dir_data) 30 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 31 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 32 | if self.input_large: self.dir_lr += 'L' 33 | 34 | -------------------------------------------------------------------------------- /data/div2k_valid.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 |
import torch.utils.data as data 10 | 11 | class div2k_valid(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(div2k_valid, self).__init__( 14 | args, name=name, train=train, benchmark=True  # the benchmark argument is ignored; always True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'DIV2K' ) 19 | self.dir_hr = os.path.join(self.apath, 'DIV2K_valid_HR') 20 | self.dir_lr = os.path.join(self.apath, 'DIV2K_valid_LR_bicubic') 21 | self.ext = ('', '.png') -------------------------------------------------------------------------------- /data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | 6 | from data import common 7 | import pdb 8 | import numpy as np 9 | import imageio 10 | import torch 11 | import torch.utils.data as data 12 | 13 | class SRData(data.Dataset):  # base SR dataset; subclasses override _set_filesystem (and optionally _scan) to locate HR/LR files 14 | def __init__(self, args, name='', train=True, benchmark=False): 15 | self.args = args 16 | self.name = name 17 | self.train = train 18 | self.split = 'train' if train else 'test' 19 | self.do_eval = True 20 | self.benchmark = benchmark 21 | self.input_large = (args.model == 'VDSR')  # selects the '...L' LR directory in _set_filesystem and HR-sized patch coordinates in get_patch 22 | self.scale = args.scale 23 | self.idx_scale = 0 24 | self._set_filesystem(args.dir_data) 25 | if args.ext.find('img') < 0:  # every non-'img' mode gets a <apath>/bin cache directory 26 | path_bin = os.path.join(self.apath, 'bin') 27 | os.makedirs(path_bin, exist_ok=True) 28 | 29 | list_hr, list_lr = self._scan() 30 | if args.ext.find('img') >= 0 or benchmark:  # keep raw image paths; images are read from disk in _load_file 31 | self.images_hr, self.images_lr = list_hr, list_lr 32 | elif args.ext.find('sep') >= 0:  # one pickled .pt file per image, created on demand by _check_and_load. NOTE(review): any other ext leaves images_hr/images_lr unset 33 | os.makedirs( 34 | self.dir_hr.replace(self.apath, path_bin), 35 | exist_ok=True 36 | ) 37 | for s in self.scale: 38 | os.makedirs( 39 | os.path.join( 40 | self.dir_lr.replace(self.apath, path_bin), 41 | 'X{}'.format(s) 42 | ), 43 | exist_ok=True 44 | ) 45 | self.images_hr, self.images_lr = [], [[] for _ in self.scale] 46 | for h in list_hr: 47 | b = h.replace(self.apath,
path_bin) 48 | b = b.replace(self.ext[0], '.pt') 49 | self.images_hr.append(b) 50 | self._check_and_load(args.ext, h, b, verbose=True) 51 | for i, ll in enumerate(list_lr): 52 | for l in ll: 53 | b = l.replace(self.apath, path_bin) 54 | b = b.replace(self.ext[1], '.pt') 55 | self.images_lr[i].append(b) 56 | self._check_and_load(args.ext, l, b, verbose=True) 57 | if train:  # repeat inflates __len__ so all training datasets together yield roughly batch_size * test_every samples per epoch 58 | n_patches = args.batch_size * args.test_every 59 | n_images = len(args.data_train) * len(self.images_hr) # 800 60 | if n_images == 0: 61 | self.repeat = 0 62 | else: 63 | self.repeat = max(n_patches // n_images, 1) 64 | 65 | # Below functions are used to prepare images 66 | def _scan(self):  # returns (sorted HR paths, one list of LR paths per scale); the LR naming rule depends on substrings of dir_hr 67 | names_hr = sorted( 68 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 69 | ) 70 | names_lr = [[] for _ in self.scale] 71 | for f in names_hr: 72 | filename, _ = os.path.splitext(os.path.basename(f)) 73 | for si, s in enumerate(self.scale): 74 | #''' 75 | if 'sub' in self.dir_hr: 76 | names_lr[si].append(os.path.join( 77 | self.dir_lr, '{}{}'.format( 78 | filename, self.ext[1] 79 | ) 80 | )) 81 | # for test2k, test4k, test8k (NOTE(review): matches 'test' anywhere in the full dir_hr path) 82 | elif 'test' in self.dir_hr: 83 | names_lr[si].append(os.path.join( 84 | self.dir_lr, '{}{}'.format( 85 | filename, self.ext[1] 86 | ) 87 | )) 88 | # 0800x4.png for DIV2K 89 | elif 'DIV2K_valid' in self.dir_hr: 90 | names_lr[si].append(os.path.join( 91 | self.dir_lr, 'X{}/{}x{}{}'.format( 92 | s, filename, s, self.ext[1] 93 | #self.dir_lr, 'X{}/{}{}'.format( 94 | # s, filename, self.ext[1] 95 | ) 96 | )) 97 | elif 'DIV2K' in self.dir_hr or 'DIV2K_partion' in self.dir_hr: 98 | names_lr[si].append(os.path.join( 99 | self.dir_lr, 'X{}/{}x{}{}'.format( 100 | s, filename, s, self.ext[1] 101 | #self.dir_lr, 'X{}/{}{}'.format( 102 | # s, filename, self.ext[1] 103 | ) 104 | )) 105 | # benchmark sets and others: X{s}/<name>x{s}<ext>, e.g. X4/baboonx4.png 106 | else: 107 | names_lr[si].append(os.path.join( 108 | self.dir_lr, 'X{}/{}x{}{}'.format( 109 | s, filename, s, self.ext[1] 110 | #self.dir_lr, 'X{}/{}{}'.format( 111
| # s, filename, self.ext[1] 112 | ) 113 | )) 114 | 115 | return names_hr, names_lr 116 | 117 | def _set_filesystem(self, dir_data): 118 | self.apath = os.path.join(dir_data, self.name) 119 | self.dir_hr = os.path.join(self.apath, 'HR') 120 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 121 | if self.input_large: self.dir_lr += 'L' 122 | self.ext = ('.png', '.png') 123 | 124 | def _check_and_load(self, ext, img, f, verbose=True): 125 | if not os.path.isfile(f) or ext.find('reset') >= 0: 126 | if verbose: 127 | print('Making a binary: {}'.format(f)) 128 | with open(f, 'wb') as _f: 129 | pickle.dump(imageio.imread(img), _f) 130 | 131 | def __getitem__(self, idx): 132 | lr, hr, filename = self._load_file(idx) 133 | pair = self.get_patch(lr, hr) 134 | 135 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 136 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 137 | 138 | return pair_t[0], pair_t[1], filename 139 | 140 | def __len__(self): 141 | if self.train: 142 | return len(self.images_hr) * self.repeat 143 | else: 144 | return len(self.images_hr) 145 | 146 | def _get_index(self, idx): 147 | if self.train: 148 | return idx % len(self.images_hr) 149 | else: 150 | return idx 151 | 152 | def _load_file(self, idx): 153 | idx = self._get_index(idx) 154 | f_hr = self.images_hr[idx] 155 | f_lr = self.images_lr[self.idx_scale][idx] 156 | 157 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 158 | 159 | if self.args.ext == 'img' or self.benchmark: 160 | hr = imageio.imread(f_hr) 161 | lr = imageio.imread(f_lr) 162 | elif self.args.ext.find('sep') >= 0: 163 | 164 | with open(f_hr, 'rb') as _f: 165 | hr = pickle.load(_f) 166 | with open(f_lr, 'rb') as _f: 167 | lr = pickle.load(_f) 168 | 169 | return lr, hr, filename 170 | 171 | # new version from EDSR srdata.py 210916 172 | # def _load_file(self, idx): 173 | # idx = self._get_index(idx) 174 | # f_hr = self.images_hr[idx] 175 | # f_lr = self.images_lr[self.idx_scale][idx] 176 | # 177 
| # if self.args.ext.find('bin') >= 0: 178 | # filename = f_hr['name'] 179 | # hr = f_hr['image'] 180 | # lr = f_lr['image'] 181 | # else: 182 | # filename, _ = os.path.splitext(os.path.basename(f_hr)) 183 | # if self.args.ext == 'img' or self.benchmark: 184 | # hr = imageio.imread(f_hr) 185 | # lr = imageio.imread(f_lr) 186 | # elif self.args.ext.find('sep') >= 0: 187 | # with open(f_hr, 'rb') as _f: hr = pickle.load(_f)[0]['image'] 188 | # with open(f_lr, 'rb') as _f: lr = pickle.load(_f)[0]['image'] 189 | # # print(lr.shape, hr.shape, filename) 190 | # return lr, hr, filename 191 | 192 | def get_patch(self, lr, hr):  # train: random aligned crop plus augmentation unless args.no_augment; test: trim hr to exactly scale x the lr size 193 | scale = self.scale[self.idx_scale] 194 | # pdb.set_trace() 195 | if self.train: 196 | 197 | lr, hr = common.get_patch( 198 | lr, hr, 199 | patch_size=self.args.patch_size, 200 | scale=scale, 201 | multi=(len(self.scale) > 1), 202 | input_large=self.input_large 203 | ) 204 | if not self.args.no_augment: lr, hr = common.augment(lr, hr) 205 | else: 206 | ih, iw = lr.shape[:2] 207 | hr = hr[0:ih * scale, 0:iw * scale] 208 | 209 | return lr, hr 210 | 211 | def set_scale(self, idx_scale):  # with input_large the requested index is ignored and a random scale index is picked 212 | if not self.input_large: 213 | self.idx_scale = idx_scale 214 | else: 215 | self.idx_scale = random.randint(0, len(self.scale) - 1) 216 | 217 | -------------------------------------------------------------------------------- /data/test2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class test2k(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(test2k, self).__init__( 14 | args, name=name, train=train, benchmark=True  # the benchmark argument is ignored; always True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'test2k') 19 | self.dir_hr = os.path.join(self.apath, 'HR') 20 | self.dir_lr = os.path.join(self.apath,
'LR') 21 | self.ext = ('', '.png') 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /data/test4k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class test4k(srdata.SRData): 12 | def __init__(self, args, name='', train=True, benchmark=True): 13 | super(test4k, self).__init__( 14 | args, name=name, train=train, benchmark=True 15 | ) 16 | 17 | def _set_filesystem(self, dir_data): 18 | self.apath = os.path.join(dir_data, 'test4k') 19 | print(self.apath) 20 | self.dir_hr = os.path.join(self.apath, 'HR') 21 | self.dir_lr = os.path.join(self.apath, 'LR') 22 | self.ext = ('', '.png') 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /main_org.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | from decimal import Decimal 4 | 5 | import cv2 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.optim.lr_scheduler import StepLR 9 | from tqdm import tqdm 10 | 11 | import data 12 | import utility 13 | 14 | from model.carn import CARN 15 | from model.edsr import EDSR 16 | from model.idn import IDN 17 | from model.srresnet import SRResNet 18 | 19 | import torch.nn as nn 20 | from option import args 21 | from utils import common as util 22 | from utils.common import AverageMeter 23 | import torch.nn.parallel as P 24 | import numpy as np 25 | 26 | import time 27 | from torchvision.utils import save_image 28 | 29 | torch.manual_seed(args.seed) 30 | checkpoint = utility.checkpoint(args) 31 | device = torch.device('cpu' if args.cpu else f'cuda:{args.gpu_id}') 32 | 33 | 34 | class Trainer(): 35 | def __init__(self, args, loader, t_model, s_model, ckp): 36 | self.args = args 37 | self.scale = args.scale 
38 | 39 | self.epoch = 0 40 | self.ckp = ckp 41 | self.loader_train = loader.loader_train 42 | self.loader_test = loader.loader_test 43 | self.t_model = t_model 44 | self.s_model = s_model 45 | 46 | 47 | if args.model == 'EDSR' or args.model == 'SRResNet': 48 | arch_param = [v for k, v in self.t_model.named_parameters( 49 | ) if 'alpha' not in k and 'net' not in k] 50 | alpha_param = [ 51 | v for k, v in self.t_model.named_parameters() if 'alpha' in k] 52 | 53 | else: 54 | alpha_param = [ 55 | v for k, v in self.t_model.named_parameters() if 'alpha' in k] 56 | arch_param = [v for k, v in self.t_model.named_parameters( 57 | ) if 'alpha' not in k and 'net' not in k] 58 | 59 | params = [{'params': arch_param}, {'params': alpha_param, 'lr': 1e-2}] 60 | self.optimizer = torch.optim.Adam( 61 | params, lr=args.lr, betas=args.betas, eps=args.epsilon) 62 | self.scheduler = StepLR( 63 | self.optimizer, step_size=int(args.decay), gamma=args.gamma) 64 | 65 | self.resume_epoch = 0 66 | if args.resume is not None: 67 | ckpt = torch.load(args.resume) 68 | self.epoch = ckpt['epoch'] 69 | print(f"Continue from {self.epoch}") 70 | self.t_model.load_state_dict(ckpt['state_dict']) 71 | self.optimizer.load_state_dict(ckpt['optimizer']) 72 | self.scheduler.load_state_dict(ckpt['scheduler']) 73 | self.resume_epoch = ckpt['epoch'] 74 | # self.epoch -= self.resume_epoch 75 | 76 | # --------------- Print Model --------------------- 77 | if args.test_only: 78 | self.ckp.write_log('Test on {}'.format(args.teacher_weights)) 79 | 80 | # --------------- Print # Params --------------------- 81 | n_params = 0 82 | for p in list(s_model.parameters()): 83 | n_p = 1 84 | for s in list(p.size()): 85 | n_p = n_p*s 86 | n_params += n_p 87 | self.ckp.write_log('Parameters: {:.1f}K'.format(n_params/(1e+3))) 88 | 89 | self.losses = AverageMeter() 90 | self.att_losses = AverageMeter() 91 | self.nor_losses = AverageMeter() 92 | self.bit_losses = AverageMeter() 93 | self.avg_bit = AverageMeter() 94 | 95 | 
self.test_patch_size = args.patch_size 96 | self.step_size = args.step_size 97 | 98 | self.mse_loss = nn.MSELoss() 99 | 100 | self.losses_list = [] 101 | self.bit_list = [] 102 | self.valpsnr_list = [] 103 | self.valbit_list = [] 104 | 105 | def train(self): 106 | self.scheduler.step() 107 | # self.cadyq_scheduler.step() 108 | 109 | self.epoch = self.epoch + 1 110 | lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 111 | # bitsel_lr = self.cadyq_optimizer.state_dict()['param_groups'][0]['lr'] 112 | 113 | self.w_bit = self.epoch*self.args.w_bit_decay + \ 114 | self.args.w_bit if self.args.cadyq else self.args.w_bit 115 | 116 | self.ckp.write_log( 117 | '[Epoch {}]\tLearning rate: {:.2e}'.format( 118 | self.epoch, Decimal(lr)) 119 | ) 120 | 121 | self.t_model.train() 122 | # self.s_model.train() 123 | 124 | self.t_model.apply(lambda m: setattr(m, 'epoch', self.epoch)) 125 | 126 | num_iterations = len(self.loader_train) 127 | timer_data, timer_model = utility.timer(), utility.timer() 128 | 129 | 130 | for batch, (lr, hr, idx_scale, ) in enumerate(self.loader_train): 131 | num_iters = num_iterations * (self.epoch-1) + batch 132 | 133 | lr, hr = self.prepare(lr, hr) 134 | 135 | data_size = lr.size(0) 136 | 137 | timer_data.hold() 138 | timer_model.tic() 139 | 140 | self.optimizer.zero_grad() 141 | 142 | if hasattr(self.t_model, 'set_scale'): 143 | self.t_model.set_scale(idx_scale) 144 | if hasattr(self.s_model, 'set_scale'): 145 | self.s_model.set_scale(idx_scale) 146 | 147 | # Teacher 148 | if self.args.model == 'CARN': 149 | t_sr = self.t_model(lr/255., self.scale[0]) 150 | t_sr *=255. 151 | else: 152 | # t_sr, t_res, _, t_feat, _ = self.t_model(lr) 153 | t_sr = self.t_model(lr) 154 | 155 | # 1. 
Pixel-wise L1 loss 156 | if self.args.model=='FSRCNN': 157 | nor_loss = self.mse_loss(t_sr, hr) 158 | else: 159 | nor_loss = args.w_l1 * F.l1_loss(t_sr, hr) 160 | loss = nor_loss 161 | 162 | loss.backward() 163 | self.optimizer.step() 164 | 165 | timer_model.hold() 166 | 167 | self.losses.update(loss.item(), data_size) 168 | self.nor_losses.update(nor_loss.item(), data_size) 169 | 170 | display_loss = f'Loss: {self.losses.avg: .3f}' 171 | display_loss_nor = f'L_1: {self.nor_losses.avg: .3f}' 172 | 173 | if (batch + 1) % self.args.print_every == 0: 174 | self.ckp.write_log('[{}/{}] \t{:.1f}+{:.1f}s+ \t{}'.format( 175 | (batch + 1) * self.args.batch_size, 176 | len(self.loader_train.dataset), 177 | timer_model.release(), 178 | timer_data.release(), 179 | display_loss_nor 180 | )) 181 | self.losses_list.append(self.losses.avg) 182 | 183 | timer_data.tic() 184 | 185 | def test(self, is_teacher=True): 186 | torch.set_grad_enabled(False) 187 | epoch = self.epoch 188 | self.ckp.write_log('\nEvaluation:') 189 | self.ckp.add_log( 190 | torch.zeros(1, len(self.loader_test), len(self.scale)) 191 | ) 192 | if is_teacher: 193 | model = self.t_model 194 | else: 195 | model = self.s_model 196 | model.eval() 197 | timer_test = utility.timer() 198 | 199 | if self.args.save_results: 200 | self.ckp.begin_background() 201 | 202 | for idx_data, d in enumerate(self.loader_test): 203 | for idx_scale, scale in enumerate(self.scale): 204 | if self.args.test_patch: 205 | # ------------------------Test patch-wise------------------------------ 206 | # Check options : --test_patch --patch_size 128 --step_size 16 --student_weights STUDENT_MODEL_DIRECTORY 207 | d.dataset.set_scale(idx_scale) 208 | i = 0 209 | tot_bits = 0 210 | for lr, hr, filename in tqdm(d, ncols=80): 211 | i += 1 212 | lr, hr = self.prepare(lr, hr) 213 | 214 | print(lr.size()) 215 | print(lr[0].size()) 216 | 217 | lr_list, num_h, num_w, h, w = self.crop( 218 | lr[0], self.test_patch_size, self.step_size) 219 | hr_list = 
self.crop( 220 | hr[0], self.test_patch_size*self.args.scale[0], self.step_size*self.args.scale[0])[0] 221 | sr_list = [] 222 | 223 | p = 0 224 | tot_bits_image = 0 225 | psnrs = [] 226 | 227 | for lr_sub_img, hr_sub_img in zip(lr_list, hr_list): 228 | time_start = time.time() 229 | # --------------------select which quantization to pass through--------------------- 230 | if self.args.model == 'CARN': 231 | sr_sub = model( 232 | lr_sub_img.unsqueeze(0)/255., scale) 233 | sr_sub *= 255. 234 | else: 235 | sr_sub = model( 236 | lr_sub_img.unsqueeze(0)) 237 | time_end = time.time() 238 | # print('\n') 239 | # print(get_bit_config(model)) 240 | 241 | avg_bit = 32.00 242 | tot_bits_image += avg_bit 243 | 244 | patch_psnr = utility.calc_psnr( 245 | sr_sub, hr_sub_img, scale, self.args.rgb_range, dataset=d) 246 | psnrs.append(patch_psnr) 247 | self.ckp.write_log( 248 | '{}-{:3d}: {:.2f} dB, {:.2f} avg bits'.format(filename[0], p, patch_psnr, avg_bit)) 249 | 250 | if self.args.save_patch: 251 | save_image(sr_sub[0]/255, './experiment/'+self.args.save+'/results-'+self.args.data_test[0] + 252 | '/{}_{}_{:.2f}_{:.2f}.png'.format(filename[0], p, patch_psnr, avg_bit)) 253 | 254 | sr_sub = utility.quantize( 255 | sr_sub, self.args.rgb_range) 256 | sr_list.append(sr_sub) 257 | p += 1 258 | 259 | sr = self.combine( 260 | sr_list, num_h, num_w, h, w, self.test_patch_size, self.step_size) 261 | sr = sr.unsqueeze(0) 262 | 263 | save_list = [sr] 264 | if self.args.add_mask: 265 | sr_mask = util.add_mask_psnr(sr.cpu(), scale, num_h, num_w, h*scale, w*scale, self.test_patch_size, self.step_size, psnrs) 266 | save_list.append(sr_mask) 267 | cur_psnr = utility.calc_psnr( 268 | sr, hr, scale, self.args.rgb_range, dataset=d) 269 | cur_ssim = utility.calc_ssim( 270 | sr, hr, scale, benchmark=d.dataset.benchmark) 271 | 272 | self.ckp.log[-1, idx_data, idx_scale] += cur_psnr 273 | self.ckp.bit_log[-1, idx_data, 274 | idx_scale] += tot_bits_image/p 275 | self.ckp.ssim_log[-1, idx_data, 
idx_scale] += cur_ssim 276 | 277 | tot_bits += tot_bits_image/p 278 | # per image 279 | self.ckp.write_log( 280 | '\n[{}] PSNR: {:.3f} dB; SSIM: {:.3f}; Avg_bit: {:.2f}; Num_patch: {}'.format( 281 | filename[0], 282 | cur_psnr, 283 | cur_ssim, 284 | tot_bits_image/p, 285 | p 286 | ) 287 | ) 288 | 289 | if self.args.save_gt: 290 | save_list.extend([lr, hr]) 291 | 292 | if self.args.save_results: 293 | save_name = '{}_{:.2f}'.format( 294 | filename[0], cur_psnr) 295 | self.ckp.save_results( 296 | d, save_name, save_list, scale) 297 | 298 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 299 | self.ckp.ssim_log[-1, idx_data, idx_scale] /= len(d) 300 | 301 | best_psnr = self.ckp.log.max(0) 302 | 303 | self.ckp.write_log( 304 | '[{} x{}] PSNR: {:.3f} SSIM:{:.3f} (Best: {:.3f} @epoch {})'.format( 305 | d.dataset.name, 306 | scale, 307 | self.ckp.log[-1, idx_data, idx_scale], 308 | self.ckp.ssim_log[-1, idx_data, idx_scale], 309 | best_psnr[0][idx_data, idx_scale], 310 | best_psnr[1][idx_data, idx_scale] + 1, 311 | ) 312 | ) 313 | 314 | else: 315 | # ------------------------Test image-wise------------------------------ 316 | d.dataset.set_scale(idx_scale) 317 | i = 0 318 | 319 | tot_bits = 0 320 | pbar = tqdm(d, ncols=80) 321 | for lr, hr, filename in pbar: 322 | i += 1 323 | lr, hr = self.prepare(lr, hr) 324 | 325 | if self.args.precision == 'half': 326 | model = model.half() 327 | if self.args.chop: 328 | sr, s_res = self.forward_chop(lr) 329 | else: 330 | if self.args.model.lower() == 'fsrcnn': 331 | sr = model(lr) 332 | elif self.args.model == 'IDN': 333 | sr = model(lr) 334 | elif self.args.model == 'CARN': 335 | sr = model( 336 | lr/255., scale) # for CARN 337 | sr *= 255. 
# for CARN 338 | else: 339 | # EDSR, SRResNet 340 | sr = model(lr) 341 | 342 | 343 | sr = utility.quantize(sr, self.args.rgb_range) 344 | save_list = [sr] 345 | 346 | cur_psnr = utility.calc_psnr( 347 | sr, hr, scale, self.args.rgb_range, dataset=d) 348 | if self.args.test_only: 349 | cur_ssim = utility.calc_ssim( 350 | sr, hr, scale, benchmark=d.dataset.benchmark) 351 | else: 352 | cur_ssim = 0 353 | 354 | self.ckp.log[-1, idx_data, idx_scale] += cur_psnr 355 | self.ckp.ssim_log[-1, idx_data, idx_scale] += cur_ssim 356 | 357 | if self.args.save_gt: 358 | save_list.extend([lr, hr]) 359 | 360 | if self.args.save_results: 361 | save_name = f'{filename[0]}_{args.k_bits}bit' + \ 362 | '_{:.2f}'.format(cur_psnr) 363 | self.ckp.save_results( 364 | d, save_name, save_list, scale) 365 | 366 | self.ckp.log[-1, idx_data, idx_scale] /= len(d) 367 | self.ckp.ssim_log[-1, idx_data, idx_scale] /= len(d) 368 | 369 | best_psnr = self.ckp.log.max(0) 370 | self.ckp.write_log( 371 | '[{} x{}] PSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 372 | d.dataset.name, 373 | scale, 374 | self.ckp.log[-1, idx_data, idx_scale], 375 | best_psnr[0][idx_data, idx_scale], 376 | best_psnr[1][idx_data, idx_scale] + 1, 377 | ) 378 | ) 379 | 380 | if self.args.save_results: 381 | self.ckp.end_background() 382 | 383 | if not self.args.test_only: 384 | is_best_psnr = (best_psnr[1][0, 0] + 1 == epoch) 385 | 386 | state = { 387 | 'epoch': epoch, 388 | 'state_dict': self.s_model.state_dict(), 389 | 'optimizer': self.optimizer.state_dict(), 390 | 'scheduler': self.scheduler.state_dict(), 391 | } 392 | util.save_checkpoint(state, is_best_psnr, 393 | checkpoint=self.ckp.dir + '/model') 394 | util.plot_psnr(self.args, self.ckp.dir, self.epoch - 395 | self.resume_epoch, self.ckp.log) 396 | util.plot_bit(self.args, self.ckp.dir, self.epoch - 397 | self.resume_epoch, self.ckp.bit_log) # in utils/common.py 398 | 399 | self.ckp.write_log( 400 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True 401 | ) 402 | 
403 | torch.set_grad_enabled(True) 404 | 405 | def prepare(self, *args): 406 | # device = torch.device('cpu' if self.args.cpu else 'cuda') 407 | def _prepare(tensor): 408 | if self.args.precision == 'half': 409 | tensor = tensor.half() 410 | return tensor.to(device) 411 | 412 | return [_prepare(a) for a in args] 413 | 414 | def terminate(self): 415 | if self.args.test_only: 416 | self.test() 417 | return True 418 | else: 419 | return self.epoch >= self.args.epochs 420 | 421 | def forward_chop(self, *args, shave=10, min_size=160000): 422 | # min_size : 400 x 400 423 | scale = self.scale[0] 424 | n_GPUs = min(self.args.n_GPUs, 4) 425 | # height, width 426 | h, w = args[0].size()[-2:] 427 | 428 | top = slice(0, h//2 + shave) 429 | bottom = slice(h - h//2 - shave, h) 430 | left = slice(0, w//2 + shave) 431 | right = slice(w - w//2 - shave, w) 432 | x_chops = [torch.cat([ 433 | a[..., top, left], 434 | a[..., top, right], 435 | a[..., bottom, left], 436 | a[..., bottom, right] 437 | ]) for a in args] 438 | 439 | y_chops = [] 440 | if h * w < 4 * min_size: 441 | for i in range(0, 4, n_GPUs): 442 | 443 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 444 | y, y_res = P.data_parallel(self.s_model, *x, range(n_GPUs)) 445 | 446 | if not isinstance(y, list): 447 | y = [y] 448 | if not y_chops: 449 | y_chops = [[c for c in _y.chunk( 450 | n_GPUs, dim=0)] for _y in y] 451 | else: 452 | for y_chop, _y in zip(y_chops, y): 453 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 454 | 455 | else: 456 | for p in zip(*x_chops): 457 | 458 | y, y_res = self.forward_chop( 459 | p[0].unsqueeze(0), shave=shave, min_size=min_size) 460 | 461 | if not isinstance(y, list): 462 | y = [y] 463 | if not y_chops: 464 | y_chops = [[_y] for _y in y] 465 | else: 466 | for y_chop, _y in zip(y_chops, y): 467 | y_chop.append(_y) 468 | 469 | h *= scale 470 | w *= scale 471 | top = slice(0, h//2) 472 | bottom = slice(h - h//2, h) 473 | bottom_r = slice(h//2 - h, None) 474 | left = slice(0, w//2) 475 | right = 
slice(w - w//2, w) 476 | right_r = slice(w//2 - w, None) 477 | 478 | b, c = y_chops[0][0].size()[:-2] 479 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 480 | for y_chop, _y in zip(y_chops, y): 481 | _y[..., top, left] = y_chop[0][..., top, left] 482 | _y[..., top, right] = y_chop[1][..., top, right_r] 483 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 484 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 485 | 486 | if len(y) == 1: 487 | y = y[0] 488 | 489 | return y, y_res 490 | 491 | def crop(self, img, crop_sz, step): 492 | n_channels = len(img.shape) 493 | if n_channels == 2: 494 | h, w = img.shape 495 | elif n_channels == 3: 496 | c, h, w = img.shape 497 | else: 498 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 499 | 500 | h_space = np.arange(0, max(h - crop_sz, 0) + 1, step) 501 | w_space = np.arange(0, max(w - crop_sz, 0) + 1, step) 502 | index = 0 503 | num_h = 0 504 | lr_list = [] 505 | for x in h_space: 506 | num_h += 1 507 | num_w = 0 508 | for y in w_space: 509 | num_w += 1 510 | index += 1 511 | if n_channels == 2: 512 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 513 | else: 514 | if x == h_space[-1]: 515 | if y == w_space[-1]: 516 | crop_img = img[:, x:h, y:w] 517 | else: 518 | crop_img = img[:, x:h, y:y + crop_sz] 519 | elif y == w_space[-1]: 520 | crop_img = img[:, x:x + crop_sz, y:w] 521 | else: 522 | crop_img = img[:, x:x + crop_sz, y:y + crop_sz] 523 | lr_list.append(crop_img) 524 | return lr_list, num_h, num_w, h, w 525 | 526 | def combine(self, sr_list, num_h, num_w, h, w, patch_size, step): 527 | index = 0 528 | 529 | sr_img = torch.zeros((3, h*self.scale[0], w*self.scale[0])).to(device) 530 | s = int(((patch_size - step) / 2)*self.scale[0]) 531 | index1 = 0 532 | index2 = 0 533 | if num_h == 1: 534 | if num_w == 1: 535 | sr_img[:, :h*self.scale[0], :w * 536 | self.scale[0]] += sr_list[index][0] 537 | else: 538 | for j in range(num_w): 539 | y0 = j*step*self.scale[0] 540 | if j == 0: 541 | 
sr_img[:, :, y0:y0+s+step*self.scale[0] 542 | ] += sr_list[index1][0][:, :, :s+step*self.scale[0]] 543 | elif j == num_w-1: 544 | sr_img[:, :, y0+s:w*self.scale[0] 545 | ] += sr_list[index1][0][:, :, s:] 546 | else: 547 | sr_img[:, :, y0+s:y0+s+step*self.scale[0] 548 | ] += sr_list[index1][0][:, :, s:s+step*self.scale[0]] 549 | index1 += 1 550 | 551 | elif num_w == 1: 552 | for i in range(num_h): 553 | x0 = i*step*self.scale[0] 554 | if i == 0: 555 | sr_img[:, x0:x0+s+step*self.scale[0], 556 | :] += sr_list[index2][0][:, :s+step*self.scale[0], :] 557 | elif i == num_h-1: 558 | sr_img[:, x0+s:h*self.scale[0], 559 | :] += sr_list[index2][0][:, s:, :] 560 | else: 561 | sr_img[:, x0+s:x0+s+step*self.scale[0], 562 | :] += sr_list[index2][0][:, s:s+step*self.scale[0], :] 563 | index2 += 1 564 | 565 | else: 566 | for i in range(num_h): 567 | for j in range(num_w): 568 | x0 = i*step*self.scale[0] 569 | y0 = j*step*self.scale[0] 570 | 571 | if i == 0: 572 | if j == 0: 573 | sr_img[:, x0:x0+s+step*self.scale[0], y0:y0+s+step*self.scale[0] 574 | ] += sr_list[index][0][:, :s+step*self.scale[0], :s+step*self.scale[0]] 575 | elif j == num_w-1: 576 | sr_img[:, x0:x0+s+step*self.scale[0], y0+s:w*self.scale[0] 577 | ] += sr_list[index][0][:, :s+step*self.scale[0], s:] 578 | else: 579 | sr_img[:, x0:x0+s+step*self.scale[0], y0+s:y0+s+step*self.scale[0] 580 | ] += sr_list[index][0][:, :s+step*self.scale[0], s:s+step*self.scale[0]] 581 | elif j == 0: 582 | if i == num_h-1: 583 | sr_img[:, x0+s:h*self.scale[0], y0:y0+s+step*self.scale[0] 584 | ] += sr_list[index][0][:, s:, :s+step*self.scale[0]] 585 | else: 586 | sr_img[:, x0+s:x0+s+step*self.scale[0], y0:y0+s+step*self.scale[0] 587 | ] += sr_list[index][0][:, s:s+step*self.scale[0], :s+step*self.scale[0]] 588 | elif i == num_h-1: 589 | if j == num_w-1: 590 | sr_img[:, x0+s:h*self.scale[0], y0+s:w * 591 | self.scale[0]] += sr_list[index][0][:, s:, s:] 592 | else: 593 | sr_img[:, x0+s:h*self.scale[0], y0+s:y0+s+step*self.scale[0] 594 | 
] += sr_list[index][0][:, s:, s:s+step*self.scale[0]] 595 | elif j == num_w-1: 596 | sr_img[:, x0+s:x0+s+step*self.scale[0], y0+s:w*self.scale[0] 597 | ] += sr_list[index][0][:, s:s+step*self.scale[0], s:] 598 | else: 599 | sr_img[:, x0+s:x0+s+step*self.scale[0], y0+s:y0+s+step*self.scale[0] 600 | ] += sr_list[index][0][:, s:s+step*self.scale[0], s:s+step*self.scale[0]] 601 | 602 | index += 1 603 | 604 | return sr_img 605 | 606 | 607 | def main(): 608 | if checkpoint.ok: 609 | loader = data.Data(args) 610 | if args.model == 'CARN': 611 | t_model = CARN(args, multi_scale=True, is_teacher=False).to(device) 612 | elif args.model == 'EDSR': 613 | t_model = EDSR(args,is_teacher=False).to(device) 614 | elif args.model == 'IDN': 615 | t_model = IDN(args, is_teacher=False).to(device) 616 | 617 | elif args.model == 'SRResNet': 618 | t_model = SRResNet( 619 | args, is_teacher=False).to(device) 620 | 621 | else: 622 | raise ValueError('not expected model = {}'.format(args.model)) 623 | 624 | if args.teacher_weights is not None: 625 | if args.test_only: 626 | ckpt = torch.load(f'{args.teacher_weights}') 627 | t_checkpoint = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt 628 | t_model.load_state_dict(t_checkpoint) 629 | 630 | t = Trainer(args, loader, t_model, t_model, checkpoint) 631 | 632 | print(f'{args.save} start!') 633 | while not t.terminate(): 634 | t.train() 635 | t.test() 636 | checkpoint.done() 637 | print(f'{args.save} done!') 638 | 639 | 640 | if __name__ == '__main__': 641 | main() 642 | -------------------------------------------------------------------------------- /metrics/calculate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | function calculate_PSNR_SSIM(dataset, scale, bit) 2 | 3 | fprintf('%s', dataset) 4 | folder_GT = ['/home/pengjun/sda4/data/sr_data/benchmark/', dataset, '/HR']; 5 | root_sr = sprintf('/home/pengjun/sda4/ycq/workspace/torch-sisr/experiment/2.0.3/rdn_x%d/%dbit/1e+3*at-2/', scale, 
bit) 6 | folder_SR = sprintf('%s/results-%s', root_sr, dataset); 7 | filename = sprintf('/home/yanchenqian/pretrained/rdn_x%d_%dbit.txt', scale, bit); 8 | 9 | %suffix = '' 10 | suffix = sprintf('_x%d_SR', scale) 11 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels 12 | if test_Y 13 | fprintf('Tesing Y channel.\n'); 14 | else 15 | fprintf('Tesing RGB channels.\n'); 16 | end 17 | filepaths = dir(fullfile(folder_GT, '*.png')); 18 | PSNR_all = zeros(1, length(filepaths)); 19 | SSIM_all = zeros(1, length(filepaths)); 20 | 21 | fileID = fopen(filename,'a'); 22 | fprintf(fileID, '%s \n\n', folder_SR); 23 | for idx_im = 1:length(filepaths) 24 | im_name = filepaths(idx_im).name; 25 | im_GT = imread(fullfile(folder_GT, im_name)); 26 | [rows, columns, c_gt] = size(im_GT); 27 | % lr_name = [im_name(1:end-4), '.png'] 28 | lr_name = sprintf('%dbit_%s%s.png', bit, im_name(1:end-4), suffix) 29 | % im_SR = imread(fullfile(folder_SR, [bit 'bit_', im_name(1:end-4), suffix '.png'])); 30 | im_SR = imread(fullfile(folder_SR, lr_name)); 31 | [rows, columns, c_sr] = size(im_SR); 32 | 33 | if (c_gt == 1 && c_sr == 3) 34 | im_GT = cat(3, im_GT, im_GT, im_GT); 35 | end 36 | 37 | im_GT = imcrop(im_GT, [0, 0, columns, rows]); 38 | 39 | if test_Y % evaluate on Y channel in YCbCr color space 40 | if size(im_GT, 3) == 3 41 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT)); 42 | im_GT_in = im_GT_YCbCr(:,:,1); 43 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR)); 44 | im_SR_in = im_SR_YCbCr(:,:,1); 45 | else 46 | im_GT_in = im2double(im_GT); 47 | im_SR_in = im2double(im_SR); 48 | end 49 | else % evaluate on RGB channels 50 | im_GT_in = im2double(im_GT); 51 | im_SR_in = im2double(im_SR); 52 | end 53 | 54 | % calculate PSNR and SSIM 55 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale); 56 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale); 57 | fprintf(fileID, '%d.(X%d)%20s: \t(PSNR/SSIM) = %.3f/%.4f\n', idx_im, scale, im_name(1:end-4), 
PSNR_all(idx_im), SSIM_all(idx_im)); 58 | end 59 | 60 | fprintf(fileID, '\n%26s: \t(PSNR/SSIM) = %.3f/%.4f\n\n\n', dataset, mean(PSNR_all), mean(SSIM_all)); 61 | fclose(fileID); 62 | %dlmwrite(xlsfile, [mean(PSNR_all) mean(SSIM_all)], '-append'); 63 | end 64 | 65 | function res = calculate_PSNR(GT, SR, border) 66 | % remove border 67 | GT = GT(border+1:end-border, border+1:end-border, :); 68 | SR = SR(border+1:end-border, border+1:end-border, :); 69 | % calculate PNSR (assume in [0,255]) 70 | error = GT(:) - SR(:); 71 | mse = mean(error.^2); 72 | res = 10 * log10(255^2/mse); 73 | end 74 | 75 | function res = calculate_SSIM(GT, SR, border) 76 | GT = GT(border+1:end-border, border+1:end-border, :); 77 | SR = SR(border+1:end-border, border+1:end-border, :); 78 | % calculate SSIM 79 | mssim = zeros(1, size(SR, 3)); 80 | for i = 1:size(SR,3) 81 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i)); 82 | end 83 | res = mean(mssim); 84 | end 85 | 86 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L) 87 | 88 | %======================================================================== 89 | %SSIM Index, Version 1.0 90 | %Copyright(c) 2003 Zhou Wang 91 | %All Rights Reserved. 92 | % 93 | %The author is with Howard Hughes Medical Institute, and Laboratory 94 | %for Computational Vision at Center for Neural Science and Courant 95 | %Institute of Mathematical Sciences, New York University. 96 | % 97 | %---------------------------------------------------------------------- 98 | %Permission to use, copy, or modify this software and its documentation 99 | %for educational and research purposes only and without fee is hereby 100 | %granted, provided that this copyright notice and the original authors' 101 | %names appear on all copies and supporting documentation. This program 102 | %shall not be used, rewritten, or adapted as the basis of a commercial 103 | %software or hardware product without first obtaining permission of the 104 | %authors. 
The authors make no representations about the suitability of 105 | %this software for any purpose. It is provided "as is" without express 106 | %or implied warranty. 107 | %---------------------------------------------------------------------- 108 | % 109 | %This is an implementation of the algorithm for calculating the 110 | %Structural SIMilarity (SSIM) index between two images. Please refer 111 | %to the following paper: 112 | % 113 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 114 | %quality assessment: From error measurement to structural similarity" 115 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. 116 | % 117 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 118 | % 119 | %---------------------------------------------------------------------- 120 | % 121 | %Input : (1) img1: the first image being compared 122 | % (2) img2: the second image being compared 123 | % (3) K: constants in the SSIM index formula (see the above 124 | % reference). defualt value: K = [0.01 0.03] 125 | % (4) window: local window for statistics (see the above 126 | % reference). default widnow is Gaussian given by 127 | % window = fspecial('gaussian', 11, 1.5); 128 | % (5) L: dynamic range of the images. default: L = 255 129 | % 130 | %Output: (1) mssim: the mean SSIM index value between 2 images. 131 | % If one of the images being compared is regarded as 132 | % perfect quality, then mssim can be considered as the 133 | % quality measure of the other image. 134 | % If img1 = img2, then mssim = 1. 135 | % (2) ssim_map: the SSIM index map of the test image. The map 136 | % has a smaller size than the input images. The actual size: 137 | % size(img1) - size(window) + 1. 138 | % 139 | %Default Usage: 140 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 141 | % 142 | % [mssim ssim_map] = ssim_index(img1, img2); 143 | % 144 | %Advanced Usage: 145 | % User defined parameters. 
For example 146 | % 147 | % K = [0.05 0.05]; 148 | % window = ones(8); 149 | % L = 100; 150 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 151 | % 152 | %See the results: 153 | % 154 | % mssim %Gives the mssim value 155 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 156 | % 157 | %======================================================================== 158 | 159 | 160 | if (nargin < 2 || nargin > 5) 161 | ssim_index = -Inf; 162 | ssim_map = -Inf; 163 | return; 164 | end 165 | 166 | if (size(img1) ~= size(img2)) 167 | ssim_index = -Inf; 168 | ssim_map = -Inf; 169 | return; 170 | end 171 | 172 | [M, N] = size(img1); 173 | 174 | if (nargin == 2) 175 | if ((M < 11) || (N < 11)) 176 | ssim_index = -Inf; 177 | ssim_map = -Inf; 178 | return 179 | end 180 | window = fspecial('gaussian', 11, 1.5); % 181 | K(1) = 0.01; % default settings 182 | K(2) = 0.03; % 183 | L = 255; % 184 | end 185 | 186 | if (nargin == 3) 187 | if ((M < 11) || (N < 11)) 188 | ssim_index = -Inf; 189 | ssim_map = -Inf; 190 | return 191 | end 192 | window = fspecial('gaussian', 11, 1.5); 193 | L = 255; 194 | if (length(K) == 2) 195 | if (K(1) < 0 || K(2) < 0) 196 | ssim_index = -Inf; 197 | ssim_map = -Inf; 198 | return; 199 | end 200 | else 201 | ssim_index = -Inf; 202 | ssim_map = -Inf; 203 | return; 204 | end 205 | end 206 | 207 | if (nargin == 4) 208 | [H, W] = size(window); 209 | if ((H*W) < 4 || (H > M) || (W > N)) 210 | ssim_index = -Inf; 211 | ssim_map = -Inf; 212 | return 213 | end 214 | L = 255; 215 | if (length(K) == 2) 216 | if (K(1) < 0 || K(2) < 0) 217 | ssim_index = -Inf; 218 | ssim_map = -Inf; 219 | return; 220 | end 221 | else 222 | ssim_index = -Inf; 223 | ssim_map = -Inf; 224 | return; 225 | end 226 | end 227 | 228 | if (nargin == 5) 229 | [H, W] = size(window); 230 | if ((H*W) < 4 || (H > M) || (W > N)) 231 | ssim_index = -Inf; 232 | ssim_map = -Inf; 233 | return 234 | end 235 | if (length(K) == 2) 236 | if (K(1) < 0 || K(2) < 0) 237 | ssim_index = -Inf; 238 | 
ssim_map = -Inf; 239 | return; 240 | end 241 | else 242 | ssim_index = -Inf; 243 | ssim_map = -Inf; 244 | return; 245 | end 246 | end 247 | 248 | C1 = (K(1)*L)^2; 249 | C2 = (K(2)*L)^2; 250 | window = window/sum(sum(window)); 251 | img1 = double(img1); 252 | img2 = double(img2); 253 | 254 | mu1 = filter2(window, img1, 'valid'); 255 | mu2 = filter2(window, img2, 'valid'); 256 | mu1_sq = mu1.*mu1; 257 | mu2_sq = mu2.*mu2; 258 | mu1_mu2 = mu1.*mu2; 259 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 260 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 261 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 262 | 263 | if (C1 > 0 && C2 > 0) 264 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 265 | else 266 | numerator1 = 2*mu1_mu2 + C1; 267 | numerator2 = 2*sigma12 + C2; 268 | denominator1 = mu1_sq + mu2_sq + C1; 269 | denominator2 = sigma1_sq + sigma2_sq + C2; 270 | ssim_map = ones(size(mu1)); 271 | index = (denominator1.*denominator2 > 0); 272 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 273 | index = (denominator1 ~= 0) & (denominator2 == 0); 274 | ssim_map(index) = numerator1(index)./denominator1(index); 275 | end 276 | 277 | mssim = mean2(ssim_map); 278 | 279 | end 280 | -------------------------------------------------------------------------------- /metrics/run.sh: -------------------------------------------------------------------------------- 1 | # for bit in "${bits[@]}" 2 | scale=4 3 | bit=6 4 | datasets=('Set5' 'Set14' 'B100' 'Urban100') 5 | 6 | for dataset in ${datasets[@]} 7 | do 8 | matlab -nodesktop -nosplash -r "calculate_PSNR_SSIM('$dataset',$scale,$bit);quit" 9 | done 10 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import 
import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel as P 7 | import torch.utils.model_zoo 8 | 9 | class Model(nn.Module): 10 | def __init__(self, args, ckp): 11 | super(Model, self).__init__() 12 | print('Making model...') 13 | 14 | self.scale = args.scale 15 | self.idx_scale = 0 16 | self.input_large = (args.model == 'VDSR') 17 | self.self_ensemble = args.self_ensemble 18 | self.chop = args.chop 19 | self.precision = args.precision 20 | self.cpu = args.cpu 21 | self.n_GPUs = args.n_GPUs 22 | self.save_models = args.save_models 23 | 24 | module = import_module('model.' + args.model.lower()) 25 | self.model = module.make_model(args).cuda() 26 | if args.precision == 'half': 27 | self.model.half() 28 | self.load( 29 | ckp.get_path('model'), 30 | pre_train=args.pre_train, 31 | resume=args.resume, 32 | cpu=args.cpu 33 | ) 34 | 35 | def forward(self, x, idx_scale): 36 | self.idx_scale = idx_scale 37 | if hasattr(self.model, 'set_scale'): 38 | self.model.set_scale(idx_scale) 39 | 40 | if self.training: 41 | if self.n_GPUs > 1: 42 | return P.data_parallel(self.model, x, range(self.n_GPUs)) 43 | else: 44 | return self.model(x) 45 | else: 46 | if self.chop: 47 | forward_function = self.forward_chop 48 | else: 49 | forward_function = self.model.forward 50 | 51 | if self.self_ensemble: 52 | return self.forward_x8(x, forward_function=forward_function) 53 | else: 54 | return forward_function(x) 55 | 56 | def save(self, apath, epoch, is_best=False): 57 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 58 | 59 | if is_best: 60 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 61 | if self.save_models: 62 | save_dirs.append( 63 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 64 | ) 65 | 66 | for s in save_dirs: 67 | torch.save(self.model.state_dict(), s) 68 | 69 | def load(self, apath, pre_train='', resume=-1, cpu=False): 70 | load_from = None 71 | kwargs = {} 72 | if cpu: 73 | kwargs = {'map_location': lambda storage, loc: 
storage} 74 | 75 | if resume == -1: 76 | load_from = torch.load( 77 | os.path.join(apath, 'model_latest.pt'), 78 | **kwargs 79 | ) 80 | 81 | elif resume is None: 82 | if pre_train == 'download': 83 | print('Download the model') 84 | dir_model = os.path.join('..', 'models') 85 | os.makedirs(dir_model, exist_ok=True) 86 | load_from = torch.utils.model_zoo.load_url( 87 | self.model.url, 88 | model_dir=dir_model, 89 | **kwargs 90 | ) 91 | elif pre_train: 92 | # print(pre_train) 93 | print('Load the model from {}'.format(pre_train)) 94 | load_from = torch.load(pre_train, **kwargs) 95 | else: 96 | load_from = torch.load(resume) 97 | 98 | if load_from: 99 | print('strcit is False') 100 | self.model.load_state_dict(load_from, strict=False) 101 | 102 | def forward_chop(self, *args, shave=10, min_size=160000): 103 | print("chopped") 104 | scale = 1 if self.input_large else self.scale[self.idx_scale] 105 | n_GPUs = min(self.n_GPUs, 4) 106 | # height, width 107 | h, w = args[0].size()[-2:] 108 | 109 | top = slice(0, h//2 + shave) 110 | bottom = slice(h - h//2 - shave, h) 111 | left = slice(0, w//2 + shave) 112 | right = slice(w - w//2 - shave, w) 113 | x_chops = [torch.cat([ 114 | a[..., top, left], 115 | a[..., top, right], 116 | a[..., bottom, left], 117 | a[..., bottom, right] 118 | ]) for a in args] 119 | 120 | y_chops = [] 121 | if h * w < 4 * min_size: 122 | for i in range(0, 4, n_GPUs): 123 | x = [x_chop[i:(i + n_GPUs)] for x_chop in x_chops] 124 | y = P.data_parallel(self.model, *x, range(n_GPUs)) 125 | if not isinstance(y, list): y = [y] 126 | if not y_chops: 127 | y_chops = [[c for c in _y.chunk(n_GPUs, dim=0)] for _y in y] 128 | else: 129 | for y_chop, _y in zip(y_chops, y): 130 | y_chop.extend(_y.chunk(n_GPUs, dim=0)) 131 | else: 132 | for p in zip(*x_chops): 133 | y = self.forward_chop(*p, shave=shave, min_size=min_size) 134 | if not isinstance(y, list): y = [y] 135 | if not y_chops: 136 | y_chops = [[_y] for _y in y] 137 | else: 138 | for y_chop, _y in 
zip(y_chops, y): y_chop.append(_y) 139 | 140 | h *= scale 141 | w *= scale 142 | top = slice(0, h//2) 143 | bottom = slice(h - h//2, h) 144 | bottom_r = slice(h//2 - h, None) 145 | left = slice(0, w//2) 146 | right = slice(w - w//2, w) 147 | right_r = slice(w//2 - w, None) 148 | 149 | # batch size, number of color channels 150 | b, c = y_chops[0][0].size()[:-2] 151 | y = [y_chop[0].new(b, c, h, w) for y_chop in y_chops] 152 | for y_chop, _y in zip(y_chops, y): 153 | _y[..., top, left] = y_chop[0][..., top, left] 154 | _y[..., top, right] = y_chop[1][..., top, right_r] 155 | _y[..., bottom, left] = y_chop[2][..., bottom_r, left] 156 | _y[..., bottom, right] = y_chop[3][..., bottom_r, right_r] 157 | 158 | if len(y) == 1: y = y[0] 159 | 160 | return y 161 | 162 | def forward_x8(self, *args, forward_function=None): 163 | def _transform(v, op): 164 | if self.precision != 'single': v = v.float() 165 | 166 | v2np = v.data.cpu().numpy() 167 | if op == 'v': 168 | tfnp = v2np[:, :, :, ::-1].copy() 169 | elif op == 'h': 170 | tfnp = v2np[:, :, ::-1, :].copy() 171 | elif op == 't': 172 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 173 | 174 | ret = torch.Tensor(tfnp).cuda() 175 | if self.precision == 'half': ret = ret.half() 176 | 177 | return ret 178 | 179 | list_x = [] 180 | for a in args: 181 | x = [a] 182 | for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x]) 183 | 184 | list_x.append(x) 185 | 186 | list_y = [] 187 | for x in zip(*list_x): 188 | y = forward_function(*x) 189 | if not isinstance(y, list): y = [y] 190 | if not list_y: 191 | list_y = [[_y] for _y in y] 192 | else: 193 | for _list_y, _y in zip(list_y, y): _list_y.append(_y) 194 | 195 | for _list_y in list_y: 196 | for i in range(len(_list_y)): 197 | if i > 3: 198 | _list_y[i] = _transform(_list_y[i], 't') 199 | if i % 4 > 1: 200 | _list_y[i] = _transform(_list_y[i], 'h') 201 | if (i % 4) % 2 == 1: 202 | _list_y[i] = _transform(_list_y[i], 'v') 203 | 204 | y = [torch.cat(_y, dim=0).mean(dim=0, 
keepdim=True) for _y in list_y] 205 | if len(y) == 1: y = y[0] 206 | 207 | return y -------------------------------------------------------------------------------- /model/cadyq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.quant_ops import quant_act_pams 5 | 6 | 7 | class BitSelector(nn.Module): 8 | def __init__(self, n_feats, bias=False, ema_epoch=1, search_space=[2,4,6,8], linq=False): 9 | super(BitSelector, self).__init__() 10 | 11 | self.quant_bit1 = quant_act_pams(k_bits=search_space[0], ema_epoch=ema_epoch) 12 | self.quant_bit2 = quant_act_pams(k_bits=search_space[1], ema_epoch=ema_epoch) 13 | self.quant_bit3 = quant_act_pams(k_bits=search_space[2], ema_epoch=ema_epoch) 14 | # self.quant_bit4 = quant_act_pams(k_bits=search_space[3], ema_epoch=ema_epoch) 15 | 16 | self.search_space =search_space 17 | 18 | self.bits_out = 8 19 | 20 | self.net_small = nn.Sequential( 21 | nn.Linear(n_feats+2, len(search_space)) 22 | ) 23 | nn.init.ones_(self.net_small[0].weight) 24 | nn.init.zeros_(self.net_small[0].bias) 25 | nn.init.ones_(self.net_small[0].bias[-1]) 26 | 27 | def forward(self, x): 28 | weighted_bits = x[3] 29 | bits = x[2] 30 | grad = x[0] 31 | x = x[1] 32 | # print('x: ' + str(x.size())) 33 | layer_std_s = torch.std(x, (2,3)).detach() # remove dim 2,3 34 | # print('layer_std: ' + str(layer_std_s.size())) 35 | # print('grad: ' + str(grad.size())) 36 | x_embed = torch.cat([grad, layer_std_s], dim=1) #[B, C+2] 37 | # print('x_embed: ' + str(x_embed.size())) 38 | 39 | bit_type = self.net_small(x_embed) 40 | # print('bit_type: ' + str(bit_type)) 41 | flag = torch.argmax(bit_type, dim=1) 42 | #[1,0] 43 | p = F.softmax(bit_type, dim=1) 44 | # print('flag: ' + str(flag.size())) 45 | if len(self.search_space)== 4: 46 | p1 = p[:,0] 47 | # print("p1: " + str(p1)) 48 | # print("p1.view: " + str(p1.view(p1.size(0),1,1,1))) 49 | p2 = p[:,1] 
50 | p3 = p[:,2] 51 | p4 = p[:,3] 52 | bits_hard = (flag==0)*self.search_space[0] + (flag==1)*self.search_space[1] + (flag==2)*self.search_space[2] + (flag==3)*self.search_space[3] 53 | bits_soft = p1*self.search_space[0]+p2*self.search_space[1]+ p3*self.search_space[2] + p4*self.search_space[3] 54 | bits_out = bits_hard.detach() - bits_soft.detach() + bits_soft 55 | self.bits_out = bits_out 56 | bits += bits_out 57 | weighted_bits += bits_out / (self.search_space[0]*p1.detach()+self.search_space[1]*p2.detach()+self.search_space[2]*p3.detach()+self.search_space[3]*p4.detach()) 58 | 59 | q_bit1 = self.quant_bit1(x) 60 | q_bit2 = self.quant_bit2(x) 61 | q_bit3 = self.quant_bit3(x) 62 | q_bit4 = self.quant_bit4(x) 63 | out_soft = p1.view(p1.size(0),1,1,1)*q_bit1 + p2.view(p2.size(0),1,1,1)*q_bit2 + p3.view(p3.size(0),1,1,1)*q_bit3 + p4.view(p4.size(0),1,1,1)*q_bit4 64 | out_hard = (flag==0).view(flag.size(0),1,1,1)*q_bit1 + (flag==1).view(flag.size(0),1,1,1)*q_bit2 + (flag==2).view(flag.size(0),1,1,1)*q_bit3 + (flag==3).view(flag.size(0),1,1,1)*q_bit4 65 | residual = out_hard.detach() - out_soft.detach() + out_soft 66 | elif len(self.search_space)== 3: 67 | p1 = p[:,0] 68 | # print("p1: " + str(p1)) 69 | # print("p1.view: " + str(p1.view(p1.size(0),1,1,1))) 70 | p2 = p[:,1] 71 | p3 = p[:,2] 72 | bits_hard = (flag==0)*self.search_space[0] + (flag==1)*self.search_space[1] + (flag==2)*self.search_space[2] 73 | # print(bits_hard) 74 | bits_soft = p1*self.search_space[0]+p2*self.search_space[1]+ p3*self.search_space[2] 75 | # print(bits_soft) 76 | bits_out = bits_hard.detach() - bits_soft.detach() + bits_soft 77 | self.bits_out = bits_out 78 | bits += bits_out 79 | weighted_bits += bits_out / (self.search_space[0]*p1.detach()+self.search_space[1]*p2.detach()+self.search_space[2]*p3.detach()) 80 | # print(weighted_bits) 81 | q_bit1 = self.quant_bit1(x) 82 | q_bit2 = self.quant_bit2(x) 83 | q_bit3 = self.quant_bit3(x) 84 | out_soft = p1.view(p1.size(0),1,1,1)*q_bit1 + 
p2.view(p2.size(0),1,1,1)*q_bit2 + p3.view(p3.size(0),1,1,1)*q_bit3 85 | out_hard = (flag==0).view(flag.size(0),1,1,1)*q_bit1 + (flag==1).view(flag.size(0),1,1,1)*q_bit2 + (flag==2).view(flag.size(0),1,1,1)*q_bit3 86 | residual = out_hard.detach() - out_soft.detach() + out_soft 87 | 88 | elif len(self.search_space)== 2: 89 | p1 = p[:,0] 90 | p2 = p[:,1] 91 | bits_hard = (flag==0)*self.search_space[0] + (flag==1)*self.search_space[1] 92 | bits_soft = p1*self.search_space[0]+p2*self.search_space[1] 93 | bits_out = bits_hard.detach() - bits_soft.detach() + bits_soft 94 | self.bits_out = bits_out 95 | bits += bits_out 96 | weighted_bits += bits_out / (self.search_space[0]*p1.detach()+self.search_space[1]*p2.detach()) 97 | q_bit1 =self.quant_bit1(x) 98 | q_bit2 = self.quant_bit2(x) 99 | out_soft = p1.view(p1.size(0),1,1,1)*q_bit1 + p2.view(p2.size(0),1,1,1)*q_bit2 100 | out_hard = (flag==0).view(flag.size(0),1,1,1)*q_bit1 + (flag==1).view(flag.size(0),1,1,1)*q_bit2 101 | residual = out_hard.detach() - out_soft.detach() + out_soft 102 | 103 | # return [grad, residual, bits, weighted_bits] 104 | return [grad, residual, bits, weighted_bits] -------------------------------------------------------------------------------- /model/carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from model.quant_ops import quant_act_pams 6 | from option import args 7 | 8 | def init_weights(modules): 9 | pass 10 | 11 | class MeanShift(nn.Module): 12 | def __init__(self, mean_rgb, sub): 13 | super(MeanShift, self).__init__() 14 | 15 | sign = -1 if sub else 1 16 | r = mean_rgb[0] * sign 17 | g = mean_rgb[1] * sign 18 | b = mean_rgb[2] * sign 19 | 20 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 21 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 22 | self.shifter.bias.data = torch.Tensor([r, g, b]) 23 | 24 | # Freeze the mean shift layer 25 | for params 
in self.shifter.parameters(): 26 | params.requires_grad = False 27 | 28 | def forward(self, x): 29 | x = self.shifter(x) 30 | return x 31 | 32 | 33 | 34 | class UpsampleBlock(nn.Module): 35 | def __init__(self, 36 | n_channels, scale=4, multi_scale=True, group=1, fully=False, k_bits =32): 37 | super(UpsampleBlock, self).__init__() 38 | 39 | if multi_scale: 40 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group, fully=fully, k_bits=k_bits) 41 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group, fully=fully, k_bits=k_bits) 42 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group, fully=fully, k_bits=k_bits) 43 | else: 44 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group, fully=fully, k_bits=k_bits) 45 | # self.up4 = _UpsampleBlock(n_channels, scale=scale, group=group) 46 | 47 | 48 | self.multi_scale = multi_scale 49 | 50 | def forward(self, x, scale=args.scale[0]): 51 | # def forward(self, x, scale=4): 52 | 53 | if self.multi_scale: 54 | if scale == 2: 55 | return self.up2(x) 56 | elif scale == 3: 57 | return self.up3(x) 58 | elif scale == 4: 59 | return self.up4(x) 60 | else: 61 | return self.up(x) 62 | 63 | 64 | class _UpsampleBlock(nn.Module): 65 | def __init__(self, 66 | n_channels, scale, 67 | group=1, fully=False, k_bits =32): 68 | super(_UpsampleBlock, self).__init__() 69 | 70 | modules = [] 71 | if scale == 2 or scale == 4 or scale == 8: 72 | for _ in range(int(math.log(scale, 2))): 73 | if fully: 74 | # modules += [quant_act_lin(k_bits)] 75 | modules += [pams_quant_act(k_bits, ema_epoch=1)] 76 | 77 | modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 78 | modules += [nn.PixelShuffle(2)] 79 | elif scale == 3: 80 | modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 81 | modules += [nn.PixelShuffle(3)] 82 | 83 | self.body = nn.Sequential(*modules) 84 | init_weights(self.modules) 85 | 86 | def forward(self, x): 87 | out = self.body(x) 88 
| return out 89 | 90 | class BasicBlock(nn.Module): 91 | def __init__(self, 92 | in_channels, out_channels, 93 | ksize=3, stride=1, pad=1): 94 | super(BasicBlock, self).__init__() 95 | 96 | self.body = nn.Sequential( 97 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 98 | nn.ReLU(inplace=True) 99 | ) 100 | 101 | init_weights(self.modules) 102 | 103 | def forward(self, x): 104 | out = self.body(x) 105 | return out 106 | 107 | class ResidualBlock(nn.Module): 108 | def __init__(self, 109 | in_channels, out_channels): 110 | super(ResidualBlock, self).__init__() 111 | 112 | self.body = nn.Sequential( 113 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 116 | ) 117 | 118 | init_weights(self.modules) 119 | 120 | def forward(self, x): 121 | f = x[1] 122 | x = x[0] 123 | 124 | 125 | out = self.body(x) 126 | f1 = out 127 | out = F.relu(out + x) 128 | if f is None: 129 | f = f1.unsqueeze(0) 130 | else: 131 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 132 | # return out 133 | return [out, f] # [activated output, f]; f stacks each call's pre-residual body output along a new leading dim 134 | 135 | class CARNBlock(nn.Module): 136 | def __init__(self, 137 | in_channels, out_channels, 138 | group=1): # NOTE(review): in_channels, out_channels and group are unused; widths are hard-coded to 64 139 | super(CARNBlock, self).__init__() 140 | 141 | self.b1 = ResidualBlock(64, 64) 142 | self.b2 = ResidualBlock(64, 64) 143 | self.b3 = ResidualBlock(64, 64) 144 | self.c1 = BasicBlock(64*2, 64, 1, 1, 0) 145 | self.c2 = BasicBlock(64*3, 64, 1, 1, 0) 146 | self.c3 = BasicBlock(64*4, 64, 1, 1, 0) 147 | 148 | def forward(self, x): # x is [features, f]; f is threaded through b1..b3 and returned as [o3, f] 149 | # added for teacher 150 | f = x[1] 151 | x = x[0] 152 | 153 | c0 = o0 = x 154 | b1,f = self.b1([o0,f]) 155 | # b1 = self.b1(o0) 156 | c1 = torch.cat([c0, b1], dim=1) 157 | o1 = self.c1(c1) 158 | 159 | b2,f = self.b2([o1,f]) 160 | # b2 = self.b2(o1) 161 | c2 = torch.cat([c1, b2], dim=1) 162 | o2 = self.c2(c2) 163 | 164 | b3,f = self.b3([o2,f]) 165 | # b3 = self.b3(o2) 166 | c3 = torch.cat([c2, b3], dim=1) 167 | o3 = self.c3(c3) 168 | 169 | return [o3, f]
| # return o3 171 | 172 | class CARN(nn.Module): 173 | def __init__(self, args, is_teacher=False, multi_scale=False): 174 | super(CARN, self).__init__() 175 | 176 | scale = args.scale[0] 177 | # multi_scale = args.multi_scale 178 | 179 | group = args.group 180 | 181 | self.is_teacher =is_teacher 182 | 183 | self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True) 184 | self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False) 185 | 186 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 187 | 188 | self.b1 = CARNBlock(64, 64) 189 | self.b2 = CARNBlock(64, 64) 190 | self.b3 = CARNBlock(64, 64) 191 | self.c1 = BasicBlock(64*2, 64, 1, 1, 0) 192 | self.c2 = BasicBlock(64*3, 64, 1, 1, 0) 193 | self.c3 = BasicBlock(64*4, 64, 1, 1, 0) 194 | 195 | self.upsample = UpsampleBlock(64, scale=scale, multi_scale=multi_scale, group=group) 196 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 197 | 198 | def forward(self, x, scale=args.scale[0]): # NOTE(review): this default is evaluated when the class body runs, so a module-level `args` must exist at import time; `scale` is not used below (upsample is called without it) 199 | # def forward(self, x, scale=4): 200 | # import pdb; pdb.set_trace() 201 | f = None 202 | x = self.sub_mean(x) 203 | 204 | x = self.entry(x) 205 | c0 = o0 = x 206 | 207 | b1,f = self.b1([o0,f]) 208 | c1 = torch.cat([c0, b1], dim=1) 209 | o1 = self.c1(c1) 210 | 211 | b2,f = self.b2([o1,f]) 212 | c2 = torch.cat([c1, b2], dim=1) 213 | o2 = self.c2(c2) 214 | 215 | b3,f = self.b3([o2,f]) 216 | c3 = torch.cat([c2, b3], dim=1) 217 | o3 = self.c3(c3) 218 | 219 | feat = o3 220 | 221 | out = self.upsample(o3) 222 | 223 | out = self.exit(out) 224 | out = self.add_mean(out) 225 | if self.is_teacher: # teacher mode also returns the last body feature (o3) 226 | return out, feat 227 | else: 228 | return out 229 | -------------------------------------------------------------------------------- /model/carn_cadyq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import carn, carn_pams 7 | from model.quant_ops import quant_conv3x3 8 | from model.cadyq import
BitSelector 9 | 10 | class ResidualBlock_CADyQ(nn.Module): 11 | def __init__(self, 12 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1,search_space=[4,6,8],loss_kdf=False, linq=False): 13 | super(ResidualBlock_CADyQ, self).__init__() 14 | 15 | self.body = nn.Sequential( 16 | BitSelector(in_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 17 | conv(in_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 18 | nn.ReLU(inplace=True), 19 | BitSelector(out_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 20 | conv(out_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 21 | ) 22 | self.loss_kdf= loss_kdf 23 | 24 | def forward(self, x): 25 | weighted_bits = x[4] 26 | f = x[3] 27 | bits = x[2] 28 | grad = x[0] 29 | x = x[1] 30 | 31 | residual = x 32 | grad,x,bits,weighted_bits= self.body[0]([grad,x,bits,weighted_bits]) # cadyq 33 | x = self.body[1:3](x) # conv-relu 34 | grad,x,bits,weighted_bits= self.body[3]([grad,x,bits,weighted_bits]) # cadyq 35 | out = self.body[4](x) # conv 36 | f1 = out 37 | if self.loss_kdf: 38 | if f is None: 39 | f = f1.unsqueeze(0) 40 | else: 41 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 42 | else: 43 | f = None 44 | 45 | out = F.relu(out + residual) # residual is the block input saved before the first BitSelector 46 | return [grad, out, bits, f, weighted_bits] # same 5-slot layout as the input list 47 | 48 | 49 | class CARNBlock_CADyQ(nn.Module): 50 | def __init__(self, 51 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1, group=1, search_space=[4,6,8],loss_kdf=False, linq=False, fully=False): # NOTE(review): group is accepted but unused in this class 52 | super(CARNBlock_CADyQ, self).__init__() 53 | 54 | self.b1 = ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 55 | self.b2 = ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 56 | self.b3 =
ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 57 | 58 | if fully: 59 | self.c1 = carn_pams.PAMS_BasicBlock(in_channels*2, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 60 | self.c2 = carn_pams.PAMS_BasicBlock(in_channels*3, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 61 | self.c3 = carn_pams.PAMS_BasicBlock(in_channels*4, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 62 | else: 63 | self.c1 = carn.BasicBlock(in_channels*2, out_channels, 1, 1, 0) 64 | self.c2 = carn.BasicBlock(in_channels*3, out_channels, 1, 1, 0) 65 | self.c3 = carn.BasicBlock(in_channels*4, out_channels, 1, 1, 0) 66 | 67 | 68 | def forward(self, x): 69 | weighted_bits = x[4] 70 | f = x[3] 71 | bits = x[2] 72 | grad = x[0] 73 | x = x[1] 74 | 75 | c0 = o0 = x 76 | 77 | grad, b1, bits, f, weighted_bits = self.b1([grad, o0, bits, f, weighted_bits]) 78 | c1 = torch.cat([c0, b1], dim=1) 79 | o1 = self.c1(c1) 80 | 81 | grad, b2, bits, f, weighted_bits = self.b2([grad, o1, bits, f, weighted_bits]) 82 | c2 = torch.cat([c1, b2], dim=1) 83 | o2 = self.c2(c2) 84 | 85 | grad, b3, bits, f, weighted_bits = self.b3([grad, o2, bits, f, weighted_bits]) 86 | c3 = torch.cat([c2, b3], dim=1) 87 | o3 = self.c3(c3) 88 | 89 | return [grad, o3, bits, f, weighted_bits] 90 | 91 | class CARN_CADyQ(nn.Module): 92 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32, multi_scale=False): 93 | super(CARN_CADyQ, self).__init__() 94 | 95 | scale = args.scale[0] 96 | group = args.group 97 | n_feats = args.n_feats 98 | self.fully = args.fully 99 | 100 | self.k_bits = args.k_bits 101 | 102 | self.sub_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 103 | self.add_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 104 | 105 | 106 | self.entry = nn.Conv2d(3, 
n_feats, 3, 1, 1) 107 | 108 | self.b1 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 109 | self.b2 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 110 | self.b3 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 111 | 112 | self.c1 = carn.BasicBlock(n_feats*2, n_feats, 1, 1, 0) 113 | self.c2 = carn.BasicBlock(n_feats*3, n_feats, 1, 1, 0) 114 | self.c3 = carn.BasicBlock(n_feats*4, n_feats, 1, 1, 0) 115 | 116 | self.upsample = carn.UpsampleBlock(n_feats, scale=scale, multi_scale=multi_scale, group=group, fully=args.fully, k_bits=args.k_bits) 117 | self.exit = nn.Conv2d(n_feats, 3, 3, 1, 1) 118 | 119 | 120 | 121 | def forward(self, x, scale): 122 | image = x 123 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image), order=1) 124 | grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) # [16,2] # CARN -- presumably the per-image mean |d/dx| and |d/dy| of the grayscale input; TODO confirm against kornia's spatial_gradient output layout 125 | 126 | f= None; weighted_bits=0; bits=0 # accumulators threaded through every CARN block 127 | 128 | 129 | x = self.sub_mean(x) 130 | # if self.fully: x = self.quant_head(x) 131 | 132 | x = self.entry(x) 133 | 134 | c0 = o0 = x 135 | grad, b1, bits, f, weighted_bits = self.b1([grad, o0, bits, f, weighted_bits]) 136 | c1 = torch.cat([c0, b1], dim=1) 137 | o1 = self.c1(c1) 138 | 139 | grad, b2, bits, f, weighted_bits = self.b2([grad, o1, bits, f, weighted_bits]) 140 | c2 = torch.cat([c1, b2], dim=1) 141 | o2 = self.c2(c2) 142 | 143 | grad, b3, bits, f, weighted_bits = self.b3([grad, o2, bits, f, weighted_bits]) 144 | c3 = torch.cat([c2, b3], dim=1) 145 | o3 = self.c3(c3) 146 | 147 | feat = o3 148 | 149 | out = self.upsample(o3, scale=scale) 150 | 151 | # if self.fully: out = self.quant_tail(out)
self.quant_tail(out) 152 | 153 | out = self.exit(out) 154 | out = self.add_mean(out) 155 | 156 | 157 | return out, feat, bits, f, weighted_bits 158 | 159 | -------------------------------------------------------------------------------- /model/carn_cadyq_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import carn, carn_pams 7 | from model.quant_ops import quant_conv3x3 8 | from model.edge import BitSelector 9 | 10 | class ResidualBlock_CADyQ(nn.Module): 11 | def __init__(self, 12 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1,search_space=[4,6,8],loss_kdf=False, linq=False): 13 | super(ResidualBlock_CADyQ, self).__init__() 14 | 15 | self.body = nn.Sequential( 16 | BitSelector(in_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 17 | conv(in_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 18 | nn.ReLU(inplace=True), 19 | BitSelector(out_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 20 | conv(out_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 21 | ) 22 | self.loss_kdf= loss_kdf 23 | 24 | def forward(self, x): 25 | weighted_bits = x[4] 26 | f = x[3] 27 | bits = x[2] 28 | grad = x[0] 29 | x = x[1] 30 | 31 | residual = x 32 | grad,x,bits= self.body[0]([grad,x,bits]) # cadyq 33 | x = self.body[1:3](x) # conv-relu 34 | grad,x,bits= self.body[3]([grad,x,bits]) # cadyq (edge.BitSelector takes and returns [grad, x, bits]; weighted_bits passes through untouched) 35 | out = self.body[4](x) # conv 36 | f1 = out 37 | if self.loss_kdf: 38 | if f is None: 39 | f = f1.unsqueeze(0) 40 | else: 41 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 42 | else: 43 | f = None 44 | 45 | out = F.relu(out + residual) 46 | return [grad, out, bits, f, weighted_bits] 47 | 48 | 49 | class CARNBlock_CADyQ(nn.Module): 50 | def __init__(self, 51 | in_channels, out_channels, conv,
k_bits=32, bias=False, ema_epoch=1, group=1, search_space=[4,6,8],loss_kdf=False, linq=False, fully=False): 52 | super(CARNBlock_CADyQ, self).__init__() 53 | 54 | self.b1 = ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 55 | self.b2 = ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 56 | self.b3 = ResidualBlock_CADyQ(in_channels, out_channels, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,search_space=search_space,loss_kdf=loss_kdf,linq=linq) 57 | 58 | if fully: 59 | self.c1 = carn_pams.PAMS_BasicBlock(in_channels*2, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 60 | self.c2 = carn_pams.PAMS_BasicBlock(in_channels*3, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 61 | self.c3 = carn_pams.PAMS_BasicBlock(in_channels*4, out_channels, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 62 | else: 63 | self.c1 = carn.BasicBlock(in_channels*2, out_channels, 1, 1, 0) 64 | self.c2 = carn.BasicBlock(in_channels*3, out_channels, 1, 1, 0) 65 | self.c3 = carn.BasicBlock(in_channels*4, out_channels, 1, 1, 0) 66 | 67 | 68 | def forward(self, x): 69 | weighted_bits = x[4] 70 | f = x[3] 71 | bits = x[2] 72 | grad = x[0] 73 | x = x[1] 74 | 75 | c0 = o0 = x 76 | 77 | grad, b1, bits, f, weighted_bits = self.b1([grad, o0, bits, f, weighted_bits]) 78 | c1 = torch.cat([c0, b1], dim=1) 79 | o1 = self.c1(c1) 80 | 81 | grad, b2, bits, f, weighted_bits = self.b2([grad, o1, bits, f, weighted_bits]) 82 | c2 = torch.cat([c1, b2], dim=1) 83 | o2 = self.c2(c2) 84 | 85 | grad, b3, bits, f, weighted_bits = self.b3([grad, o2, bits, f, weighted_bits]) 86 | c3 = torch.cat([c2, b3], dim=1) 87 | o3 = self.c3(c3) 88 | 89 | return [grad, o3, bits, f, weighted_bits] 90 | 91 | 
class CARN_CADyQ_I(nn.Module): 92 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32, multi_scale=False): 93 | super(CARN_CADyQ_I, self).__init__() 94 | 95 | scale = args.scale[0] 96 | group = args.group 97 | n_feats = args.n_feats 98 | self.fully = args.fully 99 | 100 | self.k_bits = args.k_bits 101 | 102 | self.sub_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 103 | self.add_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 104 | 105 | 106 | self.entry = nn.Conv2d(3, n_feats, 3, 1, 1) 107 | 108 | self.b1 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 109 | self.b2 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 110 | self.b3 = CARNBlock_CADyQ(n_feats, n_feats, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf,linq=args.linq,fully=args.fully) 111 | 112 | self.c1 = carn.BasicBlock(n_feats*2, n_feats, 1, 1, 0) 113 | self.c2 = carn.BasicBlock(n_feats*3, n_feats, 1, 1, 0) 114 | self.c3 = carn.BasicBlock(n_feats*4, n_feats, 1, 1, 0) 115 | 116 | self.upsample = carn.UpsampleBlock(n_feats, scale=scale, multi_scale=multi_scale, group=group, fully=args.fully, k_bits=args.k_bits) 117 | self.exit = nn.Conv2d(n_feats, 3, 3, 1, 1) 118 | 119 | 120 | 121 | def forward(self, x, scale): 122 | image = x 123 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image), order=1) 124 | grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) # [16,2] # CARN 125 | 126 | f= None; weighted_bits=0; bits=0 # accumulators threaded through every CARN block 127 | 128 | 129 | x = self.sub_mean(x) 130 | # if self.fully: x = self.quant_head(x) 131 | 132 | x = self.entry(x) 133 | 134 | c0 = o0 = x 135 | grad, b1, bits, f, weighted_bits = self.b1([grad, o0,
bits, f, weighted_bits]) 136 | c1 = torch.cat([c0, b1], dim=1) 137 | o1 = self.c1(c1) 138 | 139 | grad, b2, bits, f, weighted_bits = self.b2([grad, o1, bits, f, weighted_bits]) 140 | c2 = torch.cat([c1, b2], dim=1) 141 | o2 = self.c2(c2) 142 | 143 | grad, b3, bits, f, weighted_bits = self.b3([grad, o2, bits, f, weighted_bits]) 144 | c3 = torch.cat([c2, b3], dim=1) 145 | o3 = self.c3(c3) 146 | 147 | feat = o3 148 | 149 | out = self.upsample(o3, scale=scale) 150 | 151 | # if self.fully: out = self.quant_tail(out) 152 | 153 | out = self.exit(out) 154 | out = self.add_mean(out) 155 | 156 | 157 | return out, feat, bits, f # NOTE(review): 4 values, unlike CARN_CADyQ.forward which also returns weighted_bits 158 | 159 | -------------------------------------------------------------------------------- /model/carn_pams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model import carn 6 | from model.quant_ops import quant_act_pams, quant_conv3x3 7 | 8 | class BasicBlock_PAMS(nn.Module): # quantizes the input activation, then conv + ReLU 9 | def __init__(self, 10 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1, 11 | ksize=3, stride=1, pad=1): 12 | super(BasicBlock_PAMS, self).__init__() 13 | 14 | self.quant_act = quant_act_pams(k_bits, ema_epoch) 15 | 16 | 17 | self.body = nn.Sequential( 18 | self.quant_act, 19 | conv(in_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=ksize, stride=stride, padding=pad), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | carn.init_weights(self.modules) 24 | 25 | def forward(self, x): 26 | out = self.body(x) 27 | return out 28 | 29 | class ResidualBlock_PAMS(nn.Module): 30 | def __init__(self, 31 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1,loss_kdf= False, linq=False): 32 | super(ResidualBlock_PAMS, self).__init__() 33 | 34 | self.quant_act1 = quant_act_pams(k_bits,ema_epoch=ema_epoch) 35 | self.quant_act2 = quant_act_pams(k_bits,ema_epoch=ema_epoch) 36 | 37 | self.body = nn.Sequential( 38 |
self.quant_act1, 39 | conv(in_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 40 | nn.ReLU(inplace=True), 41 | self.quant_act2, 42 | conv(out_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 43 | ) 44 | self.loss_kdf= loss_kdf 45 | 46 | carn.init_weights(self.modules) 47 | 48 | def forward(self, x): 49 | f = x[1] 50 | x = x[0] 51 | 52 | out = self.body(x) 53 | 54 | f1 = out 55 | out = F.relu(out + x) 56 | 57 | if self.loss_kdf: 58 | if f is None: 59 | f = f1.unsqueeze(0) 60 | else: 61 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 62 | else: 63 | f = None 64 | 65 | 66 | # return out 67 | return [out, f] # f is None unless loss_kdf is set 68 | 69 | class CARNBlock_PAMS(nn.Module): 70 | def __init__(self, 71 | in_channels, out_channels, conv, k_bits=32, bias=False, ema_epoch=1, loss_kdf=False, group=1, linq=False, fully=False): # NOTE(review): in_channels, out_channels and group are ignored; all widths below are hard-coded to 64 72 | super(CARNBlock_PAMS, self).__init__() 73 | 74 | self.b1 = ResidualBlock_PAMS(64, 64, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,loss_kdf=loss_kdf,linq=linq) 75 | self.b2 = ResidualBlock_PAMS(64, 64, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,loss_kdf=loss_kdf,linq=linq) 76 | self.b3 = ResidualBlock_PAMS(64, 64, conv, k_bits=k_bits, bias=bias, ema_epoch=ema_epoch,loss_kdf=loss_kdf,linq=linq) 77 | 78 | if fully: 79 | self.c1 = BasicBlock_PAMS(64*2, 64, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 80 | self.c2 = BasicBlock_PAMS(64*3, 64, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 81 | self.c3 = BasicBlock_PAMS(64*4, 64, conv, k_bits=k_bits, bias=bias,ema_epoch=ema_epoch, ksize=1, stride=1, pad=0) 82 | else: 83 | self.c1 = carn.BasicBlock(64*2, 64, 1, 1, 0) 84 | self.c2 = carn.BasicBlock(64*3, 64, 1, 1, 0) 85 | self.c3 = carn.BasicBlock(64*4, 64, 1, 1, 0) 86 | 87 | def forward(self, x): 88 | f = x[1] 89 | x = x[0] 90 | 91 | c0 = o0 = x 92 | 93 | b1,f = self.b1([o0,f]) 94 | c1 = torch.cat([c0, b1], dim=1) 95 | o1 =
self.c1(c1) 96 | 97 | b2,f = self.b2([o1,f]) 98 | c2 = torch.cat([c1, b2], dim=1) 99 | o2 = self.c2(c2) 100 | 101 | b3,f = self.b3([o2,f]) 102 | c3 = torch.cat([c2, b3], dim=1) 103 | o3 = self.c3(c3) 104 | 105 | return [o3, f] 106 | 107 | 108 | class CARN_PAMS(nn.Module): 109 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32, multi_scale=False, linq=False, fully=False): 110 | super(CARN_PAMS, self).__init__() 111 | 112 | scale = args.scale[0] 113 | # multi_scale = args.multi_scale 114 | # multi_scale = False 115 | group = args.group 116 | self.fully = fully # NOTE(review): taken from the constructor argument, not args.fully 117 | 118 | self.k_bits = args.k_bits # the k_bits constructor argument is not used; args.k_bits is 119 | 120 | self.sub_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 121 | self.add_mean = carn.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 122 | 123 | if self.fully: 124 | self.entry = conv(3, 64, k_bits=args.k_bits, bias=bias, kernel_size=3, stride=1, padding=1) 125 | self.quant_head = quant_act_pams(k_bits=args.k_bits, ema_epoch=args.ema_epoch) 126 | 127 | else: 128 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 129 | 130 | self.b1 = CARNBlock_PAMS(64, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,loss_kdf=args.loss_kdf,linq=linq,fully=fully) 131 | self.b2 = CARNBlock_PAMS(64, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,loss_kdf=args.loss_kdf,linq=linq,fully=fully) 132 | self.b3 = CARNBlock_PAMS(64, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch,loss_kdf=args.loss_kdf,linq=linq,fully=fully) 133 | 134 | if self.fully: 135 | self.c1 = BasicBlock_PAMS(64*2, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch, ksize=1, stride=1, pad=0) 136 | self.c2 = BasicBlock_PAMS(64*3, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch, ksize=1, stride=1, pad=0) 137 | self.c3 = BasicBlock_PAMS(64*4, 64, conv, k_bits=args.k_bits, bias=bias,ema_epoch=args.ema_epoch, ksize=1, stride=1, pad=0) 138 | else: 139 | self.c1 = carn.BasicBlock(64*2, 64, 1, 1, 0) 140 | self.c2 =
carn.BasicBlock(64*3, 64, 1, 1, 0) 141 | self.c3 = carn.BasicBlock(64*4, 64, 1, 1, 0) 142 | 143 | self.upsample = carn.UpsampleBlock(64, scale=scale, multi_scale=multi_scale, group=group, fully=fully, k_bits=args.k_bits) 144 | 145 | if self.fully: 146 | self.exit = conv(64, 3, k_bits=args.k_bits, bias=bias, kernel_size=3, stride=1, padding=1) 147 | # self.quant_tail = quant_act_lin(k_bits=args.k_bits) 148 | self.quant_tail = quant_act_pams(k_bits=args.k_bits, ema_epoch=args.ema_epoch) 149 | else: 150 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 151 | 152 | def forward(self, x, scale): # returns (out, feat, f); f stays None unless loss_kdf is enabled 153 | # def forward(self, x, scale=4): 154 | f = None 155 | x = self.sub_mean(x) 156 | if self.fully: x = self.quant_head(x) 157 | 158 | x = self.entry(x) 159 | c0 = o0 = x 160 | 161 | b1,f = self.b1([o0,f]) 162 | c1 = torch.cat([c0, b1], dim=1) 163 | o1 = self.c1(c1) 164 | 165 | b2,f = self.b2([o1,f]) 166 | c2 = torch.cat([c1, b2], dim=1) 167 | o2 = self.c2(c2) 168 | 169 | b3,f = self.b3([o2,f]) 170 | c3 = torch.cat([c2, b3], dim=1) 171 | o3 = self.c3(c3) 172 | 173 | feat = o3 174 | 175 | out = self.upsample(o3, scale=scale) 176 | 177 | if self.fully: out = self.quant_tail(out) 178 | out = self.exit(out) 179 | out = self.add_mean(out) 180 | 181 | return out, feat, f -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from model.quant_ops import quant_act_pams, quant_act_lin 7 | 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): # padding=kernel_size//2 keeps H and W unchanged for odd kernel sizes 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | 14 | class ShortCut(nn.Module): 15 | def __init__(self): 16 | super(ShortCut, self).__init__() 17 | 18 | def forward(self, input): 19 | return input 20 | 21 | class MeanShift(nn.Conv2d): 22 | def
__init__( 23 | self, rgb_range, 24 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 25 | 26 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 27 | std = torch.Tensor(rgb_std).cuda() 28 | # self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 29 | self.weight.data = torch.eye(3).view(3, 3, 1, 1).cuda() / std.view(3, 1, 1, 1) 30 | 31 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean).cuda() / std 32 | for p in self.parameters(): 33 | p.requires_grad = False 34 | 35 | class BasicBlock(nn.Sequential): 36 | def __init__( 37 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 38 | bn=True, act=nn.ReLU(True)): 39 | 40 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 41 | if bn: 42 | m.append(nn.BatchNorm2d(out_channels)) 43 | if act is not None: 44 | m.append(act) 45 | 46 | super(BasicBlock, self).__init__(*m) 47 | 48 | class ResBlock(nn.Module): 49 | def __init__( 50 | self, conv, n_feats, kernel_size, 51 | bias=True, bn=False, inn=False, act=nn.ReLU(True), res_scale=1): 52 | 53 | super(ResBlock, self).__init__() 54 | m = [] 55 | for i in range(2): 56 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 57 | if bn: 58 | m.append(nn.BatchNorm2d(n_feats)) 59 | elif inn: 60 | m.append(nn.InstanceNorm2d(n_feats, affine=True)) 61 | if i == 0: 62 | m.append(act) 63 | 64 | self.body = nn.Sequential(*m) 65 | self.res_scale = res_scale 66 | self.shortcut = ShortCut() 67 | 68 | def forward(self, x): 69 | residual = self.shortcut(x) 70 | res = self.body(x).mul(self.res_scale) 71 | res += residual 72 | 73 | return res 74 | 75 | class Upsampler(nn.Sequential): 76 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 77 | 78 | m = [] 79 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
80 | for _ in range(int(math.log(scale, 2))): 81 | m.append(conv(n_feats, 4 * n_feats, 3, bias=bias)) 82 | m.append(nn.PixelShuffle(2)) 83 | if bn: 84 | m.append(nn.BatchNorm2d(n_feats)) 85 | if act == 'relu': 86 | m.append(nn.ReLU(True)) 87 | elif act == 'prelu': 88 | m.append(nn.PReLU(n_feats)) 89 | elif act == 'lrelu': 90 | m.append(nn.LeakyReLU(0.2, inplace=True)) 91 | 92 | elif scale == 3: 93 | m.append(conv(n_feats, 9 * n_feats, 3, bias=bias)) 94 | m.append(nn.PixelShuffle(3)) 95 | if bn: 96 | m.append(nn.BatchNorm2d(n_feats)) 97 | if act == 'relu': 98 | m.append(nn.ReLU(True)) 99 | elif act == 'prelu': 100 | m.append(nn.PReLU(n_feats)) 101 | elif act == 'lrelu': 102 | m.append(nn.LeakyReLU(0.2, inplace=True)) 103 | else: 104 | raise NotImplementedError 105 | 106 | super(Upsampler, self).__init__(*m) 107 | 108 | 109 | 110 | class ResBlock_srresnet(nn.Module): 111 | def __init__( 112 | self, conv, n_feats, kernel_size, 113 | bias=False, bn=False, act=nn.ReLU(True), res_scale=1): 114 | 115 | super(ResBlock_srresnet, self).__init__() 116 | 117 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias) 118 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias) 119 | 120 | self.bn1 = nn.BatchNorm2d(n_feats) 121 | self.act = act 122 | self.bn2 = nn.BatchNorm2d(n_feats) 123 | self.res_scale = res_scale 124 | 125 | 126 | self.res_scale = res_scale 127 | self.shortcut = ShortCut() 128 | 129 | def forward(self, x): 130 | residual = self.shortcut(x) 131 | res = self.act(self.bn1(self.conv1(x))) 132 | res = self.bn2(self.conv2(res)).mul(self.res_scale) 133 | res += residual 134 | 135 | 136 | # residual = self.shortcut(x) 137 | # res = self.body(x).mul(self.res_scale) 138 | # res += residual 139 | 140 | return res # conv-BN-act-conv-BN, scaled by res_scale, plus the identity shortcut; the bn flag is not consulted 141 | 142 | class Upsampler_srresnet(nn.Sequential): 143 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): # NOTE(review): bn, act and bias are accepted but unused; the convs below always use bias=False followed by PReLU 144 | # scale = 4 # for SRResNet 145 | m = [] 146 | if scale == 4: 147 | m.append(conv(n_feats, 4 * n_feats, 3,
bias=False)) 148 | m.append(nn.PixelShuffle(2)) 149 | m.append(nn.PReLU()) 150 | # m.append(nn.LeakyReLU(0.2, inplace=True)) 151 | m.append(conv(n_feats, 4 * n_feats, 3, bias=False)) 152 | m.append(nn.PixelShuffle(2)) 153 | # m.append(nn.LeakyReLU(0.2, inplace=True)) 154 | m.append(nn.PReLU()) 155 | elif scale ==2 : 156 | m.append(conv(n_feats, 4 * n_feats, 3, bias=False)) 157 | m.append(nn.PixelShuffle(2)) 158 | m.append(nn.PReLU()) 159 | else: 160 | print("not implemented") 161 | 162 | 163 | super(Upsampler_srresnet, self).__init__(*m) 164 | 165 | 166 | -------------------------------------------------------------------------------- /model/edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from model.quant_ops import quant_act_pams 5 | 6 | from option import args 7 | device = torch.device('cpu' if args.cpu else f'cuda:{args.gpu_id}') 8 | class BitSelector(nn.Module): 9 | def __init__(self, n_feats, bias=False, ema_epoch=1, search_space=[4,6,8], linq=False): 10 | super(BitSelector, self).__init__() 11 | 12 | self.quant_bit1 = quant_act_pams(k_bits=search_space[0], ema_epoch=ema_epoch) 13 | self.quant_bit2 = quant_act_pams(k_bits=search_space[1], ema_epoch=ema_epoch) 14 | self.quant_bit3 = quant_act_pams(k_bits=search_space[2], ema_epoch=ema_epoch) 15 | 16 | self.search_space =search_space 17 | self.take_index = None 18 | 19 | self.flag = None 20 | 21 | self.net_small = nn.Sequential( 22 | nn.Linear(n_feats+2, len(search_space)) 23 | ) 24 | nn.init.ones_(self.net_small[0].weight) 25 | nn.init.zeros_(self.net_small[0].bias) 26 | nn.init.ones_(self.net_small[0].bias[-1]) 27 | 28 | def forward(self, x): 29 | if len(x) >= 4 and x[3] is not None: 30 | self.take_index = x[3] 31 | else: 32 | self.take_index = torch.arange(x[1].shape[0], device=x[1].device) 33 | bits = x[2] 34 | grad = x[0] 35 | x = x[1] 36 | 37 | self.flag = self.flag.to(x.device) 38 | 
39 | if len(self.search_space)== 4: 40 | bits_hard = (self.flag[self.take_index]==0)*self.search_space[0] + (self.flag[self.take_index]==1)*self.search_space[1] + (self.flag[self.take_index]==2)*self.search_space[2] + (self.flag[self.take_index]==3)*self.search_space[3] 41 | bits_out = bits_hard.detach() 42 | if not isinstance(bits, int): 43 | bits[self.take_index] += bits_out 44 | else: 45 | bits += bits_out 46 | 47 | q_bit1 = self.quant_bit1(x) 48 | q_bit2 = self.quant_bit2(x) 49 | q_bit3 = self.quant_bit3(x) 50 | q_bit4 = self.quant_bit4(x) 51 | out_hard = (self.flag[self.take_index]==0).view(self.flag[self.take_index].size(0),1,1,1)*q_bit1 + (self.flag[self.take_index]==1).view(self.flag[self.take_index].size(0),1,1,1)*q_bit2 + (self.flag[self.take_index]==2).view(self.flag[self.take_index].size(0),1,1,1)*q_bit3 + (self.flag[self.take_index]==3).view(self.flag[self.take_index].size(0),1,1,1)*q_bit4 52 | 53 | if args.test_only: 54 | residual = out_hard.detach() 55 | else: 56 | residual = out_hard 57 | 58 | elif len(self.search_space)== 3: 59 | bits_hard = (self.flag[self.take_index]==0)*self.search_space[0] + (self.flag[self.take_index]==1)*self.search_space[1] + (self.flag[self.take_index]==2)*self.search_space[2] 60 | bits_out = bits_hard.detach() 61 | if not isinstance(bits, int): 62 | bits[self.take_index] += bits_out 63 | else: 64 | bits += bits_out 65 | 66 | q_bit1 = self.quant_bit1(x) 67 | q_bit2 = self.quant_bit2(x) 68 | q_bit3 = self.quant_bit3(x) 69 | out_hard = (self.flag[self.take_index]==0).view(self.flag[self.take_index].size(0),1,1,1)*q_bit1 + (self.flag[self.take_index]==1).view(self.flag[self.take_index].size(0),1,1,1)*q_bit2 + (self.flag[self.take_index]==2).view(self.flag[self.take_index].size(0),1,1,1)*q_bit3 70 | 71 | if args.test_only: 72 | residual = out_hard.detach() 73 | else: 74 | residual = out_hard 75 | 76 | elif len(self.search_space)== 2: 77 | bits_hard = (self.flag[self.take_index]==0)*self.search_space[0] + 
(self.flag[self.take_index]==1)*self.search_space[1] 78 | bits_out = bits_hard.detach() 79 | if not isinstance(bits, int): 80 | bits[self.take_index] += bits_out 81 | else: 82 | bits += bits_out 83 | q_bit1 =self.quant_bit1(x) 84 | q_bit2 = self.quant_bit2(x) 85 | out_hard = (self.flag[self.take_index]==0).view(self.flag[self.take_index].size(0),1,1,1)*q_bit1 + (self.flag[self.take_index]==1).view(self.flag[self.take_index].size(0),1,1,1)*q_bit2 86 | if args.test_only: 87 | residual = out_hard.detach() 88 | else: 89 | residual = out_hard 90 | 91 | return [grad, residual, bits] -------------------------------------------------------------------------------- /model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch.nn as nn 3 | 4 | url = { 5 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 6 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 7 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 8 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 9 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 10 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 11 | } 12 | 13 | class EDSR(nn.Module): 14 | def __init__(self, args, is_teacher, conv=common.default_conv): 15 | super(EDSR, self).__init__() 16 | 17 | n_resblocks = args.n_resblocks 18 | n_feats = args.n_feats 19 | kernel_size = 3 20 | scale = args.scale[0] 21 | act = nn.ReLU(True) 22 | self.is_teacher = is_teacher 23 | 24 | self.sub_mean = common.MeanShift(args.rgb_range) 25 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 26 | 27 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 28 | m_body = [ 29 | common.ResBlock( 30 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 31 | ) for _ in range(n_resblocks) 32 | ] 
33 | m_body.append(conv(n_feats, n_feats, kernel_size)) 34 | 35 | m_tail = [ 36 | common.Upsampler(conv, scale, n_feats, act=False), 37 | conv(n_feats, args.n_colors, kernel_size) 38 | ] 39 | 40 | self.head = nn.Sequential(*m_head) 41 | self.body = nn.Sequential(*m_body) 42 | self.tail = nn.Sequential(*m_tail) 43 | 44 | def forward(self, x): 45 | x = self.sub_mean(x) 46 | x = self.head(x) 47 | res = self.body(x) 48 | 49 | res += x 50 | 51 | out = res 52 | x = self.tail(res) 53 | x = self.add_mean(x) 54 | 55 | if self.is_teacher: 56 | return x, out 57 | else: 58 | return x 59 | 60 | -------------------------------------------------------------------------------- /model/edsr_cadyq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import common 7 | from model.quant_ops import conv3x3, quant_conv3x3, quant_act_pams 8 | from model.cadyq import BitSelector 9 | 10 | 11 | class ResidualBlock_CADyQ(nn.Module): 12 | def __init__(self, 13 | in_channels, out_channels, conv, act, kernel_size, res_scale, k_bits=32, bias=False, ema_epoch=1,search_space=[4,6,8],loss_kdf=False, linq=False): 14 | super(ResidualBlock_CADyQ, self).__init__() 15 | 16 | self.bitsel1 = BitSelector(in_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space) 17 | 18 | self.body = nn.Sequential( 19 | conv(in_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=kernel_size, stride=1, padding=1), 20 | # conv(in_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=1, padding=1), 21 | act, 22 | BitSelector(out_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 23 | conv(out_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=kernel_size, stride=1, padding=1), # NOTE(review): padding is fixed at 1 here and in the first conv, so spatial size is preserved only for kernel_size=3 24 | # conv(out_channels, out_channels, bias=bias, kernel_size=kernel_size, stride=1, padding=1) 25 | ) 26 | self.loss_kdf= loss_kdf 27 |
self.res_scale = res_scale 28 | 29 | self.quant_act3 = quant_act_pams(k_bits, ema_epoch=ema_epoch) 30 | self.shortcut = common.ShortCut() 31 | 32 | 33 | def forward(self, x): 34 | weighted_bits = x[4] 35 | f = x[3] 36 | bits = x[2] 37 | grad = x[0] 38 | x = x[1] 39 | 40 | x = self.shortcut(x) 41 | grad,x,bits,weighted_bits = self.bitsel1([grad,x,bits,weighted_bits]) # cadyq 42 | residual = x 43 | # grad,x,bits,weighted_bits= self.body[0]() # cadyq 44 | # x = self.body[1:3](x) # conv-relu 45 | x = self.body[0:2](x) # conv-relu 46 | # grad,x,bits,weighted_bits= self.body[3]([grad,x,bits,weighted_bits]) # cadyq 47 | grad,x,bits,weighted_bits= self.body[2]([grad,x,bits,weighted_bits]) # cadyq 48 | # out = self.body[4](x) # conv 49 | out = self.body[3](x) # conv 50 | f1 = out 51 | out = out.mul(self.res_scale) 52 | out = self.quant_act3(out) 53 | out += residual # residual is the bitsel1 output (quantized input), not the raw block input 54 | if self.loss_kdf: 55 | if f is None: 56 | f = f1.unsqueeze(0) 57 | else: 58 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 59 | else: 60 | f = None 61 | 62 | 63 | return [grad, out, bits, f, weighted_bits] 64 | 65 | 66 | 67 | 68 | class EDSR_CADyQ(nn.Module): 69 | def __init__(self, args, conv=quant_conv3x3, bias = False, k_bits = 32): 70 | super(EDSR_CADyQ, self).__init__() 71 | 72 | n_resblock = args.n_resblocks 73 | n_feats = args.n_feats 74 | kernel_size = 3 75 | scale = args.scale[0] 76 | act = nn.ReLU(True) 77 | self.k_bits = k_bits # NOTE(review): uses the constructor default (32) rather than args.k_bits, unlike CARN_CADyQ 78 | 79 | self.sub_mean = common.MeanShift(args.rgb_range) 80 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 81 | device = torch.device('cpu' if args.cpu else f'cuda:{args.gpu_id}') # NOTE(review): not used in this constructor 82 | 83 | 84 | m_head = [conv3x3(args.n_colors, n_feats, kernel_size, bias=bias)] 85 | 86 | # baseline= (n_resblock == 16 ) 87 | m_body = [ 88 | ResidualBlock_CADyQ(n_feats, n_feats, quant_conv3x3, act, kernel_size, res_scale=args.res_scale, k_bits=self.k_bits, bias=bias, ema_epoch=args.ema_epoch, search_space=args.search_space, loss_kdf=args.loss_kdf, linq=args.linq 89 | ) for i in
range(n_resblock) 90 | ] 91 | m_body.append(conv3x3(n_feats, n_feats, kernel_size, bias= bias)) 92 | 93 | 94 | m_tail = [ 95 | common.Upsampler(conv3x3, scale, n_feats, act=False), 96 | nn.Conv2d(n_feats, args.n_colors, kernel_size,padding=(kernel_size//2)) 97 | ] 98 | 99 | self.head = nn.Sequential(*m_head) 100 | self.body = nn.Sequential(*m_body) 101 | self.tail = nn.Sequential(*m_tail) 102 | 103 | def forward(self, x): 104 | 105 | image = x 106 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 107 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # [16,2]; NOTE(review): the image is scaled by 1/255 and the result by 1e3 here, whereas CARN_CADyQ uses the raw image unscaled 108 | 109 | f=None; weighted_bits = 0; bits=0; # accumulators threaded through the body blocks 110 | 111 | x = self.sub_mean(x) 112 | 113 | x = self.head(x) 114 | res = x 115 | 116 | image_grad, res, bits, f, weighted_bits = self.body[:-1]([image_grad, res, bits, f,weighted_bits]) # body[:-1] is the chain of ResidualBlock_CADyQ modules, each taking and returning the 5-slot list 117 | res = self.body[-1](res) 118 | 119 | res += x 120 | out = res 121 | x = self.tail(res) 122 | 123 | x = self.add_mean(x) 124 | 125 | return x, out, bits, f, weighted_bits # (SR image, body feature incl. global skip, bit accumulator, stacked features or None, weighted bits) 126 | 127 | -------------------------------------------------------------------------------- /model/edsr_cadyq_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import common 7 | from model.quant_ops import conv3x3, quant_conv3x3, quant_act_pams 8 | from model.edge import BitSelector 9 | 10 | 11 | class ResidualBlock_CADyQ(nn.Module): 12 | def __init__(self, 13 | in_channels, out_channels, conv, act, kernel_size, res_scale, k_bits=32, bias=False, ema_epoch=1,search_space=[4,6,8],loss_kdf=False, linq=False): 14 | super(ResidualBlock_CADyQ, self).__init__() 15 | 16 | 17 | self.bitsel1 = BitSelector(in_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space) 18 | 19 | self.body = nn.Sequential( 20 | conv(in_channels, out_channels, k_bits=k_bits, bias=
kernel_size=kernel_size, stride=1, padding=1), 21 | act, 22 | BitSelector(out_channels, bias=bias, ema_epoch=ema_epoch, search_space=search_space), 23 | conv(out_channels, out_channels, k_bits=k_bits, bias=bias, kernel_size=kernel_size, stride=1, padding=1), 24 | ) 25 | self.loss_kdf= loss_kdf 26 | self.res_scale = res_scale 27 | 28 | self.quant_act3 = quant_act_pams(k_bits, ema_epoch=ema_epoch) 29 | self.shortcut = common.ShortCut() 30 | 31 | 32 | def forward(self, x): 33 | f = x[3] 34 | bits = x[2] 35 | grad = x[0] 36 | x = x[1] 37 | 38 | x = self.shortcut(x) 39 | tmp_bits = bits 40 | grad,x,bits = self.bitsel1([grad,x,bits]) # cadyq 41 | residual = x 42 | # grad,x,bits,weighted_bits= self.body[0]() # cadyq 43 | # x = self.body[1:3](x) # conv-relu 44 | x = self.body[0:2](x) # conv-relu 45 | # grad,x,bits,weighted_bits= self.body[3]([grad,x,bits,weighted_bits]) # cadyq 46 | grad,x,bits= self.body[2]([grad,x,bits]) # cadyq 47 | # out = self.body[4](x) # conv 48 | out = self.body[3](x) # conv 49 | f1 = out 50 | out = out.mul(self.res_scale) 51 | out = self.quant_act3(out) 52 | out += residual 53 | if self.loss_kdf: 54 | if f is None: 55 | f = f1.unsqueeze(0) 56 | else: 57 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 58 | else: 59 | f = None 60 | 61 | 62 | return [grad, out, bits, f] 63 | 64 | 65 | 66 | 67 | class EDSR_CADyQ_I(nn.Module): 68 | def __init__(self, args, conv=quant_conv3x3, bias = False, k_bits = 32): 69 | super(EDSR_CADyQ_I, self).__init__() 70 | 71 | n_resblock = args.n_resblocks 72 | n_feats = args.n_feats 73 | kernel_size = 3 74 | scale = args.scale[0] 75 | act = nn.ReLU(True) 76 | self.k_bits = k_bits 77 | 78 | self.sub_mean = common.MeanShift(args.rgb_range) 79 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 80 | device = torch.device('cpu' if args.cpu else f'cuda:{args.gpu_id}') 81 | 82 | 83 | m_head = [conv3x3(args.n_colors, n_feats, kernel_size, bias=bias)] 84 | 85 | # baseline= (n_resblock == 16 ) 86 | m_body = [ 87 | 
ResidualBlock_CADyQ(n_feats, n_feats, quant_conv3x3, act, kernel_size, res_scale=args.res_scale, k_bits=self.k_bits, bias=bias, ema_epoch=args.ema_epoch, search_space=args.search_space, loss_kdf=args.loss_kdf, linq=args.linq 88 | ) for i in range(n_resblock) 89 | ] 90 | m_body.append(conv3x3(n_feats, n_feats, kernel_size, bias= bias)) 91 | 92 | 93 | m_tail = [ 94 | common.Upsampler(conv3x3, scale, n_feats, act=False), 95 | nn.Conv2d(n_feats, args.n_colors, kernel_size,padding=(kernel_size//2)) 96 | ] 97 | 98 | self.head = nn.Sequential(*m_head) 99 | self.body = nn.Sequential(*m_body) 100 | self.tail = nn.Sequential(*m_tail) 101 | 102 | def forward(self, x): 103 | 104 | image = x 105 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 106 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # [16,2] 107 | 108 | f=None; weighted_bits = 0; bits=0; 109 | 110 | x = self.sub_mean(x) 111 | 112 | x = self.head(x) 113 | res = x 114 | 115 | image_grad, res, bits, f = self.body[:-1]([image_grad, res, bits, f]) 116 | res = self.body[-1](res) 117 | 118 | res += x 119 | out = res 120 | x = self.tail(res) 121 | 122 | x = self.add_mean(x) 123 | 124 | return x, out, bits, f 125 | 126 | -------------------------------------------------------------------------------- /model/edsr_pams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.quant_ops import quant_act_pams, quant_conv3x3, conv3x3 6 | from model import common 7 | 8 | class ResBlock_PAMS(nn.Module): 9 | def __init__(self, conv, n_feats, kernel_size, bias=False, 10 | bn=False, act=nn.ReLU(False), res_scale=1, k_bits = 32, ema_epoch=1, loss_kdf=False): 11 | 12 | super(ResBlock_PAMS, self).__init__() 13 | self.k_bits = k_bits 14 | 15 | self.quant_act1 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 16 | self.quant_act2 = 
quant_act_pams(self.k_bits, ema_epoch=ema_epoch) 17 | self.quant_act3 = quant_act_pams(self.k_bits, ema_epoch=ema_epoch) 18 | 19 | self.shortcut = common.ShortCut() 20 | 21 | 22 | self.body = nn.Sequential( 23 | # self.quant_act1, 24 | conv(n_feats, n_feats, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 25 | act, 26 | self.quant_act2, 27 | conv(n_feats, n_feats, k_bits=k_bits, bias=bias, kernel_size=3, stride=1, padding=1), 28 | ) 29 | 30 | self.res_scale = res_scale 31 | self.loss_kdf = loss_kdf 32 | 33 | def forward(self, x): 34 | f = x[1] 35 | x = x[0] 36 | 37 | # residual = self.body[0](self.shortcut(x)) 38 | # body = self.body[1:](residual) 39 | residual = self.quant_act1(self.shortcut(x)) 40 | body = self.body(residual) 41 | 42 | f2 = body 43 | body = body.mul(self.res_scale) 44 | res = self.quant_act3(body) 45 | # res = body 46 | res += residual 47 | 48 | if self.loss_kdf: 49 | new_f = f2.unsqueeze(0) 50 | if f is None: 51 | f = new_f 52 | else: 53 | f = torch.cat([f, new_f], dim=0) 54 | else: 55 | f = None 56 | 57 | return [res, f] 58 | 59 | class EDSR_PAMS(nn.Module): 60 | def __init__(self, args, conv=quant_conv3x3, bias = False, k_bits = 32, mixed=False,linq=False,fully=False): 61 | super(EDSR_PAMS, self).__init__() 62 | 63 | n_resblock = args.n_resblocks 64 | n_feats = args.n_feats 65 | kernel_size = 3 66 | scale = args.scale[0] 67 | act = nn.ReLU(True) 68 | self.k_bits = args.k_bits 69 | self.fully = fully 70 | 71 | self.sub_mean = common.MeanShift(args.rgb_range) 72 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 73 | 74 | m_head = [conv3x3(args.n_colors, n_feats, kernel_size, bias=bias)] 75 | 76 | # baseline = (n_resblock == 16) 77 | m_body = [ 78 | ResBlock_PAMS( 79 | quant_conv3x3, n_feats, kernel_size, act=act, res_scale=args.res_scale, k_bits=self.k_bits, bias = bias, ema_epoch=args.ema_epoch, loss_kdf=args.loss_kdf, 80 | ) for i in range(n_resblock) 81 | ] 82 | m_body.append(conv3x3(n_feats, n_feats, kernel_size, 
bias= bias)) 83 | 84 | 85 | m_tail = [ 86 | common.Upsampler(conv3x3, scale, n_feats, act=False), 87 | conv3x3(n_feats, args.n_colors, kernel_size, bias=bias) 88 | # nn.Conv2d(n_feats, args.n_colors, kernel_size, padding=(kernel_size//2)) 89 | ] 90 | 91 | self.head = nn.Sequential(*m_head) 92 | self.body = nn.Sequential(*m_body) 93 | self.tail = nn.Sequential(*m_tail) 94 | 95 | def forward(self, x): 96 | x = self.sub_mean(x) 97 | x = self.head(x) 98 | 99 | f=None 100 | 101 | res,f = self.body[:-1]([x,f]) 102 | res = self.body[-1](res) 103 | res += x 104 | 105 | out = res 106 | x = self.tail(res) 107 | x = self.add_mean(x) 108 | 109 | return x, out, f 110 | 111 | 112 | @property 113 | def name(self): 114 | return 'edsr' -------------------------------------------------------------------------------- /model/idn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FBlock(nn.Module): 7 | def __init__(self, num_features): 8 | super(FBlock, self).__init__() 9 | self.module = nn.Sequential( 10 | nn.Conv2d(3, num_features, kernel_size=3, padding=1), 11 | nn.LeakyReLU(0.05), 12 | nn.Conv2d(num_features, num_features, kernel_size=3, padding=1), 13 | nn.LeakyReLU(0.05) 14 | ) 15 | 16 | def forward(self, x): 17 | return self.module(x) 18 | 19 | 20 | class DBlock(nn.Module): 21 | def __init__(self, num_features, d, s): 22 | super(DBlock, self).__init__() 23 | self.num_features = num_features 24 | self.s = s 25 | self.enhancement_top = nn.Sequential( 26 | nn.Conv2d(num_features, num_features - d, kernel_size=3, padding=1), 27 | nn.LeakyReLU(0.05), 28 | nn.Conv2d(num_features - d, num_features - 2 * d, kernel_size=3, padding=1, groups=4), 29 | nn.LeakyReLU(0.05), 30 | nn.Conv2d(num_features - 2 * d, num_features, kernel_size=3, padding=1), 31 | nn.LeakyReLU(0.05) 32 | ) 33 | self.enhancement_bottom = nn.Sequential( 34 | nn.Conv2d(num_features - d, 
num_features, kernel_size=3, padding=1), 35 | nn.LeakyReLU(0.05), 36 | nn.Conv2d(num_features, num_features - d, kernel_size=3, padding=1, groups=4), 37 | nn.LeakyReLU(0.05), 38 | nn.Conv2d(num_features - d, num_features + d, kernel_size=3, padding=1), 39 | nn.LeakyReLU(0.05) 40 | ) 41 | self.compression = nn.Conv2d(num_features + d, num_features, kernel_size=1) 42 | 43 | def forward(self, x): 44 | f = x[1] 45 | x = x[0] 46 | 47 | residual = x 48 | x = self.enhancement_top(x) 49 | slice_1 = x[:, :int((self.num_features - self.num_features/self.s)), :, :] 50 | slice_2 = x[:, int((self.num_features - self.num_features/self.s)):, :, :] 51 | x = self.enhancement_bottom(slice_1) 52 | 53 | f1 = x 54 | 55 | x = x + torch.cat((residual, slice_2), 1) 56 | x = self.compression(x) 57 | if f is None: 58 | f = f1.unsqueeze(0) 59 | else: 60 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 61 | return [x, f] 62 | 63 | 64 | 65 | class IDN(nn.Module): 66 | def __init__(self, args, is_teacher): 67 | super(IDN, self).__init__() 68 | self.scale = args.scale[0] 69 | num_features = args.n_feats 70 | d = args.idn_d 71 | s = args.idn_s 72 | self.is_teacher =is_teacher 73 | self.fblock = FBlock(num_features) 74 | self.dblocks = nn.Sequential(*[DBlock(num_features, d, s) for _ in range(4)]) 75 | self.deconv = nn.ConvTranspose2d(num_features, 3, kernel_size=17, stride=self.scale, padding=8, output_padding=1) 76 | 77 | self._initialize_weights() 78 | 79 | def _initialize_weights(self): 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | nn.init.kaiming_normal_(m.weight) 83 | nn.init.zeros_(m.bias) 84 | if isinstance(m, nn.ConvTranspose2d): 85 | nn.init.kaiming_normal_(m.weight) 86 | nn.init.zeros_(m.bias) 87 | 88 | def forward(self, x): 89 | bicubic = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False) 90 | 91 | x = self.fblock(x) 92 | f = None 93 | x,f = self.dblocks([x,f]) 94 | 95 | out = x 96 | x = self.deconv(x, output_size=bicubic.size()) 97 | 98 | 
if self.is_teacher: 99 | return bicubic + x, out 100 | 101 | else: 102 | return bicubic + x -------------------------------------------------------------------------------- /model/idn_cadyq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import idn_pams 7 | from model.quant_ops import quant_conv3x3 8 | from model.cadyq import BitSelector 9 | 10 | class DBlock_CADyQ(nn.Module): 11 | def __init__(self, conv, num_features, d, s, bias=False, k_bits=32, ema_epoch=1, search_space=[2,4,8],loss_kdf=False, linq=False): 12 | super(DBlock_CADyQ, self).__init__() 13 | self.num_features = num_features 14 | self.s = s 15 | self.k_bits = k_bits 16 | 17 | self.enhancement_top = nn.Sequential( 18 | BitSelector(num_features, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 19 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias), 20 | nn.LeakyReLU(0.05), 21 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 22 | conv(num_features - d, num_features - 2 * d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 23 | nn.LeakyReLU(0.05), 24 | BitSelector(num_features - 2 * d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 25 | conv(num_features - 2 * d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 26 | nn.LeakyReLU(0.05) 27 | ) 28 | self.enhancement_bottom = nn.Sequential( 29 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 30 | conv(num_features - d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 31 | nn.LeakyReLU(0.05), 32 | BitSelector(num_features, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 33 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 
34 | nn.LeakyReLU(0.05), 35 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 36 | conv(num_features - d, num_features + d, kernel_size=3, k_bits=self.k_bits, bias=bias), 37 | nn.LeakyReLU(0.05) 38 | ) 39 | self.compression = nn.Conv2d(num_features + d, num_features, kernel_size=1) 40 | self.loss_kdf= loss_kdf 41 | 42 | 43 | def forward(self, x): 44 | weighted_bits = x[4] 45 | f = x[3] 46 | bits = x[2] 47 | grad = x[0] 48 | x = x[1] 49 | 50 | residual = x 51 | 52 | # x = self.enhancement_top(x) 53 | flops=0 54 | 55 | grad,x,bits,weighted_bits= self.enhancement_top[0]([grad,x,bits,weighted_bits]) # BitSelector 56 | 57 | x = self.enhancement_top[1:3](x) 58 | grad,x,bits,weighted_bits = self.enhancement_top[3]([grad,x,bits,weighted_bits]) 59 | 60 | x = self.enhancement_top[4:6](x) 61 | grad,x,bits,weighted_bits = self.enhancement_top[6]([grad,x,bits,weighted_bits]) 62 | 63 | x = self.enhancement_top[7:9](x) 64 | 65 | 66 | slice_1 = x[:, :int((self.num_features - self.num_features/self.s)), :, :] 67 | slice_2 = x[:, int((self.num_features - self.num_features/self.s)):, :, :] 68 | 69 | grad,x,bits,weighted_bits = self.enhancement_bottom[0]([grad,slice_1,bits,weighted_bits]) 70 | 71 | x = self.enhancement_bottom[1:3](x) 72 | grad,x,bits,weighted_bits = self.enhancement_bottom[3]([grad,x,bits,weighted_bits]) 73 | 74 | x = self.enhancement_bottom[4:6](x) 75 | grad,x,bits,weighted_bits = self.enhancement_bottom[6]([grad,x,bits,weighted_bits]) 76 | 77 | x = self.enhancement_bottom[7:9](x) 78 | f1 = x 79 | 80 | x = x + torch.cat((residual, slice_2), 1) 81 | 82 | x = self.compression(x) 83 | 84 | 85 | if self.loss_kdf: 86 | if f is None: 87 | f = f1.unsqueeze(0) 88 | else: 89 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 90 | else: 91 | f = None 92 | 93 | # return x 94 | return [grad, x, bits, f, weighted_bits] 95 | 96 | 97 | class IDN_CADyQ(nn.Module): 98 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32): 
99 | super(IDN_CADyQ, self).__init__() 100 | self.scale = args.scale[0] 101 | num_features = args.n_feats 102 | d = args.idn_d 103 | s = args.idn_s 104 | 105 | self.fblock = idn_pams.FBlock_PAMS(conv, num_features, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch ) 106 | 107 | m_dblocks = [ 108 | DBlock_CADyQ(conv, num_features, d, s, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf, linq=args.linq) for _ in range(args.n_resblocks) 109 | 110 | ] 111 | # args.n_resblocks should be 4 112 | self.dblocks = nn.Sequential(*m_dblocks) 113 | self.deconv = nn.ConvTranspose2d(num_features, 3, kernel_size=17, stride=self.scale, padding=8, output_padding=1) 114 | 115 | 116 | def forward(self, x): 117 | image = x 118 | bicubic = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False) 119 | 120 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 121 | grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # [16,2] 122 | 123 | x= self.fblock(x) 124 | 125 | f = None 126 | bits = 0; weighted_bits = 0 127 | 128 | grad, x,bits, f, weighted_bits= self.dblocks([grad, x, bits, f, weighted_bits]) 129 | out = x 130 | 131 | x = self.deconv(x, output_size=bicubic.size()) 132 | 133 | return bicubic + x, out, bits, f, weighted_bits 134 | -------------------------------------------------------------------------------- /model/idn_cadyq_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import idn_pams 7 | from model.quant_ops import quant_conv3x3 8 | from model.edge import BitSelector 9 | 10 | class DBlock_CADyQ(nn.Module): 11 | def __init__(self, conv, num_features, d, s, bias=False, k_bits=32, ema_epoch=1, search_space=[2,4,8],loss_kdf=False, linq=False): 12 | super(DBlock_CADyQ, 
self).__init__() 13 | self.num_features = num_features 14 | self.s = s 15 | self.k_bits = k_bits 16 | 17 | self.enhancement_top = nn.Sequential( 18 | BitSelector(num_features, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 19 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias), 20 | nn.LeakyReLU(0.05), 21 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 22 | conv(num_features - d, num_features - 2 * d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 23 | nn.LeakyReLU(0.05), 24 | BitSelector(num_features - 2 * d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 25 | conv(num_features - 2 * d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 26 | nn.LeakyReLU(0.05) 27 | ) 28 | self.enhancement_bottom = nn.Sequential( 29 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 30 | conv(num_features - d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 31 | nn.LeakyReLU(0.05), 32 | BitSelector(num_features, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 33 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 34 | nn.LeakyReLU(0.05), 35 | BitSelector(num_features - d, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq), 36 | conv(num_features - d, num_features + d, kernel_size=3, k_bits=self.k_bits, bias=bias), 37 | nn.LeakyReLU(0.05) 38 | ) 39 | self.compression = nn.Conv2d(num_features + d, num_features, kernel_size=1) 40 | self.loss_kdf= loss_kdf 41 | 42 | 43 | def forward(self, x): 44 | weighted_bits = x[4] 45 | f = x[3] 46 | bits = x[2] 47 | grad = x[0] 48 | x = x[1] 49 | 50 | residual = x 51 | 52 | # x = self.enhancement_top(x) 53 | flops=0 54 | 55 | grad,x,bits= self.enhancement_top[0]([grad,x,bits]) # BitSelector 56 | 57 | x = self.enhancement_top[1:3](x) 58 | 
grad,x,bits = self.enhancement_top[3]([grad,x,bits]) 59 | 60 | x = self.enhancement_top[4:6](x) 61 | grad,x,bits = self.enhancement_top[6]([grad,x,bits]) 62 | 63 | x = self.enhancement_top[7:9](x) 64 | 65 | 66 | slice_1 = x[:, :int((self.num_features - self.num_features/self.s)), :, :] 67 | slice_2 = x[:, int((self.num_features - self.num_features/self.s)):, :, :] 68 | 69 | grad,x,bits = self.enhancement_bottom[0]([grad,slice_1,bits]) 70 | 71 | x = self.enhancement_bottom[1:3](x) 72 | grad,x,bits = self.enhancement_bottom[3]([grad,x,bits]) 73 | 74 | x = self.enhancement_bottom[4:6](x) 75 | grad,x,bits = self.enhancement_bottom[6]([grad,x,bits]) 76 | 77 | x = self.enhancement_bottom[7:9](x) 78 | f1 = x 79 | 80 | x = x + torch.cat((residual, slice_2), 1) 81 | 82 | x = self.compression(x) 83 | 84 | 85 | if self.loss_kdf: 86 | if f is None: 87 | f = f1.unsqueeze(0) 88 | else: 89 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 90 | else: 91 | f = None 92 | 93 | # return x 94 | return [grad, x, bits, f, weighted_bits] 95 | 96 | 97 | class IDN_CADyQ_I(nn.Module): 98 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32): 99 | super(IDN_CADyQ_I, self).__init__() 100 | self.scale = args.scale[0] 101 | num_features = args.n_feats 102 | d = args.idn_d 103 | s = args.idn_s 104 | 105 | self.fblock = idn_pams.FBlock_PAMS(conv, num_features, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch ) 106 | 107 | m_dblocks = [ 108 | DBlock_CADyQ(conv, num_features, d, s, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch,search_space=args.search_space,loss_kdf=args.loss_kdf, linq=args.linq) for _ in range(args.n_resblocks) 109 | 110 | ] 111 | # args.n_resblocks should be 4 112 | self.dblocks = nn.Sequential(*m_dblocks) 113 | self.deconv = nn.ConvTranspose2d(num_features, 3, kernel_size=17, stride=self.scale, padding=8, output_padding=1) 114 | 115 | 116 | def forward(self, x): 117 | image = x 118 | bicubic = F.interpolate(x, scale_factor=self.scale, mode='bicubic', 
align_corners=False) 119 | 120 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 121 | grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # [16,2] 122 | 123 | x= self.fblock(x) 124 | 125 | f = None 126 | bits = 0; weighted_bits = 0 127 | 128 | grad, x,bits, f, weighted_bits= self.dblocks([grad, x, bits, f, weighted_bits]) 129 | out = x 130 | 131 | x = self.deconv(x, output_size=bicubic.size()) 132 | 133 | return bicubic + x, out, bits, f 134 | -------------------------------------------------------------------------------- /model/idn_pams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.quant_ops import quant_act_pams, quant_conv3x3 6 | 7 | class FBlock_PAMS(nn.Module): 8 | def __init__(self, conv, num_features, k_bits =32, bias=False, ema_epoch=1, linq=False): 9 | super(FBlock_PAMS, self).__init__() 10 | 11 | self.quant_actf1 = quant_act_pams(k_bits, ema_epoch=ema_epoch) 12 | self.quant_actf2 = quant_act_pams(k_bits, ema_epoch=ema_epoch) 13 | 14 | self.module = nn.Sequential( 15 | self.quant_actf1, 16 | conv(3, num_features, kernel_size=3, k_bits=k_bits, bias=bias), 17 | nn.LeakyReLU(0.05), 18 | self.quant_actf2, 19 | conv(num_features, num_features, kernel_size=3, k_bits=k_bits, bias=bias), 20 | nn.LeakyReLU(0.05) 21 | ) 22 | 23 | def forward(self, x): 24 | return self.module(x) 25 | 26 | class DBlock_PAMS(nn.Module): 27 | def __init__(self, conv, num_features, d, s, bias=False, k_bits=32, ema_epoch=1, loss_kdf=False, linq=False): 28 | super(DBlock_PAMS, self).__init__() 29 | self.num_features = num_features 30 | self.s = s 31 | self.k_bits = k_bits 32 | 33 | self.quant_act1 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 34 | self.quant_act2 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 35 | self.quant_act3 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 36 | 
self.quant_act4 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 37 | self.quant_act5 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 38 | self.quant_act6 = quant_act_pams(self.k_bits,ema_epoch=ema_epoch) 39 | 40 | 41 | self.enhancement_top = nn.Sequential( 42 | self.quant_act1, 43 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias), 44 | nn.LeakyReLU(0.05), 45 | self.quant_act2, 46 | conv(num_features - d, num_features - 2 * d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 47 | nn.LeakyReLU(0.05), 48 | self.quant_act3, 49 | conv(num_features - 2 * d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 50 | nn.LeakyReLU(0.05) 51 | ) 52 | self.enhancement_bottom = nn.Sequential( 53 | self.quant_act4, 54 | conv(num_features - d, num_features, kernel_size=3, k_bits=self.k_bits, bias=bias), 55 | nn.LeakyReLU(0.05), 56 | self.quant_act5, 57 | conv(num_features, num_features - d, kernel_size=3, k_bits=self.k_bits, bias=bias, groups=4), 58 | nn.LeakyReLU(0.05), 59 | self.quant_act6, 60 | conv(num_features - d, num_features + d, kernel_size=3, k_bits=self.k_bits, bias=bias), 61 | nn.LeakyReLU(0.05) 62 | ) 63 | self.compression = nn.Conv2d(num_features + d, num_features, kernel_size=1) 64 | self.loss_kdf= loss_kdf 65 | 66 | def forward(self, x): 67 | f = x[1] 68 | x = x[0] 69 | 70 | residual = x 71 | x = self.enhancement_top(x) 72 | 73 | slice_1 = x[:, :int((self.num_features - self.num_features/self.s)), :, :] 74 | slice_2 = x[:, int((self.num_features - self.num_features/self.s)):, :, :] 75 | 76 | x = self.enhancement_bottom(slice_1) 77 | f1 = x 78 | 79 | x = x + torch.cat((residual, slice_2), 1) 80 | x = self.compression(x) 81 | 82 | if self.loss_kdf: 83 | if f is None: 84 | f = f1.unsqueeze(0) 85 | else: 86 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 87 | else: 88 | f = None 89 | 90 | return [x, f] 91 | 92 | 93 | class IDN_PAMS(nn.Module): 94 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits=32, 
linq=False): 95 | super(IDN_PAMS, self).__init__() 96 | self.scale = args.scale[0] 97 | num_features = args.n_feats 98 | d = args.idn_d 99 | s = args.idn_s 100 | 101 | self.fblock = FBlock_PAMS(conv, num_features, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch, linq=linq) 102 | 103 | 104 | m_dblocks = [ 105 | DBlock_PAMS(conv, num_features, d, s, bias=bias, k_bits=args.k_bits, ema_epoch=args.ema_epoch,loss_kdf=args.loss_kdf, linq=linq) for _ in range(4) 106 | 107 | ] 108 | self.dblocks = nn.Sequential(*m_dblocks) 109 | self.deconv = nn.ConvTranspose2d(num_features, 3, kernel_size=17, stride=self.scale, padding=8, output_padding=1) 110 | 111 | 112 | def forward(self, x): 113 | bicubic = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False) 114 | x= self.fblock(x) 115 | f = None 116 | x,f = self.dblocks([x,f]) 117 | out = x 118 | x = self.deconv(x, output_size=bicubic.size()) 119 | 120 | return bicubic + x, out, f 121 | 122 | -------------------------------------------------------------------------------- /model/quant_ops.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import pdb 4 | import random 5 | import time 6 | from itertools import repeat 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Function as F 12 | 13 | 14 | def _ntuple(n): 15 | def parse(x): 16 | if isinstance(x, collections.Iterable): 17 | return x 18 | return tuple(repeat(x, n)) 19 | return parse 20 | 21 | _pair = _ntuple(2) 22 | 23 | def quant_max(tensor): 24 | """ 25 | Returns the max value for symmetric quantization. 26 | """ 27 | return torch.abs(tensor.detach()).max() + 1e-8 28 | 29 | def TorchRound(): 30 | """ 31 | Apply STE to clamp function. 
32 | """ 33 | class identity_quant(torch.autograd.Function): 34 | @staticmethod 35 | def forward(ctx, input): 36 | out = torch.round(input) 37 | return out 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | return grad_output 42 | 43 | return identity_quant().apply 44 | 45 | 46 | class quant_weight(nn.Module): 47 | """ 48 | Quantization function for quantize weight with maximum. 49 | """ 50 | 51 | def __init__(self, k_bits): 52 | super(quant_weight, self).__init__() 53 | self.k_bits = k_bits 54 | self.qmax = 2. ** (k_bits -1) - 1. 55 | self.round = TorchRound() 56 | 57 | def forward(self, input): 58 | # no learning 59 | max_val = quant_max(input) 60 | weight = input * self.qmax / max_val 61 | q_weight = self.round(weight) 62 | q_weight = q_weight * max_val / self.qmax 63 | return q_weight 64 | 65 | class quant_act_pams(nn.Module): 66 | """ 67 | Quantization function for quantize activation with parameterized max scale. 68 | """ 69 | def __init__(self, k_bits, ema_epoch=1, decay=0.9997, is_teacher=False, rel_shift=False): 70 | super(quant_act_pams, self).__init__() 71 | self.decay = decay 72 | self.k_bits = k_bits 73 | self.qmax = 2. ** (self.k_bits -1) -1. 74 | self.qmax_shift = 2. ** (self.k_bits) -1 75 | 76 | self.round = TorchRound() 77 | self.alpha = nn.Parameter(torch.Tensor(1)) 78 | 79 | self.ema_epoch = ema_epoch 80 | self.epoch = 1 81 | 82 | self.is_teacher = is_teacher 83 | self.rel_shift= rel_shift 84 | 85 | self.register_buffer('max_val', torch.tensor(1.)) 86 | 87 | 88 | self.reset_parameter() 89 | 90 | 91 | def reset_parameter(self): 92 | nn.init.constant_(self.alpha, 10) 93 | 94 | def _ema(self, x): 95 | max_val = torch.mean(torch.max(torch.max(torch.max(abs(x),dim=1)[0],dim=1)[0],dim=1)[0]) 96 | 97 | if self.epoch == 1: 98 | self.max_val = max_val 99 | else: 100 | self.max_val = (1.0-self.decay) * max_val + self.decay * self.max_val 101 | 102 | 103 | def forward(self, x): 104 | self.qmax = 2. ** (self.k_bits -1) -1. 
105 | self.qmax_shift = 2. ** (self.k_bits) -1 106 | if self.epoch > self.ema_epoch or not self.training: 107 | # act = torch.max(torch.min(x, self.alpha), -self.alpha) 108 | if x.min()>=0: 109 | if self.rel_shift: 110 | act = torch.max(torch.min(x, self.alpha), 0.*self.alpha) 111 | else: 112 | act = torch.max(torch.min(x, self.alpha), -self.alpha) # for prelu (e.g., fsrcnn, srresnet) 113 | 114 | 115 | else: 116 | act = torch.max(torch.min(x, self.alpha), -self.alpha) 117 | 118 | elif self.epoch <= self.ema_epoch and self.training: 119 | act = x 120 | self._ema(x) 121 | self.alpha.data = self.max_val.unsqueeze(0) 122 | 123 | 124 | act = act / self.alpha 125 | # print(x.shape) 126 | if x.min()>=0 and self.rel_shift: 127 | qmax = self.qmax_shift 128 | else: 129 | qmax = self.qmax 130 | q_act = self.round(act*qmax) / qmax 131 | q_act = q_act *self.alpha 132 | 133 | 134 | 135 | return q_act 136 | 137 | 138 | 139 | class quant_act_lin(nn.Module): 140 | def __init__(self, k_bits): 141 | super(quant_act_lin, self).__init__() 142 | self.k_bits = k_bits 143 | self.qmax = 2. ** (self.k_bits -1) -1. 
144 | self.round = TorchRound() 145 | 146 | def forward(self, x): 147 | max_val = quant_max(x) 148 | x = x * self.qmax / max_val 149 | x_q = self.round(x) 150 | x_q = x_q * max_val / self.qmax 151 | return x_q 152 | 153 | class QuantConv2d(nn.Module): 154 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 155 | padding=0, dilation=1, groups=1, bias=False,k_bits=32,): 156 | super(QuantConv2d, self).__init__() 157 | # self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels,kernel_size,kernel_size)) 158 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels//groups,kernel_size,kernel_size)) 159 | 160 | # self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels,kernel_size,kernel_size)).cuda() 161 | 162 | self.stride = stride 163 | self.padding = padding 164 | self.dilation = dilation 165 | self.groups = groups 166 | self.in_channels = in_channels 167 | self.kernel_size = _pair(kernel_size) 168 | self.bias_flag = bias 169 | if self.bias_flag: 170 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 171 | # self.bias = nn.Parameter(torch.Tensor(out_channels)).cuda() 172 | 173 | else: 174 | self.register_parameter('bias',None) 175 | self.k_bits = k_bits 176 | self.quant_weight = quant_weight(k_bits = k_bits) 177 | self.output = None 178 | self.reset_parameters() 179 | 180 | def reset_parameters(self): 181 | n = self.in_channels 182 | for k in self.kernel_size: 183 | n *= k 184 | stdv = 1. 
/ math.sqrt(n) 185 | self.weight.data.uniform_(-stdv, stdv) 186 | if self.bias is not None: 187 | self.bias.data.uniform_(-stdv, stdv) 188 | 189 | def reset_parameter(self): 190 | stdv = 1.0/ math.sqrt(self.weight.size(0)) 191 | self.weight.data.uniform_(-stdv,stdv) 192 | if self.bias_flag: 193 | nn.init.constant_(self.bias,0.0) 194 | 195 | def forward(self, input, bits=None, order=None): 196 | if bits is not None: 197 | if input.size(0)!= 1: 198 | for i in range (input.size(0)): 199 | self.quant_weight = quant_weight(k_bits = bits[i]) 200 | weight_q = self.quant_weight(self.weight) 201 | out= nn.functional.conv2d(input[i].unsqueeze(0), weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) 202 | if i==0: 203 | out_stacked = out 204 | else: 205 | out_stacked = torch.cat([out_stacked, out], dim=0) 206 | return out_stacked 207 | else: 208 | self.quant_weight = quant_weight(k_bits = bits) 209 | # this works for weight during inference (batch=1) but not for training 210 | # for training, use group conv to take different bit for different batch index 211 | 212 | return nn.functional.conv2d(input, self.quant_weight(self.weight), self.bias, self.stride, self.padding, self.dilation, self.groups) 213 | 214 | def conv3x3(in_channels, out_channels,kernel_size=3,stride=1,padding =1,bias= True): 215 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, bias=bias) 216 | 217 | def quant_conv3x3(in_channels, out_channels,kernel_size=3,padding = 1,stride=1,k_bits=32,bias = False,groups=1): 218 | return QuantConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride = stride,padding=padding,k_bits=k_bits,bias = bias,groups=groups) 219 | 220 | def conv9x9(in_channels, out_channels,kernel_size=9,stride=1,padding =4,bias= False): 221 | return nn.Conv2d(in_channels, out_channels, kernel_size=9, stride=stride, padding=padding, bias=bias) 222 | 223 | 224 | def quant_conv9x9(in_channels, 
out_channels,kernel_size=9,stride=1,padding =4,bias= False, k_bits=32): 225 | return QuantConv2d(in_channels, out_channels, kernel_size=9, stride=stride, padding=padding, bias=bias, k_bits=k_bits) 226 | 227 | -------------------------------------------------------------------------------- /model/srresnet.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | from model.quant_ops import conv9x9, conv3x3 7 | 8 | 9 | 10 | class SRResNet(nn.Module): 11 | def __init__(self, args, is_teacher=False, conv=common.default_conv): 12 | super(SRResNet, self).__init__() 13 | 14 | 15 | n_resblocks = args.n_resblocks 16 | n_feats = args.n_feats 17 | kernel_size = 3 18 | scale = args.scale[0] 19 | # act = nn.LeakyReLU(0.2, inplace=True) 20 | act = nn.PReLU() 21 | 22 | self.is_teacher = is_teacher 23 | 24 | self.sub_mean = common.MeanShift(args.rgb_range) 25 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 26 | 27 | # define head module 28 | m_head = [conv9x9(args.n_colors, n_feats, kernel_size=9, bias=False)] 29 | m_head.append(act) 30 | 31 | m_body = [ 32 | common.ResBlock_srresnet( 33 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 34 | ) for _ in range(n_resblocks) 35 | ] 36 | 37 | m_body.append( conv3x3(n_feats, n_feats, kernel_size, bias=False)) 38 | m_body.append( nn.BatchNorm2d(n_feats)) 39 | # m_body.append(nn.InstanceNorm2d(n_feats, affine=True)) 40 | 41 | 42 | 43 | 44 | m_tail = [ 45 | common.Upsampler_srresnet(conv3x3, scale, n_feats, act=False), 46 | conv9x9(n_feats, args.n_colors, kernel_size=9, bias=False) 47 | ] 48 | 49 | self.head = nn.Sequential(*m_head) 50 | self.body = nn.Sequential(*m_body) 51 | self.tail = nn.Sequential(*m_tail) 52 | 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | 60 | def forward(self, x): 61 | 62 | # x = self.sub_mean(x) 63 | x = self.head(x) 64 | 65 | res = self.body(x) 66 | res += x 67 | 68 | out = res 69 | 70 | x = self.tail(res) 71 | # x = self.add_mean(x) 72 | if self.is_teacher: 73 | return x, out 74 | else: 75 | return x 76 | 77 | 78 | -------------------------------------------------------------------------------- /model/srresnet_cadyq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import common 7 | from model.quant_ops import conv3x3, conv9x9, quant_conv3x3, quant_conv9x9 8 | from model.cadyq import BitSelector 9 | 10 | class ResBlock_CADyQ(nn.Module): 11 | def __init__(self, conv, n_feats, kernel_size, bias=False, 12 | bn=False, inn=False, act=nn.PReLU(), res_scale=1, k_bits = 32, ema_epoch=1, search_space=[2,4,8], loss_kdf=False, linq=False): 13 | 14 | super(ResBlock_CADyQ, self).__init__() 15 | self.k_bits = k_bits 16 | 17 | self.classify1 = BitSelector(n_feats, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq) 18 | self.classify2 = BitSelector(n_feats, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq) 19 | 20 | self.conv1 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 21 | self.conv2 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 22 | self.bn1 = nn.BatchNorm2d(n_feats) 23 | # self.bn1 = nn.InstanceNorm2d(n_feats, affine=True) 24 | 25 | self.act = act 26 | self.bn2 = nn.BatchNorm2d(n_feats) 27 | # self.bn2 = nn.InstanceNorm2d(n_feats, affine=True) 28 | 29 | self.res_scale = res_scale 30 | 31 | self.shortcut = common.ShortCut() 32 | self.loss_kdf = loss_kdf 33 | 34 | 35 | def forward(self, x): 36 | weighted_bits = x[4] 37 | f = x[3] 38 | bits = x[2] 39 | grad = x[0] 40 | x = x[1] 41 | 42 | 
grad,residual,bits,weighted_bits = self.classify1([grad, self.shortcut(x),bits,weighted_bits]) 43 | res = self.act(self.bn1(self.conv1(x))) 44 | 45 | grad, res, bits, weighted_bits = self.classify2([grad, res, bits, weighted_bits]) 46 | res = self.conv2(res) 47 | f1 = res 48 | res = self.bn2(res).mul(self.res_scale) 49 | res += residual 50 | if self.loss_kdf: 51 | if f is None: 52 | f = f1.unsqueeze(0) 53 | else: 54 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 55 | else: 56 | f = None 57 | 58 | return [grad, res, bits, f, weighted_bits] 59 | 60 | 61 | 62 | class SRResNet_CADyQ(nn.Module): 63 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits =32 ): 64 | super(SRResNet_CADyQ, self).__init__() 65 | 66 | 67 | n_resblocks = args.n_resblocks 68 | n_feats = args.n_feats 69 | kernel_size = 3 70 | scale = args.scale[0] 71 | act = nn.PReLU() 72 | # act = nn.LeakyReLU(0.2, inplace=True) 73 | 74 | self.fully = args.fully 75 | self.k_bits = k_bits 76 | 77 | self.sub_mean = common.MeanShift(args.rgb_range) 78 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 79 | 80 | 81 | m_head = [conv9x9(args.n_colors, n_feats, kernel_size=9, bias=False)] 82 | m_head.append(act) 83 | 84 | m_body = [ 85 | ResBlock_CADyQ( 86 | quant_conv3x3, n_feats, kernel_size, bn=True, act=act, res_scale=args.res_scale, k_bits=self.k_bits, bias=bias, ema_epoch=args.ema_epoch, search_space=args.search_space, loss_kdf=args.loss_kdf, linq=args.linq 87 | ) for _ in range(n_resblocks) 88 | ] 89 | m_body.append( conv3x3(n_feats, n_feats, kernel_size, bias=False)) 90 | m_body.append( nn.BatchNorm2d(n_feats)) 91 | 92 | m_tail = [ 93 | common.Upsampler_srresnet(conv3x3, scale, n_feats, act=False), 94 | conv9x9(n_feats, args.n_colors, kernel_size=9, bias=False) 95 | ] 96 | 97 | self.head = nn.Sequential(*m_head) 98 | self.body = nn.Sequential(*m_body) 99 | self.tail = nn.Sequential(*m_tail) 100 | 101 | 102 | def forward(self, x): 103 | 104 | # x = self.sub_mean(x) 105 | image = x 106 | grads: 
torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 107 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # abs version 108 | if self.fully: x = self.quant_head(x) 109 | 110 | x = self.head(x) 111 | res = x 112 | bits = 0; weighted_bits=0; f=None 113 | 114 | image_grad, res, bits, f, weighted_bits = self.body[0:-2]([image_grad, res, bits, f, weighted_bits]) 115 | 116 | res = self.body[-2:](res) 117 | res += x 118 | out = res 119 | x = self.tail(res) 120 | # x = self.add_mean(x) 121 | 122 | return x, out, bits, f, weighted_bits 123 | 124 | 125 | -------------------------------------------------------------------------------- /model/srresnet_cadyq_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia as K 5 | 6 | from model import common 7 | from model.quant_ops import conv3x3, conv9x9, quant_conv3x3, quant_conv9x9 8 | from model.edge import BitSelector 9 | 10 | class ResBlock_CADyQ(nn.Module): 11 | def __init__(self, conv, n_feats, kernel_size, bias=False, 12 | bn=False, inn=False, act=nn.PReLU(), res_scale=1, k_bits = 32, ema_epoch=1, search_space=[2,4,8], loss_kdf=False, linq=False): 13 | 14 | super(ResBlock_CADyQ, self).__init__() 15 | self.k_bits = k_bits 16 | 17 | self.classify1 = BitSelector(n_feats, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq) 18 | self.classify2 = BitSelector(n_feats, bias=bias, ema_epoch=ema_epoch, search_space=search_space, linq=linq) 19 | 20 | self.conv1 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 21 | self.conv2 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 22 | self.bn1 = nn.BatchNorm2d(n_feats) 23 | # self.bn1 = nn.InstanceNorm2d(n_feats, affine=True) 24 | 25 | self.act = act 26 | self.bn2 = nn.BatchNorm2d(n_feats) 27 | # self.bn2 = nn.InstanceNorm2d(n_feats, affine=True) 28 
| 29 | self.res_scale = res_scale 30 | 31 | self.shortcut = common.ShortCut() 32 | self.loss_kdf = loss_kdf 33 | 34 | 35 | def forward(self, x): 36 | weighted_bits = x[4] 37 | f = x[3] 38 | bits = x[2] 39 | grad = x[0] 40 | x = x[1] 41 | 42 | grad,residual,bits = self.classify1([grad, self.shortcut(x),bits]) 43 | res = self.act(self.bn1(self.conv1(x))) 44 | 45 | grad, res, bits = self.classify2([grad, res, bits]) 46 | res = self.conv2(res) 47 | f1 = res 48 | res = self.bn2(res).mul(self.res_scale) 49 | res += residual 50 | if self.loss_kdf: 51 | if f is None: 52 | f = f1.unsqueeze(0) 53 | else: 54 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 55 | else: 56 | f = None 57 | 58 | return [grad, res, bits, f, weighted_bits] 59 | 60 | 61 | 62 | class SRResNet_CADyQ_I(nn.Module): 63 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits =32 ): 64 | super(SRResNet_CADyQ_I, self).__init__() 65 | 66 | 67 | n_resblocks = args.n_resblocks 68 | n_feats = args.n_feats 69 | kernel_size = 3 70 | scale = args.scale[0] 71 | act = nn.PReLU() 72 | # act = nn.LeakyReLU(0.2, inplace=True) 73 | 74 | self.fully = args.fully 75 | self.k_bits = k_bits 76 | 77 | self.sub_mean = common.MeanShift(args.rgb_range) 78 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 79 | 80 | 81 | m_head = [conv9x9(args.n_colors, n_feats, kernel_size=9, bias=False)] 82 | m_head.append(act) 83 | 84 | m_body = [ 85 | ResBlock_CADyQ( 86 | quant_conv3x3, n_feats, kernel_size, bn=True, act=act, res_scale=args.res_scale, k_bits=self.k_bits, bias=bias, ema_epoch=args.ema_epoch, search_space=args.search_space, loss_kdf=args.loss_kdf, linq=args.linq 87 | ) for _ in range(n_resblocks) 88 | ] 89 | m_body.append( conv3x3(n_feats, n_feats, kernel_size, bias=False)) 90 | m_body.append( nn.BatchNorm2d(n_feats)) 91 | 92 | m_tail = [ 93 | common.Upsampler_srresnet(conv3x3, scale, n_feats, act=False), 94 | conv9x9(n_feats, args.n_colors, kernel_size=9, bias=False) 95 | ] 96 | 97 | self.head = 
nn.Sequential(*m_head) 98 | self.body = nn.Sequential(*m_body) 99 | self.tail = nn.Sequential(*m_tail) 100 | 101 | 102 | def forward(self, x): 103 | 104 | # x = self.sub_mean(x) 105 | image = x 106 | grads: torch.Tensor = K.filters.spatial_gradient(K.color.rgb_to_grayscale(image/255.), order=1) 107 | image_grad = torch.mean(torch.abs(grads.squeeze(1)),(2,3)) *1e+3 # abs version 108 | if self.fully: x = self.quant_head(x) 109 | 110 | x = self.head(x) 111 | res = x 112 | bits = 0; weighted_bits=0; f=None 113 | 114 | image_grad, res, bits, f, weighted_bits = self.body[0:-2]([image_grad, res, bits, f, weighted_bits]) 115 | 116 | res = self.body[-2:](res) 117 | res += x 118 | out = res 119 | x = self.tail(res) 120 | # x = self.add_mean(x) 121 | 122 | return x, out, bits, f 123 | 124 | 125 | -------------------------------------------------------------------------------- /model/srresnet_pams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model import common 6 | 7 | from model.quant_ops import quant_act_pams 8 | from model.quant_ops import conv3x3, conv9x9, quant_conv3x3, quant_conv9x9 9 | 10 | class ResBlock_PAMS(nn.Module): 11 | def __init__(self, conv, n_feats, kernel_size, bias=False, 12 | bn=False, inn=False, act=nn.PReLU(), res_scale=1, k_bits = 32, ema_epoch=1, name=None, loss_kdf=False, linq=False): 13 | 14 | super(ResBlock_PAMS, self).__init__() 15 | self.k_bits = k_bits 16 | 17 | self.quant_act1 = quant_act_pams(k_bits,ema_epoch=ema_epoch) 18 | self.quant_act2 = quant_act_pams(k_bits,ema_epoch=ema_epoch) 19 | 20 | self.conv1 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 21 | self.conv2 = conv(n_feats, n_feats, kernel_size, k_bits=self.k_bits, bias=bias) 22 | self.bn1 = nn.BatchNorm2d(n_feats) 23 | 24 | self.act = act 25 | self.bn2 = nn.BatchNorm2d(n_feats) 26 | 27 | self.res_scale = res_scale 28 | 29 | 
self.shortcut = common.ShortCut() 30 | self.loss_kdf = loss_kdf 31 | 32 | 33 | 34 | def forward(self, x): 35 | f = x[1] 36 | x = x[0] 37 | 38 | residual = self.quant_act1(self.shortcut(x)) 39 | res = self.act(self.bn1(self.conv1(x))) 40 | 41 | res = self.quant_act2(res) 42 | res = self.conv2(res) 43 | f1 = res 44 | res = self.bn2(res).mul(self.res_scale) 45 | 46 | res += residual 47 | if self.loss_kdf: 48 | if f is None: 49 | f = f1.unsqueeze(0) 50 | else: 51 | f = torch.cat([f, f1.unsqueeze(0)], dim=0) 52 | else: 53 | f = None 54 | 55 | return [res, f] 56 | 57 | class SRResNet_PAMS(nn.Module): 58 | def __init__(self, args, conv=quant_conv3x3, bias=False, k_bits =32, linq=False, fully=False): 59 | super(SRResNet_PAMS, self).__init__() 60 | 61 | n_resblocks = args.n_resblocks 62 | n_feats = args.n_feats 63 | kernel_size = 3 64 | scale = args.scale[0] 65 | act = nn.PReLU() 66 | self.fully = fully 67 | 68 | self.k_bits = k_bits 69 | 70 | self.sub_mean = common.MeanShift(args.rgb_range) 71 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 72 | 73 | m_head = [conv9x9(args.n_colors, n_feats, kernel_size=9, bias=False)] 74 | m_head.append(act) 75 | 76 | m_body = [ 77 | ResBlock_PAMS( 78 | quant_conv3x3, n_feats, kernel_size, bn=True, act=act, res_scale=args.res_scale, k_bits=self.k_bits, bias=bias, ema_epoch=args.ema_epoch, loss_kdf=args.loss_kdf, linq=linq 79 | ) for _ in range(n_resblocks) 80 | ] 81 | m_body.append( conv3x3(n_feats, n_feats, kernel_size, bias=False)) 82 | m_body.append( nn.BatchNorm2d(n_feats)) 83 | 84 | m_tail = [ 85 | common.Upsampler_srresnet(conv3x3, scale, n_feats, act=False), 86 | conv9x9(n_feats, args.n_colors, kernel_size=9, bias=False) 87 | ] 88 | 89 | self.head = nn.Sequential(*m_head) 90 | self.body = nn.Sequential(*m_body) 91 | self.tail = nn.Sequential(*m_tail) 92 | 93 | 94 | def forward(self, x): 95 | 96 | # x = self.sub_mean(x) 97 | if self.fully: x = self.quant_head(x) 98 | x = self.head(x) 99 | if self.fully: x = 
self.quant_head2(x) 100 | f= None 101 | 102 | res,f = self.body[0:-2]([x,f]) 103 | res = self.body[-2:](res) 104 | 105 | res += x 106 | out = res 107 | x = self.tail(res) 108 | # x = self.add_mean(x) 109 | 110 | return x, out, f 111 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 4 | 5 | parser.add_argument('--debug', action='store_true', 6 | help='Enables debug mode') 7 | parser.add_argument('--template', default='.', 8 | help='You can set various templates in option.py') 9 | parser.add_argument('--show_params', action='store_true', 10 | help='You can see the parameters of the model') 11 | parser.add_argument('--select_bit', type=int, default=-1, 12 | help='number of threads for data loading') 13 | parser.add_argument('--select_float', type=int, default=2, 14 | help='number of threads for data loading') 15 | parser.add_argument('--calibration', type=float, default=6.0, 16 | help='number of threads for data loading') 17 | # Hardware specifications 18 | parser.add_argument('--n_threads', type=int, default=6, 19 | help='number of threads for data loading') 20 | parser.add_argument('--cpu', action='store_true', 21 | help='use cpu only') 22 | parser.add_argument('--gpu_id', default="0",type=str, 23 | help='the gpu id of using') 24 | parser.add_argument('--n_GPUs', type=int, default=1, 25 | help='number of GPUs') 26 | parser.add_argument('--seed', type=int, default=1, 27 | help='random seed') 28 | parser.add_argument('--threshld_ratio', type=float, default=1.0, 29 | help='random seed') 30 | # Data specifications 31 | parser.add_argument('--dir_data', type=str, default='', 32 | help='dataset image directory') 33 | parser.add_argument('--sample_config', type=str, default='', 34 | help='dataset image directory') 35 | parser.add_argument('--dir_demo', 
type=str, default='../test', 36 | help='demo image directory') 37 | parser.add_argument('--data_train', type=str, default='DIV2K', 38 | help='train dataset name') 39 | parser.add_argument('--data_test', type=str, default='Set5', 40 | help='test dataset name') 41 | parser.add_argument('--data_range', type=str, default='1-800/801-810', 42 | help='train/test data range') 43 | parser.add_argument('--ext', type=str, default='sep', 44 | help='dataset file extension') 45 | parser.add_argument('--scale', type=str, default='4', 46 | help='super resolution scale') 47 | parser.add_argument('--patch_size', type=int, default=96, 48 | help='output patch size') 49 | parser.add_argument('--rgb_range', type=int, default=255, 50 | help='maximum value of RGB') 51 | parser.add_argument('--n_colors', type=int, default=3, 52 | help='number of color channels to use') 53 | parser.add_argument('--chop', action='store_true', 54 | help='enable memory-efficient forward') 55 | parser.add_argument('--no_augment', action='store_true', 56 | help='do not use data augmentation') 57 | parser.add_argument('--kl', action='store_true', 58 | help='use kl') 59 | parser.add_argument('--conv_idx', type=str, default='22', 60 | help='vgg index') 61 | 62 | # Model specifications 63 | parser.add_argument('--model', default='EDSR', 64 | help='model name') 65 | parser.add_argument('--pix_type', default='l1', 66 | help='model name') 67 | parser.add_argument('--act', type=str, default='relu', 68 | help='activation function') 69 | parser.add_argument('--teacher_weights', type=str, default=None, 70 | help='pretrained model directory for teacher initialization') 71 | parser.add_argument('--student_weights', type=str, default=None, 72 | help='pretrained model directory for student initialization') 73 | parser.add_argument('--extend', type=str, default='.', 74 | help='pre-trained model directory') 75 | parser.add_argument('--n_resblocks', type=int, default=16, 76 | help='number of residual blocks') 77 | 
parser.add_argument('--n_feats', type=int, default=64, 78 | help='number of feature maps') 79 | parser.add_argument('--res_scale', type=float, default=1, 80 | help='residual scaling') 81 | parser.add_argument('--shift_mean', default=True, 82 | help='subtract pixel mean from the input') 83 | parser.add_argument('--dilation', action='store_true', 84 | help='use dilated convolution') 85 | parser.add_argument('--precision', type=str, default='single', 86 | choices=('single', 'half'), 87 | help='FP precision for test (single | half)') 88 | parser.add_argument('--k_bits', type=int, default=32, 89 | help='The k_bits of the quantize') 90 | 91 | # Option for Residual dense network (RDN) 92 | parser.add_argument('--G0', type=int, default=64, 93 | help='default number of filters. (Use in RDN)') 94 | parser.add_argument('--RDNkSize', type=int, default=3, 95 | help='default kernel size. (Use in RDN)') 96 | parser.add_argument('--RDNconfig', type=str, default='B', 97 | help='parameters config of RDN. 
(Use in RDN)') 98 | 99 | # Option for Residual channel attention network (RCAN) 100 | parser.add_argument('--n_resgroups', type=int, default=10, 101 | help='number of residual groups') 102 | parser.add_argument('--reduction', type=int, default=16, 103 | help='number of feature maps reduction') 104 | 105 | #---------------------IDN------------------------- 106 | parser.add_argument('--idn_d', type=int, default=16) 107 | parser.add_argument('--idn_s', type=int, default=4) 108 | # use n_feats of above : default = 64 109 | 110 | #---------------------CARN------------------------- 111 | parser.add_argument('--multi_scale',action='store_true') 112 | parser.add_argument('--group', type=int, default=1) 113 | 114 | 115 | # Training specifications 116 | parser.add_argument('--reset', action='store_true', 117 | help='reset the training') 118 | parser.add_argument('--test_every', type=int, default=1000, 119 | help='do test per every N batches') 120 | parser.add_argument('--epochs', type=int, default=30, 121 | help='number of epochs to train') 122 | parser.add_argument('--ema_epoch', type=int, default=1, 123 | help='number of epochs to train') 124 | parser.add_argument('--batch_size', type=int, default=4, 125 | help='input batch size for training') 126 | parser.add_argument('--split_batch', type=int, default=1, 127 | help='split the batch into smaller chunks') 128 | parser.add_argument('--self_ensemble', action='store_true', 129 | help='use self-ensemble method for test') 130 | parser.add_argument('--test_only', action='store_true', 131 | help='set this option to test the model') 132 | parser.add_argument('--gan_k', type=int, default=1, 133 | help='k value for adversarial loss') 134 | 135 | # Optimization specifications 136 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 137 | parser.add_argument('--w_l1', type=float, default=1.0, help='learning rate for L1') 138 | parser.add_argument('--w_at', type=float, default=1e+3, help='learning rate for 
distillation loss') 139 | parser.add_argument('--w_bit', type=float, default=0.5, help='learning rate for bit regularization loss') 140 | 141 | parser.add_argument('--decay', type=str, default='150', 142 | help='learning rate decay type') 143 | parser.add_argument('--gamma', type=float, default=0.5, 144 | help='learning rate decay factor for step decay') 145 | parser.add_argument('--optimizer', default='ADAM', 146 | choices=('SGD', 'ADAM', 'RMSprop'), 147 | help='optimizer to use (SGD | ADAM | RMSprop)') 148 | parser.add_argument('--momentum', type=float, default=0.9, 149 | help='SGD momentum') 150 | parser.add_argument('--nesterov', type=bool, default=False, 151 | help='nesterov') 152 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), 153 | help='ADAM beta') 154 | parser.add_argument('--epsilon', type=float, default=1e-8, 155 | help='ADAM epsilon for numerical stability') 156 | parser.add_argument('--weight_decay', type=float, default=0, 157 | help='weight decay') 158 | parser.add_argument('--gclip', type=float, default=5, 159 | help='gradient clipping threshold (0 = no clipping)') 160 | 161 | # Loss specifications 162 | parser.add_argument('--loss', type=str, default='1*L1', 163 | help='loss function configuration') 164 | parser.add_argument('--skip_threshold', type=float, default='1e8', 165 | help='skipping batch that has large error') 166 | parser.add_argument('--loss_kd', action='store_true', help='trainning with knowledge distillation loss') 167 | parser.add_argument('--loss_kdf', action='store_true', help='trainning with feature knowledge distillation loss') 168 | 169 | 170 | # Log specifications 171 | parser.add_argument('--suffix', default=None, type=str, 172 | help='suffix to help you remember what experiment you ran') 173 | parser.add_argument('--save', type=str, default='test', 174 | help='file name to save') 175 | parser.add_argument('--load', type=str, default='', 176 | help='file name to load') 177 | parser.add_argument('--resume', 
type=str, default=None, 178 | help='resume from specific checkpoint') 179 | parser.add_argument('--save_models', action='store_true', 180 | help='save all intermediate models') 181 | parser.add_argument('--print_every', type=int, default=100, 182 | help='how many batches to wait before logging training status') 183 | parser.add_argument('--save_results', action='store_true', 184 | help='save output results') 185 | parser.add_argument('--save_gt', action='store_true', 186 | help='save low-resolution and high-resolution images together') 187 | 188 | 189 | parser.add_argument('--cadyq', action='store_true', help='mixed precision for layer') 190 | parser.add_argument('--is_teacher', action='store_true', help='pams test') 191 | parser.add_argument('--search_space', type=str, default='32', help='bit search space') 192 | 193 | 194 | parser.add_argument('--bitsel_lr', type=float, default=1e-4) 195 | parser.add_argument('--bitsel_decay', type=str, default=150) 196 | parser.add_argument('--w_bit_decay', type=float, default=2e-6) 197 | 198 | parser.add_argument('--test_patch', action='store_true', help='testing patch-wise') 199 | parser.add_argument('--step_size', type=int, default=28, help='step size for combining patches') 200 | parser.add_argument('--save_patch', action='store_true',help='save patch results') 201 | 202 | parser.add_argument('--linq', action='store_true', help='linq') 203 | parser.add_argument('--fully', action='store_true', help='full quantization') 204 | # parser.add_argument('--train_full', action='store_true', help='pretraining 32 bit model') 205 | # parser.add_argument('--lpips', action='store_true', help='use lpips loss for optimization') 206 | 207 | 208 | # FSRCNN 209 | parser.add_argument('--m', type=int, default=4, help='m') 210 | 211 | args = parser.parse_args() 212 | 213 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 214 | args.data_train = args.data_train.split('+') 215 | args.data_test = args.data_test.split('+') 216 | 217 | 
# args.search_space = list(map(lambda x: int(x), args.search_space.split('+'))) 218 | args.search_space = [4,6,8] 219 | # [2,4,8] 220 | if args.epochs == 0: 221 | args.epochs = 1e8 222 | 223 | for arg in vars(args): 224 | if vars(args)[arg] == 'True': 225 | vars(args)[arg] = True 226 | elif vars(args)[arg] == 'False': 227 | vars(args)[arg] = False 228 | 229 | -------------------------------------------------------------------------------- /test_edsrbaseline_cabm_simple.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_inference.py \ 2 | --test_only \ 3 | --data_test Urban100 --dir_data --n_GPUs 1 \ 4 | --scale 4 --k_bits 8 --model EDSR \ 5 | --search_space 4+6+8 --save edsrbaseline_test_x4 \ 6 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 7 | --patch_size 96 --select_bit 0 --select_float 2 --calibration 100 \ 8 | --student_weights \ 9 | --test_patch --step_size 96 \ 10 | # -------------------------------------------------------------------------------- /test_edsrbaseline_get_cabm_config.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_getconfig.py \ 2 | --test_only \ 3 | --data_test div2k_valid --dir_data --n_GPUs 1 \ 4 | --scale 4 --k_bits 8 --model EDSR \ 5 | --cadyq --search_space 4+6+8 --save test_edsr_cadyq_patch \ 6 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 7 | --patch_size 96 --batch_size 16 --step_size 94 \ 8 | --student_weights \ 9 | --test_patch \ 10 | # -------------------------------------------------------------------------------- /train_edsrbaseline_cabm_simple.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_inference.py \ 2 | --data_test Set5 --dir_data --n_GPUs 1 \ 3 | --scale 4 --k_bits 8 --model EDSR \ 4 | --search_space 4+6+8 --save edsrbaseline_cabm_simple_x4 \ 5 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 6 | 
--patch_size 96 --batch_size 16 --test_patch --step_size 94 \ 7 | --epochs 300 --decay 150 --lr 1e-5 \ 8 | --loss_kd --loss_kdf --w_bit_decay 1e-6 --select_bit 0 --select_float 2 \ 9 | --teacher_weights \ 10 | --student_weights \ 11 | # 12 | -------------------------------------------------------------------------------- /train_edsrbaseline_cadyq.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_cadyq.py \ 2 | --data_test Set5 --dir_data --n_GPUs 1 \ 3 | --scale 4 --k_bits 8 --model EDSR \ 4 | --cadyq --search_space 4+6+8 --save edsrbaseline_cadyq_x4 \ 5 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 6 | --patch_size 96 --batch_size 16 \ 7 | --epochs 300 --decay 150 --lr 1e-4 --bitsel_lr 1e-4 --bitsel_decay 150 \ 8 | --loss_kd --loss_kdf --w_bit 1e-4 --w_bit_decay 1e-6 \ 9 | --teacher_weights \ 10 | --student_weights \ 11 | # -------------------------------------------------------------------------------- /train_edsrbaseline_org.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_org.py \ 2 | --data_test Set5 --dir_data --n_GPUs 1 \ 3 | --scale 4 --model EDSR \ 4 | --cadyq --save edsrbaseline_org_x4 \ 5 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 6 | --patch_size 192 --batch_size 16 \ 7 | --epochs 600 --decay 300 --lr 1e-4 \ 8 | # -------------------------------------------------------------------------------- /train_edsrbaseline_pams.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main_pams.py \ 2 | --data_test Set5 --dir_data --n_GPUs 1 \ 3 | --scale 4 --k_bits 8 --model EDSR \ 4 | --save edsrbaseline_pams_x4 \ 5 | --n_feats 64 --n_resblocks 16 --res_scale 1 \ 6 | --patch_size 96 --batch_size 16 \ 7 | --epochs 100 --lr 1e-5 --decay 50 \ 8 | --teacher_weights \ 9 | --student_weights \ 10 | # 
# ---------- /utility.py ----------
import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import imageio
import cv2

import pdb
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs


class timer():
    """Wall-clock stopwatch with an accumulator for summing several intervals."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start a new interval."""
        self.t0 = time.time()

    def toc(self, restart=False):
        """Return seconds since the last tic(); restart the interval if *restart*."""
        diff = time.time() - self.t0
        if restart:
            self.t0 = time.time()
        return diff

    def hold(self):
        """Add the current interval's elapsed time to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and zero the accumulator."""
        ret = self.acc
        self.acc = 0
        return ret

    def reset(self):
        """Zero the accumulator without reading it."""
        self.acc = 0


class checkpoint():
    """Experiment bookkeeping: directory layout, text logs, PSNR curves, result images.

    Files live under ``experiment/<args.save>`` for a fresh run, or
    ``experiment/<args.load>`` when resuming.
    """

    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        self.lpips_log = torch.Tensor()
        self.bit_log = torch.Tensor()
        self.ssim_log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if not args.load:
            if not args.save:
                args.save = now
            self.dir = os.path.join('experiment', args.save)
        else:
            self.dir = os.path.join('experiment', args.load)
            if os.path.exists(self.dir):
                self.log = torch.load(self.get_path('psnr_log.pt'))
                print('Continue from epoch {}...'.format(len(self.log)))
            else:
                args.load = ''

        if args.reset:
            # Delete without a shell: the former os.system('rm -rf ' + self.dir)
            # broke on paths with spaces and executed any shell metacharacters
            # contained in args.save / args.load.
            import shutil
            shutil.rmtree(self.dir, ignore_errors=True)
            args.load = ''

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.get_path('model'), exist_ok=True)
        for d in args.data_test:
            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
        os.makedirs(self.get_path('run'), exist_ok=True)

        # Append to existing logs when the directory already has one.
        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
        self.log_file = open(self.get_path('log.txt'), open_type)
        self.eval_file = open(self.get_path('eval.txt'), open_type)

        with open(self.get_path('config.txt'), open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

        self.n_processes = 8

    def get_path(self, *subdir):
        """Return a path inside the experiment directory."""
        return os.path.join(self.dir, *subdir)

    def save(self, trainer, epoch, is_best=False):
        """Save model, loss, optimizer state and the PSNR log/plot for *epoch*."""
        # this is not used : utils/common.py save_checkpoint is used instead
        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        trainer.optimizer.save(self.dir)
        torch.save(self.log, self.get_path('psnr_log.pt'))

    def add_log(self, log):
        """Append *log* along dim 0 to the PSNR log and to the LPIPS/bit/SSIM logs."""
        # NOTE(review): the same tensor is appended to all four logs; confirm
        # whether lpips_log / bit_log / ssim_log were meant to get their own values.
        self.log = torch.cat([self.log, log])
        self.lpips_log = torch.cat([self.lpips_log, log])
        self.bit_log = torch.cat([self.bit_log, log])
        self.ssim_log = torch.cat([self.ssim_log, log])

    def write_log(self, log, refresh=False):
        """Print *log* and append it to log.txt; reopen the file when *refresh* to flush it."""
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.get_path('log.txt'), 'a')

    def write_eval(self, log):
        """Append *log* to eval.txt."""
        self.eval_file.write(log + '\n')

    def done(self):
        """Close the log and eval files."""
        self.log_file.close()
        self.eval_file.close()

    def plot_psnr(self, epoch):
        """Save one PSNR-vs-epoch plot per test set (one curve per scale) as test_<set>.pdf."""
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            plt.close(fig)

    def begin_background(self):
        """Start ``n_processes`` workers that write queued (filename, tensor) pairs via imageio."""
        self.queue = Queue()

        def bg_target(queue):
            # Block on get() instead of polling queue.empty() in a tight loop,
            # which kept every idle worker busy-spinning on a CPU core.
            while True:
                filename, tensor = queue.get()
                if filename is None:
                    break
                imageio.imwrite(filename, tensor.numpy())

        self.process = [
            Process(target=bg_target, args=(self.queue,))
            for _ in range(self.n_processes)
        ]

        for p in self.process:
            p.start()

    def end_background(self):
        """Send one stop sentinel per worker, wait for the queue to drain, then join."""
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        while not self.queue.empty():
            time.sleep(1)
        for p in self.process:
            p.join()

    def save_results(self, dataset, filename, save_list, scale):
        """Queue the SR/LR/HR tensors of *save_list* to be written as PNGs (if args.save_results)."""
        if self.args.save_results:
            filename = self.get_path(
                'results-{}'.format(dataset.dataset.name),
                '{}_x{}_'.format(filename, scale)
            )

            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                # First image of the batch, rescaled to [0, 255], as HWC uint8.
                normalized = v[0].mul(255 / self.args.rgb_range)
                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))


def quantize(img, rgb_range):
    """Snap *img* (in [0, rgb_range]) onto the 256-level 8-bit grid, keeping its range."""
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """PSNR in dB between *sr* and *hr* (NCHW tensors in [0, rgb_range]).

    For benchmark datasets the difference is converted to luma with BT.601
    weights and *scale* border pixels are shaved; otherwise scale + 6 pixels
    are shaved on every channel. Returns 0 if *hr* has a single element and
    ``inf`` if the shaved images are identical.
    """
    if hr.nelement() == 1:
        return 0
    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    else:
        shave = scale + 6

    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()
    if mse == 0:
        # log10(0) would raise ValueError; identical images have infinite PSNR.
        return float('inf')
    return -10 * math.log10(mse)


def ssim(img1, img2):
    """SSIM of two same-shaped [0, 255] images (11x11 Gaussian window, sigma 1.5).

    calc_ssim calls this with 2-D (Y channel) arrays; a 5-pixel border is
    dropped from the filtered maps ("valid" region).
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def bgr2ycbcr(img, only_y=True):
    '''Convert a BGR image to YCbCr with MATLAB rgb2ycbcr coefficients.

    only_y: only return the Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype (rounded for uint8, rescaled to
    [0, 1] for float input). The input array is not modified.
    '''
    in_img_type = img.dtype
    # Work on a float32 copy. Previously the astype() result was discarded and
    # float inputs were then scaled in place (img *= 255.), mutating the caller's array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert (channel order B, G, R -> weights 24.966, 128.553, 65.481)
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    The input tensor is not modified.
    '''
    # Out-of-place clamp: squeeze()/float()/cpu() may all return the caller's own
    # storage, so the former clamp_() could overwrite the input tensor.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # make_grid is not imported at the top of this module; without this the
        # 4D path raised NameError.
        from torchvision.utils import make_grid
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Unlike MATLAB, numpy's uint8 cast truncates, so round explicitly first.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)


def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''Calculate SSIM on the Y channel (calc_ssim of SMSR; same outputs as MATLAB's).

    img1, img2: image tensors in [0, 255] that squeeze() to (3, H, W).
    img1 (the prediction) is clamped and rounded first. ``ceil(scale)``
    border pixels are cropped for benchmark sets, ``ceil(scale) + 6`` otherwise.

    Raises:
        ValueError: if the images differ in shape or the luma map is not 2-D.
    '''
    if benchmark:
        border = math.ceil(scale)
    else:
        border = math.ceil(scale) + 6

    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))
    # Check shapes before the colour conversion so mismatches get this message.
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    # The former per-channel fallback ignored its loop index (it recomputed the
    # same full-image SSIM three times) and was unreachable; fail loudly instead.
    if img1_y.ndim != 2:
        raise ValueError('Wrong input image dimensions.')
    return ssim(img1_y, img2_y)


def make_optimizer(args, target):
    '''
    make optimizer and scheduler together

    Builds the optimizer named by ``args.optimizer`` ('SGD', 'ADAM' or
    'RMSprop') over the trainable parameters of *target*. The returned object
    carries a MultiStepLR scheduler (milestones from ``args.decay``, e.g. '150'
    or '100-200'; factor ``args.gamma``) and save/load/schedule helpers.

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name.
    '''
    # optimizer
    trainable = filter(lambda x: x.requires_grad, target.parameters())
    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'SGD':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = args.momentum
    elif args.optimizer == 'ADAM':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = args.betas
        kwargs_optimizer['eps'] = args.epsilon
    elif args.optimizer == 'RMSprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = args.epsilon
    else:
        # Previously fell through to an UnboundLocalError on optimizer_class.
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))

    # scheduler
    milestones = [int(x) for x in str(args.decay).split('-')]
    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        def __init__(self, *args, **kwargs):
            super(CustomOptimizer, self).__init__(*args, **kwargs)

        def _register_scheduler(self, scheduler_class, **kwargs):
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, save_dir):
            torch.save(self.state_dict(), self.get_dir(save_dir))

        def load(self, load_dir, epoch=1):
            # Restore state, then fast-forward the scheduler to the resumed epoch.
            self.load_state_dict(torch.load(self.get_dir(load_dir)))
            if epoch > 1:
                for _ in range(epoch):
                    self.scheduler.step()

        def get_dir(self, dir_path):
            return os.path.join(dir_path, 'optimizer.pt')

        def schedule(self):
            self.scheduler.step()

        def get_lr(self):
            # get_last_lr() is the supported accessor; scheduler.get_lr() called
            # outside step() warns and, on a milestone epoch, applies gamma twice.
            return self.scheduler.get_last_lr()[0]

        def get_last_epoch(self):
            return self.scheduler.last_epoch

    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
    return optimizer


# ---------- /utils/__init__.py ----------
# (contents not inlined in this dump; source link:)
# https://raw.githubusercontent.com/Sheldon04/CABM-pytorch/0634f7e9539fba97f094d172b65651ea14c5c4f8/utils/__init__.py

# ---------- /utils/common.py ----------
#!/usr/bin/python3.6
# -*- coding: utf-8 -*-

# NOTE: utils/common.py begins with `from __future__ import absolute_import`;
# it is a no-op on Python 3 and a __future__ import is only legal as the first
# statement of a module, so it is recorded here as a comment.
from pathlib import Path
import datetime
import shutil
import torch.nn as nn
import torch.nn.functional as F
import logging
import coloredlogs
import os
import cv2
import torch
import functools
import numpy as np
import math
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import random
from decimal import Decimal

from option import args
from model.quant_ops import quant_act_pams

from model.edge import BitSelector
from model.cadyq import BitSelector as BitSelector_org
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class AverageMeter(object):
    """Track the latest value and a running (sample-weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* as the mean over *n* samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count

    def accumulate(self, val, n=1):
        """Add *val* as an already-summed total over *n* samples (``val`` attribute untouched)."""
        self.sum += val
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count


class Logger(object):
    """Tab-separated metrics file: a header row of names, then one row per append()."""

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                # Reload the header and all previous rows (values are kept as
                # strings), then reopen the file for appending.
                with open(fpath, 'r') as f:
                    self.names = f.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for row in f:
                        for i, value in enumerate(row.rstrip().split('\t')):
                            self.numbers[self.names[i]].append(value)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        """Write the header row and reset the per-name value lists.

        NOTE(review): in resume mode this still writes a second header row and
        clears the reloaded values, exactly as before.
        """
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        """Write one row of values (6 decimals); its length must match the header."""
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        """Plot the recorded series for *names* (default: all) on the current figure."""
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        """Close the underlying file if one is open."""
        if self.file is not None:
            self.file.close()


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in *topk* (output: (N, C) scores, target: (N,))."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed (non-contiguous)
        # tensor, and view(-1) on it can raise a RuntimeError for maxk > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


USE_CUDA = torch.cuda.is_available()
FLOAT = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
from torch.autograd import Variable


def to_numpy(var):
    """Return *var* as a NumPy array (moving it to CPU when CUDA is available)."""
    return var.cpu().data.numpy() if USE_CUDA else var.data.numpy()


def to_tensor(ndarray, volatile=False, requires_grad=False, dtype=FLOAT):
    """Wrap a NumPy array as a tensor cast to *dtype*.

    ``volatile`` is kept for backward compatibility only; it has had no effect
    since PyTorch 0.4 (use ``torch.no_grad()`` instead).
    """
    return Variable(
        torch.from_numpy(ndarray), volatile=volatile, requires_grad=requires_grad
    ).type(dtype)


def sample_from_truncated_normal_distribution(lower, upper, mu, sigma, size=1):
    """Draw *size* samples from N(mu, sigma) truncated to [lower, upper]."""
    from scipy import stats
    return stats.truncnorm.rvs((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma, size=size)


# logging: print *prt* wrapped in ANSI colour escape codes
def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))


def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    """Cross-entropy against label-smoothed one-hot targets.

    pred: logits with classes on dim 1 (typically (N, C)); target: class
    indices. Each one-hot row becomes (1 - eps) * onehot + eps / C.
    """
    # Explicit dim=1 (the class axis used by pred.size(1) below); LogSoftmax()
    # without dim relies on deprecated implicit-dimension inference.
    logsoftmax = nn.LogSoftmax(dim=1)
    n_classes = pred.size(1)
    # convert to one-hot
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros_like(pred)
    soft_target.scatter_(1, target, 1)
    # label smoothing
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


def wrapped_partial(func, *args, **kwargs):
    """functools.partial that also copies *func*'s name/docstring metadata."""
    partial_func = functools.partial(func, *args, **kwargs)
    functools.update_wrapper(partial_func, func)
    return partial_func


def get_logger(file_path, name='ED'):
    """ Make python logger """
    # [!] Since tensorboardX use default logger (e.g. logging.info()), we should use custom logger
    logger = logging.getLogger(name)
    coloredlogs.install(level='INFO', logger=logger)

    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)

    # NOTE(review): each call adds another FileHandler to the same named logger.
    logger.addHandler(file_handler)

    return logger


def print_params(config, prtf=print):
    """Print *config* (a mapping) as sorted UPPERCASE=value lines via *prtf*."""
    prtf("")
    prtf("Parameters:")
    for attr, value in sorted(config.items()):
        prtf("{}={}".format(attr.upper(), value))
    prtf("")


def as_markdown(config):
    """ Return configs as markdown format """
    text = "|name|value| \n|-|-| \n"
    for attr, value in sorted(config.items()):
        text += "|{}|{}| \n".format(attr, value)

    return text


def at(x):
    """Attention map: channel-mean of squared activations, flattened per sample, L2-normalized."""
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))


def at_loss(x, y):
    """Mean squared difference between the attention maps of *x* and *y*."""
    return (at(x) - at(y)).pow(2).mean()
def distillation(y, teacher_scores, labels, T, alpha):
    """Knowledge-distillation loss for classification logits.

    Returns ``alpha * T**2 * KL(teacher || student) / batch
    + (1 - alpha) * cross_entropy(y, labels)``, with both distributions
    softened by temperature *T* along dim 1.
    """
    p = F.log_softmax(y / T, dim=1)
    q = F.softmax(teacher_scores / T, dim=1)
    l_kl = F.kl_div(p, q, reduction='sum') * (T**2) / y.shape[0]
    l_ce = F.cross_entropy(y, labels)
    return l_kl * alpha + l_ce * (1. - alpha)


def pix_loss(x, y):
    """Mean absolute error: per-sample mean over dims (1, 2, 3), then batch mean."""
    loss = torch.mean(torch.mean(torch.abs(x - y), dim=(1, 2, 3)))
    return loss


####################
# image convert
####################

def _make_dir(path):
    """Create *path* (and any parents) if it does not exist yet."""
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(path, exist_ok=True)


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    The input tensor is not modified. (This module previously defined this
    function twice with identical bodies; this is the single copy.)
    '''
    # Out-of-place clamp: squeeze()/float()/cpu() may all return the caller's own
    # storage, so the former clamp_() could overwrite the input tensor.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Unlike MATLAB, numpy's uint8 cast truncates, so round explicitly first.
        img_np = (img_np * 255.0).round()
    return img_np.astype(out_type)


def save_img(img, img_path, mode='RGB'):
    """Write *img* to *img_path* with cv2.imwrite (*mode* is accepted but unused)."""
    cv2.imwrite(img_path, img)


def get_activation(name, activation):
    """Return a forward hook that stores a module's output in ``activation[name]``."""
    def hook(model, input, output):
        activation[name] = output
    return hook


def plot_loss(args, loss, apath, epoch, log=None):
    """Save one loss-vs-epoch plot per entry of *loss* as loss_<type>.pdf under *apath*.

    *loss* is a sequence of dicts with a 'type' key; *log* holds the values,
    column i of which is plotted for ``loss[i]``.

    Raises:
        ValueError: if *log* is not given (the original body referenced an
            undefined name ``log`` and always failed with NameError).
    """
    if log is None:
        raise ValueError('plot_loss needs the per-epoch loss log tensor (log=...)')
    axis = np.linspace(1, epoch, epoch)
    for i, l in enumerate(loss):
        label = '{} Loss'.format(l['type'])
        fig = plt.figure()
        plt.title(label)
        plt.plot(axis, log[:, i].numpy(), label=label)
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
        plt.close(fig)


def _plot_per_scale(args, apath, epoch, log, ylabel, name_fmt):
    """Plot ``log[:, data_idx, scale_idx]`` vs epoch: one figure per test set, one curve per scale."""
    axis = np.linspace(1, epoch, epoch)
    for idx_data, d in enumerate(args.data_test):
        label = 'SR on {}'.format(d)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(args.scale):
            plt.plot(
                axis,
                log[:, idx_data, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.savefig(os.path.join(apath, name_fmt.format(d, args.save)))
        plt.close(fig)


def plot_psnr(args, apath, epoch, log):
    """Save PSNR curves as test_<set>_<save>.png under *apath*."""
    _plot_per_scale(args, apath, epoch, log, 'PSNR', 'test_{}_{}.png')


def plot_bit(args, apath, epoch, log):
    """Save average-bit-width curves as test_<set>_bit_<save>.png under *apath*."""
    _plot_per_scale(args, apath, epoch, log, 'Avg Bit', 'test_{}_bit_{}.png')


def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', lpips=False):
    """Save *state* to ``<checkpoint>/<filename>`` and refresh the named copies.

    model_latest.pth.tar is refreshed on every call. When *is_best*, the file is
    also copied to model_best.pth.tar (model_best_lpips.pth.tar if *lpips*).
    """
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    # Always refresh the "latest" copy; it used to be written only in the
    # non-best branch, so it went stale whenever an epoch was also the best.
    shutil.copyfile(filepath, os.path.join(checkpoint, 'model_latest.pth.tar'))
    if is_best:
        best_name = 'model_best_lpips.pth.tar' if lpips else 'model_best.pth.tar'
        shutil.copyfile(filepath, os.path.join(checkpoint, best_name))


def laplacian(image):
    """Return the absolute 3x3 Laplacian (uint8) of a BGR image's grayscale version."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    laplac = cv2.Laplacian(gray, cv2.CV_16S, ksize=3)
    mask_img = cv2.convertScaleAbs(laplac)
    return mask_img


def set_bit_config(model, bit_config):
    """Assign ``k_bits`` to every quant_act_pams module from the flat list *bit_config*.

    The module name is split on '.'; field 1 is read as the block index and the
    last character of field 2 as the quantizer index (assumed names like
    'body.<block>.quant_act<k>' -- TODO confirm). Quantizers with index 3 are
    skipped; others take ``bit_config[block * 2 + k - 1]``.
    """
    for n, m in model.named_modules():
        if isinstance(m, quant_act_pams):
            plist = n.split('.')
            block_index = int(plist[1])
            quant_index = int(plist[2][-1])
            if quant_index != 3:
                setattr(m, 'k_bits', bit_config[block_index * 2 + quant_index - 1])


def set_bit_flag(model, flag):
    """Give the i-th BitSelector a ``flag`` int32 tensor of ``flag[b][i]`` over all rows b."""
    # flag -> batch_size bit_width
    total_index = 0
    for n, m in model.named_modules():
        if isinstance(m, BitSelector):
            cur_list = [row[total_index] for row in flag]
            setattr(m, 'flag', torch.tensor(cur_list, dtype=torch.int32))
            total_index += 1


def get_bit_config(model):
    """Return, per BitSelector_org, 2 / 1 / 0 for bits_out == search_space[2] / [1] / other."""
    bit_list = []
    for n, m in model.named_modules():
        if isinstance(m, BitSelector_org):
            bits_out = int(getattr(m, 'bits_out'))
            if bits_out == args.search_space[2]:
                flag = 2
            elif bits_out == args.search_space[1]:
                flag = 1
            else:
                flag = 0
            bit_list.append(flag)
    return bit_list


def random_pick(some_list, probabilities):
    """Return an item of *some_list* drawn according to *probabilities*.

    Walks the cumulative distribution; if the probabilities sum to less than
    the uniform draw, the last paired item is returned.

    Raises:
        ValueError: if *some_list* or *probabilities* is empty (previously an
            UnboundLocalError).
    """
    if not some_list or not probabilities:
        raise ValueError('random_pick needs at least one item and probability')
    x = random.uniform(0, 1)
    cumulative_probability = 0.0
    for item, item_probability in zip(some_list, probabilities):
        cumulative_probability += item_probability
        if x < cumulative_probability:
            break
    return item


# ---------- /utils/logger.py ----------
#!/usr/bin/python3.6
# -*- coding: utf-8 -*-

# NOTE: utils/logger.py begins with `from __future__ import absolute_import`;
# it is a no-op on Python 3 and only legal as a module's first statement, so it
# is recorded here as a comment.
import pandas as pd
import pdb
import os


class Logger(object):
    '''Save training process to log file with simple plot function.'''

    def __init__(self, csv_path='results.csv', log_path=None, resume=False):
        # log_path and resume are accepted but unused.
        self.path = csv_path
        self.figures = []
        self.results = None

    def add(self, logger_dict):
        """Append one row built from *logger_dict* (keys become columns)."""
        df = pd.DataFrame([list(logger_dict.values())], columns=list(logger_dict.keys()))
        if self.results is None:
            self.results = df
        else:
            # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
            self.results = pd.concat([self.results, df], ignore_index=True)

    def save(self, title='Training Results'):
        """Write all rows to the CSV path (title is unused)."""
        self.results.to_csv(self.path, index=False, index_label=False)

    def load(self, path=None):
        """Load rows from *path* (default: the CSV path) if that file exists."""
        path = path or self.path
        if os.path.isfile(path):
            self.results = pd.read_csv(path)

    def mask_log(self):
        """Return the stdlib logging module (it was returned without being imported)."""
        import logging
        return logging


# ---------- /utils/utility.py ----------
#!/usr/bin/python3.6
# -*- coding: utf-8 -*-

import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import imageio

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs


class timer():
    """Wall-clock stopwatch with an accumulator for summing several intervals."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start a new interval."""
        self.t0 = time.time()

    def toc(self, restart=False):
        """Return seconds since the last tic(); restart the interval if *restart*."""
        diff = time.time() - self.t0
        if restart:
            self.t0 = time.time()
        return diff

    def hold(self):
        """Add the current interval's elapsed time to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and zero the accumulator."""
        ret = self.acc
        self.acc = 0
        return ret

    def reset(self):
        """Zero the accumulator without reading it."""
        self.acc = 0


class checkpoint():
    """Experiment bookkeeping: directory layout, text log, PSNR curves, result images.

    Files live under ``experiment/<args.save>`` for a fresh run, or
    ``experiment/<args.load>`` when resuming.
    """

    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if not args.load:
            if not args.save:
                args.save = now
            self.dir = os.path.join('experiment', args.save)
        else:
            self.dir = os.path.join('experiment', args.load)
            if os.path.exists(self.dir):
                self.log = torch.load(self.get_path('psnr_log.pt'))
                print('Continue from epoch {}...'.format(len(self.log)))
            else:
                args.load = ''

        if args.reset:
            # Delete without a shell: the former os.system('rm -rf ' + self.dir)
            # broke on paths with spaces and executed any shell metacharacters.
            import shutil
            shutil.rmtree(self.dir, ignore_errors=True)
            args.load = ''

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.get_path('model'), exist_ok=True)
        for d in args.data_test:
            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

        # Append to an existing log when the directory already has one.
        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
        self.log_file = open(self.get_path('log.txt'), open_type)
        with open(self.get_path('config.txt'), open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

        self.n_processes = 8

    def get_path(self, *subdir):
        """Return a path inside the experiment directory."""
        return os.path.join(self.dir, *subdir)

    def save(self, trainer, epoch, is_best=False):
        """Save model, loss, optimizer state and the PSNR log/plot for *epoch*."""
        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        trainer.optimizer.save(self.dir)
        torch.save(self.log, self.get_path('psnr_log.pt'))

    def add_log(self, log):
        """Append *log* to the PSNR log along dim 0."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print *log* and append it to log.txt; reopen the file when *refresh* to flush it."""
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.get_path('log.txt'), 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Save one PSNR-vs-epoch plot per test set (one curve per scale) as test_<set>.pdf."""
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            plt.close(fig)

    def begin_background(self):
        """Start ``n_processes`` workers that write queued (filename, tensor) pairs via imageio."""
        self.queue = Queue()

        def bg_target(queue):
            # Block on get() instead of polling queue.empty() in a tight loop,
            # which kept every idle worker busy-spinning on a CPU core.
            while True:
                filename, tensor = queue.get()
                if filename is None:
                    break
                imageio.imwrite(filename, tensor.numpy())

        self.process = [
            Process(target=bg_target, args=(self.queue,))
            for _ in range(self.n_processes)
        ]

        for p in self.process:
            p.start()

    def end_background(self):
        """Send one stop sentinel per worker, wait for the queue to drain, then join."""
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        while not self.queue.empty():
            time.sleep(1)
        for p in self.process:
            p.join()

    def save_results(self, dataset, filename, save_list, scale):
        """Queue the SR/LR/HR tensors of *save_list* to be written as PNGs (if args.save_results)."""
        if self.args.save_results:
            filename = self.get_path(
                'results-{}'.format(dataset.dataset.name),
                '{}_x{}_'.format(filename, scale)
            )

            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                # First image of the batch, rescaled to [0, 255], as HWC uint8.
                normalized = v[0].mul(255 / self.args.rgb_range)
                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))


def quantize(img, rgb_range):
    """Snap *img* (in [0, rgb_range]) onto the 256-level 8-bit grid, keeping its range."""
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    """PSNR in dB between *sr* and *hr* (NCHW tensors in [0, rgb_range]).

    For benchmark datasets the difference is converted to luma with BT.601
    weights and *scale* border pixels are shaved; otherwise scale + 6 pixels
    are shaved on every channel. Returns 0 if *hr* has a single element and
    ``inf`` if the shaved images are identical.
    """
    if hr.nelement() == 1:
        return 0

    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    else:
        shave = scale + 6

    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()
    if mse == 0:
        # log10(0) would raise ValueError; identical images have infinite PSNR.
        return float('inf')
    return -10 * math.log10(mse)


def make_optimizer(args, target):
    '''
    make optimizer and scheduler together
    '''
    # optimizer
    trainable = filter(lambda x: x.requires_grad, target.parameters())
    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'SGD':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = args.momentum
    elif args.optimizer == 'ADAM':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = args.betas
        kwargs_optimizer['eps'] = args.epsilon
    elif args.optimizer == 'RMSprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = args.epsilon

    # scheduler
    milestones = list(map(lambda x: int(x), args.decay.split('-')))
    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        def __init__(self, *args, **kwargs):
            super(CustomOptimizer, self).__init__(*args, **kwargs)

        def _register_scheduler(self, scheduler_class, **kwargs):
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, save_dir):
            torch.save(self.state_dict(), self.get_dir(save_dir))

        def load(self, load_dir, epoch=1):
            self.load_state_dict(torch.load(self.get_dir(load_dir)))
            if epoch > 1:
                for _ in range(epoch): self.scheduler.step()

        def get_dir(self, dir_path):
            return os.path.join(dir_path, 'optimizer.pt')

        def schedule(self):
            self.scheduler.step()

        def get_lr(self):
            return self.scheduler.get_lr()[0]
        # (the remainder of make_optimizer continues past the end of this section
        # and is intentionally left unchanged here)
238 | def get_last_epoch(self): 239 | return self.scheduler.last_epoch 240 | 241 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer) 242 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler) 243 | return optimizer 244 | 245 | --------------------------------------------------------------------------------