├── README.md
├── code
│   ├── data
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── common.py
│   │   ├── demo.py
│   │   ├── rainheavy.py
│   │   ├── rainheavytest.py
│   │   └── srdata.py
│   ├── dataloader.py
│   ├── dataset
│   │   ├── test
│   │   │   ├── data
│   │   │   │   ├── norain-1.png
│   │   │   │   └── norain-2.png
│   │   │   └── label
│   │   │       ├── norain-1.png
│   │   │       └── norain-2.png
│   │   └── train
│   │       ├── data
│   │       │   ├── norain-1.png
│   │       │   └── norain-2.png
│   │       └── label
│   │           ├── norain-1.png
│   │           └── norain-2.png
│   ├── loss
│   │   ├── __init__.py
│   │   └── ssim.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── hct-ffn.py
│   ├── option.py
│   ├── template.py
│   ├── trainer.py
│   ├── util
│   │   ├── rlutrans.py
│   │   └── tools.py
│   └── utility.py
├── experiment
│   └── HCT-FFN
│       └── model
│           ├── model_best_Rain100H.pt
│           └── model_best_Rain100L.pt
└── figure
    └── network.png
/README.md:
--------------------------------------------------------------------------------
1 | # Hybrid CNN-Transformer Feature Fusion for Single Image Deraining
2 |
3 | Xiang Chen, Jinshan Pan, Jiyang Lu, Zhentao Fan, Hao Li
4 |
5 |
6 |
7 | > **Abstract:** *Since rain streaks exhibit diverse geometric appearances and irregular overlapped phenomena, these complex characteristics challenge the design of an effective single image deraining model. To this end, rich local-global information representations are increasingly indispensable for better satisfying rain removal. In this paper, we propose a lightweight Hybrid CNN-Transformer Feature Fusion Network (dubbed as HCT-FFN) in a stage-by-stage progressive manner, which can harmonize these two architectures to help image restoration by leveraging their individual learning strengths. Specifically, we stack a sequence of the degradation-aware mixture of experts (DaMoE) modules in the CNN-based stage, where appropriate local experts adaptively enable the model to emphasize spatially-varying rain distribution features. As for the Transformer-based stage, a background-aware vision Transformer (BaViT) module is employed to complement spatially-long feature dependencies of images, so as to achieve global texture recovery while preserving the required structure. Considering the indeterminate knowledge discrepancy among CNN features and Transformer features, we introduce an interactive fusion branch at adjacent stages to further facilitate the reconstruction of high-quality deraining results. Extensive evaluations show the effectiveness and extensibility of our developed HCT-FFN.*
8 |
9 |
10 | ## Network Architecture
11 |
12 | ![Network Architecture](figure/network.png)
13 |
14 | ## Installation
15 | * PyTorch == 0.4.1
16 | * torchvision == 0.2.0
17 | * Python == 3.6.0
18 | * imageio == 2.5.0
19 | * numpy == 1.14.0
20 | * opencv-python
21 | * scikit-image == 0.13.0
22 | * tqdm == 4.32.2
23 | * scipy == 1.2.1
24 | * matplotlib == 3.1.1
25 | * ipython == 7.6.1
26 | * h5py == 2.10.0
27 |
28 | ## Training
29 | 1. Modify the data paths in code/data/rainheavy.py and code/data/rainheavytest.py so that they point to your dataset:
30 |    datapath/data/\*\*\*.png
31 |    datapath/label/\*\*\*.png
32 |
33 | 2. Begin training:
34 | ```
35 | $ cd ./code/
36 | $ python main.py --save HCT-FFN --model hct-ffn --scale 2 --epochs 400 --batch_size 4 --patch_size 128 --data_train RainHeavy --n_threads 0 --data_test RainHeavyTest --data_range 1-1800/1-100 --loss 1*MSE+0.2*SSIM --save_results --lr 1e-4 --n_feats 32 --n_resblocks 3
37 | ```
38 |
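For reference, `--data_range 1-1800/1-100` is split on `/` into the training index range (1-1800) and the test index range (1-100); `code/data/srdata.py` parses it essentially as in this sketch:

```python
# How code/data/srdata.py interprets the --data_range flag
data_range = '1-1800/1-100'
train_rng, test_rng = [r.split('-') for r in data_range.split('/')]
print(train_rng, test_rng)  # ['1', '1800'] ['1', '100']
```
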
39 | ## Testing
40 | ```
41 | $ cd ./code/
42 | $ python main.py --data_test RainHeavyTest --ext img --scale 2 --data_range 1-1800/1-100 --pre_train ../experiment/HCT-FFN/model/model_best.pt --model hct-ffn --test_only --save_results --save HCT-FFN_test
43 | ```
44 | The pre-trained models (model_best_Rain100H.pt / model_best_Rain100L.pt) are available at ./experiment/HCT-FFN/model/; point --pre_train at the checkpoint matching your test set.
45 |
46 | ## Performance Evaluation
47 |
48 | The PSNR and SSIM results are computed by using this [Matlab Code](https://github.com/hongwang01/RCDNet/tree/master/Performance_evaluation), based on Y channel of YCbCr space.
49 |
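For a quick sanity check in Python, PSNR on the Y channel can be approximated as below (a minimal sketch; the official numbers should still come from the Matlab code above, whose YCbCr conversion and rounding differ slightly from scikit-image's):

```python
import numpy as np
import skimage.color as sc

def psnr_y(img1, img2, data_range=255.0):
    """PSNR between two uint8 RGB images, computed on the Y channel of YCbCr."""
    y1 = sc.rgb2ycbcr(img1)[:, :, 0]  # Y lies roughly in [16, 235]
    y2 = sc.rgb2ycbcr(img2)[:, :, 0]
    mse = np.mean((y1 - y2) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)
```
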
50 | ## Visual Deraining Results
51 |
52 | https://drive.google.com/drive/folders/1soXkMuQEQmJZmxZBIlo8dfCHM0RlGtxz?usp=sharing
53 |
54 | ## Citation
55 | If you are interested in this work, please consider citing:
56 |
57 | @inproceedings{chen2023hybrid,
58 | title={Hybrid CNN-Transformer Feature Fusion for Single Image Deraining},
59 | author={Chen, Xiang and Pan, Jinshan and Lu, Jiyang and Fan, Zhentao and Li, Hao},
60 | booktitle={AAAI},
61 | year={2023}
62 | }
63 |
64 | ## Acknowledgment
65 | This code is based on [SPDNet](https://github.com/Joyies/SPDNet). Thanks for sharing!
66 |
--------------------------------------------------------------------------------
/code/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 |
3 | from dataloader import MSDataLoader
4 | from torch.utils.data.dataloader import default_collate
5 |
6 | class Data:
7 | def __init__(self, args):
8 | self.loader_train = None
9 | if not args.test_only:
10 | module_train = import_module('data.' + args.data_train.lower())
11 | trainset = getattr(module_train, args.data_train)(args)
12 | self.loader_train = MSDataLoader(
13 | args,
14 | trainset,
15 | batch_size=args.batch_size,
16 | shuffle=True,
17 | pin_memory=not args.cpu
18 | )
19 |
20 | if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']:
21 | module_test = import_module('data.benchmark')
22 | testset = getattr(module_test, 'Benchmark')(args, train=False)
23 | else:
24 | module_test = import_module('data.' + args.data_test.lower())
25 | testset = getattr(module_test, args.data_test)(args, train=False)
26 |
27 | self.loader_test = MSDataLoader(
28 | args,
29 | testset,
30 | batch_size=1,
31 | shuffle=False,
32 | pin_memory=not args.cpu
33 | )
34 |
35 |
--------------------------------------------------------------------------------
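A note on the dispatch above: dataset classes are resolved by name, so `--data_train RainHeavy` is lowercased to locate the module `data.rainheavy` and used verbatim to pick the `RainHeavy` class inside it. A minimal sketch of that convention (the helper name `resolve_dataset` is ours, not the repo's):

```python
from importlib import import_module

def resolve_dataset(name, args, train=True):
    # e.g. 'RainHeavy' -> module 'data.rainheavy', class 'RainHeavy'
    module = import_module('data.' + name.lower())
    return getattr(module, name)(args, train=train)
```
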
/code/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Benchmark(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(Benchmark, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
21 | self.ext = ('', '.jpg')
22 |
23 |
--------------------------------------------------------------------------------
/code/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import skimage.color as sc
5 |
6 | import torch
7 | from torchvision import transforms
8 |
9 | def get_patch(*args, patch_size=96, scale=1, multi_scale=False):
10 | ih, iw = args[0].shape[:2]
11 |
12 | #p = scale if multi_scale else 1
13 | #tp = p * patch_size
14 | #ip = tp // scale
15 |
16 | tp = patch_size
17 | ip = patch_size
18 |
19 |
20 | ix = random.randrange(0, iw - ip + 1)
21 | iy = random.randrange(0, ih - ip + 1)
22 |
23 | #tx, ty = scale * ix, scale * iy
24 | tx, ty = ix, iy
25 |
26 | ret = [
27 | args[0][iy:iy + ip, ix:ix + ip, :],
28 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
29 | ]
30 |
31 | return ret
32 |
33 | def set_channel(*args, n_channels=3):
34 | def _set_channel(img):
35 | if img.ndim == 2:
36 | img = np.expand_dims(img, axis=2)
37 |
38 | c = img.shape[2]
39 | if n_channels == 1 and c == 3:
40 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
41 | elif n_channels == 3 and c == 1:
42 | img = np.concatenate([img] * n_channels, 2)
43 |
44 | return img
45 |
46 | return [_set_channel(a) for a in args]
47 |
48 | def np2Tensor(*args, rgb_range=255):
49 | def _np2Tensor(img):
50 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
51 | tensor = torch.from_numpy(np_transpose).float()
52 | tensor.mul_(rgb_range / 255)
53 |
54 | return tensor
55 |
56 | return [_np2Tensor(a) for a in args]
57 |
58 | def augment(*args, hflip=True, rot=True):  # only horizontal flip is active below
59 | hflip = hflip and random.random() < 0.5
60 | vflip = rot and random.random() < 0.5
61 | rot90 = rot and random.random() < 0.5
62 |
63 | def _augment(img):
64 | if hflip: img = img[:, ::-1, :]
65 | # if vflip: img = img[::-1, :, :]
66 | # if rot90: img = img.transpose(1, 0, 2)
67 |
68 | return img
69 |
70 | return [_augment(a) for a in args]
71 |
72 |
--------------------------------------------------------------------------------
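Since the rainy input and its label share one resolution in this repo (the scale-aware branch of `get_patch` is commented out), the same window is cropped from every array, keeping the pair pixel-aligned. A quick illustration with synthetic data (shapes are arbitrary):

```python
import numpy as np
from data import common  # run from inside ./code

rainy = np.random.randint(0, 256, (321, 481, 3), dtype=np.uint8)
clean = rainy.copy()  # stand-in for the ground-truth label
rainy_p, clean_p = common.get_patch(rainy, clean, patch_size=128)
assert rainy_p.shape == clean_p.shape == (128, 128, 3)
```
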
/code/data/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import numpy as np
6 | import imageio
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Demo(data.Dataset):
12 | def __init__(self, args, name='Demo', train=False, benchmark=False):
13 | self.args = args
14 | self.name = name
15 | self.scale = args.scale
16 | self.idx_scale = 0
17 | self.train = False
18 | self.do_eval = False
19 | self.benchmark = benchmark
20 |
21 | self.filelist = []
22 | for f in os.listdir(args.dir_demo):
23 | if f.find('.png') >= 0 or f.find('.jp') >= 0:
24 | self.filelist.append(os.path.join(args.dir_demo, f))
25 | self.filelist.sort()
26 |
27 | def __getitem__(self, idx):
28 | filename = os.path.split(self.filelist[idx])[-1]
29 | filename, _ = os.path.splitext(filename)
30 | lr = imageio.imread(self.filelist[idx])
31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 |
34 | return lr_t, -1, filename
35 |
36 | def __len__(self):
37 | return len(self.filelist)
38 |
39 | def set_scale(self, idx_scale):
40 | self.idx_scale = idx_scale
41 |
42 |
--------------------------------------------------------------------------------
/code/data/rainheavy.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class RainHeavy(srdata.SRData):
5 | def __init__(self, args, name='RainHeavy', train=True, benchmark=False):
6 | super(RainHeavy, self).__init__(
7 | args, name=name, train=train, benchmark=benchmark
8 | )
9 |
10 | def _scan(self):
11 | names_hr, names_lr = super(RainHeavy, self)._scan()
12 | names_hr = names_hr[self.begin - 1:self.end]
13 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
14 |
15 | return names_hr, names_lr
16 |
17 | def _set_filesystem(self, dir_data):
18 | super(RainHeavy, self)._set_filesystem(dir_data)
19 | self.apath = './dataset/train/' # train data path
20 |
21 | print(self.apath)
22 | self.dir_hr = os.path.join(self.apath, 'label')
23 | self.dir_lr = os.path.join(self.apath, 'data')
24 |
25 |
--------------------------------------------------------------------------------
/code/data/rainheavytest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class RainHeavyTest(srdata.SRData):
5 | def __init__(self, args, name='RainHeavyTest', train=True, benchmark=False):
6 | super(RainHeavyTest, self).__init__(
7 | args, name=name, train=train, benchmark=benchmark
8 | )
9 |
10 | def _scan(self):
11 | names_hr, names_lr = super(RainHeavyTest, self)._scan()
12 | names_hr = names_hr[self.begin - 1:self.end]
13 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
14 |
15 | return names_hr, names_lr
16 |
17 | def _set_filesystem(self, dir_data):
18 | super(RainHeavyTest, self)._set_filesystem(dir_data)
19 | self.apath = './dataset/test/' # test data path
20 | print(self.apath)
21 | self.dir_hr = os.path.join(self.apath, 'label')
22 | self.dir_lr = os.path.join(self.apath, 'data')
23 |
24 |
--------------------------------------------------------------------------------
/code/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | from data import common
5 | import pickle
6 | import numpy as np
7 | import imageio
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | class SRData(data.Dataset):
13 | def __init__(self, args, name='', train=True, benchmark=False):
14 | self.args = args
15 | self.name = name
16 | self.train = train
17 | self.split = 'train' if train else 'test'
18 | self.do_eval = True
19 | self.benchmark = benchmark
20 | self.scale = args.scale
21 | self.idx_scale = 0
22 |
23 | data_range = [r.split('-') for r in args.data_range.split('/')]
24 | if train:
25 | data_range = data_range[0]
26 | else:
27 | if args.test_only and len(data_range) == 1:
28 | data_range = data_range[0]
29 | else:
30 | data_range = data_range[1]
31 | self.begin, self.end = list(map(lambda x: int(x), data_range))
32 | self._set_filesystem(args.dir_data)
33 | if args.ext.find('img') < 0:
34 | path_bin = os.path.join(self.apath, 'bin')
35 | os.makedirs(path_bin, exist_ok=True)
36 |
37 | list_hr, list_lr = self._scan()
38 | if args.ext.find('bin') >= 0:
39 | # Binary files are stored in 'bin' folder
40 | # If the binary file exists, load it. If not, make it.
41 | list_hr, list_lr = self._scan()
42 | self.images_hr = self._check_and_load(
43 | args.ext, list_hr, self._name_hrbin()
44 | )
45 | self.images_lr = [
46 | self._check_and_load(args.ext, l, self._name_lrbin(s)) \
47 | for s, l in zip(self.scale, list_lr)
48 | ]
49 | else:
50 | if args.ext.find('img') >= 0 or benchmark:
51 | self.images_hr, self.images_lr = list_hr, list_lr
52 | elif args.ext.find('sep') >= 0:
53 | os.makedirs(
54 | self.dir_hr.replace(self.apath, path_bin),
55 | exist_ok=True
56 | )
57 | for s in self.scale:
58 | os.makedirs(
59 | os.path.join(
60 | self.dir_lr.replace(self.apath, path_bin),
61 | 'X{}'.format(s)
62 | ),
63 | exist_ok=True
64 | )
65 |
66 | self.images_hr, self.images_lr = [], [[] for _ in self.scale]
67 | for h in list_hr:
68 | b = h.replace(self.apath, path_bin)
69 | b = b.replace(self.ext[0], '.pt')
70 | self.images_hr.append(b)
71 | self._check_and_load(
72 | args.ext, [h], b, verbose=True, load=False
73 | )
74 |
75 | for i, ll in enumerate(list_lr):
76 | for l in ll:
77 | b = l.replace(self.apath, path_bin)
78 | b = b.replace(self.ext[1], '.pt')
79 | self.images_lr[i].append(b)
80 | self._check_and_load(
81 | args.ext, [l], b, verbose=True, load=False
82 | )
83 |
84 | if train:
85 | self.repeat \
86 | = args.test_every // (len(self.images_hr) // args.batch_size)
87 |
88 |
89 |     # The functions below are used to prepare images
90 | def _scan(self):
91 | names_hr = sorted(
92 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
93 | )
94 | names_lr = [[] for _ in self.scale]
95 | for f in names_hr:
96 |             #f = f.replace('.png','x2.png')  # naming variant with an 'x2' suffix
97 |             # LR file names match the HR names one-to-one here, so no renaming is needed
98 | filename, _ = os.path.splitext(os.path.basename(f))
99 | for si, s in enumerate(self.scale):
100 | names_lr[si].append(os.path.join(
101 | self.dir_lr, '{}{}'.format(
102 | filename, self.ext[1]
103 | )
104 | ))
105 |
106 | return names_hr, names_lr
107 |
108 | def _set_filesystem(self, dir_data):
109 | self.apath = os.path.join(dir_data, self.name)
110 | self.dir_hr = os.path.join(self.apath, 'HR')
111 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
112 | self.ext = ('.png', '.png')
113 |
114 | def _name_hrbin(self):
115 | return os.path.join(
116 | self.apath,
117 | 'bin',
118 | '{}_bin_HR.pt'.format(self.split)
119 | )
120 |
121 | def _name_lrbin(self, scale):
122 | return os.path.join(
123 | self.apath,
124 | 'bin',
125 | '{}_bin_LR.pt'.format(self.split)
126 | )
127 |
128 | def _check_and_load(self, ext, l, f, verbose=True, load=True):
129 | if os.path.isfile(f) and ext.find('reset') < 0:
130 | if load:
131 | if verbose: print('Loading {}...'.format(f))
132 | with open(f, 'rb') as _f: ret = pickle.load(_f)
133 | return ret
134 | else:
135 | return None
136 | else:
137 | if verbose:
138 | if ext.find('reset') >= 0:
139 | print('Making a new binary: {}'.format(f))
140 | else:
141 | print('{} does not exist. Now making binary...'.format(f))
142 | b = [{
143 | 'name': os.path.splitext(os.path.basename(_l))[0],
144 | 'image': imageio.imread(_l)
145 | } for _l in l]
146 | with open(f, 'wb') as _f: pickle.dump(b, _f)
147 | return b
148 |
149 | def __getitem__(self, idx):
150 | lr, hr, filename = self._load_file(idx)
151 | lr, hr = self.get_patch(lr, hr)
152 | lr, hr = common.set_channel(lr, hr, n_channels=self.args.n_colors)
153 | lr_tensor, hr_tensor = common.np2Tensor(
154 | lr, hr, rgb_range=self.args.rgb_range
155 | )
156 |
157 | return lr_tensor, hr_tensor, filename
158 |
159 | def __len__(self):
160 | if self.train:
161 | return len(self.images_hr) * self.repeat
162 | else:
163 | return len(self.images_hr)
164 |
165 | def _get_index(self, idx):
166 | if self.train:
167 | return idx % len(self.images_hr)
168 | else:
169 | return idx
170 |
171 | def _load_file(self, idx):
172 | idx = self._get_index(idx)
173 | f_hr = self.images_hr[idx]
174 | f_lr = self.images_lr[self.idx_scale][idx]
175 |
176 | if self.args.ext.find('bin') >= 0:
177 | filename = f_hr['name']
178 | hr = f_hr['image']
179 | lr = f_lr['image']
180 | else:
181 | filename, _ = os.path.splitext(os.path.basename(f_hr))
182 | if self.args.ext == 'img' or self.benchmark:
183 | hr = imageio.imread(f_hr)
184 | lr = imageio.imread(f_lr)
185 | elif self.args.ext.find('sep') >= 0:
186 | with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image']
187 | with open(f_lr, 'rb') as _f: lr = np.load(_f)[0]['image']
188 |
189 | return lr, hr, filename
190 |
191 | def get_patch(self, lr, hr):
192 | scale = self.scale[self.idx_scale]
193 | multi_scale = len(self.scale) > 1
194 | if self.train:
195 | # print('****preparte data****')
196 | lr, hr = common.get_patch(
197 | lr,
198 | hr,
199 | patch_size=self.args.patch_size,
200 | scale=scale,
201 | multi_scale=multi_scale
202 | )
203 | if not self.args.no_augment:
204 | # print('****use augment****')
205 | lr, hr = common.augment(lr, hr)
206 | else:
207 | ih, iw = lr.shape[:2]
208 | hr = hr[0:ih, 0:iw]
209 | #hr = hr[0:ih * scale, 0:iw * scale]
210 |
211 | return lr, hr
212 |
213 | def set_scale(self, idx_scale):
214 | self.idx_scale = idx_scale
215 |
216 |
--------------------------------------------------------------------------------
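One detail worth spelling out: during training, `__len__` multiplies the dataset by `repeat` so that one "epoch" yields roughly `test_every` batches. Worked numbers for the README command, assuming the EDSR-style default of `test_every=1000` (check `option.py` for the actual value):

```python
n_images, batch_size, test_every = 1800, 4, 1000  # assumed values
repeat = test_every // (n_images // batch_size)   # 1000 // 450 = 2
epoch_len = n_images * repeat                     # __len__ reports 3600 samples
```
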
/code/dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import threading
3 | # 'queue' is imported below, with a Python 2 fallback
4 | import random
5 | import collections
6 |
7 | import torch
8 | import torch.multiprocessing as multiprocessing
9 |
10 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
11 | _remove_worker_pids, _error_if_any_worker_fails
12 | from torch.utils.data.dataloader import DataLoader
13 | from torch.utils.data.dataloader import _DataLoaderIter
14 |
15 | from torch.utils.data.dataloader import ExceptionWrapper
16 | from torch.utils.data.dataloader import _use_shared_memory
17 | from torch.utils.data.dataloader import _worker_manager_loop
18 | from torch.utils.data.dataloader import numpy_type_map
19 | from torch.utils.data.dataloader import default_collate
20 | from torch.utils.data.dataloader import pin_memory_batch
21 | from torch.utils.data.dataloader import _SIGCHLD_handler_set
22 | from torch.utils.data.dataloader import _set_SIGCHLD_handler
23 |
24 | if sys.version_info[0] == 2:
25 | import Queue as queue
26 | else:
27 | import queue
28 |
29 | def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
30 | global _use_shared_memory
31 | _use_shared_memory = True
32 | _set_worker_signal_handlers()
33 |
34 | torch.set_num_threads(1)
35 | torch.manual_seed(seed)
36 | while True:
37 | r = index_queue.get()
38 | if r is None:
39 | break
40 | idx, batch_indices = r
41 | try:
42 | idx_scale = 0
43 | if len(scale) > 1 and dataset.train:
44 | idx_scale = random.randrange(0, len(scale))
45 | dataset.set_scale(idx_scale)
46 |
47 | samples = collate_fn([dataset[i] for i in batch_indices])
48 | samples.append(idx_scale)
49 |
50 | except Exception:
51 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
52 | else:
53 | data_queue.put((idx, samples))
54 |
55 | class _MSDataLoaderIter(_DataLoaderIter):
56 | def __init__(self, loader):
57 | self.dataset = loader.dataset
58 | self.scale = loader.scale
59 | self.collate_fn = loader.collate_fn
60 | self.batch_sampler = loader.batch_sampler
61 | self.num_workers = loader.num_workers
62 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
63 | self.timeout = loader.timeout
64 | self.done_event = threading.Event()
65 |
66 | self.sample_iter = iter(self.batch_sampler)
67 |
68 | if self.num_workers > 0:
69 | self.worker_init_fn = loader.worker_init_fn
70 | self.index_queues = [
71 | multiprocessing.Queue() for _ in range(self.num_workers)
72 | ]
73 | self.worker_queue_idx = 0
74 | self.worker_result_queue = multiprocessing.SimpleQueue()
75 | self.batches_outstanding = 0
76 | self.worker_pids_set = False
77 | self.shutdown = False
78 | self.send_idx = 0
79 | self.rcvd_idx = 0
80 | self.reorder_dict = {}
81 |
82 | base_seed = torch.LongTensor(1).random_()[0]
83 | self.workers = [
84 | multiprocessing.Process(
85 | target=_ms_loop,
86 | args=(
87 | self.dataset,
88 | self.index_queues[i],
89 | self.worker_result_queue,
90 | self.collate_fn,
91 | self.scale,
92 | base_seed + i,
93 | self.worker_init_fn,
94 | i
95 | )
96 | )
97 | for i in range(self.num_workers)]
98 |
99 | if self.pin_memory or self.timeout > 0:
100 | self.data_queue = queue.Queue()
101 | if self.pin_memory:
102 | maybe_device_id = torch.cuda.current_device()
103 | else:
104 | # do not initialize cuda context if not necessary
105 | maybe_device_id = None
106 | self.worker_manager_thread = threading.Thread(
107 | target=_worker_manager_loop,
108 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
109 | maybe_device_id))
110 | self.worker_manager_thread.daemon = True
111 | self.worker_manager_thread.start()
112 | else:
113 | self.data_queue = self.worker_result_queue
114 |
115 | for w in self.workers:
116 | w.daemon = True # ensure that the worker exits on process exit
117 | w.start()
118 |
119 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
120 | _set_SIGCHLD_handler()
121 | self.worker_pids_set = True
122 |
123 | # prime the prefetch loop
124 | for _ in range(2 * self.num_workers):
125 | self._put_indices()
126 |
127 | class MSDataLoader(DataLoader):
128 | def __init__(
129 | self, args, dataset, batch_size=1, shuffle=False,
130 | sampler=None, batch_sampler=None,
131 | collate_fn=default_collate, pin_memory=False, drop_last=False,
132 | timeout=0, worker_init_fn=None):
133 |
134 | super(MSDataLoader, self).__init__(
135 | dataset, batch_size=batch_size, shuffle=shuffle,
136 | sampler=sampler, batch_sampler=batch_sampler,
137 | num_workers=args.n_threads, collate_fn=collate_fn,
138 | pin_memory=pin_memory, drop_last=drop_last,
139 | timeout=timeout, worker_init_fn=worker_init_fn)
140 |
141 | self.scale = args.scale
142 |
143 | def __iter__(self):
144 | return _MSDataLoaderIter(self)
145 |
--------------------------------------------------------------------------------
/code/dataset/test/data/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/data/norain-1.png
--------------------------------------------------------------------------------
/code/dataset/test/data/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/data/norain-2.png
--------------------------------------------------------------------------------
/code/dataset/test/label/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/label/norain-1.png
--------------------------------------------------------------------------------
/code/dataset/test/label/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/test/label/norain-2.png
--------------------------------------------------------------------------------
/code/dataset/train/data/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/data/norain-1.png
--------------------------------------------------------------------------------
/code/dataset/train/data/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/data/norain-2.png
--------------------------------------------------------------------------------
/code/dataset/train/label/norain-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/label/norain-1.png
--------------------------------------------------------------------------------
/code/dataset/train/label/norain-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/code/dataset/train/label/norain-2.png
--------------------------------------------------------------------------------
/code/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | # import SSIM
14 | class Loss(nn.modules.loss._Loss):
15 | def __init__(self, args, ckp):
16 | super(Loss, self).__init__()
17 | print('Preparing loss function:')
18 |
19 | self.n_GPUs = args.n_GPUs
20 | self.loss = []
21 | self.loss_module = nn.ModuleList()
22 | for loss in args.loss.split('+'):
23 | weight, loss_type = loss.split('*')
24 | if loss_type == 'MSE':
25 | loss_function = nn.MSELoss()
26 | elif loss_type == 'L1':
27 | loss_function = nn.L1Loss()
28 | elif loss_type.find('VGG') >= 0:
29 | module = import_module('loss.vgg')
30 | loss_function = getattr(module, 'VGG')(
31 | loss_type[3:],
32 | rgb_range=args.rgb_range
33 | )
34 | elif loss_type.find('GAN') >= 0:
35 | module = import_module('loss.adversarial')
36 | loss_function = getattr(module, 'Adversarial')(
37 | args,
38 | loss_type
39 | )
40 | elif loss_type.find('joint') >= 0:
41 | module = import_module('loss.joint')
42 | loss_function = getattr(module, 'Joint')()
43 | elif loss_type.find('SSIM') >= 0:
44 | module = import_module('loss.ssim')
45 | loss_function = getattr(module, 'SSIM')()
46 |
47 | self.loss.append({
48 | 'type': loss_type,
49 | 'weight': float(weight),
50 | 'function': loss_function}
51 | )
52 | if loss_type.find('GAN') >= 0:
53 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
54 |
55 | if len(self.loss) > 1:
56 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
57 |
58 | for l in self.loss:
59 | if l['function'] is not None:
60 | print('{:.3f} * {}'.format(l['weight'], l['type']))
61 | self.loss_module.append(l['function'])
62 |
63 | self.log = torch.Tensor()
64 |
65 | device = torch.device('cpu' if args.cpu else 'cuda')
66 | self.loss_module.to(device)
67 | if args.precision == 'half': self.loss_module.half()
68 | if not args.cpu and args.n_GPUs > 1:
69 | self.loss_module = nn.DataParallel(
70 | #self.loss_module, range(args.n_GPUs)
71 | self.loss_module, device_ids=[0]
72 | )
73 |
74 | if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
75 |
76 | def forward(self, sr, hr, lr=None, detect_map=None):
77 | losses = []
78 | for i, l in enumerate(self.loss):
79 | if l['function'] is not None:
80 |
81 |                 if lr is not None:
82 | loss = l['function'](sr, hr, lr, detect_map)
83 | effective_loss = l['weight'] * loss
84 | losses.append(effective_loss)
85 | self.log[-1, i] += effective_loss.item()
86 | else:
87 | loss = l['function'](sr, hr)
88 | effective_loss = l['weight'] * loss
89 | losses.append(effective_loss)
90 | self.log[-1, i] += effective_loss.item()
91 |
92 | elif l['type'] == 'DIS':
93 | self.log[-1, i] += self.loss[i - 1]['function'].loss
94 |
95 | loss_sum = sum(losses)
96 | if len(self.loss) > 1:
97 | self.log[-1, -1] += loss_sum.item()
98 |
99 | return loss_sum
100 |
101 | def step(self):
102 | for l in self.get_loss_module():
103 | if hasattr(l, 'scheduler'):
104 | l.scheduler.step()
105 |
106 | def start_log(self):
107 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
108 |
109 | def end_log(self, n_batches):
110 | self.log[-1].div_(n_batches)
111 |
112 | def display_loss(self, batch):
113 | n_samples = batch + 1
114 | log = []
115 | for l, c in zip(self.loss, self.log[-1]):
116 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
117 |
118 | return ''.join(log)
119 |
120 | def plot_loss(self, apath, epoch):
121 | axis = np.linspace(1, epoch, epoch)
122 | for i, l in enumerate(self.loss):
123 | # j = i
124 | # if i == len(self.loss)-1:
125 | # break
126 | label = '{} Loss'.format(l['type'])
127 | fig = plt.figure()
128 | plt.title(label)
129 | plt.plot(axis, self.log[:, i].numpy(), label=label)
130 | plt.legend()
131 | plt.xlabel('Epochs')
132 | plt.ylabel('Loss')
133 | plt.grid(True)
134 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
135 | plt.close(fig)
136 |
137 | def get_loss_module(self):
138 | if self.n_GPUs == 1:
139 | return self.loss_module
140 | else:
141 | return self.loss_module.module
142 |
143 | def save(self, apath):
144 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
145 | torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
146 |
147 | def load(self, apath, cpu=False):
148 | if cpu:
149 | kwargs = {'map_location': lambda storage, loc: storage}
150 | else:
151 | kwargs = {}
152 |
153 | self.load_state_dict(torch.load(
154 | os.path.join(apath, 'loss.pt'),
155 | **kwargs
156 | ))
157 | self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
158 |         for l in self.get_loss_module():  # works with or without DataParallel
159 | if hasattr(l, 'scheduler'):
160 | for _ in range(len(self.log)): l.scheduler.step()
161 |
162 |
--------------------------------------------------------------------------------
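The `--loss` flag parsed above follows a small `weight*TYPE` grammar with terms joined by `+`. For example, the README's training setting:

```python
loss_spec = '1*MSE+0.2*SSIM'
for term in loss_spec.split('+'):
    weight, loss_type = term.split('*')
    print(float(weight), loss_type)  # -> 1.0 MSE, then 0.2 SSIM
```
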
/code/loss/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 1
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | (_, channel, _, _) = img1.size()
49 |
50 | if channel == self.channel and self.window.data.type() == img1.data.type():
51 | window = self.window
52 | else:
53 | window = create_window(self.window_size, channel)
54 |
55 | if img1.is_cuda:
56 | window = window.cuda(img1.get_device())
57 | window = window.type_as(img1)
58 |
59 | self.window = window
60 | self.channel = channel
61 |
62 |
63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
64 |
65 | def ssim(img1, img2, window_size = 11, size_average = True):
66 | (_, channel, _, _) = img1.size()
67 | window = create_window(window_size, channel)
68 |
69 | if img1.is_cuda:
70 | window = window.cuda(img1.get_device())
71 | window = window.type_as(img1)
72 |
73 | return _ssim(img1, img2, window, window_size, channel, size_average)
74 |
--------------------------------------------------------------------------------
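For reference, the window statistic computed above is the standard SSIM index; the constants `C1 = 0.01**2` and `C2 = 0.03**2` (K1 = 0.01, K2 = 0.03) correspond to a dynamic range of 1, i.e. images scaled to [0, 1]:

```latex
\mathrm{SSIM}(x, y) =
  \frac{(2\mu_x\mu_y + C_1)\,(2\sigma_{xy} + C_2)}
       {(\mu_x^2 + \mu_y^2 + C_1)\,(\sigma_x^2 + \sigma_y^2 + C_2)}
```
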
/code/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import utility
4 | import data
5 | import model
6 | import loss
7 | from option import args
8 | from trainer import Trainer
9 | import multiprocessing
10 | import time
11 |
12 | def print_network(net):
13 | num_params = 0
14 | for param in net.parameters():
15 | num_params += param.numel()
16 | print('Total number of parameters: %d' % num_params)
17 |
18 | if __name__ == '__main__':
19 | torch.manual_seed(args.seed)
20 | checkpoint = utility.checkpoint(args)
21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
22 |
23 |     seed = 1334  # note: this fixed seed overrides args.seed set above
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 |
27 | if checkpoint.ok:
28 | loader = data.Data(args)
29 | model = model.Model(args, checkpoint)
30 | print_network(model)
31 | loss = loss.Loss(args, checkpoint) if not args.test_only else None
32 | t = Trainer(args, loader, model, loss, checkpoint)
33 | # print('==================')
34 | while not t.terminate():
35 | # print('======++++++++++')
36 | t.train()
37 | t.test()
38 | checkpoint.done()
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/code/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 |
8 | class Model(nn.Module):
9 | def __init__(self, args, ckp):
10 | super(Model, self).__init__()
11 | print('Making model...')
12 |
13 | self.scale = args.scale
14 | self.idx_scale = 0
15 | self.self_ensemble = args.self_ensemble
16 | self.chop = args.chop
17 | self.precision = args.precision
18 | self.cpu = args.cpu
19 | self.device = torch.device('cpu' if args.cpu else 'cuda')
20 | self.n_GPUs = args.n_GPUs
21 | self.save_models = args.save_models
22 |
23 | module = import_module('model.' + args.model.lower())
24 | self.model = module.make_model(args).to(self.device)
25 | if args.precision == 'half': self.model.half()
26 |
27 | if not args.cpu and args.n_GPUs > 1:
28 | #self.model = nn.DataParallel(self.model, range(args.n_GPUs))
29 | self.model = nn.DataParallel(self.model, device_ids=[0])
30 |
31 | self.load(
32 | ckp.dir,
33 | pre_train=args.pre_train,
34 | resume=args.resume,
35 | cpu=args.cpu
36 | )
37 | print(self.model, file=ckp.log_file)
38 |
39 | def forward(self, x, idx_scale):
40 | self.idx_scale = idx_scale
41 | target = self.get_model()
42 | if hasattr(target, 'set_scale'):
43 | target.set_scale(idx_scale)
44 |
45 | if self.self_ensemble and not self.training:
46 | if self.chop:
47 | forward_function = self.forward_chop
48 | else:
49 | forward_function = self.model.forward
50 |
51 | return self.forward_x8(x, forward_function)
52 | elif self.chop and not self.training:
53 | return self.forward_chop(x)
54 | else:
55 | return self.model(x)
56 |
57 | def get_model(self):
58 | if self.n_GPUs == 1:
59 | return self.model
60 | else:
61 | return self.model.module
62 |
63 | def state_dict(self, **kwargs):
64 | target = self.get_model()
65 | return target.state_dict(**kwargs)
66 |
67 | def save(self, apath, epoch, is_best=False):
68 | target = self.get_model()
69 | torch.save(
70 | target.state_dict(),
71 | os.path.join(apath, 'model', 'model_latest.pt')
72 | )
73 | if is_best:
74 | torch.save(
75 | target.state_dict(),
76 | os.path.join(apath, 'model', 'model_best.pt')
77 | )
78 |
79 | if self.save_models:
80 | torch.save(
81 | target.state_dict(),
82 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
83 | )
84 |
85 | def load(self, apath, pre_train='.', resume=-1, cpu=False):
86 | if cpu:
87 | kwargs = {'map_location': lambda storage, loc: storage}
88 | else:
89 | kwargs = {}
90 |
91 | if resume == -1:
92 |             print('Resuming from model_latest.pt')
93 | self.get_model().load_state_dict(
94 | torch.load(
95 | os.path.join(apath, 'model', 'model_latest.pt'),
96 | **kwargs
97 | ),
98 | strict=False
99 | )
100 | elif resume == 0:
101 |             print('No resume; loading pre-trained weights if provided')
102 | if pre_train != '.':
103 | print('Loading model from {}'.format(pre_train))
104 | self.get_model().load_state_dict(
105 | torch.load(pre_train, **kwargs),
106 | strict=False
107 | )
108 | else:
109 |             print('Resuming from epoch {}'.format(resume))
110 | self.get_model().load_state_dict(
111 | torch.load(
112 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
113 | **kwargs
114 | ),
115 | strict=False
116 | )
117 |
118 | def forward_chop(self, x, shave=10, min_size=160000):
119 | scale = self.scale[self.idx_scale]
120 | n_GPUs = min(self.n_GPUs, 4)
121 | b, c, h, w = x.size()
122 | h_half, w_half = h // 2, w // 2
123 | h_size, w_size = h_half + shave, w_half + shave
124 | lr_list = [
125 | x[:, :, 0:h_size, 0:w_size],
126 | x[:, :, 0:h_size, (w - w_size):w],
127 | x[:, :, (h - h_size):h, 0:w_size],
128 | x[:, :, (h - h_size):h, (w - w_size):w]]
129 |
130 | if w_size * h_size < min_size:
131 | sr_list = []
132 | for i in range(0, 4, n_GPUs):
133 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
134 | sr_batch = self.model(lr_batch)
135 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
136 | else:
137 | sr_list = [
138 | self.forward_chop(patch, shave=shave, min_size=min_size) \
139 | for patch in lr_list
140 | ]
141 |
142 | h, w = scale * h, scale * w
143 | h_half, w_half = scale * h_half, scale * w_half
144 | h_size, w_size = scale * h_size, scale * w_size
145 | shave *= scale
146 |
147 | output = x.new(b, c, h, w)
148 | output[:, :, 0:h_half, 0:w_half] \
149 | = sr_list[0][:, :, 0:h_half, 0:w_half]
150 | output[:, :, 0:h_half, w_half:w] \
151 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
152 | output[:, :, h_half:h, 0:w_half] \
153 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
154 | output[:, :, h_half:h, w_half:w] \
155 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
156 |
157 | return output
158 |
159 | def forward_x8(self, x, forward_function):
160 | def _transform(v, op):
161 | if self.precision != 'single': v = v.float()
162 |
163 | v2np = v.data.cpu().numpy()
164 | if op == 'v':
165 | tfnp = v2np[:, :, :, ::-1].copy()
166 | elif op == 'h':
167 | tfnp = v2np[:, :, ::-1, :].copy()
168 | elif op == 't':
169 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
170 |
171 | ret = torch.Tensor(tfnp).to(self.device)
172 | if self.precision == 'half': ret = ret.half()
173 |
174 | return ret
175 |
176 | lr_list = [x]
177 | for tf in 'v', 'h', 't':
178 | lr_list.extend([_transform(t, tf) for t in lr_list])
179 |
180 | sr_list = [forward_function(aug) for aug in lr_list]
181 | for i in range(len(sr_list)):
182 | if i > 3:
183 | sr_list[i] = _transform(sr_list[i], 't')
184 | if i % 4 > 1:
185 | sr_list[i] = _transform(sr_list[i], 'h')
186 | if (i % 4) % 2 == 1:
187 | sr_list[i] = _transform(sr_list[i], 'v')
188 |
189 | output_cat = torch.cat(sr_list, dim=0)
190 | output = output_cat.mean(dim=0, keepdim=True)
191 |
192 | return output
193 |
194 |
--------------------------------------------------------------------------------
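`forward_x8` above is the geometric self-ensemble from EDSR: the input is expanded into its eight flip/transpose variants, each prediction is mapped back through the inverse transform, and the eight outputs are averaged. A minimal sketch of one ensemble member, using `torch.flip` for brevity (the repo itself round-trips through NumPy) and assuming `model` returns a single tensor:

```python
import torch

def hflip_member(model, x):
    # Average the plain prediction with the un-flipped prediction of the flipped input.
    out = model(x)
    out_flip = torch.flip(model(torch.flip(x, dims=[3])), dims=[3])
    return (out + out_flip) / 2
```
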
/code/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from torch.autograd import Variable
8 |
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
10 | return nn.Conv2d(
11 | in_channels, out_channels, kernel_size,
12 | padding=(kernel_size//2), bias=bias)
13 |
14 | class MeanShift2(nn.Conv2d):
15 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
16 |         super(MeanShift2, self).__init__(3, 3, kernel_size=1)
17 | std = torch.Tensor(rgb_std)
18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
19 | self.weight.data.div_(std.view(3, 1, 1, 1))
20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
21 | self.bias.data.div_(std)
22 |         for p in self.parameters(): p.requires_grad = False  # actually freeze the weights
23 |
24 | class MeanShift(nn.Conv2d):
25 | def __init__(
26 | self, rgb_range,
27 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
28 |
29 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
30 | std = torch.Tensor(rgb_std)
31 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
32 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
33 | for p in self.parameters():
34 | p.requires_grad = False
35 |
36 | class BasicBlock(nn.Sequential):
37 | def __init__(
38 | self, in_channels, out_channels, kernel_size, stride=1, bias=False,
39 | bn=True, act=nn.ReLU(True)):
40 |
41 | m = [nn.Conv2d(
42 | in_channels, out_channels, kernel_size,
43 | padding=(kernel_size//2), stride=stride, bias=bias)
44 | ]
45 | if bn: m.append(nn.BatchNorm2d(out_channels))
46 | if act is not None: m.append(act)
47 | super(BasicBlock, self).__init__(*m)
48 |
49 | class ResBlock(nn.Module):
50 | def __init__(
51 | self, conv, n_feats, kernel_size,
52 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
53 |
54 | super(ResBlock, self).__init__()
55 | m = []
56 | for i in range(2):
57 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
58 | if bn: m.append(nn.BatchNorm2d(n_feats))
59 | if i == 0: m.append(act)
60 |
61 | self.body = nn.Sequential(*m)
62 | self.res_scale = res_scale
63 |
64 | def forward(self, x):
65 | res = self.body(x).mul(self.res_scale)
66 | res += x
67 |
68 | return res
69 |
70 | class Upsampler(nn.Sequential):
71 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
72 |
73 | m = []
74 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
75 | for _ in range(int(math.log(scale, 2))):
76 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
77 | m.append(nn.PixelShuffle(2))
78 | if bn: m.append(nn.BatchNorm2d(n_feats))
79 |
80 | if act == 'relu':
81 | m.append(nn.ReLU(True))
82 | elif act == 'prelu':
83 | m.append(nn.PReLU(n_feats))
84 |
85 | elif scale == 3:
86 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
87 | m.append(nn.PixelShuffle(3))
88 | if bn: m.append(nn.BatchNorm2d(n_feats))
89 |
90 | if act == 'relu':
91 | m.append(nn.ReLU(True))
92 | elif act == 'prelu':
93 | m.append(nn.PReLU(n_feats))
94 | else:
95 | raise NotImplementedError
96 |
97 | super(Upsampler, self).__init__(*m)
98 |
99 | class SELayer(nn.Module):
100 | def __init__(self, channel, reduction=16):
101 | super(SELayer, self).__init__()
102 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
103 | self.fc = nn.Sequential(
104 | nn.Linear(channel, channel // reduction, bias=False),
105 | nn.ReLU(inplace=True),
106 | nn.Linear(channel // reduction, channel, bias=False),
107 | nn.Sigmoid()
108 | )
109 |
110 | def forward(self, x):
111 | b, c, _, _ = x.size()
112 | y = self.avg_pool(x).view(b, c)
113 | y = self.fc(y).view(b, c, 1, 1)
114 | return x * y.expand_as(x)
115 |
116 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
117 | """3x3 convolution with padding"""
118 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
119 | padding=1, groups=groups, bias=False)
120 |
121 |
122 | def conv1x1(in_planes, out_planes, stride=1):
123 | """1x1 convolution"""
124 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
125 |
126 |
127 | class SEModule(nn.Module):
128 | def __init__(self, channels, reduction=16):
129 | super(SEModule, self).__init__()
130 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
131 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
132 | self.relu = nn.ReLU(inplace=True)
133 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
134 | self.sigmoid = nn.Sigmoid()
135 |
136 | def forward(self, input):
137 | x = self.avg_pool(input)
138 | x = self.fc1(x)
139 | x = self.relu(x)
140 | x = self.fc2(x)
141 | x = self.sigmoid(x)
142 | return input * x
143 | #####################################################
144 | Operations = [
145 | 'sep_conv_1x1',
146 | 'sep_conv_3x3',
147 | 'sep_conv_5x5',
148 | 'sep_conv_7x7',
149 | 'dil_conv_3x3',
150 | 'dil_conv_5x5',
151 | 'dil_conv_7x7',
152 | 'avg_pool_3x3'
153 | ]
154 |
155 | OPS = {
156 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
157 | 'sep_conv_1x1' : lambda C, stride, affine: SepConv(C, C, 1, stride, 0, affine=affine),
158 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
159 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
160 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
161 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
162 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
163 | 'dil_conv_7x7' : lambda C, stride, affine: DilConv(C, C, 7, stride, 6, 2, affine=affine),
164 | }
165 |
166 | class ReLUConvBN(nn.Module):
167 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
168 | super(ReLUConvBN, self).__init__()
169 | self.op = nn.Sequential(
170 | nn.ReLU(inplace=False),
171 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
172 | nn.BatchNorm2d(C_out, affine=affine))
173 |
174 | def forward(self, x):
175 | return self.op(x)
176 |
177 | class ReLUConv(nn.Module):
178 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
179 | super(ReLUConv, self).__init__()
180 | self.op = nn.Sequential(
181 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
182 | nn.ReLU(inplace=False))
183 |
184 | def forward(self, x):
185 | return self.op(x)
186 |
187 | class DilConv(nn.Module):
188 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
189 | super(DilConv, self).__init__()
190 | self.op = nn.Sequential(
191 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
192 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),)
193 |
194 | def forward(self, x):
195 | return self.op(x)
196 |
197 | class ResBlock2(nn.Module):
198 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
199 |         super(ResBlock2, self).__init__()
200 | self.conv1 = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False)
201 | self.conv2 = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False)
202 | self.relu = nn.ReLU(inplace=False)
203 |
204 | def forward(self, x):
205 | residual = x
206 | out = self.relu(self.conv1(x))
207 | out = self.conv2(out)
208 | out = out + residual
209 | out = self.relu(out)
210 | return out
211 |
212 | class ResBlock(nn.Module):  # note: shadows the ReLU-based ResBlock defined earlier in this file
213 | def __init__(
214 | self, conv, n_feats, kernel_size,
215 | bias=True, bn=False, act=nn.PReLU(), res_scale=1):
216 |
217 | super(ResBlock, self).__init__()
218 | m = []
219 | for i in range(2):
220 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
221 | if bn:
222 | m.append(nn.BatchNorm2d(n_feats))
223 | if i == 0:
224 | m.append(act)
225 |
226 | self.body = nn.Sequential(*m)
227 | self.res_scale = res_scale
228 |
229 | def forward(self, x):
230 | res = self.body(x).mul(self.res_scale)
231 | res += x
232 |
233 | return res
234 |
235 | class SepConv(nn.Module):
236 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
237 | super(SepConv, self).__init__()
238 | self.op = nn.Sequential(
239 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
240 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
241 | nn.ReLU(inplace=False),
242 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
243 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),)
244 |
245 | def forward(self, x):
246 | return self.op(x)
247 |
248 |
--------------------------------------------------------------------------------
/code/model/hct-ffn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | from model import common
5 | from util.rlutrans import Mlp, TransBlock
6 | from util.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches
7 |
8 | def make_model(args, parent=False):
9 | return Rainnet(args)
10 |
11 | class OperationLayer(nn.Module):
12 | def __init__(self, C, stride):
13 | super(OperationLayer, self).__init__()
14 | self._ops = nn.ModuleList()
15 | for o in common.Operations:
16 | op = common.OPS[o](C, stride, False)
17 | self._ops.append(op)
18 |
19 | self._out = nn.Sequential(nn.Conv2d(C * len(common.Operations), C, 1, padding=0, bias=False), nn.ReLU())
20 |
21 | def forward(self, x, weights):
22 | weights = weights.transpose(1, 0)
23 | states = []
24 | for w, op in zip(weights, self._ops):
25 | states.append(op(x) * w.view([-1, 1, 1, 1]))
26 | return self._out(torch.cat(states[:], dim=1))
27 |
28 | class GroupOLs(nn.Module):
29 | def __init__(self, steps, C):
30 | super(GroupOLs, self).__init__()
31 | self.preprocess = common.ReLUConv(C, C, 1, 1, 0, affine=False)
32 | self._steps = steps
33 | self._ops = nn.ModuleList()
34 | self.relu = nn.ReLU()
35 | stride = 1
36 |
37 | for _ in range(self._steps):
38 | op = OperationLayer(C, stride)
39 | self._ops.append(op)
40 |
41 | def forward(self, s0, weights):
42 | s0 = self.preprocess(s0)
43 | for i in range(self._steps):
44 | res = s0
45 | s0 = self._ops[i](s0, weights[:, i, :])
46 | s0 = self.relu(s0 + res)
47 | return s0
48 |
49 | class OALayer(nn.Module):
50 | def __init__(self, channel, k, num_ops):
51 | super(OALayer, self).__init__()
52 | self.k = k
53 | self.num_ops = num_ops
54 | self.output = k * num_ops
55 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
56 | self.ca_fc = nn.Sequential(
57 | nn.Linear(channel, self.output * 2),
58 | nn.ReLU(),
59 | nn.Linear(self.output * 2, self.k * self.num_ops))
60 |
61 | def forward(self, x):
62 | y = self.avg_pool(x)
63 | y = y.view(x.size(0), -1)
64 | y = self.ca_fc(y)
65 | y = y.view(-1, self.k, self.num_ops)
66 | return y
67 |
68 | def get_residue(tensor, r_dim=1):
69 |     """
70 |     Return the residue channel: the per-pixel max minus min over RGB.
71 |     """
72 |     # torch.max/min with dim return (values, indices); [0] selects the values
73 |     max_channel = torch.max(tensor, dim=r_dim, keepdim=True)
74 |     min_channel = torch.min(tensor, dim=r_dim, keepdim=True)
75 |     res_channel = max_channel[0] - min_channel[0]
76 |     return res_channel
77 |
78 | class convd(nn.Module):
79 | def __init__(self, inputchannel, outchannel, kernel_size, stride):
80 | super(convd, self).__init__()
81 | self.relu = nn.ReLU()
82 | self.padding = nn.ReflectionPad2d(kernel_size//2)
83 | self.conv = nn.Conv2d(inputchannel, outchannel, kernel_size, stride)
84 | self.ins = nn.InstanceNorm2d(outchannel, affine=True)
85 |
86 | def forward(self, x):
87 | x = self.conv(self.padding(x))
88 | # x= self.ins(x)
89 | x = self.relu(x)
90 | return x
91 |
92 | class Upsample(nn.Module):
93 | def __init__(self, in_channels, out_channels, kernel_size, stride):
94 | super(Upsample, self).__init__()
95 | reflection_padding = kernel_size // 2
96 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
97 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
98 | self.relu = nn.ReLU()
99 |
100 | def forward(self, x, y):
101 | out = self.reflection_pad(x)
102 | out = self.conv2d(out)
103 | out = self.relu(out)
104 | out = F.interpolate(out, y.size()[2:])
105 | return out
106 |
107 | class RB(nn.Module):
108 | def __init__(self, n_feats, nm='in'):
109 | super(RB, self).__init__()
110 | module_body = []
111 | for i in range(2):
112 | module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True))
113 | module_body.append(nn.ReLU())
114 | self.module_body = nn.Sequential(*module_body)
115 | self.relu = nn.ReLU()
116 | self.se = common.SELayer(n_feats, 1)
117 |
118 | def forward(self, x):
119 | res = self.module_body(x)
120 | res = self.se(res)
121 | res += x
122 | return res
123 |
124 | class RIR(nn.Module):
125 | def __init__(self, n_feats, n_blocks, nm='in'):
126 | super(RIR, self).__init__()
127 | module_body = [
128 | RB(n_feats) for _ in range(n_blocks)
129 | ]
130 | module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True))
131 | self.module_body = nn.Sequential(*module_body)
132 | self.relu = nn.ReLU()
133 |
134 | def forward(self, x):
135 | res = self.module_body(x)
136 | res += x
137 | return self.relu(res)
138 |
139 | class res_ch(nn.Module):
140 | def __init__(self, n_feats, blocks=2):
141 | super(res_ch,self).__init__()
142 | self.conv_init1 = convd(3, n_feats//2, 3, 1)
143 | self.conv_init2 = convd(n_feats//2, n_feats, 3, 1)
144 | self.extra = RIR(n_feats, n_blocks=blocks)
145 |
146 | def forward(self,x):
147 | x = self.conv_init2(self.conv_init1(x))
148 | x = self.extra(x)
149 | return x
150 |
151 | class Fuse(nn.Module):
152 | def __init__(self, inchannel=64, outchannel=64):
153 | super(Fuse, self).__init__()
154 | self.up = Upsample(inchannel, outchannel, 3, 2)
155 | self.conv = convd(outchannel, outchannel, 3, 1)
156 | self.rb = RB(outchannel)
157 | self.relu = nn.ReLU()
158 |
159 | def forward(self, x, y):
160 | x = self.up(x, y)
161 | # x = F.interpolate(x, y.size()[2:])
162 | # y1 = torch.cat((x, y), dim=1)
163 | y = x+y
164 | # y = self.pf(y1) + y
165 |
166 | return self.relu(self.rb(y))
167 |
168 | class Prior_Sp(nn.Module):
169 | def __init__(self, in_dim=32):
170 | super(Prior_Sp, self).__init__()
171 | self.chanel_in = in_dim
172 |
173 | self.query_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True)
174 | self.key_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True)
175 |
176 | self.gamma1 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True)
177 | # self.gamma1 = nn.Parameter(torch.zeros(1))
178 | self.gamma2 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True)
179 | # self.softmax = nn.Softmax(dim=-1)
180 | self.sig = nn.Sigmoid()
181 |
182 | def forward(self,x, prior):
183 |
184 | x_q = self.query_conv(x)
185 | prior_k = self.key_conv(prior)
186 | energy = x_q * prior_k
187 | attention = self.sig(energy)
188 | # print(attention.size(),x.size())
189 | attention_x = x * attention
190 | attention_p = prior * attention
191 |
192 | x_gamma = self.gamma1(torch.cat((x, attention_x),dim=1))
193 | x_out = x * x_gamma[:, [0], :, :] + attention_x * x_gamma[:, [1], :, :]
194 |
195 | p_gamma = self.gamma2(torch.cat((prior, attention_p),dim=1))
196 | prior_out = prior * p_gamma[:, [0], :, :] + attention_p * p_gamma[:, [1], :, :]
197 |
198 | return x_out, prior_out
199 |
200 | class DaMoE(nn.Module):
201 | def __init__(self, n_feats,layer_num ,steps=4):
202 | super(DaMoE,self).__init__()
203 |
204 | # fuse res
205 | self.prior = Prior_Sp()
206 | self.fuse_res = convd(n_feats*2, n_feats, 3, 1)
207 | self._C = n_feats
208 | self.num_ops = len(common.Operations)
209 | self._layer_num = layer_num
210 | self._steps = steps
211 |
212 | self.layers = nn.ModuleList()
213 | for _ in range(self._layer_num):
214 | attention = OALayer(self._C, self._steps, self.num_ops)
215 | self.layers += [attention]
216 | layer = GroupOLs(steps, self._C)
217 | self.layers += [layer]
218 |
219 | def forward(self, x, res_feats):
220 |
221 | x_p, res_feats_p = self.prior(x, res_feats)
222 | x_s = torch.cat((x_p, res_feats_p),dim=1)
223 | x1_i = self.fuse_res(x_s)
224 |         for layer in self.layers:
225 | if isinstance(layer, OALayer):
226 | weights = layer(x1_i)
227 | weights = F.softmax(weights, dim=-1)
228 | else:
229 | x1_i = layer(x1_i, weights)
230 |
231 | return x1_i
232 |
233 | class BaViT(nn.Module):
234 | def __init__(self, n_feats, blocks=2):
235 | super(BaViT, self).__init__()
236 | # fuse res
237 | self.prior = Prior_Sp()
238 | self.fuse_res = convd(n_feats * 2, n_feats, 3, 1)
239 |
240 | self.attention = TransBlock(n_feats, dim=n_feats * 9)
241 | self.c2 = common.default_conv(n_feats, n_feats, 3)
242 | # self.attention2 = TransBlock(n_feat=n_feat, dim=n_feat*9)
243 |
244 | def forward(self, x, res_feats):
245 | x_p, res_feats_p = self.prior(x, res_feats)
246 | x_s = torch.cat((x_p, res_feats_p), dim=1)
247 | x1_init = self.fuse_res(x_s)
248 |
249 | y8 = x1_init
250 | b, c, h, w = y8.shape
251 | y8 = extract_image_patches(y8, ksizes=[3, 3],
252 | strides=[1, 1],
253 | rates=[1, 1],
254 |                                    padding='same')   # [B, C*3*3, H*W]
255 |         y8 = y8.permute(0, 2, 1)  # [B, H*W, C*3*3] patch tokens for TransBlock
256 | out_transf1 = self.attention(y8)
257 | out_transf1 = self.attention(out_transf1)
258 | out_transf1 = self.attention(out_transf1)
259 | out1 = out_transf1.permute(0, 2, 1)
260 | out1 = reverse_patches(out1, (h, w), (3, 3), 1, 1)
261 | y9 = self.c2(out1)
262 |
263 | return y9
264 |
265 | class Rainnet(nn.Module):
266 | def __init__(self,args):
267 | super(Rainnet,self).__init__()
268 | n_feats = args.n_feats
269 | blocks = args.n_resblocks
270 |
271 | self.conv_init1 = convd(3, n_feats//2, 3, 1)
272 | self.conv_init2 = convd(n_feats//2, n_feats, 3, 1)
273 | self.res_extra1 = res_ch(n_feats, blocks)
274 | self.sub1 = DaMoE(n_feats, 1)
275 | self.res_extra2 = res_ch(n_feats, blocks)
276 | self.sub2 = BaViT(n_feats, 1)
277 | self.res_extra3 = res_ch(n_feats, blocks)
278 | self.sub3 = DaMoE(n_feats, 1)
279 |
280 | self.ag1 = convd(n_feats*2,n_feats,3,1)
281 | self.ag2 = convd(n_feats*3,n_feats,3,1)
282 | self.ag2_en = convd(n_feats*2, n_feats, 3, 1)
283 | self.ag_en = convd(n_feats*3, n_feats, 3, 1)
284 |
285 | self.output1 = nn.Conv2d(n_feats, 3, 3, 1, padding=1)
286 | self.output2 = nn.Conv2d(n_feats, 3, 3, 1, padding=1)
287 | self.output3 = nn.Conv2d(n_feats, 3, 3, 1, padding=1)
288 |
289 | # self._initialize_weights()
290 |
291 | def forward(self,x):
292 |
293 | res_x = get_residue(x)
294 | x_init = self.conv_init2(self.conv_init1(x))
295 | x1 = self.sub1(x_init, self.res_extra1(torch.cat((res_x, res_x, res_x), dim=1))) #+ x # 1
296 | out1 = self.output1(x1)
297 | res_out1 = get_residue(out1)
298 | x2 = self.sub2(self.ag1(torch.cat((x1,x_init),dim=1)), self.res_extra2(torch.cat((res_out1, res_out1, res_out1), dim=1))) #+ x1 # 2
299 | x2_ = self.ag2_en(torch.cat([x2,x1], dim=1))
300 | out2 = self.output2(x2_)
301 | res_out2 = get_residue(out2)
302 | x3 = self.sub3(self.ag2(torch.cat((x2,x1,x_init),dim=1)), self.res_extra3(torch.cat((res_out2, res_out2, res_out2), dim=1))) #+ x2 # 3
303 | x3 = self.ag_en(torch.cat([x3,x2,x1],dim=1))
304 | out3 = self.output3(x3)
305 |
306 | return out3, out2, out1
307 |
308 | def _initialize_weights(self):
309 | for m in self.modules():
310 | if isinstance(m, nn.Conv2d):
311 | nn.init.normal_(m.weight, std=0.01)
312 | if m.bias is not None:
313 | nn.init.constant_(m.bias, 0)
314 | elif isinstance(m, nn.BatchNorm2d):
315 | nn.init.constant_(m.weight, 1)
316 | nn.init.constant_(m.bias, 0)
317 |
318 |
319 |
--------------------------------------------------------------------------------
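One detail in `hct-ffn.py` above that is easy to miss: the two `gamma` convolutions in `Prior_Sp` emit a 2-channel map whose channels act as spatially-varying blend weights between a feature map and its prior-gated counterpart. A minimal sketch of that per-pixel fusion with random stand-in tensors (shapes only, no learned weights):

```python
import torch

# Stand-ins for the tensors inside Prior_Sp.forward (batch 2, 32 channels).
x = torch.randn(2, 32, 16, 16)            # input features
attention_x = torch.randn(2, 32, 16, 16)  # features gated by the residue prior
x_gamma = torch.randn(2, 2, 16, 16)       # output of gamma1: 2 weights per pixel

# Channel 0 weights the original features, channel 1 the gated ones; the
# single-channel slices broadcast across all 32 feature channels.
x_out = x * x_gamma[:, [0], :, :] + attention_x * x_gamma[:, [1], :, :]
assert x_out.shape == x.shape
```
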
/code/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import template
3 |
4 | parser = argparse.ArgumentParser(description='HCT-FFN')
5 |
6 | parser.add_argument('--debug', action='store_true',
7 | help='Enables debug mode')
8 | parser.add_argument('--template', default='.',
9 |                     help='You can set various templates in template.py')
10 |
11 | # Hardware specifications
12 | parser.add_argument('--n_threads', type=int, default=0,
13 | help='number of threads for data loading')
14 | parser.add_argument('--cpu', action='store_true',
15 | help='use cpu only')
16 | parser.add_argument('--n_GPUs', type=int, default=1,
17 | help='number of GPUs')
18 | parser.add_argument('--seed', type=int, default=1,
19 | help='random seed')
20 |
21 | # Data specifications
22 | parser.add_argument('--dir_data', type=str, default='../data',
23 | help='dataset directory')
24 | parser.add_argument('--dir_demo', type=str, default='../test',
25 | help='demo image directory')
26 | parser.add_argument('--data_train', type=str, default='RainHeavy', #'DIV2K',
27 | help='train dataset name')
28 | parser.add_argument('--data_test', type=str, default='RainHeavyTest',  #'DIV2K',
29 | help='test dataset name')
30 | parser.add_argument('--data_range', type=str, default='1-20000/1-100',
31 | help='train/test data range')
32 | parser.add_argument('--ext', type=str, default='sep',
33 | help='dataset file extension')
34 | parser.add_argument('--scale', type=str, default='2',
35 | help='super resolution scale')
36 | parser.add_argument('--patch_size', type=int, default=64,
37 | help='output patch size')
38 | parser.add_argument('--rgb_range', type=int, default=255,
39 | help='maximum value of RGB')
40 | parser.add_argument('--n_colors', type=int, default=3,
41 | help='number of color channels to use')
42 | parser.add_argument('--chop', action='store_true',
43 | help='enable memory-efficient forward')
44 | parser.add_argument('--no_augment', action='store_true',
45 | help='do not use data augmentation')
46 |
47 | # Model specifications
48 | parser.add_argument('--model', default='OurNet',
49 | help='model name')
50 | parser.add_argument('--act', type=str, default='relu',
51 | help='activation function')
52 | parser.add_argument('--pre_train', type=str, default='.',
53 | help='pre-trained model directory')
54 | parser.add_argument('--extend', type=str, default='.',
55 | help='pre-trained model directory')
56 | parser.add_argument('--n_resblocks', type=int, default=3,
57 | help='number of residual blocks')
58 | parser.add_argument('--n_feats', type=int, default=32,
59 | help='number of feature maps')
60 | parser.add_argument('--res_scale', type=float, default=1,
61 | help='residual scaling')
62 | parser.add_argument('--shift_mean', default=True,
63 | help='subtract pixel mean from the input')
64 | parser.add_argument('--dilation', action='store_true',
65 | help='use dilated convolution')
66 | parser.add_argument('--precision', type=str, default='single',
67 | choices=('single', 'half'),
68 | help='FP precision for test (single | half)')
69 |
70 | # Training specifications
71 | parser.add_argument('--test_every', type=int, default=1500,
72 | help='do test per every N batches')
73 | parser.add_argument('--epochs', type=int, default=100,
74 | help='number of epochs to train')
75 | parser.add_argument('--batch_size', type=int, default=16,
76 | help='input batch size for training')
77 | parser.add_argument('--split_batch', type=int, default=1,
78 | help='split the batch into smaller chunks')
79 | parser.add_argument('--self_ensemble', action='store_true',
80 | help='use self-ensemble method for test')
81 | parser.add_argument('--test_only', action='store_true',
82 | help='set this option to test the model')
83 | parser.add_argument('--reset', action='store_true',
84 | help='reset the training')
85 | # Optimization specifications
86 | parser.add_argument('--lr', type=float, default=1e-3,
87 | help='learning rate')
88 | parser.add_argument('--lr_decay', type=int, default=25,
89 | help='learning rate decay per N epochs')
90 | parser.add_argument('--decay_type', type=str, default='step_100_150_200_230_260_280_300',#100_115_130_140_150_158_165_170_175_180
91 | help='learning rate decay type')
92 | parser.add_argument('--gamma', type=float, default=0.5,
93 | help='learning rate decay factor for step decay')
94 | parser.add_argument('--optimizer', default='ADAM',
95 | choices=('SGD', 'ADAM', 'RMSprop'),
96 | help='optimizer to use (SGD | ADAM | RMSprop)')
97 | parser.add_argument('--momentum', type=float, default=0.9,
98 | help='SGD momentum')
99 | parser.add_argument('--beta1', type=float, default=0.9,
100 | help='ADAM beta1')
101 | parser.add_argument('--beta2', type=float, default=0.999,
102 | help='ADAM beta2')
103 | parser.add_argument('--epsilon', type=float, default=1e-8,
104 | help='ADAM epsilon for numerical stability')
105 | parser.add_argument('--weight_decay', type=float, default=0,
106 | help='weight decay')
107 |
108 | # Loss specifications
109 | parser.add_argument('--loss', type=str, default='1*MSE',
110 | help='loss function configuration')
111 | parser.add_argument('--skip_threshold', type=float, default=1e6,
112 | help='skipping batch that has large error')
113 |
114 | # Log specifications
115 | parser.add_argument('--save', type=str, default='RCDNet_syn',
116 | help='file name to save')
117 | parser.add_argument('--load', type=str, default='.',
118 | help='file name to load')
119 | parser.add_argument('--resume', type=int, default=0,
120 | help='resume from specific checkpoint')
121 | parser.add_argument('--save_models', action='store_true',
122 | help='save all intermediate models')
123 | parser.add_argument('--print_every', type=int, default=100,
124 | help='how many batches to wait before logging training status')
125 | parser.add_argument('--save_results', action='store_true',
126 | help='save output results')
127 |
128 | args = parser.parse_args()
129 | template.set_template(args)
130 |
131 | args.scale = list(map(lambda x: int(x), args.scale.split('+')))
132 |
133 | if args.epochs == 0:
134 | args.epochs = 1e8
135 |
136 | for arg in vars(args):
137 | if vars(args)[arg] == 'True':
138 | vars(args)[arg] = True
139 | elif vars(args)[arg] == 'False':
140 | vars(args)[arg] = False
141 |
142 |
--------------------------------------------------------------------------------
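Note how the post-processing at the bottom of `option.py` turns `--scale` into a list: the string is split on `+`, so a single value still becomes a one-element list. A quick illustration:

```python
# Mirrors the conversion at the bottom of option.py.
for raw in ('2', '2+4'):
    print(raw, '->', list(map(int, raw.split('+'))))
# 2 -> [2]
# 2+4 -> [2, 4]
```
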
/code/template.py:
--------------------------------------------------------------------------------
1 | def set_template(args):
2 | # Set the templates here
3 | if args.template.find('jpeg') >= 0:
4 | args.data_train = 'DIV2K_jpeg'
5 | args.data_test = 'DIV2K_jpeg'
6 | args.epochs = 200
7 | args.lr_decay = 100
8 |
9 | if args.template.find('EDSR_paper') >= 0:
10 | args.model = 'EDSR'
11 | args.n_resblocks = 32
12 | args.n_feats = 256
13 | args.res_scale = 0.1
14 |
15 | if args.template.find('MDSR') >= 0:
16 | args.model = 'MDSR'
17 | args.patch_size = 48
18 | args.epochs = 650
19 |
20 | if args.template.find('DDBPN') >= 0:
21 | args.model = 'DDBPN'
22 | args.patch_size = 128
23 | args.scale = '4'
24 |
25 | args.data_test = 'Set5'
26 |
27 | args.batch_size = 20
28 | args.epochs = 1000
29 | args.lr_decay = 500
30 | args.gamma = 0.1
31 | args.weight_decay = 1e-4
32 |
33 | args.loss = '1*MSE'
34 |
35 | if args.template.find('GAN') >= 0:
36 | args.epochs = 200
37 | args.lr = 5e-5
38 | args.lr_decay = 150
39 |
40 | if args.template.find('RCAN') >= 0:
41 | args.model = 'RCAN'
42 | args.n_resgroups = 10
43 | args.n_resblocks = 20
44 | args.n_feats = 64
45 | args.chop = True
46 |
47 |
--------------------------------------------------------------------------------
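A small sketch of how `set_template` behaves: it matches substrings of `--template` and mutates the parsed args in place, so a combined name such as `MDSR+GAN` fires several presets at once. This assumes it is run from `code/` so `template` is importable:

```python
import argparse
import template

# A bare Namespace is enough; set_template assigns the preset attributes.
args = argparse.Namespace(template='GAN')
template.set_template(args)
print(args.epochs, args.lr, args.lr_decay)  # 200 5e-05 150
```
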
/code/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from decimal import Decimal
4 | import utility
5 | import IPython
6 | import torch
7 | from torch.autograd import Variable
8 | from tqdm import tqdm
9 | import scipy.io as sio
10 | import matplotlib
11 | import matplotlib.pyplot as plt
12 | import pylab
13 | import numpy as np
14 | from torchvision.transforms import ToTensor,ToPILImage
15 | class Trainer():
16 | def __init__(self, args, loader, my_model, my_loss, ckp):
17 | self.args = args
18 | self.scale = args.scale
19 | self.ckp = ckp
20 | self.loader_train = loader.loader_train
21 | self.loader_test = loader.loader_test
22 | self.model = my_model
23 | self.loss = my_loss
24 | self.optimizer = utility.make_optimizer(args, self.model)
25 | self.scheduler = utility.make_scheduler(args, self.optimizer)
26 | if self.args.load != '.':
27 | print(ckp.dir)
28 |             assert os.path.exists(os.path.join(ckp.dir, 'optimizer.pt'))
29 |             print('==============', os.path.join(ckp.dir, 'optimizer.pt'))
30 | self.optimizer.load_state_dict(
31 | torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
32 | )
33 | for _ in range(len(ckp.log)): self.scheduler.step()
34 |
35 | self.error_last = 1e8
36 |
37 | def train(self):
38 | # print('======>trian')
39 | self.scheduler.step()
40 | self.loss.step()
41 | epoch = self.scheduler.last_epoch + 1
42 | lr = self.scheduler.get_lr()[0]
43 | self.ckp.write_log(
44 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
45 | )
46 | self.loss.start_log()
47 | self.model.train()
48 |
49 | timer_data, timer_model = utility.timer(), utility.timer()
50 |
51 | for batch, (lr, hr, idx_scale) in enumerate(self.loader_train):
52 | lr, hr = self.prepare(lr, hr)
53 | timer_data.hold()
54 | timer_model.tic()
55 | self.model.zero_grad()
56 | self.optimizer.zero_grad()
57 | out3, out2, out1 = self.model(lr, idx_scale)
58 | loss = self.loss(out3, hr) + self.loss(out2, hr) + self.loss(out1, hr)
59 |
60 | if loss.item() < self.args.skip_threshold * self.error_last:
61 | loss.backward()
62 |
63 | self.optimizer.step()
64 | else:
65 | print('Skip this batch {}! (Loss: {})'.format(
66 | batch + 1, loss.item()
67 | ))
68 | timer_model.hold()
69 | if (batch + 1) % self.args.print_every == 0:
70 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
71 | (batch + 1) * self.args.batch_size,
72 | len(self.loader_train.dataset),
73 | self.loss.display_loss(batch),
74 | timer_model.release(),
75 | timer_data.release()))
76 | timer_data.tic()
77 |
78 | self.loss.end_log(len(self.loader_train))
79 | self.error_last = self.loss.log[-1, -1]
80 |
81 | def test(self):
82 | # print('=========eval')
83 | epoch = self.scheduler.last_epoch + 1
84 | self.ckp.write_log('\nEvaluation:')
85 | self.ckp.add_log(torch.zeros(1, len(self.scale)))
86 | self.model.eval()
87 |
88 | timer_test = utility.timer()
89 | with torch.no_grad():
90 | for idx_scale, scale in enumerate(self.scale):
91 | eval_acc = 0
92 | self.loader_test.dataset.set_scale(idx_scale)
93 | tqdm_test = tqdm(self.loader_test, ncols=80)
94 | for idx_img, (lr, hr, filename) in enumerate(tqdm_test):
95 | filename = filename[0]
96 | no_eval = (hr.nelement() == 1)
97 | if not no_eval:
98 | lr, hr = self.prepare(lr, hr)
99 | else:
100 | lr, = self.prepare(lr)
101 |
102 | sr,_,_ = self.model(lr, idx_scale)
103 | sr = utility.quantize(sr, self.args.rgb_range) # restored background at the last stage
104 | save_list = [sr]
105 | if not no_eval:
106 | eval_acc += utility.calc_psnr(
107 | sr, hr, scale, self.args.rgb_range,
108 | benchmark=self.loader_test.dataset.benchmark
109 | )
110 | save_list.extend([lr, hr])
111 |
112 | if self.args.save_results:
113 | self.ckp.save_results(filename, save_list, scale)
114 |
115 | self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test)
116 | best = self.ckp.log.max(0)
117 | self.ckp.write_log(
118 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
119 | self.args.data_test,
120 | scale,
121 | self.ckp.log[-1, idx_scale],
122 | best[0][idx_scale],
123 | best[1][idx_scale] + 1
124 | )
125 | )
126 |
127 | self.ckp.write_log(
128 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True
129 | )
130 | if not self.args.test_only:
131 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch))
132 |
133 | def prepare(self, *args):
134 | device = torch.device('cpu' if self.args.cpu else 'cuda:0')
135 | def _prepare(tensor):
136 | if self.args.precision == 'half': tensor = tensor.half()
137 | return tensor.to(device)
138 |
139 | return [_prepare(a) for a in args]
140 |
141 | def terminate(self):
142 | if self.args.test_only:
143 | self.test()
144 | return True
145 | else:
146 | epoch = self.scheduler.last_epoch + 1
147 | return epoch >= self.args.epochs
148 |
--------------------------------------------------------------------------------
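`Trainer.train()` above applies deep supervision: every stage output of the network is compared against the same clean image and the three terms are summed into one scalar. A minimal sketch of that objective, with `F.mse_loss` standing in for the configured `1*MSE+0.2*SSIM` loss:

```python
import torch
import torch.nn.functional as F

hr = torch.randn(4, 3, 128, 128)                             # clean target
out3, out2, out1 = (torch.randn_like(hr) for _ in range(3))  # stage outputs

# One scalar drives backprop through all three stages at once.
loss = F.mse_loss(out3, hr) + F.mse_loss(out2, hr) + F.mse_loss(out1, hr)
```
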
/code/util/rlutrans.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | # from thop import profile
6 | from util.tools import extract_image_patches, reduce_mean, reduce_sum, same_padding, reverse_patches
7 | import pdb
8 | import math
9 |
10 | class Mlp(nn.Module):
11 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
12 | super().__init__()
13 | out_features = out_features or in_features
14 | hidden_features = hidden_features or in_features//4
15 | self.fc1 = nn.Linear(in_features, hidden_features)
16 | self.act = act_layer()
17 | self.fc2 = nn.Linear(hidden_features, out_features)
18 | self.drop = nn.Dropout(drop)
19 |
20 | def forward(self, x):
21 | x = self.fc1(x)
22 | x = self.act(x)
23 | x = self.drop(x)
24 | x = self.fc2(x)
25 | x = self.drop(x)
26 | return x
27 |
28 | class EffAttention(nn.Module):
29 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
30 | super().__init__()
31 | self.num_heads = num_heads
32 | head_dim = dim // num_heads
33 | self.scale = qk_scale or head_dim ** -0.5
34 |
35 | self.reduce = nn.Linear(dim, dim//2, bias=qkv_bias)
36 | self.qkv = nn.Linear(dim//2, dim//2 * 3, bias=qkv_bias)
37 | self.proj = nn.Linear(dim//2, dim)
38 | self.attn_drop = nn.Dropout(attn_drop)
39 |
40 | def forward(self, x):
41 | x = self.reduce(x)
42 | B, N, C = x.shape
43 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
44 | q, k, v = qkv[0], qkv[1], qkv[2]
45 |
46 |         q_all = torch.split(q, math.ceil(N / 16), dim=-2)
47 |         k_all = torch.split(k, math.ceil(N / 16), dim=-2)
48 |         v_all = torch.split(v, math.ceil(N / 16), dim=-2)
49 |
50 | output = []
51 | for q,k,v in zip(q_all, k_all, v_all):
52 | attn = (q @ k.transpose(-2, -1)) * self.scale #16*8*37*37
53 | attn = attn.softmax(dim=-1)
54 | attn = self.attn_drop(attn)
55 | trans_x = (attn @ v).transpose(1, 2) #.reshape(B, N, C)
56 | output.append(trans_x)
57 | x = torch.cat(output,dim=1)
58 | x = x.reshape(B,N,C)
59 | x = self.proj(x)
60 | return x
61 |
62 | class TransBlock(nn.Module):
63 | def __init__(
64 |         self, n_feat=64, dim=64, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
65 | drop_path=0., act_layer=nn.ReLU, norm_layer=nn.LayerNorm):
66 | super(TransBlock, self).__init__()
67 | self.dim = dim
68 |         self.atten = EffAttention(self.dim, num_heads=8, qkv_bias=False, qk_scale=None,
69 | attn_drop=0., proj_drop=0.)
70 | self.norm1 = nn.LayerNorm(self.dim)
71 | self.mlp = Mlp(in_features=dim, hidden_features=dim//4, act_layer=act_layer, drop=drop)
72 | self.norm2 = nn.LayerNorm(self.dim)
73 |
74 | def forward(self, x):
75 | B = x.shape[0]
76 |
77 | x = x + self.atten(self.norm1(x))
78 | x = x + self.mlp(self.norm2(x))
79 | return x
80 |
--------------------------------------------------------------------------------
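`EffAttention` keeps its attention maps small by splitting the token axis into roughly 16 groups: tokens attend only within their own group, so one N×N map becomes about 16 maps of size (N/16)×(N/16). A self-contained sketch of that chunking with random tensors (no learned projections):

```python
import math
import torch

B, heads, N, d = 1, 8, 4096, 16
q, k, v = (torch.randn(B, heads, N, d) for _ in range(3))

chunk = math.ceil(N / 16)  # tokens per group
out = []
for qc, kc, vc in zip(q.split(chunk, dim=-2), k.split(chunk, dim=-2), v.split(chunk, dim=-2)):
    attn = (qc @ kc.transpose(-2, -1)) * d ** -0.5  # (chunk x chunk), never (N x N)
    out.append(attn.softmax(dim=-1) @ vc)           # attend within the group only
x = torch.cat(out, dim=-2)
assert x.shape == (B, heads, N, d)
```
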
/code/util/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch.nn.functional as F
7 |
8 | def normalize(x):
9 | return x.mul_(2).add_(-1)
10 |
11 | def same_padding(images, ksizes, strides, rates):
12 | assert len(images.size()) == 4
13 | batch_size, channel, rows, cols = images.size()
14 | out_rows = (rows + strides[0] - 1) // strides[0]
15 | out_cols = (cols + strides[1] - 1) // strides[1]
16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1
17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1
18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
20 | # Pad the input
21 | padding_top = int(padding_rows / 2.)
22 | padding_left = int(padding_cols / 2.)
23 | padding_bottom = padding_rows - padding_top
24 | padding_right = padding_cols - padding_left
25 | paddings = (padding_left, padding_right, padding_top, padding_bottom)
26 | images = torch.nn.ZeroPad2d(paddings)(images)
27 | return images
28 |
29 |
30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
31 | """
32 | Extract patches from images and put them in the C output dimension.
33 |     :param images: A 4-D Tensor with shape [batch, channels, in_rows, in_cols]
34 |     :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
35 |     each dimension of images
36 |     :param strides: [stride_rows, stride_cols]
37 |     :param rates: [dilation_rows, dilation_cols]
38 |     :param padding: 'same' or 'valid'
39 |     :return: A Tensor of shape [N, C*k*k, L]
40 | """
41 | assert len(images.size()) == 4
42 | assert padding in ['same', 'valid']
43 | batch_size, channel, height, width = images.size()
44 |
45 | if padding == 'same':
46 | images = same_padding(images, ksizes, strides, rates)
47 | elif padding == 'valid':
48 | pass
49 | else:
50 | raise NotImplementedError('Unsupported padding type: {}.\
51 | Only "same" or "valid" are supported.'.format(padding))
52 |
53 | unfold = torch.nn.Unfold(kernel_size=ksizes,
54 | dilation=rates,
55 | padding=0,
56 | stride=strides)
57 | patches = unfold(images)
58 | return patches # [N, C*k*k, L], L is the total number of such blocks
59 |
60 | def reverse_patches(images, out_size, ksizes, strides, padding):
61 |     """
62 |     Fold a batch of flattened patches back into image form (the inverse of
63 |     extract_image_patches); overlapping patch values are summed by nn.Fold.
64 |     :param images: [N, C*k*k, L]. Patches as produced by extract_image_patches
65 |     :param out_size: (out_rows, out_cols) of the reconstructed image
66 |     :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window
67 |     :param strides: [stride_rows, stride_cols]
68 |     :param padding: padding passed to nn.Fold
69 |     :return: A Tensor of shape [N, C, out_rows, out_cols]
70 |     """
71 |     fold = torch.nn.Fold(output_size=out_size,
72 |                          kernel_size=ksizes,
73 |                          dilation=1,
74 |                          padding=padding,
75 |                          stride=strides)
76 |     images = fold(images)
77 |     return images  # [N, C, out_rows, out_cols]
78 | def reduce_mean(x, axis=None, keepdim=False):
79 |     if axis is None:
80 | axis = range(len(x.shape))
81 | for i in sorted(axis, reverse=True):
82 | x = torch.mean(x, dim=i, keepdim=keepdim)
83 | return x
84 |
85 |
86 | def reduce_std(x, axis=None, keepdim=False):
87 |     if axis is None:
88 | axis = range(len(x.shape))
89 | for i in sorted(axis, reverse=True):
90 | x = torch.std(x, dim=i, keepdim=keepdim)
91 | return x
92 |
93 |
94 | def reduce_sum(x, axis=None, keepdim=False):
95 |     if axis is None:
96 | axis = range(len(x.shape))
97 | for i in sorted(axis, reverse=True):
98 | x = torch.sum(x, dim=i, keepdim=keepdim)
99 | return x
--------------------------------------------------------------------------------
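A subtlety of `reverse_patches`: `nn.Fold` sums overlapping patch contributions, so folding 3×3/stride-1 patches does not reproduce the input on its own. The sketch below recovers the exact round trip by dividing by the per-pixel overlap count (assumes it runs from `code/` so `util.tools` is importable):

```python
import torch
from util.tools import extract_image_patches, reverse_patches

x = torch.randn(1, 4, 8, 8)
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1],
                                rates=[1, 1], padding='same')    # [1, 36, 64]
folded = reverse_patches(patches, (8, 8), (3, 3), 1, 1)          # overlaps summed

# Folding an all-ones image counts how many patches cover each pixel.
ones = extract_image_patches(torch.ones_like(x), [3, 3], [1, 1], [1, 1], 'same')
overlap = reverse_patches(ones, (8, 8), (3, 3), 1, 1)
assert torch.allclose(folded / overlap, x, atol=1e-5)
```
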
/code/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | from functools import reduce
6 |
7 | import matplotlib
8 | matplotlib.use('Agg')
9 | import matplotlib.pyplot as plt
10 |
11 | import numpy as np
12 | import scipy.misc as misc
13 |
14 | import torch
15 | import torch.optim as optim
16 | import torch.optim.lr_scheduler as lrs
17 |
18 | class timer():
19 | def __init__(self):
20 | self.acc = 0
21 | self.tic()
22 |
23 | def tic(self):
24 | self.t0 = time.time()
25 |
26 | def toc(self):
27 | return time.time() - self.t0
28 |
29 | def hold(self):
30 | self.acc += self.toc()
31 |
32 | def release(self):
33 | ret = self.acc
34 | self.acc = 0
35 |
36 | return ret
37 |
38 | def reset(self):
39 | self.acc = 0
40 |
41 | class checkpoint():
42 | def __init__(self, args):
43 | self.args = args
44 | self.ok = True
45 | self.log = torch.Tensor()
46 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
47 |
48 | if args.load == '.':
49 | if args.save == '.': args.save = now
50 | self.dir = '../experiment/' + args.save
51 | else:
52 | self.dir = '../experiment/' + args.load
53 | if not os.path.exists(self.dir):
54 | args.load = '.'
55 | else:
56 | self.log = torch.load(self.dir + '/psnr_log.pt')
57 | print('Continue from epoch {}...'.format(len(self.log)))
58 |
59 | if args.reset:
60 | os.system('rm -rf ' + self.dir)
61 | args.load = '.'
62 |
63 | def _make_dir(path):
64 | if not os.path.exists(path): os.makedirs(path)
65 |
66 | _make_dir(self.dir)
67 | _make_dir(self.dir + '/model')
68 | _make_dir(self.dir + '/results')
69 |
70 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
71 | self.log_file = open(self.dir + '/log.txt', open_type)
72 | with open(self.dir + '/config.txt', open_type) as f:
73 | f.write(now + '\n\n')
74 | for arg in vars(args):
75 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
76 | f.write('\n')
77 |
78 | def save(self, trainer, epoch, is_best=False):
79 | trainer.model.save(self.dir, epoch, is_best=is_best)
80 | trainer.loss.save(self.dir)
81 | trainer.loss.plot_loss(self.dir, epoch)
82 |
83 | self.plot_psnr(epoch)
84 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
85 | torch.save(
86 | trainer.optimizer.state_dict(),
87 | os.path.join(self.dir, 'optimizer.pt')
88 | )
89 |
90 | def add_log(self, log):
91 | self.log = torch.cat([self.log, log])
92 |
93 | def write_log(self, log, refresh=False):
94 | print(log)
95 | self.log_file.write(log + '\n')
96 | if refresh:
97 | self.log_file.close()
98 | self.log_file = open(self.dir + '/log.txt', 'a')
99 |
100 | def done(self):
101 | self.log_file.close()
102 |
103 | def plot_psnr(self, epoch):
104 | axis = np.linspace(1, epoch, epoch)
105 | label = 'SR on {}'.format(self.args.data_test)
106 | fig = plt.figure()
107 | plt.title(label)
108 | for idx_scale, scale in enumerate(self.args.scale):
109 | plt.plot(
110 | axis,
111 | self.log[:, idx_scale].numpy(),
112 | label='Scale {}'.format(scale)
113 | )
114 | plt.legend()
115 | plt.xlabel('Epochs')
116 | plt.ylabel('PSNR')
117 | plt.grid(True)
118 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
119 | plt.close(fig)
120 |
121 | def save_results(self, filename, save_list, scale):
122 | filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
123 | postfix = ('SR','LR', 'HR')
124 | for v, p in zip(save_list, postfix):
125 | normalized = v[0].data.mul(255 / self.args.rgb_range)
126 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
127 | misc.imsave('{}{}.png'.format(filename, p), ndarr)
128 | def quantize(img, rgb_range):
129 | pixel_range = 255 / rgb_range
130 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
131 |
132 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
133 | diff = (sr - hr).data.div(rgb_range)
134 | if benchmark:
135 | shave = scale
136 | if diff.size(1) > 1:
137 | convert = diff.new(1, 3, 1, 1)
138 | convert[0, 0, 0, 0] = 65.738
139 | convert[0, 1, 0, 0] = 129.057
140 | convert[0, 2, 0, 0] = 25.064
141 | diff.mul_(convert).div_(256)
142 | diff = diff.sum(dim=1, keepdim=True)
143 | else:
144 | shave = scale + 6
145 |
146 | valid = diff[:, :, shave:-shave, shave:-shave]
147 | mse = valid.pow(2).mean()
148 |
149 | return -10 * math.log10(mse)
150 |
151 | def make_optimizer(args, my_model):
152 | trainable = filter(lambda x: x.requires_grad, my_model.parameters())
153 |
154 | if args.optimizer == 'SGD':
155 | optimizer_function = optim.SGD
156 | kwargs = {'momentum': args.momentum}
157 | elif args.optimizer == 'ADAM':
158 | optimizer_function = optim.Adam
159 | kwargs = {
160 | 'betas': (args.beta1, args.beta2),
161 | 'eps': args.epsilon
162 | }
163 | elif args.optimizer == 'RMSprop':
164 | optimizer_function = optim.RMSprop
165 | kwargs = {'eps': args.epsilon}
166 |
167 | kwargs['lr'] = args.lr
168 | kwargs['weight_decay'] = args.weight_decay
169 |
170 | return optimizer_function(trainable, **kwargs)
171 |
172 | def make_scheduler(args, my_optimizer):
173 | if args.decay_type == 'step':
174 | scheduler = lrs.StepLR(
175 | my_optimizer,
176 | step_size=args.lr_decay,
177 | gamma=args.gamma
178 | )
179 | elif args.decay_type.find('step') >= 0:
180 | milestones = args.decay_type.split('_')
181 | milestones.pop(0)
182 | milestones = list(map(lambda x: int(x), milestones))
183 | scheduler = lrs.MultiStepLR(
184 | my_optimizer,
185 | milestones=milestones,
186 | gamma=args.gamma
187 | )
188 |
189 | return scheduler
190 |
191 |
--------------------------------------------------------------------------------
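A quick check of the `quantize` contract above: it snaps values in `[0, rgb_range]` onto the 255-level grid, so applying it twice is a no-op. This assumes the pinned environment from the README so `utility` imports cleanly, run from `code/`:

```python
import torch
from utility import quantize

img = torch.rand(1, 3, 4, 4)   # rgb_range = 1 for this toy tensor
q = quantize(img, rgb_range=1)
assert torch.equal(q, quantize(q, rgb_range=1))               # idempotent
assert q.mul(255).sub(q.mul(255).round()).abs().max() < 1e-4  # on the grid
```
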
/experiment/HCT-FFN/model/model_best_Rain100H.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/experiment/HCT-FFN/model/model_best_Rain100H.pt
--------------------------------------------------------------------------------
/experiment/HCT-FFN/model/model_best_Rain100L.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/experiment/HCT-FFN/model/model_best_Rain100L.pt
--------------------------------------------------------------------------------
/figure/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenxiang/HCT-FFN/0aa6526e4642dc4efa3e49ab29700fa7a312a318/figure/network.png
--------------------------------------------------------------------------------