├── .gitignore
├── LICENSE
├── README.md
└── src
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── demo.py
    │   ├── df2k.py
    │   ├── div2k.py
    │   ├── div2kjpeg.py
    │   ├── sr291.py
    │   ├── srdata.py
    │   └── video.py
    ├── dataloader.py
    ├── loss
    │   ├── __init__.py
    │   ├── __loss__.py
    │   ├── adversarial.py
    │   ├── demo.sh
    │   ├── discriminator.py
    │   ├── hash.py
    │   └── vgg.py
    ├── main.py
    ├── model
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── common.py
    │   ├── ddbpn.py
    │   ├── edsr.py
    │   ├── mdsr.py
    │   ├── mssr.py
    │   ├── nlsn.py
    │   ├── rcan.py
    │   ├── rdn.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   └── tools.py
    │   └── vdsr.py
    ├── option.py
    ├── template.py
    ├── test.sh
    ├── train.sh
    ├── trainer.py
    ├── utility.py
    ├── utils
    │   ├── __init__.py
    │   └── tools.py
    └── videotester.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore Mac system files
2 | .DS_store
3 |
4 | dataset/
5 |
6 | experiment/
7 |
8 | __pycache__
9 |
10 | *.swp
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 njulj
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaDM
2 | ## [AdaDM: Enabling Normalization for Image Super-Resolution](https://arxiv.org/abs/2111.13905).
3 | You can apply BN, LN or GN in SR networks with our AdaDM. Pretrained models (EDSR\*/RDN\*/NLSN\*) can be downloaded from
4 | [Google Drive](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) or
5 | [BaiduYun](https://pan.baidu.com/s/18I3j4DJFvbNvTFHzDwsssA). The password for BaiduYun is `kymj`.
6 |
7 | :loudspeaker: If you use the [BasicSR](https://github.com/xinntao/BasicSR) framework, you need to turn off the Exponential Moving Average (EMA) option when
8 | applying BN in the generator network (e.g., RRDBNet). You can disable EMA by setting `ema_decay=0` in the corresponding `.yml` configuration file.
9 |
10 | | Model | Scale | File name (.pt) | Urban100 (PSNR) | Manga109 (PSNR) |
11 | | --- | --- | --- | --- | --- |
12 | |**EDSR** | 2 | | 32.93 | 39.10 |
13 | || 3 || 28.80 | 34.17 |
14 | || 4 || 26.64 | 31.02 |
15 | |**EDSR***| 2 | [EDSR_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.12 | 39.31 |
16 | || 3 | [EDSR_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 29.02 | 34.48 |
17 | || 4 | [EDSR_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 26.83 | 31.24 |
18 | |**RDN** | 2 | | 32.89 | 39.18 |
19 | || 3 | | 28.80 | 34.13 |
20 | || 4 | | 26.61 | 31.00 |
21 | |**RDN***| 2 | [RDN_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.03 | 39.18 |
22 | || 3 | [RDN_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 28.95 | 34.29 |
23 | || 4 | [RDN_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 26.72 | 31.18 |
24 | |**NLSN** | 2 | | 33.42 | 39.59 |
25 | || 3 | | 29.25 | 34.57 |
26 | || 4 | | 26.96 | 31.27 |
27 | |**NLSN*** | 2 | [NLSN_AdaDM_DIV2K_X2](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 33.59 | 39.67 |
28 | || 3 | [NLSN_AdaDM_DIV2K_X3](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 29.53 | 34.95 |
29 | || 4 | [NLSN_AdaDM_DIV2K_X4](https://drive.google.com/drive/folders/1xljnGUUPAXpdAzXxCUMz5Rs2yOMAMOx6?usp=sharing) | 27.24 | 31.73 |
30 |
31 | ## Preparation
32 | Please refer to [EDSR](https://github.com/thstkdgus35/EDSR-PyTorch) for instructions on dataset download and software installation, then clone our repository as follows:
33 | ```bash
34 | git clone https://github.com/njulj/AdaDM.git
35 | ```
36 |
37 | ## Training
38 | ```bash
39 | cd AdaDM/src
40 | bash train.sh
41 | ```
42 | An example training command in train.sh looks like:
43 | ```bash
44 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 2\
45 | --n_GPUs 1 --batch_size 16 --patch_size 96 --rgb_range 255 --res_scale 0.1\
46 | --save EDSR_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\
47 | --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results
48 | ```
49 | Here, `$GPU_ID` specifies the GPU id used for training. `EDSR_AdaDM_Test_DIV2K_X2` is the directory where all files are saved during training.
50 | `--dir_data` specifies the root directory of all datasets; you should place the DIV2K and benchmark (e.g., Urban100) datasets under this directory.
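
For reference, a minimal sketch of the expected layout under `--dir_data` (here `../dataset`), following the path conventions in `src/data/div2k.py`, `src/data/benchmark.py`, and `src/data/srdata.py`; the listed image names are only illustrative:
```bash
ls ../dataset/DIV2K/DIV2K_train_HR              # 0001.png, 0002.png, ...
ls ../dataset/DIV2K/DIV2K_train_LR_bicubic/X2   # 0001x2.png, 0002x2.png, ...
ls ../dataset/benchmark/Urban100/HR             # HR benchmark images
ls ../dataset/benchmark/Urban100/LR_bicubic/X2  # matching bicubic LR images
```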
51 |
52 | ## Testing
53 | ```bash
54 | cd AdaDM/src
55 | bash test.sh
56 | ```
57 | An example testing command in test.sh looks like:
58 | ```bash
59 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale $SCALE\
60 | --pre_train ../experiment/test/model/EDSR_AdaDM_DIV2K_X$SCALE.pt\
61 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET
62 | ```
63 | Here, `$GPU_ID` specifies the GPU id used for testing. `$SCALE` indicates the upscaling factor (e.g., 2, 3, 4). `--pre_train` specifies the path to the
64 | saved checkpoint. `$TEST_DATASET` indicates the dataset to be tested.
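
For instance, with the variables filled in, evaluating the X2 model on Urban100 looks like the following sketch (the checkpoint path is only an example; point `--pre_train` at wherever you placed the downloaded model):
```bash
CUDA_VISIBLE_DEVICES=0 python3 main.py --template EDSR_paper --scale 2 \
    --pre_train ../experiment/test/model/EDSR_AdaDM_DIV2K_X2.pt \
    --dir_data ../dataset --n_GPUs 1 --test_only --data_test Urban100
```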
65 |
66 | ## Acknowledgement
67 | This repository is built on [EDSR](https://github.com/thstkdgus35/EDSR-PyTorch) and [NLSN](https://github.com/HarukiYqM/Non-Local-Sparse-Attention). We thank the authors for sharing their codes.
68 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/__init__.py
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | #from dataloader import MSDataLoader
3 | from torch.utils.data import dataloader
4 | from torch.utils.data import ConcatDataset
5 |
6 | # This is a simple wrapper function for ConcatDataset
7 | class MyConcatDataset(ConcatDataset):
8 | def __init__(self, datasets):
9 | super(MyConcatDataset, self).__init__(datasets)
10 | self.train = datasets[0].train
11 |
12 | def set_scale(self, idx_scale):
13 | for d in self.datasets:
14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
15 |
16 | class Data:
17 | def __init__(self, args):
18 | self.loader_train = None
19 | if not args.test_only:
20 | datasets = []
21 | for d in args.data_train:
22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
23 | m = import_module('data.' + module_name.lower())
24 | datasets.append(getattr(m, module_name)(args, name=d))
25 |
26 | self.loader_train = dataloader.DataLoader(
27 | MyConcatDataset(datasets),
28 | batch_size=args.batch_size,
29 | shuffle=True,
30 | pin_memory=not args.cpu,
31 | num_workers=args.n_threads,
32 | )
33 |
34 | self.loader_test = []
35 | for d in args.data_test:
36 | if d in ['Set5', 'Set14', 'B100', 'Urban100', 'Manga109']:
37 | m = import_module('data.benchmark')
38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d)
39 | else:
40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
41 | m = import_module('data.' + module_name.lower())
42 | testset = getattr(m, module_name)(args, train=False, name=d)
43 |
44 | self.loader_test.append(
45 | dataloader.DataLoader(
46 | testset,
47 | batch_size=1,
48 | shuffle=False,
49 | pin_memory=not args.cpu,
50 | num_workers=args.n_threads,
51 | )
52 | )
53 |
--------------------------------------------------------------------------------
/src/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 | from data import srdata
5 |
6 | import numpy as np
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Benchmark(srdata.SRData):
12 | def __init__(self, args, name='', train=True, benchmark=True):
13 | super(Benchmark, self).__init__(
14 | args, name=name, train=train, benchmark=True
15 | )
16 |
17 | def _set_filesystem(self, dir_data):
18 | self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 | self.dir_hr = os.path.join(self.apath, 'HR')
20 | if self.input_large:
21 | self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
22 | else:
23 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
24 | self.ext = ('', '.png')
25 |
26 |
--------------------------------------------------------------------------------
/src/data/common.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import skimage.color as sc
5 |
6 | import torch
7 |
8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
9 | ih, iw = args[0].shape[:2]
10 |
11 | if not input_large:
12 | p = scale if multi else 1
13 | tp = p * patch_size
14 | ip = tp // scale
15 | else:
16 | tp = patch_size
17 | ip = patch_size
18 |
19 | ix = random.randrange(0, iw - ip + 1)
20 | iy = random.randrange(0, ih - ip + 1)
21 |
22 | if not input_large:
23 | tx, ty = scale * ix, scale * iy
24 | else:
25 | tx, ty = ix, iy
26 |
27 | ret = [
28 | args[0][iy:iy + ip, ix:ix + ip, :],
29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
30 | ]
31 |
32 | return ret
33 |
34 | def set_channel(*args, n_channels=3):
35 | def _set_channel(img):
36 | if img.ndim == 2:
37 | img = np.expand_dims(img, axis=2)
38 |
39 | c = img.shape[2]
40 | if n_channels == 1 and c == 3:
41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
42 | elif n_channels == 3 and c == 1:
43 | img = np.concatenate([img] * n_channels, 2)
44 |
45 | return img
46 |
47 | return [_set_channel(a) for a in args]
48 |
49 | def np2Tensor(*args, rgb_range=255):
50 | def _np2Tensor(img):
51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
52 | tensor = torch.from_numpy(np_transpose).float()
53 | tensor.mul_(rgb_range / 255)
54 |
55 | return tensor
56 |
57 | return [_np2Tensor(a) for a in args]
58 |
59 | def augment(*args, hflip=True, rot=True):
60 | hflip = hflip and random.random() < 0.5
61 | vflip = rot and random.random() < 0.5
62 | rot90 = rot and random.random() < 0.5
63 |
64 | def _augment(img):
65 | if hflip: img = img[:, ::-1, :]
66 | if vflip: img = img[::-1, :, :]
67 | if rot90: img = img.transpose(1, 0, 2)
68 |
69 | return img
70 |
71 | return [_augment(a) for a in args]
72 |
73 |
--------------------------------------------------------------------------------
/src/data/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import numpy as np
6 | import imageio
7 |
8 | import torch
9 | import torch.utils.data as data
10 |
11 | class Demo(data.Dataset):
12 | def __init__(self, args, name='Demo', train=False, benchmark=False):
13 | self.args = args
14 | self.name = name
15 | self.scale = args.scale
16 | self.idx_scale = 0
17 | self.train = False
18 | self.benchmark = benchmark
19 |
20 | self.filelist = []
21 | for f in os.listdir(args.dir_demo):
22 | if f.find('.png') >= 0 or f.find('.jp') >= 0:
23 | self.filelist.append(os.path.join(args.dir_demo, f))
24 | self.filelist.sort()
25 |
26 | def __getitem__(self, idx):
27 | filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
28 | lr = imageio.imread(self.filelist[idx])
29 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
30 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
31 |
32 | return lr_t, -1, filename
33 |
34 | def __len__(self):
35 | return len(self.filelist)
36 |
37 | def set_scale(self, idx_scale):
38 | self.idx_scale = idx_scale
39 |
40 |
--------------------------------------------------------------------------------
/src/data/df2k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class DF2K(srdata.SRData):
5 | def __init__(self, args, name='DF2K', train=True, benchmark=False):
6 | data_range = [r.split('-') for r in args.data_range.split('/')]
7 | if train:
8 | data_range = data_range[0]
9 | else:
10 | if args.test_only and len(data_range) == 1:
11 | data_range = data_range[0]
12 | else:
13 | data_range = data_range[1]
14 |
15 | self.begin, self.end = list(map(lambda x: int(x), data_range))
16 | super(DF2K, self).__init__(
17 | args, name=name, train=train, benchmark=benchmark
18 | )
19 |
20 | def _scan(self):
21 | names_hr, names_lr = super(DF2K, self)._scan()
22 | names_hr = names_hr[self.begin - 1:self.end]
23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 |
25 | return names_hr, names_lr
26 |
27 | def _set_filesystem(self, dir_data):
28 | super(DF2K, self)._set_filesystem(dir_data)
29 | self.dir_hr = os.path.join(self.apath, 'DF2K_train_HR')
30 | self.dir_lr = os.path.join(self.apath, 'DF2K_train_LR_bicubic')
31 | if self.input_large: self.dir_lr += 'L'
32 |
33 |
--------------------------------------------------------------------------------
/src/data/div2k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 |
4 | class DIV2K(srdata.SRData):
5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
6 | data_range = [r.split('-') for r in args.data_range.split('/')]
7 | if train:
8 | data_range = data_range[0]
9 | else:
10 | if args.test_only and len(data_range) == 1:
11 | data_range = data_range[0]
12 | else:
13 | data_range = data_range[1]
14 |
15 | self.begin, self.end = list(map(lambda x: int(x), data_range))
16 | super(DIV2K, self).__init__(
17 | args, name=name, train=train, benchmark=benchmark
18 | )
19 |
20 | def _scan(self):
21 | names_hr, names_lr = super(DIV2K, self)._scan()
22 | names_hr = names_hr[self.begin - 1:self.end]
23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr]
24 |
25 | return names_hr, names_lr
26 |
27 | def _set_filesystem(self, dir_data):
28 | super(DIV2K, self)._set_filesystem(dir_data)
29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
31 | if self.input_large: self.dir_lr += 'L'
32 |
33 |
--------------------------------------------------------------------------------
/src/data/div2kjpeg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from data import srdata
3 | from data import div2k
4 |
5 | class DIV2KJPEG(div2k.DIV2K):
6 | def __init__(self, args, name='', train=True, benchmark=False):
7 | self.q_factor = int(name.replace('DIV2K-Q', ''))
8 | super(DIV2KJPEG, self).__init__(
9 | args, name=name, train=train, benchmark=benchmark
10 | )
11 |
12 | def _set_filesystem(self, dir_data):
13 | self.apath = os.path.join(dir_data, 'DIV2K')
14 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
15 | self.dir_lr = os.path.join(
16 | self.apath, 'DIV2K_Q{}'.format(self.q_factor)
17 | )
18 | if self.input_large: self.dir_lr += 'L'
19 | self.ext = ('.png', '.jpg')
20 |
21 |
--------------------------------------------------------------------------------
/src/data/sr291.py:
--------------------------------------------------------------------------------
1 | from data import srdata
2 |
3 | class SR291(srdata.SRData):
4 | def __init__(self, args, name='SR291', train=True, benchmark=False):
5 | super(SR291, self).__init__(args, name=name)
6 |
7 |
--------------------------------------------------------------------------------
/src/data/srdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | import pickle
5 |
6 | from data import common
7 |
8 | import numpy as np
9 | import imageio
10 | import torch
11 | import torch.utils.data as data
12 |
13 | class SRData(data.Dataset):
14 | def __init__(self, args, name='', train=True, benchmark=False):
15 | self.args = args
16 | self.name = name
17 | self.train = train
18 | self.split = 'train' if train else 'test'
19 | self.do_eval = True
20 | self.benchmark = benchmark
21 | self.input_large = (args.model == 'VDSR')
22 | self.scale = args.scale
23 | self.idx_scale = 0
24 |
25 | self._set_filesystem(args.dir_data)
26 | if args.ext.find('img') < 0:
27 | path_bin = os.path.join(self.apath, 'bin')
28 | os.makedirs(path_bin, exist_ok=True)
29 |
30 | list_hr, list_lr = self._scan()
31 | if args.ext.find('img') >= 0 or benchmark:
32 | self.images_hr, self.images_lr = list_hr, list_lr
33 | elif args.ext.find('sep') >= 0:
34 | os.makedirs(
35 | self.dir_hr.replace(self.apath, path_bin),
36 | exist_ok=True
37 | )
38 | for s in self.scale:
39 | os.makedirs(
40 | os.path.join(
41 | self.dir_lr.replace(self.apath, path_bin),
42 | 'X{}'.format(s)
43 | ),
44 | exist_ok=True
45 | )
46 |
47 | self.images_hr, self.images_lr = [], [[] for _ in self.scale]
48 | for h in list_hr:
49 | b = h.replace(self.apath, path_bin)
50 | b = b.replace(self.ext[0], '.pt')
51 | self.images_hr.append(b)
52 | self._check_and_load(args.ext, h, b, verbose=True)
53 | for i, ll in enumerate(list_lr):
54 | for l in ll:
55 | b = l.replace(self.apath, path_bin)
56 | b = b.replace(self.ext[1], '.pt')
57 | self.images_lr[i].append(b)
58 | self._check_and_load(args.ext, l, b, verbose=True)
59 | if train:
60 | n_patches = args.batch_size * args.test_every
61 | n_images = len(args.data_train) * len(self.images_hr)
62 | if n_images == 0:
63 | self.repeat = 0
64 | else:
65 | self.repeat = max(n_patches // n_images, 1)
66 |
67 |     # Below functions are used to prepare images
68 | def _scan(self):
69 | names_hr = sorted(
70 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
71 | )
72 | names_lr = [[] for _ in self.scale]
73 | for f in names_hr:
74 | filename, _ = os.path.splitext(os.path.basename(f))
75 | if self.name == 'Manga109':
76 | filename = filename[:-6]
77 | for si, s in enumerate(self.scale):
78 | names_lr[si].append(os.path.join(
79 | self.dir_lr, 'X{}/{}x{}{}'.format(
80 | s, filename, s, self.ext[1]
81 | )
82 | ))
83 |
84 | return names_hr, names_lr
85 |
86 | def _set_filesystem(self, dir_data):
87 | self.apath = os.path.join(dir_data, self.name)
88 | self.dir_hr = os.path.join(self.apath, 'HR')
89 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
90 | if self.input_large: self.dir_lr += 'L'
91 | self.ext = ('.png', '.png')
92 |
93 | def _check_and_load(self, ext, img, f, verbose=True):
94 | if not os.path.isfile(f) or ext.find('reset') >= 0:
95 | if verbose:
96 | print('Making a binary: {}'.format(f))
97 | with open(f, 'wb') as _f:
98 | pickle.dump(imageio.imread(img), _f)
99 |
100 | def __getitem__(self, idx):
101 | lr, hr, filename = self._load_file(idx)
102 | pair = self.get_patch(lr, hr)
103 | pair = common.set_channel(*pair, n_channels=self.args.n_colors)
104 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)
105 |
106 | return pair_t[0], pair_t[1], filename
107 |
108 | def __len__(self):
109 | if self.train:
110 | return len(self.images_hr) * self.repeat
111 | else:
112 | return len(self.images_hr)
113 |
114 | def _get_index(self, idx):
115 | if self.train:
116 | return idx % len(self.images_hr)
117 | else:
118 | return idx
119 |
120 | def _load_file(self, idx):
121 | idx = self._get_index(idx)
122 | f_hr = self.images_hr[idx]
123 | f_lr = self.images_lr[self.idx_scale][idx]
124 |
125 | filename, _ = os.path.splitext(os.path.basename(f_hr))
126 | if self.args.ext == 'img' or self.benchmark:
127 | hr = imageio.imread(f_hr)
128 | lr = imageio.imread(f_lr)
129 | elif self.args.ext.find('sep') >= 0:
130 | with open(f_hr, 'rb') as _f:
131 | hr = pickle.load(_f)
132 | with open(f_lr, 'rb') as _f:
133 | lr = pickle.load(_f)
134 |
135 | return lr, hr, filename
136 |
137 | def get_patch(self, lr, hr):
138 | scale = self.scale[self.idx_scale]
139 | if self.train:
140 | lr, hr = common.get_patch(
141 | lr, hr,
142 | patch_size=self.args.patch_size,
143 | scale=scale,
144 | multi=(len(self.scale) > 1),
145 | input_large=self.input_large
146 | )
147 | if not self.args.no_augment: lr, hr = common.augment(lr, hr)
148 | else:
149 | ih, iw = lr.shape[:2]
150 | hr = hr[0:ih * scale, 0:iw * scale]
151 |
152 | return lr, hr
153 |
154 | def set_scale(self, idx_scale):
155 | if not self.input_large:
156 | self.idx_scale = idx_scale
157 | else:
158 | self.idx_scale = random.randint(0, len(self.scale) - 1)
159 |
160 |
--------------------------------------------------------------------------------
/src/data/video.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data import common
4 |
5 | import cv2
6 | import numpy as np
7 | import imageio
8 |
9 | import torch
10 | import torch.utils.data as data
11 |
12 | class Video(data.Dataset):
13 | def __init__(self, args, name='Video', train=False, benchmark=False):
14 | self.args = args
15 | self.name = name
16 | self.scale = args.scale
17 | self.idx_scale = 0
18 | self.train = False
19 | self.do_eval = False
20 | self.benchmark = benchmark
21 |
22 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
23 | self.vidcap = cv2.VideoCapture(args.dir_demo)
24 | self.n_frames = 0
25 | self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
26 |
27 | def __getitem__(self, idx):
28 | success, lr = self.vidcap.read()
29 | if success:
30 | self.n_frames += 1
31 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
32 | lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
33 |
34 | return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
35 | else:
36 |             self.vidcap.release()
37 | return None
38 |
39 | def __len__(self):
40 | return self.total_frames
41 |
42 | def set_scale(self, idx_scale):
43 | self.idx_scale = idx_scale
44 |
45 |
--------------------------------------------------------------------------------
/src/dataloader.py:
--------------------------------------------------------------------------------
1 | import sys, threading  # sys is needed for sys.exc_info() in _ms_loop
2 | import random
3 |
4 | import torch
5 | import torch.multiprocessing as multiprocessing
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data import SequentialSampler
8 | from torch.utils.data import RandomSampler
9 | from torch.utils.data import BatchSampler
10 | from torch.utils.data import _utils
11 | from torch.utils.data.dataloader import _DataLoaderIter
12 |
13 | from torch.utils.data._utils import collate
14 | from torch.utils.data._utils import signal_handling
15 | from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
16 | from torch.utils.data._utils import ExceptionWrapper
17 | from torch.utils.data._utils import IS_WINDOWS
18 | from torch.utils.data._utils.worker import ManagerWatchdog
19 |
20 | from torch._six import queue
21 |
22 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
23 | try:
24 | collate._use_shared_memory = True
25 | signal_handling._set_worker_signal_handlers()
26 |
27 | torch.set_num_threads(1)
28 | random.seed(seed)
29 | torch.manual_seed(seed)
30 |
31 | data_queue.cancel_join_thread()
32 |
33 | if init_fn is not None:
34 | init_fn(worker_id)
35 |
36 | watchdog = ManagerWatchdog()
37 |
38 | while watchdog.is_alive():
39 | try:
40 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
41 | except queue.Empty:
42 | continue
43 |
44 | if r is None:
45 | assert done_event.is_set()
46 | return
47 | elif done_event.is_set():
48 | continue
49 |
50 | idx, batch_indices = r
51 | try:
52 | idx_scale = 0
53 | if len(scale) > 1 and dataset.train:
54 | idx_scale = random.randrange(0, len(scale))
55 | dataset.set_scale(idx_scale)
56 |
57 | samples = collate_fn([dataset[i] for i in batch_indices])
58 | samples.append(idx_scale)
59 | except Exception:
60 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
61 | else:
62 | data_queue.put((idx, samples))
63 | del samples
64 |
65 | except KeyboardInterrupt:
66 | pass
67 |
68 | class _MSDataLoaderIter(_DataLoaderIter):
69 |
70 | def __init__(self, loader):
71 | self.dataset = loader.dataset
72 | self.scale = loader.scale
73 | self.collate_fn = loader.collate_fn
74 | self.batch_sampler = loader.batch_sampler
75 | self.num_workers = loader.num_workers
76 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
77 | self.timeout = loader.timeout
78 |
79 | self.sample_iter = iter(self.batch_sampler)
80 |
81 | base_seed = torch.LongTensor(1).random_().item()
82 |
83 | if self.num_workers > 0:
84 | self.worker_init_fn = loader.worker_init_fn
85 | self.worker_queue_idx = 0
86 | self.worker_result_queue = multiprocessing.Queue()
87 | self.batches_outstanding = 0
88 | self.worker_pids_set = False
89 | self.shutdown = False
90 | self.send_idx = 0
91 | self.rcvd_idx = 0
92 | self.reorder_dict = {}
93 | self.done_event = multiprocessing.Event()
94 |
95 | base_seed = torch.LongTensor(1).random_()[0]
96 |
97 | self.index_queues = []
98 | self.workers = []
99 | for i in range(self.num_workers):
100 | index_queue = multiprocessing.Queue()
101 | index_queue.cancel_join_thread()
102 | w = multiprocessing.Process(
103 | target=_ms_loop,
104 | args=(
105 | self.dataset,
106 | index_queue,
107 | self.worker_result_queue,
108 | self.done_event,
109 | self.collate_fn,
110 | self.scale,
111 | base_seed + i,
112 | self.worker_init_fn,
113 | i
114 | )
115 | )
116 | w.daemon = True
117 | w.start()
118 | self.index_queues.append(index_queue)
119 | self.workers.append(w)
120 |
121 | if self.pin_memory:
122 | self.data_queue = queue.Queue()
123 | pin_memory_thread = threading.Thread(
124 | target=_utils.pin_memory._pin_memory_loop,
125 | args=(
126 | self.worker_result_queue,
127 | self.data_queue,
128 | torch.cuda.current_device(),
129 | self.done_event
130 | )
131 | )
132 | pin_memory_thread.daemon = True
133 | pin_memory_thread.start()
134 | self.pin_memory_thread = pin_memory_thread
135 | else:
136 | self.data_queue = self.worker_result_queue
137 |
138 | _utils.signal_handling._set_worker_pids(
139 | id(self), tuple(w.pid for w in self.workers)
140 | )
141 | _utils.signal_handling._set_SIGCHLD_handler()
142 | self.worker_pids_set = True
143 |
144 | for _ in range(2 * self.num_workers):
145 | self._put_indices()
146 |
147 |
148 | class MSDataLoader(DataLoader):
149 |
150 | def __init__(self, cfg, *args, **kwargs):
151 | super(MSDataLoader, self).__init__(
152 | *args, **kwargs, num_workers=cfg.n_threads
153 | )
154 | self.scale = cfg.scale
155 |
156 | def __iter__(self):
157 | return _MSDataLoaderIter(self)
158 |
159 |
--------------------------------------------------------------------------------
/src/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import matplotlib.pyplot as plt
7 |
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | class Loss(nn.modules.loss._Loss):
15 | def __init__(self, args, ckp):
16 | super(Loss, self).__init__()
17 | print('Preparing loss function:')
18 |
19 | self.n_GPUs = args.n_GPUs
20 | self.loss = []
21 | self.loss_module = nn.ModuleList()
22 | for loss in args.loss.split('+'):
23 | weight, loss_type = loss.split('*')
24 | if loss_type == 'MSE':
25 | loss_function = nn.MSELoss()
26 | elif loss_type == 'L1':
27 | loss_function = nn.L1Loss()
28 | elif loss_type.find('VGG') >= 0:
29 | module = import_module('loss.vgg')
30 | loss_function = getattr(module, 'VGG')(
31 | loss_type[3:],
32 | rgb_range=args.rgb_range
33 | )
34 | elif loss_type.find('GAN') >= 0:
35 | module = import_module('loss.adversarial')
36 | loss_function = getattr(module, 'Adversarial')(
37 | args,
38 | loss_type
39 | )
40 |
41 | self.loss.append({
42 | 'type': loss_type,
43 | 'weight': float(weight),
44 | 'function': loss_function}
45 | )
46 | if loss_type.find('GAN') >= 0:
47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
48 |
49 | if len(self.loss) > 1:
50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
51 |
52 | for l in self.loss:
53 | if l['function'] is not None:
54 | print('{:.3f} * {}'.format(l['weight'], l['type']))
55 | self.loss_module.append(l['function'])
56 |
57 | self.log = torch.Tensor()
58 |
59 | device = torch.device('cpu' if args.cpu else 'cuda')
60 | self.loss_module.to(device)
61 | if args.precision == 'half': self.loss_module.half()
62 | if not args.cpu and args.n_GPUs > 1:
63 | self.loss_module = nn.DataParallel(self.loss_module,range(args.n_GPUs))
64 |
65 | if args.load != '': self.load(ckp.dir, cpu=args.cpu)
66 |
67 | def forward(self, sr, hr):
68 | losses = []
69 | for i, l in enumerate(self.loss):
70 | if l['function'] is not None:
71 | loss = l['function'](sr, hr)
72 | effective_loss = l['weight'] * loss
73 | losses.append(effective_loss)
74 | self.log[-1, i] += effective_loss.item()
75 | elif l['type'] == 'DIS':
76 | self.log[-1, i] += self.loss[i - 1]['function'].loss
77 |
78 | loss_sum = sum(losses)
79 | if len(self.loss) > 1:
80 | self.log[-1, -1] += loss_sum.item()
81 |
82 | return loss_sum
83 |
84 | def step(self):
85 | for l in self.get_loss_module():
86 | if hasattr(l, 'scheduler'):
87 | l.scheduler.step()
88 |
89 | def start_log(self):
90 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
91 |
92 | def end_log(self, n_batches):
93 | self.log[-1].div_(n_batches)
94 |
95 | def display_loss(self, batch):
96 | n_samples = batch + 1
97 | log = []
98 | for l, c in zip(self.loss, self.log[-1]):
99 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
100 |
101 | return ''.join(log)
102 |
103 | def plot_loss(self, apath, epoch):
104 | axis = np.linspace(1, epoch, epoch)
105 | for i, l in enumerate(self.loss):
106 | label = '{} Loss'.format(l['type'])
107 | fig = plt.figure()
108 | plt.title(label)
109 | plt.plot(axis, self.log[:, i].numpy(), label=label)
110 | plt.legend()
111 | plt.xlabel('Epochs')
112 | plt.ylabel('Loss')
113 | plt.grid(True)
114 | plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
115 | plt.close(fig)
116 |
117 | def get_loss_module(self):
118 | if self.n_GPUs == 1:
119 | return self.loss_module
120 | else:
121 | return self.loss_module.module
122 |
123 | def save(self, apath):
124 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
125 | torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
126 |
127 | def load(self, apath, cpu=False):
128 | if cpu:
129 | kwargs = {'map_location': lambda storage, loc: storage}
130 | else:
131 | kwargs = {}
132 |
133 | self.load_state_dict(torch.load(
134 | os.path.join(apath, 'loss.pt'),
135 | **kwargs
136 | ))
137 | self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
138 | for l in self.get_loss_module():
139 | if hasattr(l, 'scheduler'):
140 | for _ in range(len(self.log)): l.scheduler.step()
141 |
142 |
--------------------------------------------------------------------------------
/src/loss/__loss__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/loss/__loss__.py
--------------------------------------------------------------------------------
/src/loss/adversarial.py:
--------------------------------------------------------------------------------
1 | import utility
2 | from types import SimpleNamespace
3 |
4 | from model import common
5 | from loss import discriminator
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 |
12 | class Adversarial(nn.Module):
13 | def __init__(self, args, gan_type):
14 | super(Adversarial, self).__init__()
15 | self.gan_type = gan_type
16 | self.gan_k = args.gan_k
17 | self.dis = discriminator.Discriminator(args)
18 | if gan_type == 'WGAN_GP':
19 | # see https://arxiv.org/pdf/1704.00028.pdf pp.4
20 | optim_dict = {
21 | 'optimizer': 'ADAM',
22 | 'betas': (0, 0.9),
23 | 'epsilon': 1e-8,
24 | 'lr': 1e-5,
25 | 'weight_decay': args.weight_decay,
26 | 'decay': args.decay,
27 | 'gamma': args.gamma
28 | }
29 | optim_args = SimpleNamespace(**optim_dict)
30 | else:
31 | optim_args = args
32 |
33 | self.optimizer = utility.make_optimizer(optim_args, self.dis)
34 |
35 | def forward(self, fake, real):
36 | # updating discriminator...
37 | self.loss = 0
38 | fake_detach = fake.detach() # do not backpropagate through G
39 | for _ in range(self.gan_k):
40 | self.optimizer.zero_grad()
41 | # d: B x 1 tensor
42 | d_fake = self.dis(fake_detach)
43 | d_real = self.dis(real)
44 | retain_graph = False
45 | if self.gan_type == 'GAN':
46 | loss_d = self.bce(d_real, d_fake)
47 | elif self.gan_type.find('WGAN') >= 0:
48 | loss_d = (d_fake - d_real).mean()
49 | if self.gan_type.find('GP') >= 0:
50 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
51 | hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
52 | hat.requires_grad = True
53 | d_hat = self.dis(hat)
54 | gradients = torch.autograd.grad(
55 | outputs=d_hat.sum(), inputs=hat,
56 | retain_graph=True, create_graph=True, only_inputs=True
57 | )[0]
58 | gradients = gradients.view(gradients.size(0), -1)
59 | gradient_norm = gradients.norm(2, dim=1)
60 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
61 | loss_d += gradient_penalty
62 | # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
63 | elif self.gan_type == 'RGAN':
64 | better_real = d_real - d_fake.mean(dim=0, keepdim=True)
65 | better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
66 | loss_d = self.bce(better_real, better_fake)
67 | retain_graph = True
68 |
69 | # Discriminator update
70 | self.loss += loss_d.item()
71 | loss_d.backward(retain_graph=retain_graph)
72 | self.optimizer.step()
73 |
74 | if self.gan_type == 'WGAN':
75 | for p in self.dis.parameters():
76 | p.data.clamp_(-1, 1)
77 |
78 | self.loss /= self.gan_k
79 |
80 | # updating generator...
81 | d_fake_bp = self.dis(fake) # for backpropagation, use fake as it is
82 | if self.gan_type == 'GAN':
83 | label_real = torch.ones_like(d_fake_bp)
84 | loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
85 | elif self.gan_type.find('WGAN') >= 0:
86 | loss_g = -d_fake_bp.mean()
87 | elif self.gan_type == 'RGAN':
88 | better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
89 | better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
90 | loss_g = self.bce(better_fake, better_real)
91 |
92 | # Generator loss
93 | return loss_g
94 |
95 | def state_dict(self, *args, **kwargs):
96 | state_discriminator = self.dis.state_dict(*args, **kwargs)
97 | state_optimizer = self.optimizer.state_dict()
98 |
99 | return dict(**state_discriminator, **state_optimizer)
100 |
101 | def bce(self, real, fake):
102 | label_real = torch.ones_like(real)
103 | label_fake = torch.zeros_like(fake)
104 | bce_real = F.binary_cross_entropy_with_logits(real, label_real)
105 | bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
106 | bce_loss = bce_real + bce_fake
107 | return bce_loss
108 |
109 | # Some references
110 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
111 | # OR
112 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
113 |
--------------------------------------------------------------------------------
/src/loss/demo.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/loss/demo.sh
--------------------------------------------------------------------------------
/src/loss/discriminator.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | class Discriminator(nn.Module):
6 | '''
7 | output is not normalized
8 | '''
9 | def __init__(self, args):
10 | super(Discriminator, self).__init__()
11 |
12 | in_channels = args.n_colors
13 | out_channels = 64
14 | depth = 7
15 |
16 | def _block(_in_channels, _out_channels, stride=1):
17 | return nn.Sequential(
18 | nn.Conv2d(
19 | _in_channels,
20 | _out_channels,
21 | 3,
22 | padding=1,
23 | stride=stride,
24 | bias=False
25 | ),
26 | nn.BatchNorm2d(_out_channels),
27 | nn.LeakyReLU(negative_slope=0.2, inplace=True)
28 | )
29 |
30 | m_features = [_block(in_channels, out_channels)]
31 | for i in range(depth):
32 | in_channels = out_channels
33 | if i % 2 == 1:
34 | stride = 1
35 | out_channels *= 2
36 | else:
37 | stride = 2
38 | m_features.append(_block(in_channels, out_channels, stride=stride))
39 |
40 | patch_size = args.patch_size // (2**((depth + 1) // 2))
41 | m_classifier = [
42 | nn.Linear(out_channels * patch_size**2, 1024),
43 | nn.LeakyReLU(negative_slope=0.2, inplace=True),
44 | nn.Linear(1024, 1)
45 | ]
46 |
47 | self.features = nn.Sequential(*m_features)
48 | self.classifier = nn.Sequential(*m_classifier)
49 |
50 | def forward(self, x):
51 | features = self.features(x)
52 | output = self.classifier(features.view(features.size(0), -1))
53 |
54 | return output
55 |
56 |
--------------------------------------------------------------------------------
/src/loss/hash.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
8 | class HASH(nn.Module):
9 | def __init__(self):
10 | super(HASH, self).__init__()
11 | self.l1 = nn.L1Loss()
12 | def forward(self, sr, qk, orders, hr, m=3):
13 | #hash loss
14 | qk = F.normalize(qk, p=2, dim=1, eps=5e-5)
15 | N,C,H,W = qk.shape
16 | qk = qk.view(N,C,H*W)
17 | qk_t = qk.permute(0,2,1).contiguous()
18 | similarity_map = F.relu(torch.matmul(qk_t, qk),inplace=True) #[N,H*W,H*W]
19 |
20 | orders = orders.unsqueeze(2).expand_as(similarity_map)#[N,H*W,H*W]
21 | orders_t = torch.transpose(orders,1,2)
22 | dist = torch.pow(orders-orders_t,2)
23 |
24 | ls = torch.mean(similarity_map*torch.log(torch.exp(dist+m)+1))
25 | ld = torch.mean((1-similarity_map)*torch.log(torch.exp(-dist+m)+1))
26 | loss = 0.005*(ls+ld)+self.l1(sr,hr)
27 |
28 | return loss
29 |
--------------------------------------------------------------------------------
/src/loss/vgg.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models as models
7 |
8 | class VGG(nn.Module):
9 | def __init__(self, conv_index, rgb_range=1):
10 | super(VGG, self).__init__()
11 | vgg_features = models.vgg19(pretrained=True).features
12 | modules = [m for m in vgg_features]
13 | if conv_index.find('22') >= 0:
14 | self.vgg = nn.Sequential(*modules[:8])
15 | elif conv_index.find('54') >= 0:
16 | self.vgg = nn.Sequential(*modules[:35])
17 |
18 | vgg_mean = (0.485, 0.456, 0.406)
19 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
20 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
21 | for p in self.parameters():
22 | p.requires_grad = False
23 |
24 | def forward(self, sr, hr):
25 | def _forward(x):
26 | x = self.sub_mean(x)
27 | x = self.vgg(x)
28 | return x
29 |
30 | vgg_sr = _forward(sr)
31 | with torch.no_grad():
32 | vgg_hr = _forward(hr.detach())
33 |
34 | loss = F.mse_loss(vgg_sr, vgg_hr)
35 |
36 | return loss
37 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import utility
4 | import data
5 | import model
6 | import loss
7 | from option import args
8 | from trainer import Trainer
9 |
10 | torch.manual_seed(args.seed)
11 | checkpoint = utility.checkpoint(args)
12 |
13 | def main():
14 | global model
15 | if args.data_test == ['video']:
16 | from videotester import VideoTester
17 | model = model.Model(args, checkpoint)
18 | print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
19 | t = VideoTester(args, model, checkpoint)
20 | t.test()
21 | else:
22 | if checkpoint.ok:
23 | loader = data.Data(args)
24 | _model = model.Model(args, checkpoint)
25 | print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0))
26 | _loss = loss.Loss(args, checkpoint) if not args.test_only else None
27 | t = Trainer(args, loader, _model, _loss, checkpoint)
28 | while not t.terminate():
29 | t.train()
30 | t.test()
31 |
32 | checkpoint.done()
33 |
34 | if __name__ == '__main__':
35 | main()
36 |
--------------------------------------------------------------------------------
/src/model/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Sanghyun Son
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/model/README.md:
--------------------------------------------------------------------------------
1 | # EDSR-PyTorch
2 | 
3 |
4 | This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017, 2nd NTIRE**.
5 | You can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017).
6 |
7 | If you find our work useful in your research or publication, please cite our work:
8 |
9 | [1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
10 | ```
11 | @InProceedings{Lim_2017_CVPR_Workshops,
12 | author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
13 | title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
14 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
15 | month = {July},
16 | year = {2017}
17 | }
18 | ```
19 | We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use a pre-trained model to enlarge your images.
20 |
21 | **Differences from the Torch version**
22 | * Codes are much more compact. (Removed all unnecessary parts.)
23 | * Models are smaller. (About half.)
24 | * Slightly better performances.
25 | * Training and evaluation require less memory.
26 | * Python-based.
27 |
28 | ## Dependencies
29 | * Python 3.6
30 | * PyTorch >= 0.4.0
31 | * numpy
32 | * skimage
33 | * **imageio**
34 | * matplotlib
35 | * tqdm
36 |
37 | **Recent updates**
38 |
39 | * July 22, 2018
40 |   * Thanks for the recent commits that contain RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
41 |   * The dataloader is now much more stable than the previous version. Please erase the ``DIV2K/bin`` folder created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but it is slow!).
42 |
43 |
44 | ## Code
45 | Clone this repository into any place you want.
46 | ```bash
47 | git clone https://github.com/thstkdgus35/EDSR-PyTorch
48 | cd EDSR-PyTorch
49 | ```
50 |
51 | ## Quick start (Demo)
52 | You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder (e.g., ``test/``). We support **png** and **jpeg** files.
53 |
54 | Run the script in the ``src`` folder. Before you run the demo, please uncomment the line in ```demo.sh``` that you want to execute.
55 | ```bash
56 | cd src # You are now in */EDSR-PyTorch/src
57 | sh demo.sh
58 | ```
59 |
60 | You can find the result images in the ```experiment/test/results``` folder.
61 |
62 | | Model | Scale | File name (.pt) | Parameters | **PSNR |
63 | | --- | --- | --- | --- | --- |
64 | | **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |
65 | | | | *EDSR_x2 | 40.7 M | 35.03 dB |
66 | | | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |
67 | | | | *EDSR_x3 | 43.7 M | 31.26 dB |
68 | | | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |
69 | | | | *EDSR_x4 | 43.1 M | 29.25 dB |
70 | | **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |
71 | | | | *MDSR | 7.95 M| 34.92 dB |
72 | | | 3 | MDSR_baseline | | 30.94 dB |
73 | | | | *MDSR | | 31.22 dB |
74 | | | 4 | MDSR_baseline | | 28.97 dB |
75 | | | | *MDSR | | 29.24 dB |
76 |
77 | *Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB)
78 | **We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.
79 |
80 | You can evaluate your models with widely-used benchmark datasets:
81 |
82 | [Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),
83 |
84 | [Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),
85 |
86 | [B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
87 |
88 | [Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
89 |
90 | For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data`` to the directory that contains the ``benchmark`` folder to evaluate EDSR and MDSR with the benchmarks.
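
A sketch of such an evaluation (the flags are the same ones used in the training scripts; the checkpoint path and dataset root are placeholders):
```bash
# Evaluate a trained model on Set5; --dir_data must contain the benchmark/ folder
python main.py --data_test Set5 --scale 2 --test_only \
    --pre_train <path_to_model.pt> --dir_data <dataset_root>
```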
91 |
92 | ## How to train EDSR and MDSR
93 | We used [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).
94 |
95 | Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where DIV2K images are located.
96 |
97 | We recommend pre-processing the images before training. This step will decode all **png** files and save them as binaries. Use the ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use the saved binaries with the ``--ext sep`` argument.
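
For example (a sketch; the decoded binaries are cached under a ``bin`` folder inside the dataset directory, see ``src/data/srdata.py``):
```bash
# First run: decode the png files and cache them as binaries
python main.py --model EDSR --scale 2 --ext sep_reset
# Later runs: reuse the cached binaries
python main.py --model EDSR --scale 2 --ext sep
```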
98 |
99 | If you have enough RAM (>= 32GB), you can use ``--ext bin`` argument to pack all DIV2K images in one binary file.
100 |
101 | You can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2) model; a sketch of such a command follows the script below. You can ignore this constraint by removing the ```--pre_train``` argument.
102 |
103 | ```bash
104 | cd src # You are now in */EDSR-PyTorch/src
105 | sh demo.sh
106 | ```
107 |
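Regarding the pre-training note above, a training command for EDSR x4 initialized from a pre-trained x2 model might look like this sketch (the checkpoint path and save name are placeholders):
```bash
python main.py --model EDSR --scale 4 --ext sep \
    --pre_train <path_to_EDSR_x2.pt> --save EDSR_x4
```
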
108 | **Update log**
109 | * Jan 04, 2018
110 | * Many parts are re-written. You cannot use previous scripts and models directly.
111 | * Pre-trained MDSR is temporarily disabled.
112 | * Training details are included.
113 |
114 | * Jan 09, 2018
115 | * Missing files are included (```src/data/MyImage.py```).
116 | * Some links are fixed.
117 |
118 | * Jan 16, 2018
119 | * Memory efficient forward function is implemented.
120 |     * Add the --chop_forward argument to your script to enable it.
121 |     * Basically, this function first splits a large image into small patches. Those patches are merged after super-resolution. I checked this function with 12GB memory and a 4000 x 2000 input image at scale 4. (Therefore, the output will be 16000 x 8000.)
122 |
123 | * Feb 21, 2018
124 | * Fixed the problem when loading pre-trained multi-gpu model.
125 | * Added pre-trained scale 2 baseline model.
126 | * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use --save_models argument to save all the intermediate models.
127 |   * PyTorch 0.3.1 changed its implementation of the DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on the feature/dataloader branch.
128 |
129 | * Feb 23, 2018
130 |   * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
131 |
132 |   * With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
133 |     * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)
134 |     * With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires huge RAM (~45GB; swap can be used), so please be careful.)
135 | * If you cannot make the binary pack, just use the default setting (``--ext img``).
136 |
137 |   * Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
138 | * Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.)
139 | * Added performance comparison between Torch7 model and PyTorch models.
140 |
141 | * Mar 5, 2018
142 | * All baseline models are uploaded.
143 | * Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images.
144 |
145 | * Mar 11, 2018
146 | * Fixed some typos in the code and script.
147 |   * Now --ext img is the default setting. Although we recommend using --ext bin when training, please use --ext img when you use --test_only.
148 |   * The skip_batch operation is implemented. Use the --skip_threshold argument to skip the batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it will work as you expect.
149 |
150 | * Mar 20, 2018
151 | * Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved to the same directory with DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
152 | * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
153 | * Changed the behavior of skip_batch.
154 |
155 | * Mar 29, 2018
156 | * We now provide all models from our paper.
157 |   * We also provide the ``MDSR_baseline_jpeg`` model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
158 |   * The ``MyImage`` dataset has been changed to the ``Demo`` dataset. Also, it works more efficiently than before.
159 |   * Some code and scripts are re-written.
160 |
161 | * Apr 9, 2018
162 |   * VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
163 | * Many codes are refactored. If there exists a bug, please report it.
164 | * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. Default setting is D-DBPN-L.
165 |
166 | * Apr 26, 2018
167 | * Compatible with PyTorch 0.4.0
168 | * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
169 | * Minor bug fixes
170 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 |
8 | class Model(nn.Module):
9 | def __init__(self, args, ckp):
10 | super(Model, self).__init__()
11 | print('Making model...')
12 |
13 | self.scale = args.scale
14 | self.idx_scale = 0
15 | self.self_ensemble = args.self_ensemble
16 | self.chop = args.chop
17 | self.precision = args.precision
18 | self.cpu = args.cpu
19 | self.device = torch.device('cpu' if args.cpu else 'cuda')
20 | self.n_GPUs = args.n_GPUs
21 | self.save_models = args.save_models
22 |
23 | module = import_module('model.' + args.model.lower())
24 | self.model = module.make_model(args).to(self.device)
25 | if args.precision == 'half': self.model.half()
26 |
27 | if not args.cpu and args.n_GPUs > 1:
28 | self.model = nn.DataParallel(self.model, range(args.n_GPUs))
29 |
30 | self.load(
31 | ckp.dir,
32 | pre_train=args.pre_train,
33 | resume=args.resume,
34 | cpu=args.cpu
35 | )
36 | print(self.model, file=ckp.log_file)
37 |
38 | def forward(self, x, idx_scale):
39 | self.idx_scale = idx_scale
40 | target = self.get_model()
41 | if hasattr(target, 'set_scale'):
42 | target.set_scale(idx_scale)
43 |
44 | if self.self_ensemble and not self.training:
45 | if self.chop:
46 | forward_function = self.forward_chop
47 | else:
48 | forward_function = self.model.forward
49 |
50 | return self.forward_x8(x, forward_function)
51 | elif self.chop and not self.training:
52 | return self.forward_chop(x)
53 | else:
54 | return self.model(x)
55 |
56 | def get_model(self):
57 | if self.n_GPUs == 1:
58 | return self.model
59 | else:
60 | return self.model.module
61 |
62 | def state_dict(self, **kwargs):
63 | target = self.get_model()
64 | return target.state_dict(**kwargs)
65 |
66 | def save(self, apath, epoch, is_best=False):
67 | target = self.get_model()
68 | torch.save(
69 | target.state_dict(),
70 | os.path.join(apath, 'model_latest.pt')
71 | )
72 | if is_best:
73 | torch.save(
74 | target.state_dict(),
75 | os.path.join(apath, 'model_best.pt')
76 | )
77 |
78 | #if self.save_models:
79 | if self.save_models and epoch % 50 == 0:
80 | torch.save(
81 | target.state_dict(),
82 | os.path.join(apath, 'model_{}.pt'.format(epoch))
83 | )
84 |
85 | def load(self, apath, pre_train='.', resume=-1, cpu=False):
86 | if cpu:
87 | kwargs = {'map_location': lambda storage, loc: storage}
88 | else:
89 | kwargs = {}
90 |
91 | if resume == -1:
92 | self.get_model().load_state_dict(
93 | torch.load(
94 | os.path.join(apath, 'model_latest.pt'),
95 | **kwargs
96 | ),
97 | strict=False
98 | )
99 | elif resume == 0:
100 | if pre_train != '.':
101 | print('Loading model from {}'.format(pre_train))
102 | self.get_model().load_state_dict(
103 | torch.load(pre_train, **kwargs),
104 | strict=False
105 | )
106 | else:
107 | self.get_model().load_state_dict(
108 | torch.load(
109 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
110 | **kwargs
111 | ),
112 | strict=False
113 | )
114 |
115 | def forward_chop(self, x, shave=10, min_size=120000):
116 | scale = self.scale[self.idx_scale]
117 | n_GPUs = min(self.n_GPUs, 4)
118 | b, c, h, w = x.size()
119 | h_half, w_half = h // 2, w // 2
120 | h_size, w_size = h_half + shave, w_half + shave
121 |         h_size += 4 - h_size % 4
122 |         w_size += 8 - w_size % 8
123 |
124 | lr_list = [
125 | x[:, :, 0:h_size, 0:w_size],
126 | x[:, :, 0:h_size, (w - w_size):w],
127 | x[:, :, (h - h_size):h, 0:w_size],
128 | x[:, :, (h - h_size):h, (w - w_size):w]]
129 |
130 | if w_size * h_size < min_size:
131 | sr_list = []
132 | for i in range(0, 4, n_GPUs):
133 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
134 | sr_batch = self.model(lr_batch)
135 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
136 | else:
137 | sr_list = [
138 | self.forward_chop(patch, shave=shave, min_size=min_size) \
139 | for patch in lr_list
140 | ]
141 |
142 | h, w = scale * h, scale * w
143 | h_half, w_half = scale * h_half, scale * w_half
144 | h_size, w_size = scale * h_size, scale * w_size
145 | shave *= scale
146 |
147 | output = x.new(b, c, h, w)
148 | output[:, :, 0:h_half, 0:w_half] \
149 | = sr_list[0][:, :, 0:h_half, 0:w_half]
150 | output[:, :, 0:h_half, w_half:w] \
151 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
152 | output[:, :, h_half:h, 0:w_half] \
153 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
154 | output[:, :, h_half:h, w_half:w] \
155 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
156 |
157 | return output
158 |
159 | def forward_x8(self, x, forward_function):
160 | def _transform(v, op):
161 | if self.precision != 'single': v = v.float()
162 |
163 | v2np = v.data.cpu().numpy()
164 | if op == 'v':
165 | tfnp = v2np[:, :, :, ::-1].copy()
166 | elif op == 'h':
167 | tfnp = v2np[:, :, ::-1, :].copy()
168 | elif op == 't':
169 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
170 |
171 | ret = torch.Tensor(tfnp).to(self.device)
172 | if self.precision == 'half': ret = ret.half()
173 |
174 | return ret
175 |
176 | lr_list = [x]
177 | for tf in 'v', 'h', 't':
178 | lr_list.extend([_transform(t, tf) for t in lr_list])
179 |
180 | sr_list = [forward_function(aug) for aug in lr_list]
181 | for i in range(len(sr_list)):
182 | if i > 3:
183 | sr_list[i] = _transform(sr_list[i], 't')
184 | if i % 4 > 1:
185 | sr_list[i] = _transform(sr_list[i], 'h')
186 | if (i % 4) % 2 == 1:
187 | sr_list[i] = _transform(sr_list[i], 'v')
188 |
189 | output_cat = torch.cat(sr_list, dim=0)
190 | output = output_cat.mean(dim=0, keepdim=True)
191 |
192 | return output
193 |
194 |
--------------------------------------------------------------------------------
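A note on the x8 self-ensemble above: the input is augmented with all eight flip/transpose variants, each variant is super-resolved, the geometric transform is undone on each output, and the results are averaged. Below is a minimal standalone sketch of the same idea using plain tensor ops; `model` here is a hypothetical callable standing in for any SR network, not an API of this repository.

```python
import torch

def self_ensemble_x8(model, x):
    """Average model outputs over the 8 flip/transpose variants of x (NCHW)."""
    def tf(v, op):
        if op == 'v': return torch.flip(v, dims=[-1])      # flip width
        if op == 'h': return torch.flip(v, dims=[-2])      # flip height
        if op == 't': return v.transpose(-2, -1)           # swap H and W

    variants = [x]
    for op in ('v', 'h', 't'):
        variants.extend([tf(t, op) for t in variants])     # 1 -> 2 -> 4 -> 8 inputs

    outputs = [model(v) for v in variants]
    for i in range(len(outputs)):                          # undo transforms in reverse order
        if i > 3: outputs[i] = tf(outputs[i], 't')
        if i % 4 > 1: outputs[i] = tf(outputs[i], 'h')
        if (i % 4) % 2 == 1: outputs[i] = tf(outputs[i], 'v')

    return torch.stack(outputs, dim=0).mean(dim=0)

# Sanity check with an identity "model": the ensembled output equals the input.
x = torch.randn(1, 3, 12, 12)
assert torch.allclose(self_ensemble_x8(lambda t: t, x), x, atol=1e-6)
```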
/src/model/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from model import common
5 |
6 | class NonLocalSparseAttention(nn.Module):
7 | def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1):
8 | super(NonLocalSparseAttention,self).__init__()
9 | self.chunk_size = chunk_size
10 | self.n_hashes = n_hashes
11 | self.reduction = reduction
12 | self.res_scale = res_scale
13 | self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
14 | self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)
15 |
16 | def LSH(self, hash_buckets, x):
17 | #x: [N,H*W,C]
18 | N = x.shape[0]
19 | device = x.device
20 |
21 | #generate random rotation matrix
22 | rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2]
23 | random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2]
24 |
25 | #locality sensitive hashing
26 | rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2]
27 | rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets]
28 |
29 | #get hash codes
30 | hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W]
31 |
32 | #add offsets to avoid hash codes overlapping between hash rounds
33 | offsets = torch.arange(self.n_hashes, device=device)
34 | offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))
35 | hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W]
36 |
37 | return hash_codes
38 |
39 | def add_adjacent_buckets(self, x):
40 | x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2)
41 | x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2)
42 | return torch.cat([x, x_extra_back,x_extra_forward], dim=3)
43 |
44 | def forward(self, input):
45 |
46 | N,_,H,W = input.shape
47 | x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1)
48 | y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1)
49 | L,C = x_embed.shape[-2:]
50 |
51 | #number of hash buckets/hash bits
52 | hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128)
53 |
54 | #get assigned hash codes/bucket number
55 | hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W]
56 | hash_codes = hash_codes.detach()
57 |
58 | #group elements with same hash code by sorting
59 | _, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W]
60 | _, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order
61 | mod_indices = (indices % L) #now range from (0->H*W)
62 | x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]
63 | y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C]
64 |
65 | #pad the embedding if it cannot be divided by chunk_size
66 | padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0
67 | x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes,-1, C)) #[N, n_hashes, H*W,C]
68 | y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes,-1, C*self.reduction))
69 | if padding:
70 | pad_x = x_att_buckets[:,:,-padding:,:].clone()
71 | pad_y = y_att_buckets[:,:,-padding:,:].clone()
72 | x_att_buckets = torch.cat([x_att_buckets,pad_x],dim=2)
73 | y_att_buckets = torch.cat([y_att_buckets,pad_y],dim=2)
74 |
75 | x_att_buckets = torch.reshape(x_att_buckets,(N,self.n_hashes,-1,self.chunk_size,C)) #[N, n_hashes, num_chunks, chunk_size, C]
76 | y_att_buckets = torch.reshape(y_att_buckets,(N,self.n_hashes,-1,self.chunk_size, C*self.reduction))
77 |
78 | x_match = F.normalize(x_att_buckets, p=2, dim=-1,eps=5e-5)
79 |
80 |         #allow attending to adjacent buckets
81 | x_match = self.add_adjacent_buckets(x_match)
82 | y_att_buckets = self.add_adjacent_buckets(y_att_buckets)
83 |
84 |         #unnormalized attention score
85 | raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3]
86 |
87 | #softmax
88 | bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
89 | score = torch.exp(raw_score - bucket_score) #(after softmax)
90 | bucket_score = torch.reshape(bucket_score,[N,self.n_hashes,-1])
91 |
92 | #attention
93 | ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C]
94 | ret = torch.reshape(ret,(N,self.n_hashes,-1,C*self.reduction))
95 |
96 | #if padded, then remove extra elements
97 | if padding:
98 | ret = ret[:,:,:-padding,:].clone()
99 | bucket_score = bucket_score[:,:,:-padding].clone()
100 |
101 | #recover the original order
102 | ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W,C]
103 | bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W]
104 | ret = common.batched_index_select(ret, undo_sort)#[N, n_hashes*H*W,C]
105 | bucket_score = bucket_score.gather(1, undo_sort)#[N,n_hashes*H*W]
106 |
107 | #weighted sum multi-round attention
108 |         ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes, H*W, C]
109 | bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
110 | probs = nn.functional.softmax(bucket_score,dim=1)
111 | ret = torch.sum(ret * probs, dim=1)
112 |
113 | ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale+input
114 | return ret
115 |
116 |
117 | class NonLocalAttention(nn.Module):
118 | def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1,conv=common.default_conv):
119 | super(NonLocalAttention, self).__init__()
120 | self.res_scale = res_scale
121 | self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
122 | self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act = nn.PReLU())
123 | self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,bn=False, act=nn.PReLU())
124 |
125 | def forward(self, input):
126 | x_embed_1 = self.conv_match1(input)
127 | x_embed_2 = self.conv_match2(input)
128 | x_assembly = self.conv_assembly(input)
129 |
130 | N,C,H,W = x_embed_1.shape
131 | x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))
132 | x_embed_2 = x_embed_2.view(N,C,H*W)
133 | score = torch.matmul(x_embed_1, x_embed_2)
134 | score = F.softmax(score, dim=2)
135 | x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)
136 | x_final = torch.matmul(score, x_assembly)
137 | return x_final.permute(0,2,1).view(N,-1,H,W)+self.res_scale*input
138 |
--------------------------------------------------------------------------------
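For reference, the `LSH` step above follows the random-rotation bucketing used in Reformer-style sparse attention: each pixel embedding is projected onto a few random directions, its bucket is the direction (or its negation) with the largest projection, and offsets keep buckets from different hash rounds disjoint. A toy-sized sketch of just that assignment, with arbitrary illustrative dimensions, is shown below.

```python
import torch

# Toy sizes: one image, L=6 pixel embeddings of dim C=4, 2 hash rounds, 8 buckets.
N, L, C, n_hashes, hash_buckets = 1, 6, 4, 2, 8
x = torch.randn(N, L, C)

rotations = torch.randn(1, C, n_hashes, hash_buckets // 2).expand(N, -1, -1, -1)
rotated = torch.einsum('btf,bfhi->bhti', x, rotations)   # [N, n_hashes, L, buckets//2]
rotated = torch.cat([rotated, -rotated], dim=-1)         # [N, n_hashes, L, buckets]
codes = torch.argmax(rotated, dim=-1)                    # bucket id per pixel and round

# Offset each round so bucket ids never collide across rounds, then flatten,
# mirroring what LSH() returns before the sort/undo_sort bookkeeping.
offsets = torch.arange(n_hashes).reshape(1, -1, 1) * hash_buckets
codes = (codes + offsets).reshape(N, -1)
print(codes.shape)  # torch.Size([1, 12])
```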
/src/model/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def batched_index_select(values, indices):
9 | last_dim = values.shape[-1]
10 | return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))
11 |
12 | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
13 | return nn.Conv2d(
14 | in_channels, out_channels, kernel_size,
15 | padding=(kernel_size//2),stride=stride, bias=bias)
16 |
17 | class MeanShift(nn.Conv2d):
18 | def __init__(
19 | self, rgb_range,
20 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
21 |
22 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
23 | std = torch.Tensor(rgb_std)
24 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
25 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
26 | for p in self.parameters():
27 | p.requires_grad = False
28 |
29 | class BasicBlock(nn.Sequential):
30 | def __init__(
31 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
32 | bn=False, act=nn.PReLU()):
33 |
34 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
35 | if bn:
36 | m.append(nn.BatchNorm2d(out_channels))
37 | if act is not None:
38 | m.append(act)
39 |
40 | super(BasicBlock, self).__init__(*m)
41 |
42 | class ResBlock(nn.Module):
43 | def __init__(
44 | self, conv, n_feats, kernel_size,
45 | bias=True, bn=False, act=nn.PReLU(), res_scale=1):
46 |
47 | super(ResBlock, self).__init__()
48 | m = []
49 | for i in range(2):
50 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
51 | if bn:
52 | m.append(nn.BatchNorm2d(n_feats))
53 | if i == 0:
54 | m.append(act)
55 |
56 | self.body = nn.Sequential(*m)
57 | self.res_scale = res_scale
58 |
59 | def forward(self, x):
60 | res = self.body(x).mul(self.res_scale)
61 | res += x
62 |
63 | return res
64 |
65 | class ResBlock_AdaDM(nn.Module):
66 | def __init__(
67 | self, conv, n_feats, kernel_size,
68 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
69 |
70 | super(ResBlock_AdaDM, self).__init__()
71 | self.conv1 = conv(n_feats, n_feats, kernel_size, bias=bias)
72 | self.conv2 = conv(n_feats, n_feats, kernel_size, bias=bias)
73 | self.act = act
74 | self.res_scale = res_scale
75 | self.phi = nn.Conv2d(1, 1, 1, 1, 0, bias=True)
76 | self.phi.weight.data.fill_(1)
77 | self.phi.bias.data.fill_(0)
78 | self.norm1 = nn.BatchNorm2d(n_feats)
79 | self.norm2 = nn.BatchNorm2d(n_feats)
80 |
81 | def forward(self, x):
82 | s = torch.std(x, dim=[1,2,3], keepdim=True)
83 | x_n = self.norm1(x)
84 | res = self.conv1(x_n)
85 | res = self.act(res)
86 | res = self.norm2(res)
87 | res = self.conv2(res).mul(self.res_scale)
88 | res = res * torch.exp(self.phi(torch.log(s)))
89 |
90 | res += x
91 |
92 | return res
93 |
94 | class Upsampler(nn.Sequential):
95 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
96 |
97 | m = []
98 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
99 | for _ in range(int(math.log(scale, 2))):
100 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
101 | m.append(nn.PixelShuffle(2))
102 | if bn:
103 | m.append(nn.BatchNorm2d(n_feats))
104 | if act == 'relu':
105 | m.append(nn.ReLU(True))
106 | elif act == 'prelu':
107 | m.append(nn.PReLU(n_feats))
108 |
109 | elif scale == 3:
110 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
111 | m.append(nn.PixelShuffle(3))
112 | if bn:
113 | m.append(nn.BatchNorm2d(n_feats))
114 | if act == 'relu':
115 | m.append(nn.ReLU(True))
116 | elif act == 'prelu':
117 | m.append(nn.PReLU(n_feats))
118 | else:
119 | raise NotImplementedError
120 |
121 | super(Upsampler, self).__init__(*m)
122 |
123 |
--------------------------------------------------------------------------------
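`ResBlock_AdaDM` is the AdaDM variant of the plain `ResBlock`: the input is BatchNorm-normalized, the residual branch is computed, and the branch is rescaled by `exp(phi(log(s)))`, where `s` is the per-sample standard deviation of the block input and `phi` is a learnable 1x1 convolution initialized to the identity. A quick shape smoke test (sizes are arbitrary; run from `src/` so that `model.common` is importable) might look like this.

```python
import torch
from model import common

block = common.ResBlock_AdaDM(common.default_conv, n_feats=64, kernel_size=3, res_scale=0.1)
block.eval()  # use fixed BatchNorm running statistics for a deterministic pass

x = torch.randn(2, 64, 24, 24)
with torch.no_grad():
    y = block(x)
print(y.shape)  # torch.Size([2, 64, 24, 24]); the block is shape-preserving
```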
/src/model/ddbpn.py:
--------------------------------------------------------------------------------
1 | # Deep Back-Projection Networks For Super-Resolution
2 | # https://arxiv.org/abs/1803.02735
3 |
4 | from model import common
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def make_model(args, parent=False):
11 | return DDBPN(args)
12 |
13 | def projection_conv(in_channels, out_channels, scale, up=True):
14 | kernel_size, stride, padding = {
15 | 2: (6, 2, 2),
16 | 4: (8, 4, 2),
17 | 8: (12, 8, 2)
18 | }[scale]
19 | if up:
20 | conv_f = nn.ConvTranspose2d
21 | else:
22 | conv_f = nn.Conv2d
23 |
24 | return conv_f(
25 | in_channels, out_channels, kernel_size,
26 | stride=stride, padding=padding
27 | )
28 |
29 | class DenseProjection(nn.Module):
30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
31 | super(DenseProjection, self).__init__()
32 | if bottleneck:
33 | self.bottleneck = nn.Sequential(*[
34 | nn.Conv2d(in_channels, nr, 1),
35 | nn.PReLU(nr)
36 | ])
37 | inter_channels = nr
38 | else:
39 | self.bottleneck = None
40 | inter_channels = in_channels
41 |
42 | self.conv_1 = nn.Sequential(*[
43 | projection_conv(inter_channels, nr, scale, up),
44 | nn.PReLU(nr)
45 | ])
46 | self.conv_2 = nn.Sequential(*[
47 | projection_conv(nr, inter_channels, scale, not up),
48 | nn.PReLU(inter_channels)
49 | ])
50 | self.conv_3 = nn.Sequential(*[
51 | projection_conv(inter_channels, nr, scale, up),
52 | nn.PReLU(nr)
53 | ])
54 |
55 | def forward(self, x):
56 | if self.bottleneck is not None:
57 | x = self.bottleneck(x)
58 |
59 | a_0 = self.conv_1(x)
60 | b_0 = self.conv_2(a_0)
61 | e = b_0.sub(x)
62 | a_1 = self.conv_3(e)
63 |
64 | out = a_0.add(a_1)
65 |
66 | return out
67 |
68 | class DDBPN(nn.Module):
69 | def __init__(self, args):
70 | super(DDBPN, self).__init__()
71 | scale = args.scale[0]
72 |
73 | n0 = 128
74 | nr = 32
75 | self.depth = 6
76 |
77 | rgb_mean = (0.4488, 0.4371, 0.4040)
78 | rgb_std = (1.0, 1.0, 1.0)
79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
80 | initial = [
81 | nn.Conv2d(args.n_colors, n0, 3, padding=1),
82 | nn.PReLU(n0),
83 | nn.Conv2d(n0, nr, 1),
84 | nn.PReLU(nr)
85 | ]
86 | self.initial = nn.Sequential(*initial)
87 |
88 | self.upmodules = nn.ModuleList()
89 | self.downmodules = nn.ModuleList()
90 | channels = nr
91 | for i in range(self.depth):
92 | self.upmodules.append(
93 | DenseProjection(channels, nr, scale, True, i > 1)
94 | )
95 | if i != 0:
96 | channels += nr
97 |
98 | channels = nr
99 | for i in range(self.depth - 1):
100 | self.downmodules.append(
101 | DenseProjection(channels, nr, scale, False, i != 0)
102 | )
103 | channels += nr
104 |
105 | reconstruction = [
106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
107 | ]
108 | self.reconstruction = nn.Sequential(*reconstruction)
109 |
110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
111 |
112 | def forward(self, x):
113 | x = self.sub_mean(x)
114 | x = self.initial(x)
115 |
116 | h_list = []
117 | l_list = []
118 | for i in range(self.depth - 1):
119 | if i == 0:
120 | l = x
121 | else:
122 | l = torch.cat(l_list, dim=1)
123 | h_list.append(self.upmodules[i](l))
124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
125 |
126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
127 | out = self.reconstruction(torch.cat(h_list, dim=1))
128 | out = self.add_mean(out)
129 |
130 | return out
131 |
132 |
--------------------------------------------------------------------------------
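The (kernel, stride, padding) triplets in `projection_conv` are chosen so that an up-projection enlarges the spatial size exactly by the scale factor and the matching down-projection undoes it. A small sanity check of that geometry (arbitrary channel count and input size; run from `src/`) is sketched below.

```python
import torch
from model.ddbpn import projection_conv

x = torch.randn(1, 32, 24, 24)
for scale in (2, 4, 8):
    up = projection_conv(32, 32, scale, up=True)     # ConvTranspose2d
    down = projection_conv(32, 32, scale, up=False)  # Conv2d
    assert up(x).shape[-2:] == (24 * scale, 24 * scale)
    assert down(up(x)).shape[-2:] == (24, 24)
```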
/src/model/edsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | def make_model(args, parent=False):
6 | return EDSR(args)
7 |
8 | class EDSR(nn.Module):
9 | def __init__(self, args, conv=common.default_conv):
10 | super(EDSR, self).__init__()
11 |
12 | n_resblocks = args.n_resblocks
13 | n_feats = args.n_feats
14 | kernel_size = 3
15 | scale = args.scale[0]
16 | act = nn.ReLU(True)
17 |
18 | # define head module
19 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
20 |
21 | # define body module
22 | m_body = [
23 | common.ResBlock_AdaDM(
24 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
25 | ) for _ in range(n_resblocks)
26 | ]
27 | m_body.append(conv(n_feats, n_feats, kernel_size))
28 |
29 | # define tail module
30 | m_tail = [
31 | common.Upsampler(conv, scale, n_feats, act=False),
32 | conv(n_feats, args.n_colors, kernel_size)
33 | ]
34 |
35 | self.head = nn.Sequential(*m_head)
36 | self.body = nn.Sequential(*m_body)
37 | self.tail = nn.Sequential(*m_tail)
38 |
39 | def forward(self, x):
40 | x = self.head(x)
41 |
42 | res = self.body(x)
43 | res += x
44 |
45 | x = self.tail(res)
46 |
47 | return x
48 |
49 | def load_state_dict(self, state_dict, strict=True):
50 | own_state = self.state_dict()
51 | for name, param in state_dict.items():
52 | if name in own_state:
53 | if isinstance(param, nn.Parameter):
54 | param = param.data
55 | try:
56 | own_state[name].copy_(param)
57 | except Exception:
58 | if name.find('tail') == -1:
59 | raise RuntimeError('While copying the parameter named {}, '
60 | 'whose dimensions in the model are {} and '
61 | 'whose dimensions in the checkpoint are {}.'
62 | .format(name, own_state[name].size(), param.size()))
63 | elif strict:
64 | if name.find('tail') == -1:
65 | raise KeyError('unexpected key "{}" in state_dict'
66 | .format(name))
67 |
--------------------------------------------------------------------------------
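This EDSR variant reads only a few fields from `args` (`n_resblocks`, `n_feats`, `scale`, `n_colors`, `res_scale`). A minimal construction sketch with a hypothetical namespace follows; the values are illustrative and much smaller than the paper-sized configuration that `template.py` selects (32 blocks, 256 features).

```python
import torch
from types import SimpleNamespace
from model import edsr

# Hypothetical, deliberately small configuration for a quick forward pass.
args = SimpleNamespace(n_resblocks=4, n_feats=64, scale=[2], n_colors=3, res_scale=0.1)
net = edsr.make_model(args).eval()

with torch.no_grad():
    y = net(torch.randn(1, 3, 48, 48))
print(y.shape)  # torch.Size([1, 3, 96, 96]) for the x2 configuration
```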
/src/model/mdsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 |
5 | def make_model(args, parent=False):
6 | return MDSR(args)
7 |
8 | class MDSR(nn.Module):
9 | def __init__(self, args, conv=common.default_conv):
10 | super(MDSR, self).__init__()
11 | n_resblocks = args.n_resblocks
12 | n_feats = args.n_feats
13 | kernel_size = 3
14 | self.scale_idx = 0
15 |
16 | act = nn.ReLU(True)
17 |
18 | rgb_mean = (0.4488, 0.4371, 0.4040)
19 | rgb_std = (1.0, 1.0, 1.0)
20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
21 |
22 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
23 |
24 | self.pre_process = nn.ModuleList([
25 | nn.Sequential(
26 | common.ResBlock(conv, n_feats, 5, act=act),
27 | common.ResBlock(conv, n_feats, 5, act=act)
28 | ) for _ in args.scale
29 | ])
30 |
31 | m_body = [
32 | common.ResBlock(
33 | conv, n_feats, kernel_size, act=act
34 | ) for _ in range(n_resblocks)
35 | ]
36 | m_body.append(conv(n_feats, n_feats, kernel_size))
37 |
38 | self.upsample = nn.ModuleList([
39 | common.Upsampler(
40 | conv, s, n_feats, act=False
41 | ) for s in args.scale
42 | ])
43 |
44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)]
45 |
46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
47 |
48 | self.head = nn.Sequential(*m_head)
49 | self.body = nn.Sequential(*m_body)
50 | self.tail = nn.Sequential(*m_tail)
51 |
52 | def forward(self, x):
53 | x = self.sub_mean(x)
54 | x = self.head(x)
55 | x = self.pre_process[self.scale_idx](x)
56 |
57 | res = self.body(x)
58 | res += x
59 |
60 | x = self.upsample[self.scale_idx](res)
61 | x = self.tail(x)
62 | x = self.add_mean(x)
63 |
64 | return x
65 |
66 | def set_scale(self, scale_idx):
67 | self.scale_idx = scale_idx
68 |
69 |
--------------------------------------------------------------------------------
/src/model/mssr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | import torch.nn as nn
3 | import torch
4 | from model.attention import ContextualAttention,NonLocalAttention
5 | def make_model(args, parent=False):
6 | return MSSR(args)
7 |
8 | class MultisourceProjection(nn.Module):
9 | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
10 | super(MultisourceProjection, self).__init__()
11 | self.up_attention = ContextualAttention(scale=2)
12 | self.down_attention = NonLocalAttention()
13 | self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
14 | self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1)
15 |
16 | def forward(self,x):
17 | down_map = self.upsample(self.down_attention(x))
18 | up_map = self.up_attention(x)
19 |
20 | err = self.encoder(up_map-down_map)
21 | final_map = down_map + err
22 |
23 | return final_map
24 |
25 | class RecurrentProjection(nn.Module):
26 | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
27 | super(RecurrentProjection, self).__init__()
28 | self.multi_source_projection_1 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
29 | self.multi_source_projection_2 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
30 | self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
31 | #self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
32 | self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
33 | self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
34 | self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
35 | self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
36 | self.post_conv = common.BasicBlock(conv,in_channel,in_channel,kernel_size,stride=1,bias=True,act=nn.PReLU())
37 |
38 |
39 | def forward(self, x):
40 | x_up = self.multi_source_projection_1(x)
41 |
42 | x_down = self.down_sample_1(x_up)
43 | error_up = self.error_encode_1(x-x_down)
44 | h_estimate_1 = x_up + error_up
45 |
46 | x_up_2 = self.multi_source_projection_2(h_estimate_1)
47 | x_down_2 = self.down_sample_3(x_up_2)
48 | error_up_2 = self.error_encode_2(x-x_down_2)
49 | h_estimate_2 = x_up_2 + error_up_2
50 | x_final = self.post_conv(self.down_sample_4(h_estimate_2))
51 |
52 | return x_final, h_estimate_2
53 |
54 |
55 |
56 |
57 |
58 | class MSSR(nn.Module):
59 | def __init__(self, args, conv=common.default_conv):
60 | super(MSSR, self).__init__()
61 |
62 | #n_convblock = args.n_convblocks
63 | n_feats = args.n_feats
64 | self.depth = args.depth
65 | kernel_size = 3
66 | scale = args.scale[0]
67 |
68 |
69 | rgb_mean = (0.4488, 0.4371, 0.4040)
70 | rgb_std = (1.0, 1.0, 1.0)
71 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
72 |
73 | # define head module
74 | m_head = [common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU()),
75 | common.BasicBlock(conv,n_feats, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU())]
76 |
77 | # define multiple reconstruction module
78 |
79 | self.body = RecurrentProjection(n_feats)
80 |
81 |
82 | # define tail module
83 | m_tail = [
84 | nn.Conv2d(
85 | n_feats*self.depth, args.n_colors, kernel_size,
86 | padding=(kernel_size//2)
87 | )
88 | ]
89 |
90 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
91 |
92 | self.head = nn.Sequential(*m_head)
93 | self.tail = nn.Sequential(*m_tail)
94 | def forward(self,input):
95 | x = self.sub_mean(input)
96 | x = self.head(x)
97 | bag = []
98 | for i in range(self.depth):
99 | x, h_estimate = self.body(x)
100 | bag.append(h_estimate)
101 | h_feature = torch.cat(bag,dim=1)
102 | h_final = self.tail(h_feature)
103 |
104 | return self.add_mean(h_final)
105 |
--------------------------------------------------------------------------------
/src/model/nlsn.py:
--------------------------------------------------------------------------------
1 | from model import common
2 | from model import attention
3 | import torch.nn as nn
4 |
5 | def make_model(args, parent=False):
6 | if args.dilation:
7 | from model import dilated
8 | return NLSN(args, dilated.dilated_conv)
9 | else:
10 | return NLSN(args)
11 |
12 |
13 | class NLSN(nn.Module):
14 | def __init__(self, args, conv=common.default_conv):
15 | super(NLSN, self).__init__()
16 |
17 | n_resblock = args.n_resblocks
18 | n_feats = args.n_feats
19 | kernel_size = 3
20 | scale = args.scale[0]
21 | act = nn.ReLU(True)
22 |
23 | rgb_mean = (0.4488, 0.4371, 0.4040)
24 | rgb_std = (1.0, 1.0, 1.0)
25 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
26 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
27 |
28 | # define body module
29 | m_body = [attention.NonLocalSparseAttention(
30 | channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)]
31 |
32 | for i in range(n_resblock):
33 | m_body.append( common.ResBlock_AdaDM(
34 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
35 | ))
36 | if (i+1)%8==0:
37 | m_body.append(attention.NonLocalSparseAttention(
38 | channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale))
39 | m_body.append(conv(n_feats, n_feats, kernel_size))
40 |
41 | # define tail module
42 | m_tail = [
43 | common.Upsampler(conv, scale, n_feats, act=False),
44 | nn.Conv2d(
45 | n_feats, args.n_colors, kernel_size,
46 | padding=(kernel_size//2)
47 | )
48 | ]
49 |
50 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
51 |
52 | self.head = nn.Sequential(*m_head)
53 | self.body = nn.Sequential(*m_body)
54 | self.tail = nn.Sequential(*m_tail)
55 |
56 | def forward(self, x):
57 | x = self.sub_mean(x)
58 | x = self.head(x)
59 |
60 | res = self.body(x)
61 | res += x
62 |
63 | x = self.tail(res)
64 | x = self.add_mean(x)
65 |
66 | return x
67 |
68 | def load_state_dict(self, state_dict, strict=True):
69 | own_state = self.state_dict()
70 | for name, param in state_dict.items():
71 | if name in own_state:
72 | if isinstance(param, nn.Parameter):
73 | param = param.data
74 | try:
75 | own_state[name].copy_(param)
76 | except Exception:
77 | if name.find('tail') == -1:
78 | raise RuntimeError('While copying the parameter named {}, '
79 | 'whose dimensions in the model are {} and '
80 | 'whose dimensions in the checkpoint are {}.'
81 | .format(name, own_state[name].size(), param.size()))
82 | elif strict:
83 | if name.find('tail') == -1:
84 | raise KeyError('unexpected key "{}" in state_dict'
85 | .format(name))
86 |
87 |
--------------------------------------------------------------------------------
/src/model/rcan.py:
--------------------------------------------------------------------------------
1 | ## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks
2 | ## https://arxiv.org/abs/1807.02758
3 | from model import common
4 |
5 | import torch.nn as nn
6 | import torch
7 | def make_model(args, parent=False):
8 | return RCAN(args)
9 |
10 | ## Channel Attention (CA) Layer
11 | class CALayer(nn.Module):
12 | def __init__(self, channel, reduction=16):
13 | super(CALayer, self).__init__()
14 | # global average pooling: feature --> point
15 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
16 | # feature channel downscale and upscale --> channel weight
17 | #self.a = torch.nn.Parameter(torch.Tensor([0]))
18 | #self.a.requires_grad=True
19 |
20 | self.conv_du = nn.Sequential(
21 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
22 | nn.ReLU(inplace=True),
23 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
24 | nn.Sigmoid()
25 | )
26 |
27 | def forward(self, x):
28 | y = self.avg_pool(x)
29 | y = self.conv_du(y)
30 | return x * y
31 |
32 | ## Residual Channel Attention Block (RCAB)
33 | class RCAB(nn.Module):
34 | def __init__(
35 | self, conv, n_feat, kernel_size, reduction,
36 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
37 |
38 | super(RCAB, self).__init__()
39 | modules_body = []
40 | for i in range(2):
41 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
42 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
43 | if i == 0: modules_body.append(act)
44 | modules_body.append(CALayer(n_feat, reduction))
45 | self.body = nn.Sequential(*modules_body)
46 | self.res_scale = res_scale
47 |
48 | def forward(self, x):
49 | res = self.body(x)
50 | #res = self.body(x).mul(self.res_scale)
51 | res += x
52 | return res
53 |
54 | ## Residual Group (RG)
55 | class ResidualGroup(nn.Module):
56 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
57 | super(ResidualGroup, self).__init__()
58 | modules_body = []
59 | modules_body = [
60 | RCAB(
61 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
62 | for _ in range(n_resblocks)]
63 | modules_body.append(conv(n_feat, n_feat, kernel_size))
64 | self.body = nn.Sequential(*modules_body)
65 |
66 | def forward(self, x):
67 | res = self.body(x)
68 | res += x
69 | return res
70 |
71 | ## Residual Channel Attention Network (RCAN)
72 | class RCAN(nn.Module):
73 | def __init__(self, args, conv=common.default_conv):
74 | super(RCAN, self).__init__()
75 | self.a = nn.Parameter(torch.Tensor([0]))
76 | self.a.requires_grad=True
77 | n_resgroups = args.n_resgroups
78 | n_resblocks = args.n_resblocks
79 | n_feats = args.n_feats
80 | kernel_size = 3
81 | reduction = args.reduction
82 | scale = args.scale[0]
83 | act = nn.ReLU(True)
84 |
85 | # RGB mean for DIV2K
86 | rgb_mean = (0.4488, 0.4371, 0.4040)
87 | rgb_std = (1.0, 1.0, 1.0)
88 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
89 |
90 | # define head module
91 | modules_head = [conv(args.n_colors, n_feats, kernel_size)]
92 |
93 | # define body module
94 | modules_body = [
95 | ResidualGroup(
96 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
97 | for _ in range(n_resgroups)]
98 | modules_body.append(conv(n_feats, n_feats, kernel_size))
99 |
100 | # define tail module
101 | modules_tail = [
102 | common.Upsampler(conv, scale, n_feats, act=False),
103 | conv(n_feats, args.n_colors, kernel_size)]
104 |
105 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
106 |
107 | self.head = nn.Sequential(*modules_head)
108 | self.body = nn.Sequential(*modules_body)
109 | self.tail = nn.Sequential(*modules_tail)
110 |
111 | def forward(self, x):
112 | x = self.sub_mean(x)
113 | x = self.head(x)
114 | res = self.body(x)
115 | res += x
116 |
117 | x = self.tail(res)
118 | x = self.add_mean(x)
119 |
120 | return x
121 |
122 | def load_state_dict(self, state_dict, strict=False):
123 | own_state = self.state_dict()
124 | for name, param in state_dict.items():
125 | if name in own_state:
126 | if isinstance(param, nn.Parameter):
127 | param = param.data
128 | try:
129 | own_state[name].copy_(param)
130 | except Exception:
131 |                     if name.find('msa') >= 0 or name.find('a') >= 0:
132 | print('Replace pre-trained upsampler to new one...')
133 | else:
134 | raise RuntimeError('While copying the parameter named {}, '
135 | 'whose dimensions in the model are {} and '
136 | 'whose dimensions in the checkpoint are {}.'
137 | .format(name, own_state[name].size(), param.size()))
138 | elif strict:
139 | if name.find('msa') == -1:
140 | raise KeyError('unexpected key "{}" in state_dict'
141 | .format(name))
142 |
143 | if strict:
144 | missing = set(own_state.keys()) - set(state_dict.keys())
145 | if len(missing) > 0:
146 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
147 |
--------------------------------------------------------------------------------
/src/model/rdn.py:
--------------------------------------------------------------------------------
1 | # Residual Dense Network for Image Super-Resolution
2 | # https://arxiv.org/abs/1802.08797
3 |
4 | from model import common
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | def make_model(args, parent=False):
11 | return RDN(args)
12 |
13 | class RDB_Conv(nn.Module):
14 | def __init__(self, inChannels, growRate, kSize=3):
15 | super(RDB_Conv, self).__init__()
16 | Cin = inChannels
17 | G = growRate
18 | self.conv = nn.Sequential(*[
19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
20 | nn.ReLU()
21 | ])
22 |
23 | def forward(self, x):
24 | out = self.conv(x)
25 | return torch.cat((x, out), 1)
26 |
27 | class RDB(nn.Module):
28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
29 | super(RDB, self).__init__()
30 | G0 = growRate0
31 | G = growRate
32 | C = nConvLayers
33 |
34 | convs = []
35 | for c in range(C):
36 | convs.append(RDB_Conv(G0 + c*G, G))
37 | self.convs = nn.Sequential(*convs)
38 |
39 | # Local Feature Fusion
40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
41 |
42 | def forward(self, x):
43 | return self.LFF(self.convs(x)) + x
44 |
45 | class RDB_AdaDM(nn.Module):
46 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
47 | super(RDB_AdaDM, self).__init__()
48 | G0 = growRate0
49 | G = growRate
50 | C = nConvLayers
51 |
52 | convs = []
53 | for c in range(C):
54 | convs.append(RDB_Conv(G0 + c*G, G))
55 | self.convs = nn.Sequential(*convs)
56 |
57 | # Local Feature Fusion
58 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
59 | self.phi = nn.Conv2d(1, 1, 1, 1, 0, bias=True)
60 | self.phi.weight.data.fill_(1)
61 | self.phi.bias.data.fill_(0)
62 | self.norm = nn.BatchNorm2d(G0)
63 |
64 | def forward(self, x):
65 | s = torch.std(x, dim=[1,2,3], keepdim=True)
66 | x_n = self.norm(x)
67 | F = self.LFF(self.convs(x_n))
68 | F = F * (torch.exp(self.phi(torch.log(s))))
69 |
70 | return F + x
71 |
72 | class RDN(nn.Module):
73 | def __init__(self, args):
74 | super(RDN, self).__init__()
75 | r = args.scale[0]
76 | G0 = args.G0
77 | kSize = args.RDNkSize
78 |
79 | # number of RDB blocks, conv layers, out channels
80 | self.D, C, G = {
81 | 'A': (20, 6, 32),
82 | 'B': (16, 8, 64),
83 | }[args.RDNconfig]
84 |
85 | # Shallow feature extraction net
86 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
87 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
88 |
89 |         # Residual dense blocks and dense feature fusion
90 | self.RDBs = nn.ModuleList()
91 | for i in range(self.D):
92 | self.RDBs.append(
93 | RDB_AdaDM(growRate0 = G0, growRate = G, nConvLayers = C)
94 | )
95 |
96 | # Global Feature Fusion
97 | self.GFF = nn.Sequential(*[
98 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
99 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
100 | ])
101 |
102 | # Up-sampling net
103 | if r == 2 or r == 3:
104 | self.UPNet = nn.Sequential(*[
105 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
106 | nn.PixelShuffle(r),
107 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
108 | ])
109 | elif r == 4:
110 | self.UPNet = nn.Sequential(*[
111 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
112 | nn.PixelShuffle(2),
113 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
114 | nn.PixelShuffle(2),
115 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
116 | ])
117 | else:
118 | raise ValueError("scale must be 2 or 3 or 4.")
119 |
120 | def forward(self, x):
121 | f__1 = self.SFENet1(x)
122 | x = self.SFENet2(f__1)
123 |
124 | RDBs_out = []
125 | for i in range(self.D):
126 | x = self.RDBs[i](x)
127 | RDBs_out.append(x)
128 |
129 | x = self.GFF(torch.cat(RDBs_out,1))
130 | x += f__1
131 |
132 | return self.UPNet(x)
133 |
--------------------------------------------------------------------------------
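In both `RDB_AdaDM` above and `ResBlock_AdaDM` in `common.py`, `phi` starts as the identity map (weight 1, bias 0), so at initialization `exp(phi(log(s))) == s` and the fused features are rescaled by exactly the per-sample standard deviation `s` computed before normalization; training can then adapt this rescaling. A quick numeric check of that identity:

```python
import torch

phi = torch.nn.Conv2d(1, 1, 1, 1, 0, bias=True)
phi.weight.data.fill_(1)
phi.bias.data.fill_(0)

s = torch.rand(4, 1, 1, 1) + 0.5                 # stand-in for the per-sample std
with torch.no_grad():
    restored = torch.exp(phi(torch.log(s)))
print(torch.allclose(restored, s))               # True: identity phi reproduces s
```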
/src/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/model/utils/__init__.py
--------------------------------------------------------------------------------
/src/model/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch.nn.functional as F
7 |
8 | def normalize(x):
9 | return x.mul_(2).add_(-1)
10 |
11 | def same_padding(images, ksizes, strides, rates):
12 | assert len(images.size()) == 4
13 | batch_size, channel, rows, cols = images.size()
14 | out_rows = (rows + strides[0] - 1) // strides[0]
15 | out_cols = (cols + strides[1] - 1) // strides[1]
16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1
17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1
18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
20 | # Pad the input
21 | padding_top = int(padding_rows / 2.)
22 | padding_left = int(padding_cols / 2.)
23 | padding_bottom = padding_rows - padding_top
24 | padding_right = padding_cols - padding_left
25 | paddings = (padding_left, padding_right, padding_top, padding_bottom)
26 | images = torch.nn.ZeroPad2d(paddings)(images)
27 | return images
28 |
29 |
30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
31 | """
32 | Extract patches from images and put them in the C output dimension.
33 |     :param images: a 4-D tensor with shape [batch, channels, in_rows, in_cols]
34 |     :param ksizes: [ksize_rows, ksize_cols], the size of the sliding window
35 |     for each dimension of images
36 |     :param strides: [stride_rows, stride_cols]
37 |     :param rates: [dilation_rows, dilation_cols]
38 |     :param padding: 'same' or 'valid'
39 |     :return: a tensor of shape [N, C*k*k, L]
40 | """
41 | assert len(images.size()) == 4
42 | assert padding in ['same', 'valid']
43 | batch_size, channel, height, width = images.size()
44 |
45 | if padding == 'same':
46 | images = same_padding(images, ksizes, strides, rates)
47 | elif padding == 'valid':
48 | pass
49 | else:
50 | raise NotImplementedError('Unsupported padding type: {}.\
51 | Only "same" or "valid" are supported.'.format(padding))
52 |
53 | unfold = torch.nn.Unfold(kernel_size=ksizes,
54 | dilation=rates,
55 | padding=0,
56 | stride=strides)
57 | patches = unfold(images)
58 | return patches # [N, C*k*k, L], L is the total number of such blocks
59 | def reduce_mean(x, axis=None, keepdim=False):
60 | if not axis:
61 | axis = range(len(x.shape))
62 | for i in sorted(axis, reverse=True):
63 | x = torch.mean(x, dim=i, keepdim=keepdim)
64 | return x
65 |
66 |
67 | def reduce_std(x, axis=None, keepdim=False):
68 | if not axis:
69 | axis = range(len(x.shape))
70 | for i in sorted(axis, reverse=True):
71 | x = torch.std(x, dim=i, keepdim=keepdim)
72 | return x
73 |
74 |
75 | def reduce_sum(x, axis=None, keepdim=False):
76 | if not axis:
77 | axis = range(len(x.shape))
78 | for i in sorted(axis, reverse=True):
79 | x = torch.sum(x, dim=i, keepdim=keepdim)
80 | return x
81 |
82 |
--------------------------------------------------------------------------------
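`extract_image_patches` applies 'same' padding and then `nn.Unfold`, so with stride 1 it yields one k x k patch per input position, flattened into the channel dimension. A small shape check (arbitrary sizes; run from `src/`):

```python
import torch
from model.utils.tools import extract_image_patches

x = torch.randn(2, 3, 17, 23)
patches = extract_image_patches(x, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same')
print(patches.shape)  # torch.Size([2, 27, 391]) == [N, C*k*k, H*W] for stride 1
```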
/src/model/vdsr.py:
--------------------------------------------------------------------------------
1 | from model import common
2 |
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 |
6 | url = {
7 | 'r20f64': ''
8 | }
9 |
10 | def make_model(args, parent=False):
11 | return VDSR(args)
12 |
13 | class VDSR(nn.Module):
14 | def __init__(self, args, conv=common.default_conv):
15 | super(VDSR, self).__init__()
16 |
17 | n_resblocks = args.n_resblocks
18 | n_feats = args.n_feats
19 | kernel_size = 3
20 | self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
21 | self.sub_mean = common.MeanShift(args.rgb_range)
22 | self.add_mean = common.MeanShift(args.rgb_range, sign=1)
23 |
24 | def basic_block(in_channels, out_channels, act):
25 | return common.BasicBlock(
26 | conv, in_channels, out_channels, kernel_size,
27 | bias=True, bn=False, act=act
28 | )
29 |
30 | # define body module
31 | m_body = []
32 | m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
33 | for _ in range(n_resblocks - 2):
34 | m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
35 | m_body.append(basic_block(n_feats, args.n_colors, None))
36 |
37 | self.body = nn.Sequential(*m_body)
38 |
39 | def forward(self, x):
40 | x = self.sub_mean(x)
41 | res = self.body(x)
42 | res += x
43 | x = self.add_mean(res)
44 |
45 | return x
46 |
47 |
--------------------------------------------------------------------------------
/src/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import template
3 |
4 | parser = argparse.ArgumentParser(description='EDSR and MDSR')
5 |
6 | parser.add_argument('--debug', action='store_true',
7 | help='Enables debug mode')
8 | parser.add_argument('--template', default='.',
9 | help='You can set various templates in option.py')
10 |
11 | # Hardware specifications
12 | parser.add_argument('--n_threads', type=int, default=18,
13 | help='number of threads for data loading')
14 | parser.add_argument('--cpu', action='store_true',
15 | help='use cpu only')
16 | parser.add_argument('--n_GPUs', type=int, default=1,
17 | help='number of GPUs')
18 | parser.add_argument('--seed', type=int, default=1,
19 | help='random seed')
20 | parser.add_argument('--local_rank',type=int, default=0)
21 | # Data specifications
22 | parser.add_argument('--dir_data', type=str, default='../../../',
23 | help='dataset directory')
24 | parser.add_argument('--dir_demo', type=str, default='../Demo',
25 | help='demo image directory')
26 | parser.add_argument('--data_train', type=str, default='DIV2K',
27 | help='train dataset name')
28 | parser.add_argument('--data_test', type=str, default='DIV2K',
29 | help='test dataset name')
30 | parser.add_argument('--data_range', type=str, default='1-800/801-810',
31 | help='train/test data range')
32 | parser.add_argument('--ext', type=str, default='sep',
33 | help='dataset file extension')
34 | parser.add_argument('--scale', type=str, default='4',
35 | help='super resolution scale')
36 | parser.add_argument('--patch_size', type=int, default=192,
37 | help='output patch size')
38 | parser.add_argument('--rgb_range', type=int, default=255,
39 | help='maximum value of RGB')
40 | parser.add_argument('--n_colors', type=int, default=3,
41 | help='number of color channels to use')
42 | parser.add_argument('--chunk_size',type=int,default=144,
43 | help='attention bucket size')
44 | parser.add_argument('--n_hashes',type=int,default=4,
45 | help='number of hash rounds')
46 | parser.add_argument('--chop', action='store_true',
47 | help='enable memory-efficient forward')
48 | parser.add_argument('--no_augment', action='store_true',
49 | help='do not use data augmentation')
50 |
51 | # Model specifications
52 | parser.add_argument('--model', default='EDSR',
53 | help='model name')
54 |
55 | parser.add_argument('--act', type=str, default='relu',
56 | help='activation function')
57 | parser.add_argument('--pre_train', type=str, default='.',
58 | help='pre-trained model directory')
59 | parser.add_argument('--extend', type=str, default='.',
60 | help='pre-trained model directory')
61 | parser.add_argument('--n_resblocks', type=int, default=20,
62 | help='number of residual blocks')
63 | parser.add_argument('--n_feats', type=int, default=64,
64 | help='number of feature maps')
65 | parser.add_argument('--res_scale', type=float, default=1,
66 | help='residual scaling')
67 | parser.add_argument('--shift_mean', default=True,
68 | help='subtract pixel mean from the input')
69 | parser.add_argument('--dilation', action='store_true',
70 | help='use dilated convolution')
71 | parser.add_argument('--precision', type=str, default='single',
72 | choices=('single', 'half'),
73 | help='FP precision for test (single | half)')
74 |
75 | # Option for Residual dense network (RDN)
76 | parser.add_argument('--G0', type=int, default=64,
77 | help='default number of filters. (Use in RDN)')
78 | parser.add_argument('--RDNkSize', type=int, default=3,
79 | help='default kernel size. (Use in RDN)')
80 | parser.add_argument('--RDNconfig', type=str, default='B',
81 | help='parameters config of RDN. (Use in RDN)')
82 |
83 | parser.add_argument('--depth', type=int, default=12,
84 |                     help='number of recurrent projection steps (used in MSSR)')
85 | # Option for Residual channel attention network (RCAN)
86 | parser.add_argument('--n_resgroups', type=int, default=10,
87 | help='number of residual groups')
88 | parser.add_argument('--reduction', type=int, default=16,
89 | help='number of feature maps reduction')
90 |
91 | # Training specifications
92 | parser.add_argument('--reset', action='store_true',
93 | help='reset the training')
94 | parser.add_argument('--test_every', type=int, default=1000,
95 | help='do test per every N batches')
96 | parser.add_argument('--epochs', type=int, default=1000,
97 | help='number of epochs to train')
98 | parser.add_argument('--batch_size', type=int, default=16,
99 | help='input batch size for training')
100 | parser.add_argument('--split_batch', type=int, default=1,
101 | help='split the batch into smaller chunks')
102 | parser.add_argument('--self_ensemble', action='store_true',
103 | help='use self-ensemble method for test')
104 | parser.add_argument('--test_only', action='store_true',
105 | help='set this option to test the model')
106 | parser.add_argument('--gan_k', type=int, default=1,
107 | help='k value for adversarial loss')
108 |
109 | # Optimization specifications
110 | parser.add_argument('--lr', type=float, default=1e-4,
111 | help='learning rate')
112 | parser.add_argument('--decay', type=str, default='200',
113 | help='learning rate decay type')
114 | parser.add_argument('--gamma', type=float, default=0.5,
115 | help='learning rate decay factor for step decay')
116 | parser.add_argument('--optimizer', default='ADAM',
117 | choices=('SGD', 'ADAM', 'RMSprop'),
118 | help='optimizer to use (SGD | ADAM | RMSprop)')
119 | parser.add_argument('--momentum', type=float, default=0.9,
120 | help='SGD momentum')
121 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
122 | help='ADAM beta')
123 | parser.add_argument('--epsilon', type=float, default=1e-8,
124 | help='ADAM epsilon for numerical stability')
125 | parser.add_argument('--weight_decay', type=float, default=0,
126 | help='weight decay')
127 | parser.add_argument('--gclip', type=float, default=0,
128 | help='gradient clipping threshold (0 = no clipping)')
129 |
130 | # Loss specifications
131 | parser.add_argument('--loss', type=str, default='1*L1',
132 | help='loss function configuration')
133 | parser.add_argument('--skip_threshold', type=float, default=1e8,
134 |                     help='skip batches whose error is abnormally large')
135 |
136 | # Log specifications
137 | parser.add_argument('--save', type=str, default='test',
138 | help='file name to save')
139 | parser.add_argument('--load', type=str, default='',
140 | help='file name to load')
141 | parser.add_argument('--resume', type=int, default=0,
142 | help='resume from specific checkpoint')
143 | parser.add_argument('--save_models', action='store_true',
144 | help='save all intermediate models')
145 | parser.add_argument('--print_every', type=int, default=100,
146 | help='how many batches to wait before logging training status')
147 | parser.add_argument('--save_results', action='store_true',
148 | help='save output results')
149 | parser.add_argument('--save_gt', action='store_true',
150 | help='save low-resolution and high-resolution images together')
151 |
152 | args = parser.parse_args()
153 | template.set_template(args)
154 |
155 | args.scale = list(map(lambda x: int(x), args.scale.split('+')))
156 | args.data_train = args.data_train.split('+')
157 | args.data_test = args.data_test.split('+')
158 |
159 | if args.epochs == 0:
160 | args.epochs = 1e8
161 |
162 | for arg in vars(args):
163 | if vars(args)[arg] == 'True':
164 | vars(args)[arg] = True
165 | elif vars(args)[arg] == 'False':
166 | vars(args)[arg] = False
167 |
168 |
--------------------------------------------------------------------------------
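Several options above are '+'-separated lists (`--scale`, `--data_train`, `--data_test`); the post-processing at the end of the file splits them so that training and testing on several datasets or scales fits in a single command line. A tiny illustration of that convention with hypothetical values:

```python
# Hypothetical values showing how option.py post-processes '+'-separated options.
scale, data_test = '2+3+4', 'Set5+Urban100'

scale = list(map(int, scale.split('+')))  # [2, 3, 4]
data_test = data_test.split('+')          # ['Set5', 'Urban100']
print(scale, data_test)
```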
/src/template.py:
--------------------------------------------------------------------------------
1 | def set_template(args):
2 | # Set the templates here
3 | if args.template.find('jpeg') >= 0:
4 | args.data_train = 'DIV2K_jpeg'
5 | args.data_test = 'DIV2K_jpeg'
6 | args.epochs = 200
7 | args.decay = '100'
8 |
9 | if args.template.find('EDSR_paper') >= 0:
10 | args.model = 'EDSR'
11 | args.n_resblocks = 32
12 | args.n_feats = 256
13 | args.res_scale = 0.1
14 |
15 | if args.template.find('MDSR') >= 0:
16 | args.model = 'MDSR'
17 | args.patch_size = 48
18 | args.epochs = 650
19 |
20 | if args.template.find('DDBPN') >= 0:
21 | args.model = 'DDBPN'
22 | args.patch_size = 128
23 | args.scale = '4'
24 |
25 | args.data_test = 'Set5'
26 |
27 | args.batch_size = 20
28 | args.epochs = 1000
29 | args.decay = '500'
30 | args.gamma = 0.1
31 | args.weight_decay = 1e-4
32 |
33 | args.loss = '1*MSE'
34 |
35 | if args.template.find('GAN') >= 0:
36 | args.epochs = 200
37 | args.lr = 5e-5
38 | args.decay = '150'
39 |
40 | if args.template.find('RCAN') >= 0:
41 | args.model = 'RCAN'
42 | args.n_resgroups = 10
43 | args.n_resblocks = 20
44 | args.n_feats = 64
45 | args.chop = True
46 |
47 | if args.template.find('VDSR') >= 0:
48 | args.model = 'VDSR'
49 | args.n_resblocks = 20
50 | args.n_feats = 64
51 | args.patch_size = 41
52 | args.lr = 1e-1
53 |
54 |
--------------------------------------------------------------------------------
/src/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPU_ID=3
4 | SCALE=2
5 | TEST_DATASET="Urban100"
6 | TEST_MODEL="EDSR"
7 |
8 | ######################################################################################################
9 | # EDSR Test
10 | if [[ $TEST_MODEL == "EDSR" ]]; then
11 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale $SCALE\
12 | --res_scale 0.1 --pre_train ../experiment/test/model/EDSR_AdaDM_DIV2K_X$SCALE.pt\
13 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET
14 | fi
15 |
16 | ######################################################################################################
17 | # RDN Test
18 | if [[ $TEST_MODEL == "RDN" ]]; then
19 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale $SCALE\
20 | --pre_train ../experiment/test/model/RDN_AdaDM_DIV2K_X$SCALE.pt\
21 | --dir_data ../dataset --n_GPUs 1 --test_only --data_test $TEST_DATASET
22 | fi
23 |
24 | ######################################################################################################
25 | # NLSN Test
26 | if [[ $TEST_MODEL == "NLSN" ]]; then
27 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --dir_data ../dataset --model NLSN --chunk_size 144\
28 | --n_hashes 4 --chop --rgb_range 1 --scale $SCALE --n_feats 256 --n_resblocks 32 --res_scale 0.1\
29 | --pre_train ../experiment/test/model/NLSN_AdaDM_DIV2K_X$SCALE.pt --test_only --data_test $TEST_DATASET
30 | fi
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
--------------------------------------------------------------------------------
/src/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPU_ID=0
4 |
5 | ######################################################################################################
6 | # EDSR Train X2
7 | CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 2\
8 | --n_GPUs 1 --batch_size 16 --patch_size 96 --rgb_range 255 --res_scale 0.1\
9 | --save EDSR_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\
10 | --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results
11 |
12 | # EDSR Train X3
13 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 3\
14 | # --n_GPUs 1 --batch_size 16 --patch_size 144 --rgb_range 255 --res_scale 0.1\
15 | # --save EDSR_AdaDM_Test_DIV2K_X3 --dir_data ../dataset --data_test Urban100\
16 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\
17 | # --pre_train ../experiment/EDSR_AdaDM_Test_DIV2K_X2/model/model_best.pt
18 |
19 | # EDSR Train X4
20 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 4\
21 | # --n_GPUs 1 --batch_size 16 --patch_size 192 --rgb_range 255 --res_scale 0.1\
22 | # --save EDSR_AdaDM_Test_DIV2K_X4 --dir_data ../dataset --data_test Urban100\
23 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\
24 | # --pre_train ../experiment/EDSR_AdaDM_Test_DIV2K_X2/model/model_best.pt
25 |
26 | ######################################################################################################
27 | # RDN Train X2
28 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 2\
29 | # --batch_size 16 --patch_size 96 --rgb_range 255 --n_GPUs 1\
30 | # --save RDN_AdaDM_Test_DIV2K_X2 --dir_data ../dataset --data_test Urban100\
31 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results
32 |
33 | # RDN Train X3
34 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 3\
35 | # --batch_size 16 --patch_size 144 --rgb_range 255 --n_GPUs 1\
36 | # --save RDN_AdaDM_Test_DIV2K_X3 --dir_data ../dataset --data_test Urban100\
37 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\
38 | # --pre_train ../experiment/RDN_AdaDM_Test_DIV2K_X2/model/model_best.pt
39 |
40 | # RDN Train X4
41 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model RDN --scale 4\
42 | # --batch_size 16 --patch_size 192 --rgb_range 255 --n_GPUs 1\
43 | # --save RDN_AdaDM_Test_DIV2K_X4 --dir_data ../dataset --data_test Urban100\
44 | # --epochs 1000 --decay 200-400-600-800 --lr 1e-4 --save_models --save_results\
45 | # --pre_train ../experiment/RDN_AdaDM_Test_DIV2K_X2/model/model_best.pt
46 |
47 |
48 |
49 | ######################################################################################################
50 | # NLSN Train X2
51 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\
52 | # --chunk_size 144 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\
53 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 2\
54 | # --patch_size 96 --save NLSN_AdaDM_Test_DIV2K_X2 --data_test Urban100 --save_models --save_results
55 |
56 | # NLSN Train X3
57 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\
58 | # --chunk_size 144 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\
59 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 3\
60 | # --patch_size 144 --save NLSN_AdaDM_Test_DIV2K_X3 --data_test Urban100 --save_models --save_results\
61 | # --pre_train ../experiment/NLSN_AdaDM_Test_DIV2K_X2/model/model_best.pt
62 |
63 | # NLSN Train X4
64 | # CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --model NLSN --dir_data ../dataset --n_GPUs 1\
65 | # --chunk_size 144 --n_hashes 4 --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop\
66 | # --n_resblocks 32 --n_feats 256 --rgb_range 1 --res_scale 0.1 --batch_size 16 --scale 4\
67 | # --patch_size 192 --save NLSN_AdaDM_Test_DIV2K_X4 --data_test Urban100 --save_models --save_results\
68 | # --pre_train ../experiment/NLSN_AdaDM_Test_DIV2K_X2/model/model_best.pt
69 |
70 |
71 |
72 |
73 |
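######################################################################################################
# Evaluation sketch (an illustration only -- see test.sh in this repo for the official test commands):
# reuse the training options above, point --pre_train at a trained checkpoint, and add --test_only so
# the Trainer runs just the test loop.
# CUDA_VISIBLE_DEVICES=$GPU_ID python3 main.py --template EDSR_paper --scale 2 --n_GPUs 1\
#     --rgb_range 255 --res_scale 0.1 --dir_data ../dataset --data_test Urban100\
#     --pre_train ../experiment/EDSR_AdaDM_Test_DIV2K_X2/model/model_best.pt --test_only --save_results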
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | from decimal import Decimal
4 |
5 | import utility
6 |
7 | import torch
8 | import torch.nn.utils as utils
9 | from tqdm import tqdm
10 |
11 | class Trainer():
12 | def __init__(self, args, loader, my_model, my_loss, ckp):
13 | self.args = args
14 | self.scale = args.scale
15 |
16 | self.ckp = ckp
17 | self.loader_train = loader.loader_train
18 | self.loader_test = loader.loader_test
19 | self.model = my_model
20 | self.loss = my_loss
21 | self.optimizer = utility.make_optimizer(args, self.model)
22 |
23 | if self.args.load != '':
24 | self.optimizer.load(ckp.dir, epoch=len(ckp.log))
25 |
26 | self.error_last = 1e8
27 |
28 | def train(self):
29 | self.loss.step()
30 | epoch = self.optimizer.get_last_epoch() + 1
31 | lr = self.optimizer.get_lr()
32 |
33 | self.ckp.write_log(
34 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
35 | )
36 | self.loss.start_log()
37 | self.model.train()
38 |
39 | timer_data, timer_model = utility.timer(), utility.timer()
40 | # TEMP
41 | self.loader_train.dataset.set_scale(0)
42 | for batch, (lr, hr, _,) in enumerate(self.loader_train):
43 | lr, hr = self.prepare(lr, hr)
44 | timer_data.hold()
45 | timer_model.tic()
46 |
47 | self.optimizer.zero_grad()
48 | sr = self.model(lr, 0)
49 | loss = self.loss(sr, hr)
50 | loss.backward()
51 | if self.args.gclip > 0:
52 | utils.clip_grad_value_(
53 | self.model.parameters(),
54 | self.args.gclip
55 | )
56 | self.optimizer.step()
57 |
58 | timer_model.hold()
59 |
60 | if (batch + 1) % self.args.print_every == 0:
61 | self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
62 | (batch + 1) * self.args.batch_size,
63 | len(self.loader_train.dataset),
64 | self.loss.display_loss(batch),
65 | timer_model.release(),
66 | timer_data.release()))
67 |
68 | timer_data.tic()
69 |
70 | self.loss.end_log(len(self.loader_train))
71 | self.error_last = self.loss.log[-1, -1]
72 | self.optimizer.schedule()
73 |
74 | def test(self):
75 | torch.set_grad_enabled(False)
76 |
77 | epoch = self.optimizer.get_last_epoch()
78 | self.ckp.write_log('\nEvaluation:')
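# self.ckp.log is indexed as [epoch, test set, scale]; reserve one PSNR slot per
# (test set, scale) pair for this epoch, accumulate per-image PSNR below, then average.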
79 | self.ckp.add_log(
80 | torch.zeros(1, len(self.loader_test), len(self.scale))
81 | )
82 | self.model.eval()
83 |
84 | timer_test = utility.timer()
85 | if self.args.save_results: self.ckp.begin_background()
86 | for idx_data, d in enumerate(self.loader_test):
87 | for idx_scale, scale in enumerate(self.scale):
88 | d.dataset.set_scale(idx_scale)
89 | for lr, hr, filename in tqdm(d, ncols=80):
90 | lr, hr = self.prepare(lr, hr)
91 | sr = self.model(lr, idx_scale)
92 | sr = utility.quantize(sr, self.args.rgb_range)
93 |
94 | save_list = [sr]
95 | self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
96 | sr, hr, scale, self.args.rgb_range, dataset=d
97 | )
98 | if self.args.save_gt:
99 | save_list.extend([lr, hr])
100 |
101 | if self.args.save_results:
102 | self.ckp.save_results(d, filename[0], save_list, scale)
103 |
104 | self.ckp.log[-1, idx_data, idx_scale] /= len(d)
105 | best = self.ckp.log.max(0)
106 | self.ckp.write_log(
107 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
108 | d.dataset.name,
109 | scale,
110 | self.ckp.log[-1, idx_data, idx_scale],
111 | best[0][idx_data, idx_scale],
112 | best[1][idx_data, idx_scale] + 1
113 | )
114 | )
115 |
116 | self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
117 | self.ckp.write_log('Saving...')
118 |
119 | if self.args.save_results:
120 | self.ckp.end_background()
121 |
122 | if not self.args.test_only:
123 | self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
124 |
125 | self.ckp.write_log(
126 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
127 | )
128 |
129 | torch.set_grad_enabled(True)
130 |
131 | def prepare(self, *args):
132 | device = torch.device('cpu' if self.args.cpu else 'cuda')
133 | def _prepare(tensor):
134 | if self.args.precision == 'half': tensor = tensor.half()
135 | return tensor.to(device)
136 |
137 | return [_prepare(a) for a in args]
138 |
139 | def terminate(self):
140 | if self.args.test_only:
141 | self.test()
142 | return True
143 | else:
144 | epoch = self.optimizer.get_last_epoch() + 1
145 | return epoch >= self.args.epochs
146 |
147 |
--------------------------------------------------------------------------------
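
Note: Trainer is driven from src/main.py in the usual EDSR-PyTorch way; a minimal sketch of that
loop (not the verbatim file; args, loader, _model, _loss, and checkpoint are built earlier in main.py):

    t = Trainer(args, loader, _model, _loss, checkpoint)
    while not t.terminate():   # terminate() covers both --test_only and the --epochs budget
        t.train()
        t.test()
    checkpoint.done()
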
/src/utility.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import datetime
5 | from multiprocessing import Process
6 | from multiprocessing import Queue
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 |
12 | import numpy as np
13 | import imageio
14 |
15 | import torch
16 | import torch.optim as optim
17 | import torch.optim.lr_scheduler as lrs
18 |
19 | class timer():
20 | def __init__(self):
21 | self.acc = 0
22 | self.tic()
23 |
24 | def tic(self):
25 | self.t0 = time.time()
26 |
27 | def toc(self, restart=False):
28 | diff = time.time() - self.t0
29 | if restart: self.t0 = time.time()
30 | return diff
31 |
32 | def hold(self):
33 | self.acc += self.toc()
34 |
35 | def release(self):
36 | ret = self.acc
37 | self.acc = 0
38 |
39 | return ret
40 |
41 | def reset(self):
42 | self.acc = 0
43 |
44 | class checkpoint():
45 | def __init__(self, args):
46 | self.args = args
47 | self.ok = True
48 | self.log = torch.Tensor()
49 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
50 |
51 | if not args.load:
52 | if not args.save:
53 | args.save = now
54 | self.dir = os.path.join('..', 'experiment', args.save)
55 | else:
56 | self.dir = os.path.join('..', 'experiment', args.load)
57 | if os.path.exists(self.dir):
58 | self.log = torch.load(self.get_path('psnr_log.pt'))
59 | print('Continue from epoch {}...'.format(len(self.log)))
60 | else:
61 | args.load = ''
62 |
63 | if args.reset:
64 | os.system('rm -rf ' + self.dir)
65 | args.load = ''
66 |
67 | os.makedirs(self.dir, exist_ok=True)
68 | os.makedirs(self.get_path('model'), exist_ok=True)
69 | for d in args.data_test:
70 | os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
71 |
72 | open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
73 | self.log_file = open(self.get_path('log.txt'), open_type)
74 | with open(self.get_path('config.txt'), open_type) as f:
75 | f.write(now + '\n\n')
76 | for arg in vars(args):
77 | f.write('{}: {}\n'.format(arg, getattr(args, arg)))
78 | f.write('\n')
79 |
80 | self.n_processes = 8
81 |
82 | def get_path(self, *subdir):
83 | return os.path.join(self.dir, *subdir)
84 |
85 | def save(self, trainer, epoch, is_best=False):
86 | trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
87 | trainer.loss.save(self.dir)
88 | trainer.loss.plot_loss(self.dir, epoch)
89 |
90 | self.plot_psnr(epoch)
91 | trainer.optimizer.save(self.dir)
92 | torch.save(self.log, self.get_path('psnr_log.pt'))
93 |
94 | def add_log(self, log):
95 | self.log = torch.cat([self.log, log])
96 |
97 | def write_log(self, log, refresh=False):
98 | print(log)
99 | self.log_file.write(log + '\n')
100 | if refresh:
101 | self.log_file.close()
102 | self.log_file = open(self.get_path('log.txt'), 'a')
103 |
104 | def done(self):
105 | self.log_file.close()
106 |
107 | def plot_psnr(self, epoch):
108 | axis = np.linspace(1, epoch, epoch)
109 | for idx_data, d in enumerate(self.args.data_test):
110 | label = 'SR on {}'.format(d)
111 | fig = plt.figure()
112 | plt.title(label)
113 | for idx_scale, scale in enumerate(self.args.scale):
114 | plt.plot(
115 | axis,
116 | self.log[:, idx_data, idx_scale].numpy(),
117 | label='Scale {}'.format(scale)
118 | )
119 | plt.legend()
120 | plt.xlabel('Epochs')
121 | plt.ylabel('PSNR')
122 | plt.grid(True)
123 | plt.savefig(self.get_path('test_{}.pdf'.format(d)))
124 | plt.close(fig)
125 |
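# Background PNG writers: save_results() enqueues (filename, tensor) pairs, and n_processes
# worker processes drain the queue with imageio.imwrite so disk I/O overlaps with evaluation.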
126 | def begin_background(self):
127 | self.queue = Queue()
128 |
129 | def bg_target(queue):
130 | while True:
131 | if not queue.empty():
132 | filename, tensor = queue.get()
133 | if filename is None: break
134 | imageio.imwrite(filename, tensor.numpy())
135 |
136 | self.process = [
137 | Process(target=bg_target, args=(self.queue,)) \
138 | for _ in range(self.n_processes)
139 | ]
140 |
141 | for p in self.process: p.start()
142 |
143 | def end_background(self):
144 | for _ in range(self.n_processes): self.queue.put((None, None))
145 | while not self.queue.empty(): time.sleep(1)
146 | for p in self.process: p.join()
147 |
148 | def save_results(self, dataset, filename, save_list, scale):
149 | if self.args.save_results:
150 | filename = self.get_path(
151 | 'results-{}'.format(dataset.dataset.name),
152 | '{}_x{}_'.format(filename, scale)
153 | )
154 |
155 | postfix = ('SR', 'LR', 'HR')
156 | for v, p in zip(save_list, postfix):
157 | normalized = v[0].mul(255 / self.args.rgb_range)
158 | tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
159 | self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
160 |
161 | def quantize(img, rgb_range):
162 | pixel_range = 255 / rgb_range
163 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
164 |
165 | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
166 | if hr.nelement() == 1: return 0
167 |
168 | diff = (sr - hr) / rgb_range
169 | if dataset and dataset.dataset.benchmark:
170 | shave = scale
171 | if diff.size(1) > 1:
172 | gray_coeffs = [65.738, 129.057, 25.064]
173 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
174 | diff = diff.mul(convert).sum(dim=1)
175 | else:
176 | shave = scale + 6
177 |
178 | valid = diff[..., shave:-shave, shave:-shave]
179 | mse = valid.pow(2).mean()
180 |
181 | return -10 * math.log10(mse)
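
# Note on calc_psnr: diff is normalized by rgb_range, so mse is on a 0-1 scale and
# PSNR = -10 * log10(mse). For benchmark sets the RGB error is first projected onto the
# Y (luma) channel with the ITU-R BT.601 weights (65.738, 129.057, 25.064) / 256 and
# `scale` border pixels are shaved; otherwise (e.g. DIV2K validation) scale + 6 are shaved.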
182 |
183 | def make_optimizer(args, target):
184 | '''
185 | make optimizer and scheduler together
186 | '''
187 | # optimizer
188 | trainable = filter(lambda x: x.requires_grad, target.parameters())
189 | kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
190 |
191 | if args.optimizer == 'SGD':
192 | optimizer_class = optim.SGD
193 | kwargs_optimizer['momentum'] = args.momentum
194 | elif args.optimizer == 'ADAM':
195 | optimizer_class = optim.Adam
196 | kwargs_optimizer['betas'] = args.betas
197 | kwargs_optimizer['eps'] = args.epsilon
198 | elif args.optimizer == 'RMSprop':
199 | optimizer_class = optim.RMSprop
200 | kwargs_optimizer['eps'] = args.epsilon
201 |
202 | # scheduler
203 | milestones = list(map(lambda x: int(x), args.decay.split('-')))
204 | kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
205 | scheduler_class = lrs.MultiStepLR
206 |
207 | class CustomOptimizer(optimizer_class):
208 | def __init__(self, *args, **kwargs):
209 | super(CustomOptimizer, self).__init__(*args, **kwargs)
210 |
211 | def _register_scheduler(self, scheduler_class, **kwargs):
212 | self.scheduler = scheduler_class(self, **kwargs)
213 |
214 | def save(self, save_dir):
215 | torch.save(self.state_dict(), self.get_dir(save_dir))
216 |
217 | def load(self, load_dir, epoch=1):
218 | self.load_state_dict(torch.load(self.get_dir(load_dir)))
219 | if epoch > 1:
220 | for _ in range(epoch): self.scheduler.step()
221 |
222 | def get_dir(self, dir_path):
223 | return os.path.join(dir_path, 'optimizer.pt')
224 |
225 | def schedule(self):
226 | self.scheduler.step()
227 |
228 | def get_lr(self):
229 | return self.scheduler.get_lr()[0]
230 |
231 | def get_last_epoch(self):
232 | return self.scheduler.last_epoch
233 |
234 | optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
235 | optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
236 | return optimizer
237 |
238 |
--------------------------------------------------------------------------------
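
Note: the CustomOptimizer returned by make_optimizer bundles the optimizer with its MultiStepLR
scheduler; this is a rough sketch of how src/trainer.py drives it, using the option names above:

    optimizer = utility.make_optimizer(args, model)   # e.g. ADAM; '--decay 200-400-600-800'
                                                      # becomes milestones [200, 400, 600, 800]
    epoch = optimizer.get_last_epoch() + 1
    lr = optimizer.get_lr()
    # ... forward/backward and optimizer.step() per batch ...
    optimizer.schedule()                              # advance the scheduler once per epoch
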
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/njulj/AdaDM/7777bf000fb341720c8896acf087e5837858edc6/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch.nn.functional as F
7 |
8 | def normalize(x):
9 | return x.mul_(2).add_(-1)
10 |
11 | def same_padding(images, ksizes, strides, rates):
12 | assert len(images.size()) == 4
13 | batch_size, channel, rows, cols = images.size()
14 | out_rows = (rows + strides[0] - 1) // strides[0]
15 | out_cols = (cols + strides[1] - 1) // strides[1]
16 | effective_k_row = (ksizes[0] - 1) * rates[0] + 1
17 | effective_k_col = (ksizes[1] - 1) * rates[1] + 1
18 | padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
19 | padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
20 | # Pad the input
21 | padding_top = int(padding_rows / 2.)
22 | padding_left = int(padding_cols / 2.)
23 | padding_bottom = padding_rows - padding_top
24 | padding_right = padding_cols - padding_left
25 | paddings = (padding_left, padding_right, padding_top, padding_bottom)
26 | images = torch.nn.ZeroPad2d(paddings)(images)
27 | return images
28 |
29 |
30 | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
31 | """
32 | Extract patches from images and put them in the C output dimension.
33 | :param padding:
34 | :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
35 | :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
36 | each dimension of images
37 | :param strides: [stride_rows, stride_cols]
38 | :param rates: [dilation_rows, dilation_cols]
39 | :return: A Tensor
40 | """
41 | assert len(images.size()) == 4
42 | assert padding in ['same', 'valid']
43 | batch_size, channel, height, width = images.size()
44 |
45 | if padding == 'same':
46 | images = same_padding(images, ksizes, strides, rates)
47 | elif padding == 'valid':
48 | pass
49 | else:
50 | raise NotImplementedError('Unsupported padding type: {}. '
51 | 'Only "same" or "valid" are supported.'.format(padding))
52 |
53 | unfold = torch.nn.Unfold(kernel_size=ksizes,
54 | dilation=rates,
55 | padding=0,
56 | stride=strides)
57 | patches = unfold(images)
58 | return patches # [N, C*k*k, L], L is the total number of such blocks
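# With padding='same', same_padding() pads the input so unfold yields exactly one column per
# output location: L = ceil(rows / stride_rows) * ceil(cols / stride_cols).
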
59 | def reduce_mean(x, axis=None, keepdim=False):
60 | if not axis:
61 | axis = range(len(x.shape))
62 | for i in sorted(axis, reverse=True):
63 | x = torch.mean(x, dim=i, keepdim=keepdim)
64 | return x
65 |
66 |
67 | def reduce_std(x, axis=None, keepdim=False):
68 | if not axis:
69 | axis = range(len(x.shape))
70 | for i in sorted(axis, reverse=True):
71 | x = torch.std(x, dim=i, keepdim=keepdim)
72 | return x
73 |
74 |
75 | def reduce_sum(x, axis=None, keepdim=False):
76 | if not axis:
77 | axis = range(len(x.shape))
78 | for i in sorted(axis, reverse=True):
79 | x = torch.sum(x, dim=i, keepdim=keepdim)
80 | return x
81 |
82 |
--------------------------------------------------------------------------------
/src/videotester.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | import utility
5 | from data import common
6 |
7 | import torch
8 | import cv2
9 |
10 | from tqdm import tqdm
11 |
12 | class VideoTester():
13 | def __init__(self, args, my_model, ckp):
14 | self.args = args
15 | self.scale = args.scale
16 |
17 | self.ckp = ckp
18 | self.model = my_model
19 |
20 | self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
21 |
22 | def test(self):
23 | torch.set_grad_enabled(False)
24 |
25 | self.ckp.write_log('\nEvaluation on video:')
26 | self.model.eval()
27 |
28 | timer_test = utility.timer()
29 | for idx_scale, scale in enumerate(self.scale):
30 | vidcap = cv2.VideoCapture(self.args.dir_demo)
31 | total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
32 | vidwri = cv2.VideoWriter(
33 | self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
34 | cv2.VideoWriter_fourcc(*'XVID'),
35 | vidcap.get(cv2.CAP_PROP_FPS),
36 | (
37 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
38 | int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39 | )
40 | )
41 |
42 | tqdm_test = tqdm(range(total_frames), ncols=80)
43 | for _ in tqdm_test:
44 | success, lr = vidcap.read()
45 | if not success: break
46 |
47 | lr, = common.set_channel(lr, n_channels=self.args.n_colors)
48 | lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
49 | lr, = self.prepare(lr.unsqueeze(0))
50 | sr = self.model(lr, idx_scale)
51 | sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
52 |
53 | normalized = sr * 255 / self.args.rgb_range
54 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
55 | vidwri.write(ndarr)
56 |
57 | vidcap.release()
58 | vidwri.release()
59 |
60 | self.ckp.write_log(
61 | 'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
62 | )
63 | torch.set_grad_enabled(True)
64 |
65 | def prepare(self, *args):
66 | device = torch.device('cpu' if self.args.cpu else 'cuda')
67 | def _prepare(tensor):
68 | if self.args.precision == 'half': tensor = tensor.half()
69 | return tensor.to(device)
70 |
71 | return [_prepare(a) for a in args]
72 |
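# Usage note (a sketch of the surrounding plumbing; the exact selection logic lives in src/main.py):
# VideoTester is presumably constructed there, instead of Trainer, when a video input given via
# --dir_demo is tested with --test_only and a --pre_train checkpoint. Frames are read with OpenCV,
# super-resolved one at a time, and written back as an XVID-encoded .avi under the experiment directory.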
73 |
--------------------------------------------------------------------------------