├── README.md
├── code
│   ├── data
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── common.py
│   │   ├── demo.py
│   │   ├── rainheavy.py
│   │   ├── rainheavytest.py
│   │   └── srdata.py
│   ├── dataloader.py
│   ├── dataset
│   │   ├── test
│   │   │   ├── data
│   │   │   │   ├── norain-1.png
│   │   │   │   └── norain-2.png
│   │   │   └── label
│   │   │       ├── norain-1.png
│   │   │       └── norain-2.png
│   │   └── train
│   │       ├── data
│   │       │   ├── norain-1.png
│   │       │   └── norain-2.png
│   │       └── label
│   │           ├── norain-1.png
│   │           └── norain-2.png
│   ├── loss
│   │   ├── __init__.py
│   │   └── ssim.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── hct-ffn.py
│   ├── option.py
│   ├── template.py
│   ├── trainer.py
│   ├── util
│   │   ├── rlutrans.py
│   │   └── tools.py
│   └── utility.py
├── experiment
│   └── HCT-FFN
│       └── model
│           ├── model_best_Rain100H.pt
│           └── model_best_Rain100L.pt
└── figure
    └── network.png

/README.md:
--------------------------------------------------------------------------------
# Hybrid CNN-Transformer Feature Fusion for Single Image Deraining

Xiang Chen, Jinshan Pan, Jiyang Lu, Zhentao Fan, Hao Li

> **Abstract:** *Since rain streaks exhibit diverse geometric appearances and irregular overlapped phenomena, these complex characteristics challenge the design of an effective single image deraining model. To this end, rich local-global information representations are increasingly indispensable for better satisfying rain removal. In this paper, we propose a lightweight Hybrid CNN-Transformer Feature Fusion Network (dubbed as HCT-FFN) in a stage-by-stage progressive manner, which can harmonize these two architectures to help image restoration by leveraging their individual learning strengths. Specifically, we stack a sequence of the degradation-aware mixture of experts (DaMoE) modules in the CNN-based stage, where appropriate local experts adaptively enable the model to emphasize spatially-varying rain distribution features. As for the Transformer-based stage, a background-aware vision Transformer (BaViT) module is employed to complement spatially-long feature dependencies of images, so as to achieve global texture recovery while preserving the required structure. Considering the indeterminate knowledge discrepancy among CNN features and Transformer features, we introduce an interactive fusion branch at adjacent stages to further facilitate the reconstruction of high-quality deraining results. Extensive evaluations show the effectiveness and extensibility of our developed HCT-FFN.*


## Network Architecture

## Installation
* PyTorch == 0.4.1
* torchvision == 0.2.0
* Python == 3.6.0
* imageio == 2.5.0
* numpy == 1.14.0
* opencv-python
* scikit-image == 0.13.0
* tqdm == 4.32.2
* scipy == 1.2.1
* matplotlib == 3.1.1
* ipython == 7.6.1
* h5py == 2.10.0

## Training
1. Set the data paths (`self.apath`) in code/data/rainheavy.py and code/data/rainheavytest.py. Rainy inputs go in `data/` and the clean ground truths in `label/`, with matching file names:
   * datapath/data/\*\*\*.png
   * datapath/label/\*\*\*.png

2. Begin training:
```
$ cd ./code/
$ python main.py --save HCT-FFN --model hct-ffn --scale 2 --epochs 400 --batch_size 4 --patch_size 128 --data_train RainHeavy --n_threads 0 --data_test RainHeavyTest --data_range 1-1800/1-100 --loss 1*MSE+0.2*SSIM --save_results --lr 1e-4 --n_feats 32 --n_resblocks 3
```

## Testing
```
$ cd ./code/
$ python main.py --data_test RainHeavyTest --ext img --scale 2 --data_range 1-1800/1-100 --pre_train ../experiment/HCT-FFN/model/model_best.pt --model hct-ffn --test_only --save_results --save HCT-FFN_test
```
The pre-trained models (model_best_Rain100H.pt and model_best_Rain100L.pt) are available in ./experiment/HCT-FFN/model/. Set --pre_train to the one you want to evaluate.

## Performance Evaluation

The PSNR and SSIM results are computed with this [Matlab Code](https://github.com/hongwang01/RCDNet/tree/master/Performance_evaluation) on the Y channel of the YCbCr color space.

## Visual Deraining Results

https://drive.google.com/drive/folders/1soXkMuQEQmJZmxZBIlo8dfCHM0RlGtxz?usp=sharing

## Citation
If you are interested in this work, please consider citing:

    @inproceedings{chen2023hybrid,
      title={Hybrid CNN-Transformer Feature Fusion for Single Image Deraining},
      author={Chen, Xiang and Pan, Jinshan and Lu, Jiyang and Fan, Zhentao and Li, Hao},
      booktitle={AAAI},
      year={2023}
    }

## Acknowledgment
This code is based on [SPDNet](https://github.com/Joyies/SPDNet). Thanks for sharing!
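
If Matlab is not available, the Y-channel PSNR described under Performance Evaluation can be approximated in Python. This is a minimal sketch, not the evaluation code used for the paper. It assumes both inputs are 8-bit RGB arrays of identical size, and it applies the BT.601 luma conversion that Matlab's `rgb2ycbcr` uses. `rgb_to_y` and `psnr_y` are hypothetical helper names. The values can differ slightly from the linked script, for example because of rounding.

```python
import numpy as np


def rgb_to_y(img):
    """Y (luma) channel of an 8-bit RGB image, BT.601 as in Matlab's rgb2ycbcr.

    Kept as float64 (Matlab rounds to uint8 for uint8 input; this sketch does not).
    Output lies in the nominal [16, 235] range.
    """
    img = img.astype(np.float64)
    return 16.0 + (65.481 * img[..., 0]
                   + 128.553 * img[..., 1]
                   + 24.966 * img[..., 2]) / 255.0


def psnr_y(derained, clean):
    """PSNR in dB between two 8-bit RGB images, measured on the Y channel only."""
    if derained.shape != clean.shape:
        raise ValueError("images must have the same shape")
    mse = np.mean((rgb_to_y(derained) - rgb_to_y(clean)) ** 2)
    if mse == 0:
        return float("inf")  # identical Y channels
    return 10.0 * np.log10(255.0 ** 2 / mse)
```

To compare a result with its label, read both PNGs (for example with `imageio.imread`, which the repository already uses) and call `psnr_y(result, label)`. Use the Matlab script for any numbers you compare against published results.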

--------------------------------------------------------------------------------
/code/data/__init__.py:
--------------------------------------------------------------------------------
from importlib import import_module

from dataloader import MSDataLoader
from torch.utils.data.dataloader import default_collate

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())
            trainset = getattr(module_train, args.data_train)(args)
            self.loader_train = MSDataLoader(
                args,
                trainset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu
            )

        if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
            module_test = import_module('data.benchmark')
            testset = getattr(module_test, 'Benchmark')(args, train=False)
        else:
            module_test = import_module('data.' + args.data_test.lower())
            testset = getattr(module_test, args.data_test)(args, train=False)

        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            pin_memory=not args.cpu
        )

--------------------------------------------------------------------------------
/code/data/benchmark.py:
--------------------------------------------------------------------------------
import os

from data import common
from data import srdata

import numpy as np

import torch
import torch.utils.data as data

class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.jpg')

--------------------------------------------------------------------------------
/code/data/common.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import skimage.color as sc

import torch
from torchvision import transforms

def get_patch(*args, patch_size=96, scale=1, multi_scale=False):
    ih, iw = args[0].shape[:2]

    #p = scale if multi_scale else 1
    #tp = p * patch_size
    #ip = tp // scale

    tp = patch_size
    ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    #tx, ty = scale * ix, scale * iy
    tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret

def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        # if vflip: img = img[::-1, :, :]
        # if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]

--------------------------------------------------------------------------------
/code/data/demo.py:
--------------------------------------------------------------------------------
import os

from data import common

import numpy as np
import imageio

import torch
import torch.utils.data as data

class Demo(data.Dataset):
    def __init__(self, args, name='Demo', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.do_eval = False
        self.benchmark = benchmark

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.split(self.filelist[idx])[-1]
        filename, _ = os.path.splitext(filename)
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

        return lr_t, -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

--------------------------------------------------------------------------------
/code/data/rainheavy.py:
--------------------------------------------------------------------------------
import os
from data import srdata

class RainHeavy(srdata.SRData):
    def __init__(self, args, name='RainHeavy', train=True, benchmark=False):
        super(RainHeavy, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(RainHeavy, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(RainHeavy, self)._set_filesystem(dir_data)
        self.apath = './dataset/train/'  # train data path

        print(self.apath)
        self.dir_hr = os.path.join(self.apath, 'label')
        self.dir_lr = os.path.join(self.apath, 'data')

--------------------------------------------------------------------------------
/code/data/rainheavytest.py:
--------------------------------------------------------------------------------
import os
from data import srdata

class RainHeavyTest(srdata.SRData):
    def __init__(self, args, name='RainHeavyTest', train=True, benchmark=False):
        super(RainHeavyTest, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(RainHeavyTest, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(RainHeavyTest, self)._set_filesystem(dir_data)
        self.apath = './dataset/test/'  # test data path
        print(self.apath)
        self.dir_hr = os.path.join(self.apath, 'label')
        self.dir_lr = os.path.join(self.apath, 'data')

--------------------------------------------------------------------------------
/code/data/srdata.py:
--------------------------------------------------------------------------------
import os
import glob

from data import common
import pickle
import numpy as np
import imageio

import torch
import torch.utils.data as data

class SRData(data.Dataset):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.scale = args.scale
        self.idx_scale = 0

        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(lambda x: int(x), data_range))
        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('bin') >= 0:
            # Binary files are stored in 'bin' folder
            # If the binary file exists, load it. If not, make it.
            list_hr, list_lr = self._scan()
            self.images_hr = self._check_and_load(
                args.ext, list_hr, self._name_hrbin()
            )
            self.images_lr = [
                self._check_and_load(args.ext, l, self._name_lrbin(s)) \
                for s, l in zip(self.scale, list_lr)
            ]
        else:
            if args.ext.find('img') >= 0 or benchmark:
                self.images_hr, self.images_lr = list_hr, list_lr
            elif args.ext.find('sep') >= 0:
                os.makedirs(
                    self.dir_hr.replace(self.apath, path_bin),
                    exist_ok=True
                )
                for s in self.scale:
                    os.makedirs(
                        os.path.join(
                            self.dir_lr.replace(self.apath, path_bin),
                            'X{}'.format(s)
                        ),
                        exist_ok=True
                    )

                self.images_hr, self.images_lr = [], [[] for _ in self.scale]
                for h in list_hr:
                    b = h.replace(self.apath, path_bin)
                    b = b.replace(self.ext[0], '.pt')
                    self.images_hr.append(b)
                    self._check_and_load(
                        args.ext, [h], b, verbose=True, load=False
                    )

                for i, ll in enumerate(list_lr):
                    for l in ll:
                        b = l.replace(self.apath, path_bin)
                        b = b.replace(self.ext[1], '.pt')
                        self.images_lr[i].append(b)
                        self._check_and_load(
                            args.ext, [l], b, verbose=True, load=False
                        )

        if train:
            self.repeat \
                = args.test_every // (len(self.images_hr) // args.batch_size)

    # Below functions are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            #f = f.replace('.png','x2.png')
            f = f.replace('.png','.png')
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, '{}{}'.format(
                        filename, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _name_hrbin(self):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_HR.pt'.format(self.split)
        )

    def _name_lrbin(self, scale):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_LR.pt'.format(self.split)
        )

    def _check_and_load(self, ext, l, f, verbose=True, load=True):
        if os.path.isfile(f) and ext.find('reset') < 0:
            if load:
                if verbose: print('Loading {}...'.format(f))
                with open(f, 'rb') as _f: ret = pickle.load(_f)
                return ret
            else:
                return None
        else:
            if verbose:
                if ext.find('reset') >= 0:
                    print('Making a new binary: {}'.format(f))
                else:
                    print('{} does not exist. Now making binary...'.format(f))
            b = [{
                'name': os.path.splitext(os.path.basename(_l))[0],
                'image': imageio.imread(_l)
            } for _l in l]
            with open(f, 'wb') as _f: pickle.dump(b, _f)
            return b

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        lr, hr = self.get_patch(lr, hr)
        lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
        lr_tensor, hr_tensor = common.np2Tensor(
            lr, hr, rgb_range=self.args.rgb_range
        )

        return lr_tensor, hr_tensor, filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        if self.args.ext.find('bin') >= 0:
            filename = f_hr['name']
            hr = f_hr['image']
            lr = f_lr['image']
        else:
            filename, _ = os.path.splitext(os.path.basename(f_hr))
            if self.args.ext == 'img' or self.benchmark:
                hr = imageio.imread(f_hr)
                lr = imageio.imread(f_lr)
            elif self.args.ext.find('sep') >= 0:
                with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
                with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']

        return lr, hr, filename

    def get_patch(self, lr, hr):
        scale = self.scale[self.idx_scale]
        multi_scale = len(self.scale) > 1
        if self.train:
            # print('****prepare data****')
            lr, hr = common.get_patch(
                lr,
                hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi_scale=multi_scale
            )
            if not self.args.no_augment:
                # print('****use augment****')
                lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih, 0:iw]
            #hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

--------------------------------------------------------------------------------
/code/dataloader.py:
--------------------------------------------------------------------------------
import sys
import threading
import queue
import random
import collections

import torch
import torch.multiprocessing as multiprocessing

from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
    _remove_worker_pids, _error_if_any_worker_fails
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data.dataloader import ExceptionWrapper
from torch.utils.data.dataloader import _use_shared_memory
from torch.utils.data.dataloader import _worker_manager_loop
from torch.utils.data.dataloader import numpy_type_map
from torch.utils.data.dataloader import default_collate
from torch.utils.data.dataloader import pin_memory_batch
from torch.utils.data.dataloader import _SIGCHLD_handler_set
from torch.utils.data.dataloader import _set_SIGCHLD_handler

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)

        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.worker_manager_thread = threading.Thread(
                    target=_worker_manager_loop,
                    args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
                          maybe_device_id))
                self.worker_manager_thread.daemon = True
                self.worker_manager_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)

--------------------------------------------------------------------------------
/code/dataset/test/data/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/data/norain-1.png

--------------------------------------------------------------------------------
/code/dataset/test/data/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/data/norain-2.png

--------------------------------------------------------------------------------
/code/dataset/test/label/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/label/norain-1.png

--------------------------------------------------------------------------------
/code/dataset/test/label/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/label/norain-2.png

--------------------------------------------------------------------------------
/code/dataset/train/data/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/data/norain-1.png

--------------------------------------------------------------------------------
/code/dataset/train/data/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/data/norain-2.png

--------------------------------------------------------------------------------
/code/dataset/train/label/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/label/norain-1.png

--------------------------------------------------------------------------------
/code/dataset/train/label/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/label/norain-2.png

--------------------------------------------------------------------------------
/code/loss/__init__.py:
--------------------------------------------------------------------------------
import os
from importlib import import_module

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
# import SSIM
class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            elif loss_type.find('joint') >= 0:
                module = import_module('loss.joint')
                loss_function = getattr(module, 'Joint')()
            elif loss_type.find('SSIM') >= 0:
                module = import_module('loss.ssim')
                loss_function = getattr(module, 'SSIM')()

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                #self.loss_module, range(args.n_GPUs)
                self.loss_module, device_ids=[0]
            )

        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr, lr=None, detect_map=None):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:

                if lr is not None:
                    loss = l['function'](sr, hr, lr, detect_map)
                    effective_loss = l['weight'] * loss
                    losses.append(effective_loss)
                    self.log[-1, i] += effective_loss.item()
                else:
                    loss = l['function'](sr, hr)
                    effective_loss = l['weight'] * loss
                    losses.append(effective_loss)
                    self.log[-1, i] += effective_loss.item()

            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            # j = i
            # if i == len(self.loss)-1:
            #     break
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        # Use get_loss_module(): a plain ModuleList (single GPU) has no .module
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()

--------------------------------------------------------------------------------
/code/loss/ssim.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)

--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
import torch
import os
import utility
import data
import model
import loss
from option import args
from trainer import Trainer
import multiprocessing
import time

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Total number of parameters: %d' % num_params)

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    seed = 1334
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if checkpoint.ok:
        loader = data.Data(args)
        model = model.Model(args, checkpoint)
        print_network(model)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        # print('==================')
        while not t.terminate():
            # print('======++++++++++')
            t.train()
            t.test()
        checkpoint.done()

--------------------------------------------------------------------------------
/code/model/__init__.py:
--------------------------------------------------------------------------------
import os
from importlib import import_module

import torch
import torch.nn as nn
from torch.autograd import Variable

class Model(nn.Module):
    def __init__(self, args, ckp):
        super(Model, self).__init__()
        print('Making model...')

        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        module = import_module('model.' + args.model.lower())
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half': self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            #self.model = nn.DataParallel(self.model, range(args.n_GPUs))
            self.model = nn.DataParallel(self.model, device_ids=[0])

        self.load(
            ckp.dir,
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
        print(self.model, file=ckp.log_file)

    def forward(self, x, idx_scale):
        self.idx_scale = idx_scale
        target = self.get_model()
        if hasattr(target, 'set_scale'):
            target.set_scale(idx_scale)

        if self.self_ensemble and not self.training:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            return self.forward_x8(x, forward_function)
        elif self.chop and not self.training:
            return self.forward_chop(x)
        else:
            return self.model(x)

    def get_model(self):
        if self.n_GPUs == 1:
            return self.model
        else:
            return self.model.module

    def state_dict(self, **kwargs):
        target = self.get_model()
        return target.state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        target = self.get_model()
        torch.save(
            target.state_dict(),
            os.path.join(apath, 'model', 'model_latest.pt')
        )
        if is_best:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_best.pt')
            )

        if self.save_models:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            print('-1')
            self.get_model().load_state_dict(
                torch.load(
                    os.path.join(apath, 'model', 'model_latest.pt'),
**kwargs 97 | ), 98 | strict=False 99 | ) 100 | elif resume == 0: 101 | print('rest') 102 | if pre_train != '.': 103 | print('Loading model from {}'.format(pre_train)) 104 | self.get_model().load_state_dict( 105 | torch.load(pre_train, **kwargs), 106 | strict=False 107 | ) 108 | else: 109 | print('specific') 110 | self.get_model().load_state_dict( 111 | torch.load( 112 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 113 | **kwargs 114 | ), 115 | strict=False 116 | ) 117 | 118 | def forward_chop(self, x, shave=10, min_size=160000): 119 | scale = self.scale[self.idx_scale] 120 | n_GPUs = min(self.n_GPUs, 4) 121 | b, c, h, w = x.size() 122 | h_half, w_half = h // 2, w // 2 123 | h_size, w_size = h_half + shave, w_half + shave 124 | lr_list = [ 125 | x[:, :, 0:h_size, 0:w_size], 126 | x[:, :, 0:h_size, (w - w_size):w], 127 | x[:, :, (h - h_size):h, 0:w_size], 128 | x[:, :, (h - h_size):h, (w - w_size):w]] 129 | 130 | if w_size * h_size < min_size: 131 | sr_list = [] 132 | for i in range(0, 4, n_GPUs): 133 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 134 | sr_batch = self.model(lr_batch) 135 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 136 | else: 137 | sr_list = [ 138 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 139 | for patch in lr_list 140 | ] 141 | 142 | h, w = scale * h, scale * w 143 | h_half, w_half = scale * h_half, scale * w_half 144 | h_size, w_size = scale * h_size, scale * w_size 145 | shave *= scale 146 | 147 | output = x.new(b, c, h, w) 148 | output[:, :, 0:h_half, 0:w_half] \ 149 | = sr_list[0][:, :, 0:h_half, 0:w_half] 150 | output[:, :, 0:h_half, w_half:w] \ 151 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 152 | output[:, :, h_half:h, 0:w_half] \ 153 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 154 | output[:, :, h_half:h, w_half:w] \ 155 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 156 | 157 | return output 158 | 159 | def forward_x8(self, x, 
forward_function): 160 | def _transform(v, op): 161 | if self.precision != 'single': v = v.float() 162 | 163 | v2np = v.data.cpu().numpy() 164 | if op == 'v': 165 | tfnp = v2np[:, :, :, ::-1].copy() 166 | elif op == 'h': 167 | tfnp = v2np[:, :, ::-1, :].copy() 168 | elif op == 't': 169 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 170 | 171 | ret = torch.Tensor(tfnp).to(self.device) 172 | if self.precision == 'half': ret = ret.half() 173 | 174 | return ret 175 | 176 | lr_list = [x] 177 | for tf in 'v', 'h', 't': 178 | lr_list.extend([_transform(t, tf) for t in lr_list]) 179 | 180 | sr_list = [forward_function(aug) for aug in lr_list] 181 | for i in range(len(sr_list)): 182 | if i > 3: 183 | sr_list[i] = _transform(sr_list[i], 't') 184 | if i % 4 > 1: 185 | sr_list[i] = _transform(sr_list[i], 'h') 186 | if (i % 4) % 2 == 1: 187 | sr_list[i] = _transform(sr_list[i], 'v') 188 | 189 | output_cat = torch.cat(sr_list, dim=0) 190 | output = output_cat.mean(dim=0, keepdim=True) 191 | 192 | return output 193 | 194 | -------------------------------------------------------------------------------- /code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Variable 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | 14 | class MeanShift2(nn.Conv2d): 15 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 16 | super(MeanShift2, self).__init__(3, 3, kernel_size=1) 17 | std = torch.Tensor(rgb_std) 18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 19 | self.weight.data.div_(std.view(3, 1, 1, 1)) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 21 | self.bias.data.div_(std) 22 | self.requires_grad_(False) # freeze weight and bias 23 | 24 | class MeanShift(nn.Conv2d): 
25 | def __init__( 26 | self, rgb_range, 27 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 28 | 29 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 30 | std = torch.Tensor(rgb_std) 31 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 32 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 33 | for p in self.parameters(): 34 | p.requires_grad = False 35 | 36 | class BasicBlock(nn.Sequential): 37 | def __init__( 38 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 39 | bn=True, act=nn.ReLU(True)): 40 | 41 | m = [nn.Conv2d( 42 | in_channels, out_channels, kernel_size, 43 | padding=(kernel_size//2), stride=stride, bias=bias) 44 | ] 45 | if bn: m.append(nn.BatchNorm2d(out_channels)) 46 | if act is not None: m.append(act) 47 | super(BasicBlock, self).__init__(*m) 48 | 49 | class ResBlock(nn.Module): 50 | def __init__( 51 | self, conv, n_feats, kernel_size, 52 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 53 | 54 | super(ResBlock, self).__init__() 55 | m = [] 56 | for i in range(2): 57 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 58 | if bn: m.append(nn.BatchNorm2d(n_feats)) 59 | if i == 0: m.append(act) 60 | 61 | self.body = nn.Sequential(*m) 62 | self.res_scale = res_scale 63 | 64 | def forward(self, x): 65 | res = self.body(x).mul(self.res_scale) 66 | res += x 67 | 68 | return res 69 | 70 | class Upsampler(nn.Sequential): 71 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 72 | 73 | m = [] 74 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
75 | for _ in range(int(math.log(scale, 2))): 76 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(2)) 78 | if bn: m.append(nn.BatchNorm2d(n_feats)) 79 | 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | 85 | elif scale == 3: 86 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 87 | m.append(nn.PixelShuffle(3)) 88 | if bn: m.append(nn.BatchNorm2d(n_feats)) 89 | 90 | if act == 'relu': 91 | m.append(nn.ReLU(True)) 92 | elif act == 'prelu': 93 | m.append(nn.PReLU(n_feats)) 94 | else: 95 | raise NotImplementedError 96 | 97 | super(Upsampler, self).__init__(*m) 98 | 99 | class SELayer(nn.Module): 100 | def __init__(self, channel, reduction=16): 101 | super(SELayer, self).__init__() 102 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 103 | self.fc = nn.Sequential( 104 | nn.Linear(channel, channel // reduction, bias=False), 105 | nn.ReLU(inplace=True), 106 | nn.Linear(channel // reduction, channel, bias=False), 107 | nn.Sigmoid() 108 | ) 109 | 110 | def forward(self, x): 111 | b, c, _, _ = x.size() 112 | y = self.avg_pool(x).view(b, c) 113 | y = self.fc(y).view(b, c, 1, 1) 114 | return x * y.expand_as(x) 115 | 116 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 117 | """3x3 convolution with padding""" 118 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 119 | padding=1, groups=groups, bias=False) 120 | 121 | 122 | def conv1x1(in_planes, out_planes, stride=1): 123 | """1x1 convolution""" 124 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 125 | 126 | 127 | class SEModule(nn.Module): 128 | def __init__(self, channels, reduction=16): 129 | super(SEModule, self).__init__() 130 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 131 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 
padding=0) 134 | self.sigmoid = nn.Sigmoid() 135 | 136 | def forward(self, input): 137 | x = self.avg_pool(input) 138 | x = self.fc1(x) 139 | x = self.relu(x) 140 | x = self.fc2(x) 141 | x = self.sigmoid(x) 142 | return input * x 143 | ##################################################### 144 | Operations = [ 145 | 'sep_conv_1x1', 146 | 'sep_conv_3x3', 147 | 'sep_conv_5x5', 148 | 'sep_conv_7x7', 149 | 'dil_conv_3x3', 150 | 'dil_conv_5x5', 151 | 'dil_conv_7x7', 152 | 'avg_pool_3x3' 153 | ] 154 | 155 | OPS = { 156 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 157 | 'sep_conv_1x1' : lambda C, stride, affine: SepConv(C, C, 1, stride, 0, affine=affine), 158 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 159 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 160 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 161 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 162 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 163 | 'dil_conv_7x7' : lambda C, stride, affine: DilConv(C, C, 7, stride, 6, 2, affine=affine), 164 | } 165 | 166 | class ReLUConvBN(nn.Module): 167 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 168 | super(ReLUConvBN, self).__init__() 169 | self.op = nn.Sequential( 170 | nn.ReLU(inplace=False), 171 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 172 | nn.BatchNorm2d(C_out, affine=affine)) 173 | 174 | def forward(self, x): 175 | return self.op(x) 176 | 177 | class ReLUConv(nn.Module): 178 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 179 | super(ReLUConv, self).__init__() 180 | self.op = nn.Sequential( 181 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 182 | 
nn.ReLU(inplace=False)) 183 | 184 | def forward(self, x): 185 | return self.op(x) 186 | 187 | class DilConv(nn.Module): 188 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 189 | super(DilConv, self).__init__() 190 | self.op = nn.Sequential( 191 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 192 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),) 193 | 194 | def forward(self, x): 195 | return self.op(x) 196 | 197 | class ResBlock2(nn.Module): 198 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 199 | super(ResBlock2, self).__init__() 200 | self.conv1 = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False) 201 | self.conv2 = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False) 202 | self.relu = nn.ReLU(inplace=False) 203 | 204 | def forward(self, x): 205 | residual = x 206 | out = self.relu(self.conv1(x)) 207 | out = self.conv2(out) 208 | out = out + residual 209 | out = self.relu(out) 210 | return out 211 | 212 | class ResBlock(nn.Module): 213 | def __init__( 214 | self, conv, n_feats, kernel_size, 215 | bias=True, bn=False, act=nn.PReLU(), res_scale=1): 216 | 217 | super(ResBlock, self).__init__() 218 | m = [] 219 | for i in range(2): 220 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 221 | if bn: 222 | m.append(nn.BatchNorm2d(n_feats)) 223 | if i == 0: 224 | m.append(act) 225 | 226 | self.body = nn.Sequential(*m) 227 | self.res_scale = res_scale 228 | 229 | def forward(self, x): 230 | res = self.body(x).mul(self.res_scale) 231 | res += x 232 | 233 | return res 234 | 235 | class SepConv(nn.Module): 236 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 237 | super(SepConv, self).__init__() 238 | self.op = nn.Sequential( 239 | nn.Conv2d(C_in, C_in, 
kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 240 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 241 | nn.ReLU(inplace=False), 242 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 243 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),) 244 | 245 | def forward(self, x): 246 | return self.op(x) 247 | 248 | -------------------------------------------------------------------------------- /code/model/hct-ffn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from model import common 5 | from util.rlutrans import Mlp, TransBlock 6 | from util.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches 7 | 8 | def make_model(args, parent=False): 9 | return Rainnet(args) 10 | 11 | class OperationLayer(nn.Module): 12 | def __init__(self, C, stride): 13 | super(OperationLayer, self).__init__() 14 | self._ops = nn.ModuleList() 15 | for o in common.Operations: 16 | op = common.OPS[o](C, stride, False) 17 | self._ops.append(op) 18 | 19 | self._out = nn.Sequential(nn.Conv2d(C * len(common.Operations), C, 1, padding=0, bias=False), nn.ReLU()) 20 | 21 | def forward(self, x, weights): 22 | weights = weights.transpose(1, 0) 23 | states = [] 24 | for w, op in zip(weights, self._ops): 25 | states.append(op(x) * w.view([-1, 1, 1, 1])) 26 | return self._out(torch.cat(states[:], dim=1)) 27 | 28 | class GroupOLs(nn.Module): 29 | def __init__(self, steps, C): 30 | super(GroupOLs, self).__init__() 31 | self.preprocess = common.ReLUConv(C, C, 1, 1, 0, affine=False) 32 | self._steps = steps 33 | self._ops = nn.ModuleList() 34 | self.relu = nn.ReLU() 35 | stride = 1 36 | 37 | for _ in range(self._steps): 38 | op = OperationLayer(C, stride) 39 | self._ops.append(op) 40 | 41 | def forward(self, s0, weights): 42 | s0 = 
self.preprocess(s0) 43 | for i in range(self._steps): 44 | res = s0 45 | s0 = self._ops[i](s0, weights[:, i, :]) 46 | s0 = self.relu(s0 + res) 47 | return s0 48 | 49 | class OALayer(nn.Module): 50 | def __init__(self, channel, k, num_ops): 51 | super(OALayer, self).__init__() 52 | self.k = k 53 | self.num_ops = num_ops 54 | self.output = k * num_ops 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.ca_fc = nn.Sequential( 57 | nn.Linear(channel, self.output * 2), 58 | nn.ReLU(), 59 | nn.Linear(self.output * 2, self.k * self.num_ops)) 60 | 61 | def forward(self, x): 62 | y = self.avg_pool(x) 63 | y = y.view(x.size(0), -1) 64 | y = self.ca_fc(y) 65 | y = y.view(-1, self.k, self.num_ops) 66 | return y 67 | 68 | def get_residue(tensor , r_dim = 1): 69 | """ 70 | Return the residue channel: per-pixel max minus min over the RGB channels. 71 | """ 72 | # torch.max/torch.min along a dim return (values, indices); [0] selects the values 73 | max_channel = torch.max(tensor, dim=r_dim, keepdim=True) # keepdim=True keeps shape (B, 1, H, W) 74 | min_channel = torch.min(tensor, dim=r_dim, keepdim=True) 75 | res_channel = max_channel[0] - min_channel[0] 76 | return res_channel 77 | 78 | class convd(nn.Module): 79 | def __init__(self, inputchannel, outchannel, kernel_size, stride): 80 | super(convd, self).__init__() 81 | self.relu = nn.ReLU() 82 | self.padding = nn.ReflectionPad2d(kernel_size//2) 83 | self.conv = nn.Conv2d(inputchannel, outchannel, kernel_size, stride) 84 | self.ins = nn.InstanceNorm2d(outchannel, affine=True) 85 | 86 | def forward(self, x): 87 | x = self.conv(self.padding(x)) 88 | # x= self.ins(x) 89 | x = self.relu(x) 90 | return x 91 | 92 | class Upsample(nn.Module): 93 | def __init__(self, in_channels, out_channels, kernel_size, stride): 94 | super(Upsample, self).__init__() 95 | reflection_padding = kernel_size // 2 96 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 97 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride) 98 | self.relu = nn.ReLU() 99 | 100 | def forward(self, x, y): 101 | out = self.reflection_pad(x) 102 | out = 
self.conv2d(out) 103 | out = self.relu(out) 104 | out = F.interpolate(out, y.size()[2:]) 105 | return out 106 | 107 | class RB(nn.Module): 108 | def __init__(self, n_feats, nm='in'): 109 | super(RB, self).__init__() 110 | module_body = [] 111 | for i in range(2): 112 | module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True)) 113 | module_body.append(nn.ReLU()) 114 | self.module_body = nn.Sequential(*module_body) 115 | self.relu = nn.ReLU() 116 | self.se = common.SELayer(n_feats, 1) 117 | 118 | def forward(self, x): 119 | res = self.module_body(x) 120 | res = self.se(res) 121 | res += x 122 | return res 123 | 124 | class RIR(nn.Module): 125 | def __init__(self, n_feats, n_blocks, nm='in'): 126 | super(RIR, self).__init__() 127 | module_body = [ 128 | RB(n_feats) for _ in range(n_blocks) 129 | ] 130 | module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True)) 131 | self.module_body = nn.Sequential(*module_body) 132 | self.relu = nn.ReLU() 133 | 134 | def forward(self, x): 135 | res = self.module_body(x) 136 | res += x 137 | return self.relu(res) 138 | 139 | class res_ch(nn.Module): 140 | def __init__(self, n_feats, blocks=2): 141 | super(res_ch,self).__init__() 142 | self.conv_init1 = convd(3, n_feats//2, 3, 1) 143 | self.conv_init2 = convd(n_feats//2, n_feats, 3, 1) 144 | self.extra = RIR(n_feats, n_blocks=blocks) 145 | 146 | def forward(self,x): 147 | x = self.conv_init2(self.conv_init1(x)) 148 | x = self.extra(x) 149 | return x 150 | 151 | class Fuse(nn.Module): 152 | def __init__(self, inchannel=64, outchannel=64): 153 | super(Fuse, self).__init__() 154 | self.up = Upsample(inchannel, outchannel, 3, 2) 155 | self.conv = convd(outchannel, outchannel, 3, 1) 156 | self.rb = RB(outchannel) 157 | self.relu = nn.ReLU() 158 | 159 | def forward(self, x, y): 160 | x = self.up(x, y) 161 | # x = F.interpolate(x, y.size()[2:]) 162 | # y1 = torch.cat((x, y), dim=1) 163 | y = x+y 164 | # y = self.pf(y1) 
+ y 165 | 166 | return self.relu(self.rb(y)) 167 | 168 | class Prior_Sp(nn.Module): 169 | def __init__(self, in_dim=32): 170 | super(Prior_Sp, self).__init__() 171 | self.chanel_in = in_dim 172 | 173 | self.query_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True) 174 | self.key_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True) 175 | 176 | self.gamma1 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True) 177 | # self.gamma1 = nn.Parameter(torch.zeros(1)) 178 | self.gamma2 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True) 179 | # self.softmax = nn.Softmax(dim=-1) 180 | self.sig = nn.Sigmoid() 181 | 182 | def forward(self,x, prior): 183 | 184 | x_q = self.query_conv(x) 185 | prior_k = self.key_conv(prior) 186 | energy = x_q * prior_k 187 | attention = self.sig(energy) 188 | # print(attention.size(),x.size()) 189 | attention_x = x * attention 190 | attention_p = prior * attention 191 | 192 | x_gamma = self.gamma1(torch.cat((x, attention_x),dim=1)) 193 | x_out = x * x_gamma[:, [0], :, :] + attention_x * x_gamma[:, [1], :, :] 194 | 195 | p_gamma = self.gamma2(torch.cat((prior, attention_p),dim=1)) 196 | prior_out = prior * p_gamma[:, [0], :, :] + attention_p * p_gamma[:, [1], :, :] 197 | 198 | return x_out, prior_out 199 | 200 | class DaMoE(nn.Module): 201 | def __init__(self, n_feats,layer_num ,steps=4): 202 | super(DaMoE,self).__init__() 203 | 204 | # fuse res 205 | self.prior = Prior_Sp() 206 | self.fuse_res = convd(n_feats*2, n_feats, 3, 1) 207 | self._C = n_feats 208 | self.num_ops = len(common.Operations) 209 | self._layer_num = layer_num 210 | self._steps = steps 211 | 212 | self.layers = nn.ModuleList() 213 | for _ in range(self._layer_num): 214 | attention = OALayer(self._C, self._steps, self.num_ops) 215 | self.layers += [attention] 216 | layer = GroupOLs(steps, self._C) 217 | self.layers += [layer] 218 | 219 | def forward(self, x, res_feats): 220 | 221 | x_p, res_feats_p = self.prior(x, res_feats) 222 | x_s = torch.cat((x_p, res_feats_p),dim=1) 223 | x1_i = 
self.fuse_res(x_s) 224 | for _, layer in enumerate(self.layers): 225 | if isinstance(layer, OALayer): 226 | weights = layer(x1_i) 227 | weights = F.softmax(weights, dim=-1) 228 | else: 229 | x1_i = layer(x1_i, weights) 230 | 231 | return x1_i 232 | 233 | class BaViT(nn.Module): 234 | def __init__(self, n_feats, blocks=2): 235 | super(BaViT, self).__init__() 236 | # fuse res 237 | self.prior = Prior_Sp() 238 | self.fuse_res = convd(n_feats * 2, n_feats, 3, 1) 239 | 240 | self.attention = TransBlock(n_feats, dim=n_feats * 9) 241 | self.c2 = common.default_conv(n_feats, n_feats, 3) 242 | # self.attention2 = TransBlock(n_feat=n_feat, dim=n_feat*9) 243 | 244 | def forward(self, x, res_feats): 245 | x_p, res_feats_p = self.prior(x, res_feats) 246 | x_s = torch.cat((x_p, res_feats_p), dim=1) 247 | x1_init = self.fuse_res(x_s) 248 | 249 | y8 = x1_init 250 | b, c, h, w = y8.shape 251 | y8 = extract_image_patches(y8, ksizes=[3, 3], 252 | strides=[1, 1], 253 | rates=[1, 1], 254 | padding='same') # 16*2304*576 255 | y8 = y8.permute(0, 2, 1) 256 | out_transf1 = self.attention(y8) 257 | out_transf1 = self.attention(out_transf1) 258 | out_transf1 = self.attention(out_transf1) 259 | out1 = out_transf1.permute(0, 2, 1) 260 | out1 = reverse_patches(out1, (h, w), (3, 3), 1, 1) 261 | y9 = self.c2(out1) 262 | 263 | return y9 264 | 265 | class Rainnet(nn.Module): 266 | def __init__(self,args): 267 | super(Rainnet,self).__init__() 268 | n_feats = args.n_feats 269 | blocks = args.n_resblocks 270 | 271 | self.conv_init1 = convd(3, n_feats//2, 3, 1) 272 | self.conv_init2 = convd(n_feats//2, n_feats, 3, 1) 273 | self.res_extra1 = res_ch(n_feats, blocks) 274 | self.sub1 = DaMoE(n_feats, 1) 275 | self.res_extra2 = res_ch(n_feats, blocks) 276 | self.sub2 = BaViT(n_feats, 1) 277 | self.res_extra3 = res_ch(n_feats, blocks) 278 | self.sub3 = DaMoE(n_feats, 1) 279 | 280 | self.ag1 = convd(n_feats*2,n_feats,3,1) 281 | self.ag2 = convd(n_feats*3,n_feats,3,1) 282 | self.ag2_en = convd(n_feats*2, 
n_feats, 3, 1) 283 | self.ag_en = convd(n_feats*3, n_feats, 3, 1) 284 | 285 | self.output1 = nn.Conv2d(n_feats, 3, 3, 1, padding=1) 286 | self.output2 = nn.Conv2d(n_feats, 3, 3, 1, padding=1) 287 | self.output3 = nn.Conv2d(n_feats, 3, 3, 1, padding=1) 288 | 289 | # self._initialize_weights() 290 | 291 | def forward(self,x): 292 | 293 | res_x = get_residue(x) 294 | x_init = self.conv_init2(self.conv_init1(x)) 295 | x1 = self.sub1(x_init, self.res_extra1(torch.cat((res_x, res_x, res_x), dim=1))) #+ x # 1 296 | out1 = self.output1(x1) 297 | res_out1 = get_residue(out1) 298 | x2 = self.sub2(self.ag1(torch.cat((x1,x_init),dim=1)), self.res_extra2(torch.cat((res_out1, res_out1, res_out1), dim=1))) #+ x1 # 2 299 | x2_ = self.ag2_en(torch.cat([x2,x1], dim=1)) 300 | out2 = self.output2(x2_) 301 | res_out2 = get_residue(out2) 302 | x3 = self.sub3(self.ag2(torch.cat((x2,x1,x_init),dim=1)), self.res_extra3(torch.cat((res_out2, res_out2, res_out2), dim=1))) #+ x2 # 3 303 | x3 = self.ag_en(torch.cat([x3,x2,x1],dim=1)) 304 | out3 = self.output3(x3) 305 | 306 | return out3, out2, out1 307 | 308 | def _initialize_weights(self): 309 | for m in self.modules(): 310 | if isinstance(m, nn.Conv2d): 311 | nn.init.normal_(m.weight, std=0.01) 312 | if m.bias is not None: 313 | nn.init.constant_(m.bias, 0) 314 | elif isinstance(m, nn.BatchNorm2d): 315 | nn.init.constant_(m.weight, 1) 316 | nn.init.constant_(m.bias, 0) 317 | 318 | 319 | -------------------------------------------------------------------------------- /code/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='RCDNet') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, 
default=0, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir_data', type=str, default='../data', 23 | help='dataset directory') 24 | parser.add_argument('--dir_demo', type=str, default='../test', 25 | help='demo image directory') 26 | parser.add_argument('--data_train', type=str, default='RainHeavy', #'DIV2K', 27 | help='train dataset name') 28 | parser.add_argument('--data_test', type=str, default= 'RainHeavyTest', #'DIV2K', 29 | help='test dataset name') 30 | parser.add_argument('--data_range', type=str, default='1-20000/1-100', 31 | help='train/test data range') 32 | parser.add_argument('--ext', type=str, default='sep', 33 | help='dataset file extension') 34 | parser.add_argument('--scale', type=str, default='2', 35 | help='super resolution scale') 36 | parser.add_argument('--patch_size', type=int, default=64, 37 | help='output patch size') 38 | parser.add_argument('--rgb_range', type=int, default=255, 39 | help='maximum value of RGB') 40 | parser.add_argument('--n_colors', type=int, default=3, 41 | help='number of color channels to use') 42 | parser.add_argument('--chop', action='store_true', 43 | help='enable memory-efficient forward') 44 | parser.add_argument('--no_augment', action='store_true', 45 | help='do not use data augmentation') 46 | 47 | # Model specifications 48 | parser.add_argument('--model', default='OurNet', 49 | help='model name') 50 | parser.add_argument('--act', type=str, default='relu', 51 | help='activation function') 52 | parser.add_argument('--pre_train', type=str, default='.', 53 | help='pre-trained model directory') 54 | parser.add_argument('--extend', type=str, default='.', 55 | help='pre-trained model directory') 56 | 
parser.add_argument('--n_resblocks', type=int, default=3, 57 | help='number of residual blocks') 58 | parser.add_argument('--n_feats', type=int, default=32, 59 | help='number of feature maps') 60 | parser.add_argument('--res_scale', type=float, default=1, 61 | help='residual scaling') 62 | parser.add_argument('--shift_mean', default=True, 63 | help='subtract pixel mean from the input') 64 | parser.add_argument('--dilation', action='store_true', 65 | help='use dilated convolution') 66 | parser.add_argument('--precision', type=str, default='single', 67 | choices=('single', 'half'), 68 | help='FP precision for test (single | half)') 69 | 70 | # Training specifications 71 | parser.add_argument('--test_every', type=int, default=1500, 72 | help='do test per every N batches') 73 | parser.add_argument('--epochs', type=int, default=100, 74 | help='number of epochs to train') 75 | parser.add_argument('--batch_size', type=int, default=16, 76 | help='input batch size for training') 77 | parser.add_argument('--split_batch', type=int, default=1, 78 | help='split the batch into smaller chunks') 79 | parser.add_argument('--self_ensemble', action='store_true', 80 | help='use self-ensemble method for test') 81 | parser.add_argument('--test_only', action='store_true', 82 | help='set this option to test the model') 83 | parser.add_argument('--reset', action='store_true', 84 | help='reset the training') 85 | # Optimization specifications 86 | parser.add_argument('--lr', type=float, default=1e-3, 87 | help='learning rate') 88 | parser.add_argument('--lr_decay', type=int, default=25, 89 | help='learning rate decay per N epochs') 90 | parser.add_argument('--decay_type', type=str, default='step_100_150_200_230_260_280_300',#100_115_130_140_150_158_165_170_175_180 91 | help='learning rate decay type') 92 | parser.add_argument('--gamma', type=float, default=0.5, 93 | help='learning rate decay factor for step decay') 94 | parser.add_argument('--optimizer', default='ADAM', 95 | choices=('SGD', 
'ADAM', 'RMSprop'), 96 | help='optimizer to use (SGD | ADAM | RMSprop)') 97 | parser.add_argument('--momentum', type=float, default=0.9, 98 | help='SGD momentum') 99 | parser.add_argument('--beta1', type=float, default=0.9, 100 | help='ADAM beta1') 101 | parser.add_argument('--beta2', type=float, default=0.999, 102 | help='ADAM beta2') 103 | parser.add_argument('--epsilon', type=float, default=1e-8, 104 | help='ADAM epsilon for numerical stability') 105 | parser.add_argument('--weight_decay', type=float, default=0, 106 | help='weight decay') 107 | 108 | # Loss specifications 109 | parser.add_argument('--loss', type=str, default='1*MSE', 110 | help='loss function configuration') 111 | parser.add_argument('--skip_threshold', type=float, default='1e6', 112 | help='skipping batch that has large error') 113 | 114 | # Log specifications 115 | parser.add_argument('--save', type=str, default='RCDNet_syn', 116 | help='file name to save') 117 | parser.add_argument('--load', type=str, default='.', 118 | help='file name to load') 119 | parser.add_argument('--resume', type=int, default=0, 120 | help='resume from specific checkpoint') 121 | parser.add_argument('--save_models', action='store_true', 122 | help='save all intermediate models') 123 | parser.add_argument('--print_every', type=int, default=100, 124 | help='how many batches to wait before logging training status') 125 | parser.add_argument('--save_results', action='store_true', 126 | help='save output results') 127 | 128 | args = parser.parse_args() 129 | template.set_template(args) 130 | 131 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 132 | 133 | if args.epochs == 0: 134 | args.epochs = 1e8 135 | 136 | for arg in vars(args): 137 | if vars(args)[arg] == 'True': 138 | vars(args)[arg] = True 139 | elif vars(args)[arg] == 'False': 140 | vars(args)[arg] = False 141 | 142 | -------------------------------------------------------------------------------- /code/template.py: 
-------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.lr_decay = 100 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.lr_decay = 500 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.lr_decay = 150 39 | 40 | if args.template.find('RCAN') >= 0: 41 | args.model = 'RCAN' 42 | args.n_resgroups = 10 43 | args.n_resblocks = 20 44 | args.n_feats = 64 45 | args.chop = True 46 | 47 | -------------------------------------------------------------------------------- /code/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | import utility 5 | import IPython 6 | import torch 7 | from torch.autograd import Variable 8 | from tqdm import tqdm 9 | import scipy.io as sio 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import pylab 13 | import numpy as np 14 | from torchvision.transforms import ToTensor,ToPILImage 15 | class Trainer(): 16 | def __init__(self, args, loader, my_model, my_loss, ckp): 17 | self.args = args 18 | self.scale = args.scale 19 | self.ckp = ckp 20 | self.loader_train = loader.loader_train 21 | self.loader_test = loader.loader_test 22 | self.model = my_model 
23 | self.loss = my_loss 24 | self.optimizer = utility.make_optimizer(args, self.model) 25 | self.scheduler = utility.make_scheduler(args, self.optimizer) 26 | if self.args.load != '.': 27 | print(ckp.dir) 28 | assert os.path.exists(os.path.join(ckp.dir, 'optimizer.pt')) 29 | print('Loading optimizer state from', os.path.join(ckp.dir, 'optimizer.pt')) 30 | self.optimizer.load_state_dict( 31 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 32 | ) 33 | for _ in range(len(ckp.log)): self.scheduler.step() 34 | 35 | self.error_last = 1e8 36 | 37 | def train(self): 38 | 39 | self.scheduler.step() 40 | self.loss.step() 41 | epoch = self.scheduler.last_epoch + 1 42 | lr = self.scheduler.get_lr()[0] 43 | self.ckp.write_log( 44 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 45 | ) 46 | self.loss.start_log() 47 | self.model.train() 48 | 49 | timer_data, timer_model = utility.timer(), utility.timer() 50 | 51 | for batch, (lr, hr, idx_scale) in enumerate(self.loader_train): 52 | lr, hr = self.prepare(lr, hr) 53 | timer_data.hold() 54 | timer_model.tic() 55 | self.model.zero_grad() 56 | self.optimizer.zero_grad() 57 | out3, out2, out1 = self.model(lr, idx_scale) 58 | loss = self.loss(out3, hr) + self.loss(out2, hr) + self.loss(out1, hr) 59 | 60 | if loss.item() < self.args.skip_threshold * self.error_last: 61 | loss.backward() 62 | 63 | self.optimizer.step() 64 | else: 65 | print('Skip this batch {}!
(Loss: {})'.format( 66 | batch + 1, loss.item() 67 | )) 68 | timer_model.hold() 69 | if (batch + 1) % self.args.print_every == 0: 70 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format( 71 | (batch + 1) * self.args.batch_size, 72 | len(self.loader_train.dataset), 73 | self.loss.display_loss(batch), 74 | timer_model.release(), 75 | timer_data.release())) 76 | timer_data.tic() 77 | 78 | self.loss.end_log(len(self.loader_train)) 79 | self.error_last = self.loss.log[-1, -1] 80 | 81 | def test(self): 82 | # print('=========eval') 83 | epoch = self.scheduler.last_epoch + 1 84 | self.ckp.write_log('\nEvaluation:') 85 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 86 | self.model.eval() 87 | 88 | timer_test = utility.timer() 89 | with torch.no_grad(): 90 | for idx_scale, scale in enumerate(self.scale): 91 | eval_acc = 0 92 | self.loader_test.dataset.set_scale(idx_scale) 93 | tqdm_test = tqdm(self.loader_test, ncols=80) 94 | for idx_img, (lr, hr, filename) in enumerate(tqdm_test): 95 | filename = filename[0] 96 | no_eval = (hr.nelement() == 1) 97 | if not no_eval: 98 | lr, hr = self.prepare(lr, hr) 99 | else: 100 | lr, = self.prepare(lr) 101 | 102 | sr,_,_ = self.model(lr, idx_scale) 103 | sr = utility.quantize(sr, self.args.rgb_range) # restored background at the last stage 104 | save_list = [sr] 105 | if not no_eval: 106 | eval_acc += utility.calc_psnr( 107 | sr, hr, scale, self.args.rgb_range, 108 | benchmark=self.loader_test.dataset.benchmark 109 | ) 110 | save_list.extend([lr, hr]) 111 | 112 | if self.args.save_results: 113 | self.ckp.save_results(filename, save_list, scale) 114 | 115 | self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) 116 | best = self.ckp.log.max(0) 117 | self.ckp.write_log( 118 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 119 | self.args.data_test, 120 | scale, 121 | self.ckp.log[-1, idx_scale], 122 | best[0][idx_scale], 123 | best[1][idx_scale] + 1 124 | ) 125 | ) 126 | 127 | self.ckp.write_log( 128 | 'Total 
time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 129 | ) 130 | if not self.args.test_only: 131 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) 132 | 133 | def prepare(self, *args): 134 | device = torch.device('cpu' if self.args.cpu else 'cuda:0') 135 | def _prepare(tensor): 136 | if self.args.precision == 'half': tensor = tensor.half() 137 | return tensor.to(device) 138 | 139 | return [_prepare(a) for a in args] 140 | 141 | def terminate(self): 142 | if self.args.test_only: 143 | self.test() 144 | return True 145 | else: 146 | epoch = self.scheduler.last_epoch + 1 147 | return epoch >= self.args.epochs 148 | -------------------------------------------------------------------------------- /code/util/rlutrans.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | # from thop import profile 6 | from util.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches 7 | import pdb 8 | import math 9 | 10 | class Mlp(nn.Module): 11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.): 12 | super().__init__() 13 | out_features = out_features or in_features 14 | hidden_features = hidden_features or in_features//4 15 | self.fc1 = nn.Linear(in_features, hidden_features) 16 | self.act = act_layer() 17 | self.fc2 = nn.Linear(hidden_features, out_features) 18 | self.drop = nn.Dropout(drop) 19 | 20 | def forward(self, x): 21 | x = self.fc1(x) 22 | x = self.act(x) 23 | x = self.drop(x) 24 | x = self.fc2(x) 25 | x = self.drop(x) 26 | return x 27 | 28 | class EffAttention(nn.Module): 29 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 30 | super().__init__() 31 | self.num_heads = num_heads 32 | head_dim = dim // num_heads 33 | self.scale = qk_scale or head_dim ** -0.5 34 | 35 | self.reduce = 
nn.Linear(dim, dim//2, bias=qkv_bias) 36 | self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias) 37 | self.proj = nn.Linear(dim//2, dim) 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | 40 | def forward(self, x): 41 | x = self.reduce(x) 42 | B, N, C = x.shape 43 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 44 | q, k, v = qkv[0], qkv[1], qkv[2] 45 | 46 | q_all = torch.split(q, max(1, N // 16), dim=-2) # chunks of N//16 tokens; max() avoids a zero split size when N < 16 47 | k_all = torch.split(k, max(1, N // 16), dim=-2) 48 | v_all = torch.split(v, max(1, N // 16), dim=-2) 49 | 50 | output = [] 51 | for q,k,v in zip(q_all, k_all, v_all): 52 | attn = (q @ k.transpose(-2, -1)) * self.scale # [B, heads, chunk, chunk] 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | trans_x = (attn @ v).transpose(1, 2) # [B, chunk, heads, C // heads] 56 | output.append(trans_x) 57 | x = torch.cat(output,dim=1) 58 | x = x.reshape(B,N,C) 59 | x = self.proj(x) 60 | return x 61 | 62 | class TransBlock(nn.Module): 63 | def __init__( 64 | self, n_feat = 64,dim=64, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 65 | drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm): 66 | super(TransBlock, self).__init__() 67 | self.dim = dim 68 | self.atten = EffAttention(self.dim, num_heads=8, qkv_bias=False, qk_scale=None, \ 69 | attn_drop=0., proj_drop=0.)
70 | self.norm1 = nn.LayerNorm(self.dim) 71 | self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop) 72 | self.norm2 = nn.LayerNorm(self.dim) 73 | 74 | def forward(self, x): 75 | B = x.shape[0] 76 | 77 | x = x + self.atten(self.norm1(x)) 78 | x = x + self.mlp(self.norm2(x)) 79 | return x 80 | -------------------------------------------------------------------------------- /code/util/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch.nn.functional as F 7 | 8 | def normalize(x): 9 | return x.mul_(2).add_(-1) 10 | 11 | def same_padding(images, ksizes, strides, rates): 12 | assert len(images.size()) == 4 13 | batch_size, channel, rows, cols = images.size() 14 | out_rows = (rows + strides[0] - 1) // strides[0] 15 | out_cols = (cols + strides[1] - 1) // strides[1] 16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1 17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1 18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) 19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) 20 | # Pad the input 21 | padding_top = int(padding_rows / 2.) 22 | padding_left = int(padding_cols / 2.) 23 | padding_bottom = padding_rows - padding_top 24 | padding_right = padding_cols - padding_left 25 | paddings = (padding_left, padding_right, padding_top, padding_bottom) 26 | images = torch.nn.ZeroPad2d(paddings)(images) 27 | return images 28 | 29 | 30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'): 31 | """ 32 | Extract patches from images and put them in the C output dimension. 33 | :param padding: 34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape 35 | :param ksizes: [ksize_rows, ksize_cols]. 
The size of the sliding window for 36 | each dimension of images 37 | :param strides: [stride_rows, stride_cols] 38 | :param rates: [dilation_rows, dilation_cols] 39 | :return: A Tensor 40 | """ 41 | assert len(images.size()) == 4 42 | assert padding in ['same', 'valid'] 43 | batch_size, channel, height, width = images.size() 44 | 45 | if padding == 'same': 46 | images = same_padding(images, ksizes, strides, rates) 47 | elif padding == 'valid': 48 | pass 49 | else: 50 | raise NotImplementedError('Unsupported padding type: {}.\ 51 | Only "same" or "valid" are supported.'.format(padding)) 52 | 53 | unfold = torch.nn.Unfold(kernel_size=ksizes, 54 | dilation=rates, 55 | padding=0, 56 | stride=strides) 57 | patches = unfold(images) 58 | return patches # [N, C*k*k, L], L is the total number of such blocks 59 | 60 | def reverse_patches(images, out_size, ksizes, strides, padding): 61 | """ 62 | Reassemble sliding-window patches into an image with torch.nn.Fold (the counterpart of extract_image_patches); values at overlapping positions are summed. 63 | :param padding: implicit zero-padding passed to torch.nn.Fold 64 | :param images: [batch, channels*k*k, L]. A 3-D Tensor of flattened patches 65 | :param ksizes: [ksize_rows, ksize_cols]. 
The size of the sliding window for 66 | each dimension of images 67 | :param strides: [stride_rows, stride_cols] 68 | :param out_size: [out_rows, out_cols]. Spatial size of the output image 69 | :return: A Tensor of shape [batch, channels, out_rows, out_cols] 70 | """ 71 | fold = torch.nn.Fold(output_size = out_size, 72 | kernel_size=ksizes, 73 | dilation=1, 74 | padding=padding, 75 | stride=strides) 76 | patches = fold(images) 77 | return patches # [N, C, out_rows, out_cols] 78 | def reduce_mean(x, axis=None, keepdim=False): 79 | if not axis: 80 | axis = range(len(x.shape)) 81 | for i in sorted(axis, reverse=True): 82 | x = torch.mean(x, dim=i, keepdim=keepdim) 83 | return x 84 | 85 | 86 | def reduce_std(x, axis=None, keepdim=False): 87 | if not axis: 88 | axis = range(len(x.shape)) 89 | for i in sorted(axis, reverse=True): 90 | x = torch.std(x, dim=i, keepdim=keepdim) 91 | return x 92 | 93 | 94 | def reduce_sum(x, axis=None, keepdim=False): 95 | if not axis: 96 | axis = range(len(x.shape)) 97 | for i in sorted(axis, reverse=True): 98 | x = torch.sum(x, dim=i, keepdim=keepdim) 99 | return x -------------------------------------------------------------------------------- /code/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from functools import reduce 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | import numpy as np 12 | import scipy.misc as misc 13 | 14 | import torch 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lrs 17 | 18 | class timer(): 19 | def __init__(self): 20 | self.acc = 0 21 | self.tic() 22 | 23 | def tic(self): 24 | self.t0 = time.time() 25 | 26 | def toc(self): 27 | return time.time() - self.t0 28 | 29 | def hold(self): 30 | self.acc += self.toc() 31 | 32 | def release(self): 33 | ret = self.acc 34 | self.acc = 0 35 | 36 | return ret 37 | 38 | def reset(self): 39 | self.acc = 0 40 | 41 | class 
checkpoint(): 42 | def __init__(self, args): 43 | self.args = args 44 | self.ok = True 45 | self.log = torch.Tensor() 46 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 47 | 48 | if args.load == '.': 49 | if args.save == '.': args.save = now 50 | self.dir = '../experiment/' + args.save 51 | else: 52 | self.dir = '../experiment/' + args.load 53 | if not os.path.exists(self.dir): 54 | args.load = '.' 55 | else: 56 | self.log = torch.load(self.dir + '/psnr_log.pt') 57 | print('Continue from epoch {}...'.format(len(self.log))) 58 | 59 | if args.reset: 60 | os.system('rm -rf ' + self.dir) 61 | args.load = '.' 62 | 63 | def _make_dir(path): 64 | if not os.path.exists(path): os.makedirs(path) 65 | 66 | _make_dir(self.dir) 67 | _make_dir(self.dir + '/model') 68 | _make_dir(self.dir + '/results') 69 | 70 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 71 | self.log_file = open(self.dir + '/log.txt', open_type) 72 | with open(self.dir + '/config.txt', open_type) as f: 73 | f.write(now + '\n\n') 74 | for arg in vars(args): 75 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 76 | f.write('\n') 77 | 78 | def save(self, trainer, epoch, is_best=False): 79 | trainer.model.save(self.dir, epoch, is_best=is_best) 80 | trainer.loss.save(self.dir) 81 | trainer.loss.plot_loss(self.dir, epoch) 82 | 83 | self.plot_psnr(epoch) 84 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) 85 | torch.save( 86 | trainer.optimizer.state_dict(), 87 | os.path.join(self.dir, 'optimizer.pt') 88 | ) 89 | 90 | def add_log(self, log): 91 | self.log = torch.cat([self.log, log]) 92 | 93 | def write_log(self, log, refresh=False): 94 | print(log) 95 | self.log_file.write(log + '\n') 96 | if refresh: 97 | self.log_file.close() 98 | self.log_file = open(self.dir + '/log.txt', 'a') 99 | 100 | def done(self): 101 | self.log_file.close() 102 | 103 | def plot_psnr(self, epoch): 104 | axis = np.linspace(1, epoch, epoch) 105 | label = 'SR on 
{}'.format(self.args.data_test) 106 | fig = plt.figure() 107 | plt.title(label) 108 | for idx_scale, scale in enumerate(self.args.scale): 109 | plt.plot( 110 | axis, 111 | self.log[:, idx_scale].numpy(), 112 | label='Scale {}'.format(scale) 113 | ) 114 | plt.legend() 115 | plt.xlabel('Epochs') 116 | plt.ylabel('PSNR') 117 | plt.grid(True) 118 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) 119 | plt.close(fig) 120 | 121 | def save_results(self, filename, save_list, scale): 122 | filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale) 123 | postfix = ('SR','LR', 'HR') 124 | for v, p in zip(save_list, postfix): 125 | normalized = v[0].data.mul(255 / self.args.rgb_range) 126 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 127 | misc.imsave('{}{}.png'.format(filename, p), ndarr) 128 | def quantize(img, rgb_range): 129 | pixel_range = 255 / rgb_range 130 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 131 | 132 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): 133 | diff = (sr - hr).data.div(rgb_range) 134 | if benchmark: 135 | shave = scale 136 | if diff.size(1) > 1: 137 | convert = diff.new(1, 3, 1, 1) 138 | convert[0, 0, 0, 0] = 65.738 139 | convert[0, 1, 0, 0] = 129.057 140 | convert[0, 2, 0, 0] = 25.064 141 | diff.mul_(convert).div_(256) 142 | diff = diff.sum(dim=1, keepdim=True) 143 | else: 144 | shave = scale + 6 145 | 146 | valid = diff[:, :, shave:-shave, shave:-shave] 147 | mse = valid.pow(2).mean() 148 | 149 | return -10 * math.log10(mse) 150 | 151 | def make_optimizer(args, my_model): 152 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 153 | 154 | if args.optimizer == 'SGD': 155 | optimizer_function = optim.SGD 156 | kwargs = {'momentum': args.momentum} 157 | elif args.optimizer == 'ADAM': 158 | optimizer_function = optim.Adam 159 | kwargs = { 160 | 'betas': (args.beta1, args.beta2), 161 | 'eps': args.epsilon 162 | } 163 | elif args.optimizer == 'RMSprop': 164 | 
optimizer_function = optim.RMSprop 165 | kwargs = {'eps': args.epsilon} 166 | 167 | kwargs['lr'] = args.lr 168 | kwargs['weight_decay'] = args.weight_decay 169 | 170 | return optimizer_function(trainable, **kwargs) 171 | 172 | def make_scheduler(args, my_optimizer): 173 | if args.decay_type == 'step': 174 | scheduler = lrs.StepLR( 175 | my_optimizer, 176 | step_size=args.lr_decay, 177 | gamma=args.gamma 178 | ) 179 | elif args.decay_type.find('step') >= 0: 180 | milestones = args.decay_type.split('_') 181 | milestones.pop(0) 182 | milestones = list(map(lambda x: int(x), milestones)) 183 | scheduler = lrs.MultiStepLR( 184 | my_optimizer, 185 | milestones=milestones, 186 | gamma=args.gamma 187 | ) 188 | 189 | return scheduler 190 | 191 | -------------------------------------------------------------------------------- /experiment/HCT-FFN/model/model_best_Rain100H.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/experiment/HCT-FFN/model/model_best_Rain100H.pt -------------------------------------------------------------------------------- /experiment/HCT-FFN/model/model_best_Rain100L.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/experiment/HCT-FFN/model/model_best_Rain100L.pt -------------------------------------------------------------------------------- /figure/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/figure/network.png --------------------------------------------------------------------------------
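The `--decay_type` string drives `make_scheduler` in `code/utility.py`. The value `step` selects `StepLR` with `--lr_decay` as the step size. Any other string containing `step` is split on `_`, and the remaining fields become `MultiStepLR` milestones. A value without `step` in it matches neither branch, so `make_scheduler` would fail with an `UnboundLocalError`. The stdlib-only sketch below reproduces the milestone parsing and the closed-form rate `MultiStepLR` yields: the base rate multiplied by `gamma` once for every milestone at or before the epoch. The helper names `parse_milestones` and `lr_at` are hypothetical and are not part of the repository. The sketch is also stricter than the original, which only checks that the string contains `step`.

```python
from bisect import bisect_right


def parse_milestones(decay_type):
    """Split a 'step_<m1>_<m2>_...' string into integer milestone epochs.

    Hypothetical helper mirroring the MultiStepLR branch of make_scheduler,
    but rejecting strings that do not have the form 'step_<int>_<int>...'.
    """
    parts = decay_type.split('_')
    if parts[0] != 'step' or len(parts) < 2:
        raise ValueError('expected step_<m1>_<m2>..., got {!r}'.format(decay_type))
    return [int(p) for p in parts[1:]]


def lr_at(epoch, base_lr, milestones, gamma):
    """Closed-form MultiStepLR rate: multiply by gamma once per milestone <= epoch."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


if __name__ == '__main__':
    # Defaults from option.py: --lr 1e-3, --gamma 0.5, --decay_type step_100_..._300
    milestones = parse_milestones('step_100_150_200_230_260_280_300')
    for epoch in (1, 99, 100, 150, 300):
        print(epoch, lr_at(epoch, 1e-3, milestones, 0.5))
```

With the defaults, the rate stays at 1e-3 until epoch 100 and is halved at each later milestone. Because `Trainer.train` calls `scheduler.step()` at the start of each epoch, the scheduler's `last_epoch` is already 1 during the first training epoch.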