├── .gitignore
├── README.md
├── codes
│ ├── .ipynb_checkpoints
│ │ └── Test_demo-checkpoint.ipynb
│ ├── Test_demo.ipynb
│ ├── data
│ │ ├── LQGT_dataset.py
│ │ ├── LQ_dataset.py
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── Color_dataset.cpython-37.pyc
│ │ │ ├── ContinueLQGT_dataset.cpython-37.pyc
│ │ │ ├── LQGT_dataset.cpython-36.pyc
│ │ │ ├── LQGT_dataset.cpython-37.pyc
│ │ │ ├── LQ_dataset.cpython-37.pyc
│ │ │ ├── Vimeo90K_dataset.cpython-37.pyc
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── data_sampler.cpython-36.pyc
│ │ │ ├── data_sampler.cpython-37.pyc
│ │ │ ├── util.cpython-36.pyc
│ │ │ ├── util.cpython-37.pyc
│ │ │ └── video_test_dataset.cpython-37.pyc
│ │ ├── data_sampler.py
│ │ └── util.py
│ ├── data_scripts
│ │ ├── __pycache__
│ │ │ └── generate_mod_LR_bic.cpython-36.pyc
│ │ ├── create_lmdb.py
│ │ ├── extract_subimages.py
│ │ ├── generate_mod_LR_bic.m
│ │ ├── generate_mod_LR_bic.py
│ │ ├── prepare_DIV2K_x4_dataset.sh
│ │ ├── rename.py
│ │ └── test_dataloader.py
│ ├── metrics
│ │ ├── calculate_PSNR_SSIM.m
│ │ └── calculate_PSNR_SSIM.py
│ ├── models
│ │ ├── SR_model.py
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── SRGAN_model.cpython-37.pyc
│ │ │ ├── SR_model.cpython-36.pyc
│ │ │ ├── SR_model.cpython-37.pyc
│ │ │ ├── Video_base_model.cpython-37.pyc
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── base_model.cpython-36.pyc
│ │ │ ├── base_model.cpython-37.pyc
│ │ │ ├── loss.cpython-36.pyc
│ │ │ ├── loss.cpython-37.pyc
│ │ │ ├── lr_scheduler.cpython-36.pyc
│ │ │ ├── lr_scheduler.cpython-37.pyc
│ │ │ ├── networks.cpython-36.pyc
│ │ │ └── networks.cpython-37.pyc
│ │ ├── archs
│ │ │ ├── PAN_arch.py
│ │ │ ├── RCAN_arch.py
│ │ │ ├── SRResNet_arch.py
│ │ │ ├── __init__.py
│ │ │ ├── __pycache__
│ │ │ │ ├── AWSRN_arch.cpython-37.pyc
│ │ │ │ ├── DUF_arch.cpython-37.pyc
│ │ │ │ ├── Dual_arch.cpython-37.pyc
│ │ │ │ ├── EDVR_arch.cpython-36.pyc
│ │ │ │ ├── EfficientSR_arch.cpython-37.pyc
│ │ │ │ ├── EfficientSR_clean.cpython-37.pyc
│ │ │ │ ├── FSRCNN_arch.cpython-37.pyc
│ │ │ │ ├── KernelMD_arch.cpython-37.pyc
│ │ │ │ ├── MSSResNet_deblur_arch.cpython-37.pyc
│ │ │ │ ├── Octave_arch.cpython-37.pyc
│ │ │ │ ├── PAN_arch.cpython-36.pyc
│ │ │ │ ├── PAN_arch.cpython-37.pyc
│ │ │ │ ├── PAN_arch_update.cpython-37.pyc
│ │ │ │ ├── PANet_arch.cpython-37.pyc
│ │ │ │ ├── PANv2_arch.cpython-37.pyc
│ │ │ │ ├── RCAN_arch.cpython-36.pyc
│ │ │ │ ├── RCAN_arch.cpython-37.pyc
│ │ │ │ ├── RRDBNet_arch.cpython-36.pyc
│ │ │ │ ├── RRDBNet_arch.cpython-37.pyc
│ │ │ │ ├── SRResNet_arch.cpython-36.pyc
│ │ │ │ ├── SRResNet_arch.cpython-37.pyc
│ │ │ │ ├── UNet_arch.cpython-37.pyc
│ │ │ │ ├── __init__.cpython-36.pyc
│ │ │ │ ├── __init__.cpython-37.pyc
│ │ │ │ ├── arch_util.cpython-36.pyc
│ │ │ │ ├── arch_util.cpython-37.pyc
│ │ │ │ ├── discriminator_vgg_arch.cpython-36.pyc
│ │ │ │ ├── discriminator_vgg_arch.cpython-37.pyc
│ │ │ │ ├── unet_arch.cpython-37.pyc
│ │ │ │ └── unet_parts.cpython-37.pyc
│ │ │ ├── arch_util.py
│ │ │ └── dcn
│ │ │ │ ├── __init__.py
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── __init__.cpython-36.pyc
│ │ │ │ │ └── deform_conv.cpython-36.pyc
│ │ │ │ ├── deform_conv.egg-info
│ │ │ │ │ ├── PKG-INFO
│ │ │ │ │ ├── SOURCES.txt
│ │ │ │ │ ├── dependency_links.txt
│ │ │ │ │ ├── not-zip-safe
│ │ │ │ │ └── top_level.txt
│ │ │ │ ├── deform_conv.py
│ │ │ │ ├── setup.py
│ │ │ │ └── src
│ │ │ │ │ ├── deform_conv_cuda.cpp
│ │ │ │ │ └── deform_conv_cuda_kernel.cu
│ │ ├── base_model.py
│ │ ├── loss.py
│ │ ├── lr_scheduler.py
│ │ └── networks.py
│ ├── options
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── options.cpython-36.pyc
│ │ │ └── options.cpython-37.pyc
│ │ ├── options.py
│ │ ├── test
│ │ │ ├── test_PANx2.yml
│ │ │ ├── test_PANx3.yml
│ │ │ ├── test_PANx4.yml
│ │ │ ├── test_RCAN.yml
│ │ │ └── test_SRResNet.yml
│ │ └── train
│ │ │ ├── train_PANx2.yml
│ │ │ ├── train_PANx3.yml
│ │ │ ├── train_PANx4.yml
│ │ │ ├── train_RCAN.yml
│ │ │ └── train_SRResNet.yml
│ ├── requirements.txt
│ ├── run_scripts.sh
│ ├── scripts
│ │ ├── back_projection
│ │ │ ├── backprojection.m
│ │ │ ├── main_bp.m
│ │ │ └── main_reverse_filter.m
│ │ └── transfer_params_MSRResNet.py
│ ├── test.py
│ ├── test_running_time.py
│ ├── test_summary.py
│ ├── train.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-36.pyc
│ │ │ ├── __init__.cpython-37.pyc
│ │ │ ├── util.cpython-36.pyc
│ │ │ └── util.cpython-37.pyc
│ │ └── util.py
├── datasets
│ ├── Set14
│ │ ├── .DS_Store
│ │ ├── HR
│ │ │ ├── baboon.png
│ │ │ ├── barbara.png
│ │ │ ├── bridge.png
│ │ │ ├── coastguard.png
│ │ │ ├── comic.png
│ │ │ ├── face.png
│ │ │ ├── flowers.png
│ │ │ ├── foreman.png
│ │ │ ├── lenna.png
│ │ │ ├── man.png
│ │ │ ├── monarch.png
│ │ │ ├── pepper.png
│ │ │ ├── ppt3.png
│ │ │ └── zebra.png
│ │ └── LR_bicubic
│ │ │ ├── X2
│ │ │ │ ├── baboonx2.png
│ │ │ │ ├── barbarax2.png
│ │ │ │ ├── bridgex2.png
│ │ │ │ ├── coastguardx2.png
│ │ │ │ ├── comicx2.png
│ │ │ │ ├── facex2.png
│ │ │ │ ├── flowersx2.png
│ │ │ │ ├── foremanx2.png
│ │ │ │ ├── lennax2.png
│ │ │ │ ├── manx2.png
│ │ │ │ ├── monarchx2.png
│ │ │ │ ├── pepperx2.png
│ │ │ │ ├── ppt3x2.png
│ │ │ │ └── zebrax2.png
│ │ │ ├── X3
│ │ │ │ ├── baboonx3.png
│ │ │ │ ├── barbarax3.png
│ │ │ │ ├── bridgex3.png
│ │ │ │ ├── coastguardx3.png
│ │ │ │ ├── comicx3.png
│ │ │ │ ├── facex3.png
│ │ │ │ ├── flowersx3.png
│ │ │ │ ├── foremanx3.png
│ │ │ │ ├── lennax3.png
│ │ │ │ ├── manx3.png
│ │ │ │ ├── monarchx3.png
│ │ │ │ ├── pepperx3.png
│ │ │ │ ├── ppt3x3.png
│ │ │ │ └── zebrax3.png
│ │ │ ├── X4
│ │ │ │ ├── baboonx4.png
│ │ │ │ ├── barbarax4.png
│ │ │ │ ├── bridgex4.png
│ │ │ │ ├── coastguardx4.png
│ │ │ │ ├── comicx4.png
│ │ │ │ ├── facex4.png
│ │ │ │ ├── flowersx4.png
│ │ │ │ ├── foremanx4.png
│ │ │ │ ├── lennax4.png
│ │ │ │ ├── manx4.png
│ │ │ │ ├── monarchx4.png
│ │ │ │ ├── pepperx4.png
│ │ │ │ ├── ppt3x4.png
│ │ │ │ └── zebrax4.png
│ │ │ └── X8
│ │ │ │ ├── baboonx8.png
│ │ │ │ ├── barbarax8.png
│ │ │ │ ├── bridgex8.png
│ │ │ │ ├── coastguardx8.png
│ │ │ │ ├── comicx8.png
│ │ │ │ ├── facex8.png
│ │ │ │ ├── flowersx8.png
│ │ │ │ ├── foremanx8.png
│ │ │ │ ├── lennax8.png
│ │ │ │ ├── manx8.png
│ │ │ │ ├── monarchx8.png
│ │ │ │ ├── pepperx8.png
│ │ │ │ ├── ppt3x8.png
│ │ │ │ └── zebrax8.png
│ └── Set5
│ │ ├── HR
│ │ │ ├── baby.png
│ │ │ ├── bird.png
│ │ │ ├── butterfly.png
│ │ │ ├── head.png
│ │ │ └── woman.png
│ │ └── LR_bicubic
│ │ │ ├── X2
│ │ │ │ ├── babyx2.png
│ │ │ │ ├── birdx2.png
│ │ │ │ ├── butterflyx2.png
│ │ │ │ ├── headx2.png
│ │ │ │ └── womanx2.png
│ │ │ ├── X3
│ │ │ │ ├── babyx3.png
│ │ │ │ ├── birdx3.png
│ │ │ │ ├── butterflyx3.png
│ │ │ │ ├── headx3.png
│ │ │ │ └── womanx3.png
│ │ │ ├── X4
│ │ │ │ ├── babyx4.png
│ │ │ │ ├── birdx4.png
│ │ │ │ ├── butterflyx4.png
│ │ │ │ ├── headx4.png
│ │ │ │ └── womanx4.png
│ │ │ └── X8
│ │ │ │ ├── babyx8.png
│ │ │ │ ├── birdx8.png
│ │ │ │ ├── butterflyx8.png
│ │ │ │ ├── headx8.png
│ │ │ │ └── womanx8.png
├── experiments
│ └── pretrained_models
│ │ ├── PANx2_DF2K.pth
│ │ ├── PANx3_DF2K.pth
│ │ └── PANx4_DF2K.pth
├── results
│ ├── PANx2_DF2K
│ │ └── test_PANx2_DF2Ktrain_200912-232704.log
│ ├── PANx3_DF2K
│ │ └── test_PANx3_DF2Ktrain_200912-232635.log
│ └── PANx4_DF2K
│ │ └── test_PANx4_DF2Ktrain_200912-232209.log
└── show_figs
│ ├── main.jpg
│ └── main.pdf
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PAN [🔥 272K parameters]
2 | ### The fewest parameters in the AIM 2020 Efficient Super-Resolution Challenge.
3 |
4 | ## [Paper](https://arxiv.org/abs/2010.01073) | [Video](https://www.bilibili.com/video/BV1Qh411R7vZ/)
5 | ## Efficient Image Super-Resolution Using Pixel Attention
6 | Authors: Hengyuan Zhao, [Xiangtao Kong](https://github.com/Xiangtaokong), [Jingwen He](https://github.com/hejingwenhejingwen), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl=zh-CN), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=zh-CN)
7 |
8 |
9 |
10 |
11 |
12 |
13 | ## Dependencies
14 |
15 | - Python >= 3.6 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
16 | - [PyTorch >= 1.5.0](https://pytorch.org/)
17 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
18 | - Python packages: `pip install numpy opencv-python lmdb`
19 | - [optional] Python packages: [`pip install tensorboardX`](https://github.com/lanpa/tensorboardX) for visualizing training curves.
20 |
21 | ## Codes
22 | - Our code is based on [mmsr](https://github.com/open-mmlab/mmsr).
23 | - This repository provides both testing and training code.
24 |
25 |
26 |
27 | ## How to Test
28 | 1. Clone this github repo.
29 | ```
30 | git clone https://github.com/zhaohengyuan1/PAN.git
31 | cd PAN
32 | ```
33 | 2. Download the five test datasets (Set5, Set14, B100, Urban100, Manga109) from [Google Drive](https://drive.google.com/drive/folders/1lsoyAjsUEyp7gm1t6vZI9j7jr9YzKzcF?usp=sharing).
34 |
35 | 3. Pretrained models have been placed in the `./experiments/pretrained_models/` folder. More models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1_zZqTvvAb_ad4T4-uiIGF9CkNiPrBXGr?usp=sharing).
36 |
37 | 4. Run a test. We provide `x2, x3, x4` pretrained models.
38 | ```
39 | cd codes
40 | python test.py -opt options/test/test_PANx4.yml
41 | ```
42 | More testing commands can be found in the `./codes/run_scripts.sh` file.
43 | 5. The output results will be saved in `./results`. (We have already put our testing log files in `./results`.) We also provide our testing results on five benchmark datasets on [Google Drive](https://drive.google.com/drive/folders/1F6unBkp6L1oJb_gOgSHYM5ZZbyLImDPH?usp=sharing).
44 |
45 | ## How to Train
46 |
47 | 1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017) from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA).
48 |
49 | 2. Generate training patches: modify the paths of your training datasets in the `./codes/data_scripts/extract_subimages.py` file, then run it.
50 |
51 | 3. Run Training.
52 |
53 | ```
54 | python train.py -opt options/train/train_PANx4.yml
55 | ```
56 | 4. More training commands can be found in the `./codes/run_scripts.sh` file.
57 |
58 | ## Testing the Parameters, Mult-Adds and Running Time
59 |
60 | 1. Testing the parameters and Mult-Adds.
61 | ```
62 | python test_summary.py
63 | ```
64 |
65 | 2. Testing the Running Time.
66 |
67 | ```
68 | python test_running_time.py
69 | ```
70 |
71 | ## Related Work on AIM2020
72 | Enhanced Quadratic Video Interpolation (winning solution of AIM2020 VTSR Challenge)
73 | [paper](https://arxiv.org/pdf/2009.04642.pdf) | [code](https://github.com/lyh-18/EQVI)
74 |
75 | ## Contact
76 | Email: hubylidayuan@gmail.com
77 |
78 | If you find our work useful, please cite it as follows.
79 | ```
80 | @inproceedings{zhao2020efficient,
81 | title={Efficient image super-resolution using pixel attention},
82 | author={Zhao, Hengyuan and Kong, Xiangtao and He, Jingwen and Qiao, Yu and Dong, Chao},
83 | booktitle={European Conference on Computer Vision},
84 | pages={56--72},
85 | year={2020},
86 | organization={Springer}
87 | }
88 | ```
89 |
90 |
--------------------------------------------------------------------------------
/codes/data/LQGT_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import lmdb
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 |
9 |
10 | class LQGTDataset(data.Dataset):
11 | """
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
13 | If only GT images are provided, generate LQ images on-the-fly.
14 | """
15 |
16 | def __init__(self, opt):
17 | super(LQGTDataset, self).__init__()
18 | self.opt = opt
19 | self.data_type = self.opt['data_type']
20 | self.paths_LQ, self.paths_GT = None, None
21 | self.sizes_LQ, self.sizes_GT = None, None
22 | self.LQ_env, self.GT_env = None, None # environments for lmdb
23 |
24 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
25 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
26 |
27 | assert self.paths_GT, 'Error: GT path is empty.'
28 | if self.paths_LQ and self.paths_GT:
29 | assert len(self.paths_LQ) == len(
30 | self.paths_GT
31 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
32 | len(self.paths_LQ), len(self.paths_GT))
33 | self.random_scale_list = [1]
34 |
35 | def _init_lmdb(self):
36 | # https://github.com/chainer/chainermn/issues/129
37 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
38 | meminit=False)
39 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
40 | meminit=False)
41 |
42 | def __getitem__(self, index):
43 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
44 | self._init_lmdb()
45 | GT_path, LQ_path = None, None
46 | scale = self.opt['scale']
47 | GT_size = self.opt['GT_size']
48 |
49 | # get GT image
50 | GT_path = self.paths_GT[index]
51 | resolution = [int(s) for s in self.sizes_GT[index].split('_')
52 | ] if self.data_type == 'lmdb' else None
53 | img_GT = util.read_img(self.GT_env, GT_path, resolution)
54 | if img_GT.shape[2] == 1:
55 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
56 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase
57 | img_GT = util.modcrop(img_GT, scale)
58 | if self.opt['color']: # change color space if necessary
59 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
60 |
61 | # get LQ image
62 | if self.paths_LQ:
63 | LQ_path = self.paths_LQ[index]
64 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
65 | ] if self.data_type == 'lmdb' else None
66 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
67 | if img_LQ.shape[2] == 1:
68 | # for test gray to color
69 | img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_GRAY2BGR)
70 | else: # down-sampling on-the-fly
71 | # randomly scale during training
72 | if self.opt['phase'] == 'train':
73 | random_scale = random.choice(self.random_scale_list)
74 | H_s, W_s, _ = img_GT.shape
75 |
76 | def _mod(n, random_scale, scale, thres):
77 | rlt = int(n * random_scale)
78 | rlt = (rlt // scale) * scale
79 | return thres if rlt < thres else rlt
80 |
81 | H_s = _mod(H_s, random_scale, scale, GT_size)
82 | W_s = _mod(W_s, random_scale, scale, GT_size)
83 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
84 | if img_GT.ndim == 2:
85 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
86 |
87 | H, W, _ = img_GT.shape
88 | # using matlab imresize
89 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
90 | if img_LQ.ndim == 2:
91 | img_LQ = np.expand_dims(img_LQ, axis=2)
92 |
93 | if self.opt['phase'] == 'train':
94 | # if the image size is too small
95 | H, W, _ = img_GT.shape
96 | if H < GT_size or W < GT_size:
97 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
98 | # using matlab imresize
99 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
100 | if img_LQ.ndim == 2:
101 | img_LQ = np.expand_dims(img_LQ, axis=2)
102 |
103 | H, W, C = img_LQ.shape
104 | LQ_size = GT_size // scale
105 | # randomly crop
106 | rnd_h = random.randint(0, max(0, H - LQ_size))
107 | rnd_w = random.randint(0, max(0, W - LQ_size))
108 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
109 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
110 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
111 |
112 | # print(img_GT.shape, img_LQ.shape)
113 | # augmentation - flip, rotate
114 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
115 | self.opt['use_rot'])
116 |
117 |         if self.opt['color']:  # change color space if necessary
118 |             img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'],
119 |                                           [img_LQ])[0]  # works in both train and val phases
120 |
121 | # BGR to RGB, HWC to CHW, numpy to tensor
122 | if img_GT.shape[2] == 3:
123 | img_GT = img_GT[:, :, [2, 1, 0]]
124 | img_LQ = img_LQ[:, :, [2, 1, 0]]
125 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
126 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
127 |
128 | if LQ_path is None:
129 | LQ_path = GT_path
130 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
131 |
132 | def __len__(self):
133 | return len(self.paths_GT)
134 |
--------------------------------------------------------------------------------
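
A minimal usage sketch for `LQGTDataset` above, assuming it is run from the `codes/` directory; the folder paths and option values are placeholders for illustration (the bundled Set5 folders are used here), not the repository's defaults:

```python
from data.LQGT_dataset import LQGTDataset

# Illustrative options; the keys mirror those read in __getitem__ above.
opt = {
    'data_type': 'img',   # plain image folders ('lmdb' for LMDB roots)
    'dataroot_GT': '../datasets/Set5/HR',             # placeholder GT folder
    'dataroot_LQ': '../datasets/Set5/LR_bicubic/X4',  # placeholder LQ folder
    'phase': 'val',       # 'train' adds random crop / flip / rotation
    'scale': 4,
    'GT_size': None,      # only consulted when phase == 'train'
    'color': None,        # e.g. 'y' to convert to the Y channel
}

dataset = LQGTDataset(opt)
sample = dataset[0]
# Each item is a dict of CHW float tensors plus the source paths.
print(sample['LQ'].shape, sample['GT'].shape, sample['GT_path'])
```
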
/codes/data/LQ_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import lmdb
3 | import torch
4 | import torch.utils.data as data
5 | import data.util as util
6 |
7 |
8 | class LQDataset(data.Dataset):
9 | '''Read LQ images only in the test phase.'''
10 |
11 | def __init__(self, opt):
12 | super(LQDataset, self).__init__()
13 | self.opt = opt
14 | self.data_type = self.opt['data_type']
15 | self.paths_LQ, self.paths_GT = None, None
16 | self.LQ_env = None # environment for lmdb
17 |
18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
19 | assert self.paths_LQ, 'Error: LQ paths are empty.'
20 |
21 | def _init_lmdb(self):
22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
23 | meminit=False)
24 |
25 | def __getitem__(self, index):
26 | if self.data_type == 'lmdb' and self.LQ_env is None:
27 | self._init_lmdb()
28 | LQ_path = None
29 |
30 | # get LQ image
31 | LQ_path = self.paths_LQ[index]
32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
33 | ] if self.data_type == 'lmdb' else None
34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
35 | H, W, C = img_LQ.shape
36 |
37 | if self.opt['color']: # change color space if necessary
38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
39 |
40 | # BGR to RGB, HWC to CHW, numpy to tensor
41 | if img_LQ.shape[2] == 3:
42 | img_LQ = img_LQ[:, :, [2, 1, 0]]
43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
44 |
45 | return {'LQ': img_LQ, 'LQ_path': LQ_path}
46 |
47 | def __len__(self):
48 | return len(self.paths_LQ)
49 |
--------------------------------------------------------------------------------
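
The same pattern works for `LQDataset` above when no ground truth is available; a short sketch under the same assumptions (run from `codes/`, placeholder folder):

```python
from data.LQ_dataset import LQDataset

opt = {
    'data_type': 'img',
    'dataroot_LQ': '../datasets/Set5/LR_bicubic/X4',  # placeholder folder
    'color': None,
}
dataset = LQDataset(opt)
item = dataset[0]
print(item['LQ'].shape, item['LQ_path'])  # CHW float tensor + source path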
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | """create dataset and dataloader"""
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
25 | pin_memory=False)
26 |
27 |
28 | def create_dataset(dataset_opt):
29 | mode = dataset_opt['mode']
30 | # datasets for image restoration
31 | if mode == 'LQ':
32 | from data.LQ_dataset import LQDataset as D
33 | elif mode == 'LQGT':
34 | from data.LQGT_dataset import LQGTDataset as D
35 | elif mode == 'Color':
36 | from data.Color_dataset import ColorDataset as D
37 | elif mode == 'ContinueLQGT':
38 | from data.ContinueLQGT_dataset import ContinueLQGTDataset as D
39 | # datasets for video restoration
40 | elif mode == 'REDS':
41 | from data.REDS_dataset import REDSDataset as D
42 | elif mode == 'Vimeo90K':
43 | from data.Vimeo90K_dataset import Vimeo90KDataset as D
44 | elif mode == 'video_test':
45 | from data.video_test_dataset import VideoTestDataset as D
46 | else:
47 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
48 | dataset = D(dataset_opt)
49 |
50 | logger = logging.getLogger('base')
51 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
52 | dataset_opt['name']))
53 | return dataset
54 |
--------------------------------------------------------------------------------
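
A minimal, non-distributed sketch of wiring `create_dataset` and `create_dataloader` together, mirroring `codes/data_scripts/test_dataloader.py`; the LMDB paths are placeholders:

```python
from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'DIV2K800', 'mode': 'LQGT', 'phase': 'train',
    'dataroot_GT': '../datasets/DIV2K/DIV2K800_sub.lmdb',          # placeholder
    'dataroot_LQ': '../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb',  # placeholder
    'data_type': 'lmdb', 'n_workers': 4, 'batch_size': 16,
    'GT_size': 128, 'scale': 4, 'use_flip': True, 'use_rot': True,
    'color': 'RGB', 'use_shuffle': True,
}
opt = {'dist': False, 'gpu_ids': [0]}  # single GPU, no distributed training

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt, sampler=None)
batch = next(iter(train_loader))
# For GT_size=128 and scale=4: LQ [16, 3, 32, 32], GT [16, 3, 128, 128]
print(batch['LQ'].shape, batch['GT'].shape)
```
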
/codes/data/__pycache__/Color_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/Color_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/ContinueLQGT_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/ContinueLQGT_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/LQGT_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQGT_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/LQ_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQ_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/Vimeo90K_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/Vimeo90K_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/data_sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/data_sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/data_sampler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/data_sampler.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/video_test_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/video_test_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/codes/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iteration-oriented* training, for saving time when restart the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
--------------------------------------------------------------------------------
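
A sketch of how `DistIterSampler` above plugs into a `DataLoader` under `torch.distributed`, assuming the process group has already been initialized by a launcher; the toy dataset and sizes are illustrative:

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from data.data_sampler import DistIterSampler

# Assumes e.g. dist.init_process_group('nccl') was done by the launcher.
world_size, rank = dist.get_world_size(), dist.get_rank()

dataset = TensorDataset(torch.randn(1000, 3, 32, 32))  # toy constant-size dataset
# ratio=100 virtually enlarges the dataset 100x, so the (slow) dataloader
# restart happens once per 100 passes over the real data.
sampler = DistIterSampler(dataset, num_replicas=world_size, rank=rank, ratio=100)
loader = DataLoader(dataset, batch_size=16, shuffle=False, sampler=sampler,
                    num_workers=4, drop_last=True)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseeds the deterministic shuffle
    for (batch,) in loader:
        pass  # training step here
```
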
/codes/data_scripts/__pycache__/generate_mod_LR_bic.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data_scripts/__pycache__/generate_mod_LR_bic.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data_scripts/extract_subimages.py:
--------------------------------------------------------------------------------
1 | """A multi-thread tool to crop large images to sub-images for faster IO."""
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 | import numpy as np
7 | import cv2
8 | from PIL import Image
9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
10 | from utils.util import ProgressBar # noqa: E402
11 | import data.util as data_util # noqa: E402
12 |
13 |
14 | def main():
15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
16 | opt = {}
17 | opt['n_thread'] = 20
18 | opt['compression_level'] = 3 # 3 is the default value in cv2
19 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
20 | # compression time. If read raw images during training, use 0 for faster IO speed.
21 | if mode == 'single':
22 | # opt['input_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_LR_bicubic/X4_blur'
23 | # opt['save_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_LR_bicubic/X4_blur_sub'
24 | opt['input_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_Bic'
25 | opt['save_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_Bic/Bic_sub480'
26 | opt['crop_sz'] = 480 # the size of each sub-image
27 | opt['step'] = 120 # step of the sliding crop window
28 | opt['thres_sz'] = 48 # size threshold
29 | extract_signle(opt)
30 | elif mode == 'pair':
31 | # GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
32 | # LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
33 | # save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
34 | # save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
35 |
36 | GT_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/HR'
37 | LR_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/LR/X3'
38 | save_GT_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/HRx3_sub360'
39 | save_LR_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/LRx3_sub120'
40 |
41 | scale_ratio = 3
42 | crop_sz = 360 # the size of each sub-image (GT)
43 | step = 180 # step of the sliding crop window (GT)
44 |
45 | thres_sz = 48 # size threshold
46 | ########################################################################
47 | # check that all the GT and LR images have correct scale ratio
48 | img_GT_list = data_util._get_paths_from_images(GT_folder)
49 | img_LR_list = data_util._get_paths_from_images(LR_folder)
50 | assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
51 | for path_GT, path_LR in zip(img_GT_list, img_LR_list):
52 | img_GT = Image.open(path_GT)
53 | img_LR = Image.open(path_LR)
54 | w_GT, h_GT = img_GT.size
55 | w_LR, h_LR = img_LR.size
56 |         assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X the LR width [{:d}] for {:s}.'.format(  # noqa: E501
57 |             w_GT, scale_ratio, w_LR, path_GT)
58 |         assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X the LR height [{:d}] for {:s}.'.format(  # noqa: E501
59 |             h_GT, scale_ratio, h_LR, path_GT)
60 | # check crop size, step and threshold size
61 | assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
62 | scale_ratio)
63 | assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
64 | assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
65 | scale_ratio)
66 | print('process GT...')
67 | opt['input_folder'] = GT_folder
68 | opt['save_folder'] = save_GT_folder
69 | opt['crop_sz'] = crop_sz
70 | opt['step'] = step
71 | opt['thres_sz'] = thres_sz
72 | extract_signle(opt)
73 | print('process LR...')
74 | opt['input_folder'] = LR_folder
75 | opt['save_folder'] = save_LR_folder
76 | opt['crop_sz'] = crop_sz // scale_ratio
77 | opt['step'] = step // scale_ratio
78 | opt['thres_sz'] = thres_sz // scale_ratio
79 | extract_signle(opt)
80 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
81 | data_util._get_paths_from_images(
82 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
83 | else:
84 | raise ValueError('Wrong mode.')
85 |
86 |
87 | def extract_signle(opt):
88 | input_folder = opt['input_folder']
89 | save_folder = opt['save_folder']
90 | if not osp.exists(save_folder):
91 | os.makedirs(save_folder)
92 | print('mkdir [{:s}] ...'.format(save_folder))
93 | else:
94 | print('Folder [{:s}] already exists. Exit...'.format(save_folder))
95 | sys.exit(1)
96 | img_list = data_util._get_paths_from_images(input_folder)
97 |
98 | def update(arg):
99 | pbar.update(arg)
100 |
101 | pbar = ProgressBar(len(img_list))
102 |
103 | pool = Pool(opt['n_thread'])
104 | for path in img_list:
105 | pool.apply_async(worker, args=(path, opt), callback=update)
106 | pool.close()
107 | pool.join()
108 | print('All subprocesses done.')
109 |
110 |
111 | def worker(path, opt):
112 | crop_sz = opt['crop_sz']
113 | step = opt['step']
114 | thres_sz = opt['thres_sz']
115 | img_name = osp.basename(path)
116 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
117 |
118 | n_channels = len(img.shape)
119 | if n_channels == 2:
120 | h, w = img.shape
121 | elif n_channels == 3:
122 | h, w, c = img.shape
123 | else:
124 | raise ValueError('Wrong image shape - {}'.format(n_channels))
125 |
126 | h_space = np.arange(0, h - crop_sz + 1, step)
127 | if h - (h_space[-1] + crop_sz) > thres_sz:
128 | h_space = np.append(h_space, h - crop_sz)
129 | w_space = np.arange(0, w - crop_sz + 1, step)
130 | if w - (w_space[-1] + crop_sz) > thres_sz:
131 | w_space = np.append(w_space, w - crop_sz)
132 |
133 | index = 0
134 | for x in h_space:
135 | for y in w_space:
136 | index += 1
137 | if n_channels == 2:
138 | crop_img = img[x:x + crop_sz, y:y + crop_sz]
139 | else:
140 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
141 | crop_img = np.ascontiguousarray(crop_img)
142 | cv2.imwrite(
143 | osp.join(opt['save_folder'],
144 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
145 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
146 | return 'Processing {:s} ...'.format(img_name)
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
151 |
--------------------------------------------------------------------------------
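
A small numeric check of the sliding-window placement in `worker()` above, using the script's single-mode defaults (`crop_sz=480`, `step=120`, `thres_sz=48`) and an illustrative image height:

```python
import numpy as np

h, crop_sz, step, thres_sz = 1404, 480, 120, 48
h_space = np.arange(0, h - crop_sz + 1, step)  # [0, 120, ..., 840]
# The leftover strip is 1404 - (840 + 480) = 84 px; since 84 > thres_sz,
# one extra crop is anchored flush against the bottom edge.
if h - (h_space[-1] + crop_sz) > thres_sz:
    h_space = np.append(h_space, h - crop_sz)
print(h_space)  # [  0 120 240 360 480 600 720 840 924]
```
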
/codes/data_scripts/generate_mod_LR_bic.m:
--------------------------------------------------------------------------------
1 | function generate_mod_LR_bic()
2 | %% matlab code to generate mod images, bicubic-downsampled LR, and bicubic-upsampled images.
3 |
4 | %% set parameters
5 | % comment the unnecessary line
6 | input_folder = '../../datasets/DIV2K/DIV2K800';
7 | % save_mod_folder = '';
8 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
9 | % save_bic_folder = '';
10 |
11 | up_scale = 4;
12 | mod_scale = 4;
13 |
14 | if exist('save_mod_folder', 'var')
15 | if exist(save_mod_folder, 'dir')
16 | disp(['It will cover ', save_mod_folder]);
17 | else
18 | mkdir(save_mod_folder);
19 | end
20 | end
21 | if exist('save_LR_folder', 'var')
22 | if exist(save_LR_folder, 'dir')
23 | disp(['It will cover ', save_LR_folder]);
24 | else
25 | mkdir(save_LR_folder);
26 | end
27 | end
28 | if exist('save_bic_folder', 'var')
29 | if exist(save_bic_folder, 'dir')
30 | disp(['It will cover ', save_bic_folder]);
31 | else
32 | mkdir(save_bic_folder);
33 | end
34 | end
35 |
36 | idx = 0;
37 | filepaths = dir(fullfile(input_folder,'*.*'));
38 | for i = 1 : length(filepaths)
39 | [paths,imname,ext] = fileparts(filepaths(i).name);
40 | if isempty(imname)
41 | disp('Ignore . folder.');
42 | elseif strcmp(imname, '.')
43 | disp('Ignore .. folder.');
44 | else
45 | idx = idx + 1;
46 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
47 | fprintf(str_rlt);
48 | % read image
49 | img = imread(fullfile(input_folder, [imname, ext]));
50 | img = im2double(img);
51 | % modcrop
52 | img = modcrop(img, mod_scale);
53 | if exist('save_mod_folder', 'var')
54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
55 | end
56 | % LR
57 | im_LR = imresize(img, 1/up_scale, 'bicubic');
58 | if exist('save_LR_folder', 'var')
59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
60 | end
61 | % Bicubic
62 | if exist('save_bic_folder', 'var')
63 | im_B = imresize(im_LR, up_scale, 'bicubic');
64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
65 | end
66 | end
67 | end
68 | end
69 |
70 | %% modcrop
71 | function img = modcrop(img, modulo)
72 | if size(img,3) == 1
73 | sz = size(img);
74 | sz = sz - mod(sz, modulo);
75 | img = img(1:sz(1), 1:sz(2));
76 | else
77 | tmpsz = size(img);
78 | sz = tmpsz(1:2);
79 | sz = sz - mod(sz, modulo);
80 | img = img(1:sz(1), 1:sz(2),:);
81 | end
82 | end
83 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 |
6 | try:
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | from data.util import imresize_np
9 | except ImportError:
10 | pass
11 |
12 |
13 | def generate_mod_LR_bic():
14 | # set parameters
15 | up_scale = 4
16 | mod_scale = 4
17 | # set data dir
18 | sourcedir = '/mnt/hyzhao/Documents/datasets/DF2K_train/HR'
19 | savedir = '/mnt/hyzhao/Documents/datasets/DF2K_train/'
20 |
21 | saveHRpath = os.path.join(savedir, 'HR', 'X' + str(mod_scale))
22 | saveLRpath = os.path.join(savedir, 'LR', 'X' + str(up_scale))
23 | saveBicpath = os.path.join(savedir, 'Bic', 'X' + str(up_scale))
24 |
25 | if not os.path.isdir(sourcedir):
26 | print('Error: No source data found')
27 |         sys.exit(1)
28 | if not os.path.isdir(savedir):
29 | os.mkdir(savedir)
30 |
31 | if not os.path.isdir(os.path.join(savedir, 'HR')):
32 | os.mkdir(os.path.join(savedir, 'HR'))
33 | if not os.path.isdir(os.path.join(savedir, 'LR')):
34 | os.mkdir(os.path.join(savedir, 'LR'))
35 | if not os.path.isdir(os.path.join(savedir, 'Bic')):
36 | os.mkdir(os.path.join(savedir, 'Bic'))
37 |
38 | if not os.path.isdir(saveHRpath):
39 | os.mkdir(saveHRpath)
40 | else:
41 | print('It will cover ' + str(saveHRpath))
42 |
43 | if not os.path.isdir(saveLRpath):
44 | os.mkdir(saveLRpath)
45 | else:
46 | print('It will cover ' + str(saveLRpath))
47 |
48 | if not os.path.isdir(saveBicpath):
49 | os.mkdir(saveBicpath)
50 | else:
51 | print('It will cover ' + str(saveBicpath))
52 |
53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
54 | num_files = len(filepaths)
55 |
56 |     # process each HR image: modcrop, then bicubic LR and bicubic-upsampled copies
57 | for i in range(num_files):
58 | filename = filepaths[i]
59 | print('No.{} -- Processing {}'.format(i, filename))
60 | # read image
61 | image = cv2.imread(os.path.join(sourcedir, filename))
62 |
63 | width = int(np.floor(image.shape[1] / mod_scale))
64 | height = int(np.floor(image.shape[0] / mod_scale))
65 | # modcrop
66 | if len(image.shape) == 3:
67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
68 | else:
69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width]
70 | # LR
71 | image_LR = imresize_np(image_HR, 1 / up_scale, True)
72 | # bic
73 | image_Bic = imresize_np(image_LR, up_scale, True)
74 |
75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
78 |
79 |
80 | if __name__ == "__main__":
81 | generate_mod_LR_bic()
--------------------------------------------------------------------------------
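
The modcrop step above simply trims the HR image to a multiple of `mod_scale` before downscaling; a minimal sketch, assuming `imresize_np` from `codes/data/util.py` is importable (run from `codes/`) and using an illustrative odd-sized input:

```python
import numpy as np
from data.util import imresize_np

up_scale = mod_scale = 4
image = np.random.rand(1403, 2041, 3).astype(np.float32)  # stand-in HR image

height = image.shape[0] // mod_scale                          # 350
width = image.shape[1] // mod_scale                           # 510
image_HR = image[:mod_scale * height, :mod_scale * width, :]  # (1400, 2040, 3)
image_LR = imresize_np(image_HR, 1 / up_scale, True)          # (350, 510, 3)
image_Bic = imresize_np(image_LR, up_scale, True)             # (1400, 2040, 3)
print(image_HR.shape, image_LR.shape, image_Bic.shape)
```
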
/codes/data_scripts/prepare_DIV2K_x4_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | echo "Prepare DIV2K X4 datasets..."
4 | cd ../../datasets
5 | mkdir DIV2K
6 | cd DIV2K
7 |
8 | #### Step 1
9 | echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
10 | # GT
11 | FOLDER=DIV2K_train_HR
12 | FILE=DIV2K_train_HR.zip
13 | if [ ! -d "$FOLDER" ]; then
14 | if [ ! -f "$FILE" ]; then
15 | echo "Downloading $FILE..."
16 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
17 | fi
18 | unzip $FILE
19 | fi
20 | # LR
21 | FOLDER=DIV2K_train_LR_bicubic
22 | FILE=DIV2K_train_LR_bicubic_X4.zip
23 | if [ ! -d "$FOLDER" ]; then
24 | if [ ! -f "$FILE" ]; then
25 | echo "Downloading $FILE..."
26 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
27 | fi
28 | unzip $FILE
29 | fi
30 |
31 | #### Step 2
32 | echo "Step 2: Rename the LR images..."
33 | cd ../../codes/data_scripts
34 | python rename.py
35 |
36 | #### Step 3
37 | echo "Step 3: Crop to sub-images..."
38 | python extract_subimages.py
39 |
40 | #### Step 4
41 | echo "Step 4: Create LMDB files..."
42 | python create_lmdb.py
43 |
--------------------------------------------------------------------------------
/codes/data_scripts/rename.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 |
5 | def main():
6 | folder = '../../results/006_RRDBNet_ILRx4_Flickr2K_100w+/DIV2K100'
7 | DIV2K(folder)
8 | print('Finished.')
9 |
10 |
11 | def DIV2K(path):
12 | img_path_l = glob.glob(os.path.join(path, '*'))
13 | for img_path in img_path_l:
14 | # new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
15 | new_path = img_path.replace('.png', 'x4.png')
16 | os.rename(img_path, new_path)
17 |
18 |
19 | if __name__ == "__main__":
20 | main()
--------------------------------------------------------------------------------
/codes/data_scripts/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 | import math
4 | import torchvision.utils
5 |
6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
7 | from data import create_dataloader, create_dataset # noqa: E402
8 | from utils import util # noqa: E402
9 |
10 |
11 | def main():
12 | dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
13 | opt = {}
14 | opt['dist'] = False
15 | opt['gpu_ids'] = [0]
16 | if dataset == 'REDS':
17 | opt['name'] = 'test_REDS'
18 | opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
19 | opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
20 | opt['mode'] = 'REDS'
21 | opt['N_frames'] = 5
22 | opt['phase'] = 'train'
23 | opt['use_shuffle'] = True
24 | opt['n_workers'] = 8
25 | opt['batch_size'] = 16
26 | opt['GT_size'] = 256
27 | opt['LQ_size'] = 64
28 | opt['scale'] = 4
29 | opt['use_flip'] = True
30 | opt['use_rot'] = True
31 | opt['interval_list'] = [1]
32 | opt['random_reverse'] = False
33 | opt['border_mode'] = False
34 | opt['cache_keys'] = None
35 | opt['data_type'] = 'lmdb' # img | lmdb | mc
36 | elif dataset == 'Vimeo90K':
37 | opt['name'] = 'test_Vimeo90K'
38 | opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
39 | opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
40 | opt['mode'] = 'Vimeo90K'
41 | opt['N_frames'] = 7
42 | opt['phase'] = 'train'
43 | opt['use_shuffle'] = True
44 | opt['n_workers'] = 8
45 | opt['batch_size'] = 16
46 | opt['GT_size'] = 256
47 | opt['LQ_size'] = 64
48 | opt['scale'] = 4
49 | opt['use_flip'] = True
50 | opt['use_rot'] = True
51 | opt['interval_list'] = [1]
52 | opt['random_reverse'] = False
53 | opt['border_mode'] = False
54 | opt['cache_keys'] = None
55 | opt['data_type'] = 'lmdb' # img | lmdb | mc
56 | elif dataset == 'DIV2K800_sub':
57 | opt['name'] = 'DIV2K800'
58 | opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
59 | opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
60 | opt['mode'] = 'LQGT'
61 | opt['phase'] = 'train'
62 | opt['use_shuffle'] = True
63 | opt['n_workers'] = 8
64 | opt['batch_size'] = 16
65 | opt['GT_size'] = 128
66 | opt['scale'] = 4
67 | opt['use_flip'] = True
68 | opt['use_rot'] = True
69 | opt['color'] = 'RGB'
70 | opt['data_type'] = 'lmdb' # img | lmdb
71 | else:
72 | raise ValueError('Please implement by yourself.')
73 |
74 | util.mkdir('tmp')
75 | train_set = create_dataset(opt)
76 | train_loader = create_dataloader(train_set, opt, opt, None)
77 | nrow = int(math.sqrt(opt['batch_size']))
78 | padding = 2 if opt['phase'] == 'train' else 0
79 |
80 | print('start...')
81 | for i, data in enumerate(train_loader):
82 | if i > 5:
83 | break
84 | print(i)
85 | if dataset == 'REDS' or dataset == 'Vimeo90K':
86 | LQs = data['LQs']
87 | else:
88 | LQ = data['LQ']
89 | GT = data['GT']
90 |
91 | if dataset == 'REDS' or dataset == 'Vimeo90K':
92 | for j in range(LQs.size(1)):
93 | torchvision.utils.save_image(LQs[:, j, :, :, :],
94 | 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
95 | padding=padding, normalize=False)
96 | else:
97 | torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
98 | padding=padding, normalize=False)
99 | torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
100 | normalize=False)
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
1 | function calculate_PSNR_SSIM()
2 |
3 | % GT and SR folder
4 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
5 | folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
6 | scale = 4;
7 | suffix = ''; % suffix for SR images
8 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
9 | if test_Y
10 |     fprintf('Testing Y channel.\n');
11 | else
12 |     fprintf('Testing RGB channels.\n');
13 | end
14 | filepaths = dir(fullfile(folder_GT, '*.png'));
15 | PSNR_all = zeros(1, length(filepaths));
16 | SSIM_all = zeros(1, length(filepaths));
17 |
18 | for idx_im = 1:length(filepaths)
19 | im_name = filepaths(idx_im).name;
20 | im_GT = imread(fullfile(folder_GT, im_name));
21 | im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
22 |
23 | if test_Y % evaluate on Y channel in YCbCr color space
24 | if size(im_GT, 3) == 3
25 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
26 | im_GT_in = im_GT_YCbCr(:,:,1);
27 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
28 | im_SR_in = im_SR_YCbCr(:,:,1);
29 | else
30 | im_GT_in = im2double(im_GT);
31 | im_SR_in = im2double(im_SR);
32 | end
33 | else % evaluate on RGB channels
34 | im_GT_in = im2double(im_GT);
35 | im_SR_in = im2double(im_SR);
36 | end
37 |
38 | % calculate PSNR and SSIM
39 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
40 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
41 | fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
42 | end
43 |
44 | fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
45 | end
46 |
47 | function res = calculate_PSNR(GT, SR, border)
48 | % remove border
49 | GT = GT(border+1:end-border, border+1:end-border, :);
50 | SR = SR(border+1:end-border, border+1:end-border, :);
51 | % calculate PSNR (assume in [0,255])
52 | error = GT(:) - SR(:);
53 | mse = mean(error.^2);
54 | res = 10 * log10(255^2/mse);
55 | end
56 |
57 | function res = calculate_SSIM(GT, SR, border)
58 | GT = GT(border+1:end-border, border+1:end-border, :);
59 | SR = SR(border+1:end-border, border+1:end-border, :);
60 | % calculate SSIM
61 | mssim = zeros(1, size(SR, 3));
62 | for i = 1:size(SR,3)
63 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
64 | end
65 | res = mean(mssim);
66 | end
67 |
68 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
69 |
70 | %========================================================================
71 | %SSIM Index, Version 1.0
72 | %Copyright(c) 2003 Zhou Wang
73 | %All Rights Reserved.
74 | %
75 | %The author is with Howard Hughes Medical Institute, and Laboratory
76 | %for Computational Vision at Center for Neural Science and Courant
77 | %Institute of Mathematical Sciences, New York University.
78 | %
79 | %----------------------------------------------------------------------
80 | %Permission to use, copy, or modify this software and its documentation
81 | %for educational and research purposes only and without fee is hereby
82 | %granted, provided that this copyright notice and the original authors'
83 | %names appear on all copies and supporting documentation. This program
84 | %shall not be used, rewritten, or adapted as the basis of a commercial
85 | %software or hardware product without first obtaining permission of the
86 | %authors. The authors make no representations about the suitability of
87 | %this software for any purpose. It is provided "as is" without express
88 | %or implied warranty.
89 | %----------------------------------------------------------------------
90 | %
91 | %This is an implementation of the algorithm for calculating the
92 | %Structural SIMilarity (SSIM) index between two images. Please refer
93 | %to the following paper:
94 | %
95 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
96 | %quality assessment: From error measurement to structural similarity"
97 | %IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
98 | %
99 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
100 | %
101 | %----------------------------------------------------------------------
102 | %
103 | %Input : (1) img1: the first image being compared
104 | % (2) img2: the second image being compared
105 | % (3) K: constants in the SSIM index formula (see the above
106 | %                        reference). default value: K = [0.01 0.03]
107 | % (4) window: local window for statistics (see the above
108 | %                        reference). default window is Gaussian given by
109 | % window = fspecial('gaussian', 11, 1.5);
110 | % (5) L: dynamic range of the images. default: L = 255
111 | %
112 | %Output: (1) mssim: the mean SSIM index value between 2 images.
113 | % If one of the images being compared is regarded as
114 | % perfect quality, then mssim can be considered as the
115 | % quality measure of the other image.
116 | % If img1 = img2, then mssim = 1.
117 | % (2) ssim_map: the SSIM index map of the test image. The map
118 | % has a smaller size than the input images. The actual size:
119 | % size(img1) - size(window) + 1.
120 | %
121 | %Default Usage:
122 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
123 | %
124 | % [mssim ssim_map] = ssim_index(img1, img2);
125 | %
126 | %Advanced Usage:
127 | % User defined parameters. For example
128 | %
129 | % K = [0.05 0.05];
130 | % window = ones(8);
131 | % L = 100;
132 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
133 | %
134 | %See the results:
135 | %
136 | % mssim %Gives the mssim value
137 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
138 | %
139 | %========================================================================
140 |
141 |
142 | if (nargin < 2 || nargin > 5)
143 |    mssim = -Inf;
144 | ssim_map = -Inf;
145 | return;
146 | end
147 |
148 | if (size(img1) ~= size(img2))
149 |    mssim = -Inf;
150 | ssim_map = -Inf;
151 | return;
152 | end
153 |
154 | [M, N] = size(img1);
155 |
156 | if (nargin == 2)
157 | if ((M < 11) || (N < 11))
158 |       mssim = -Inf;
159 | ssim_map = -Inf;
160 | return
161 | end
162 | window = fspecial('gaussian', 11, 1.5); %
163 | K(1) = 0.01; % default settings
164 | K(2) = 0.03; %
165 | L = 255; %
166 | end
167 |
168 | if (nargin == 3)
169 | if ((M < 11) || (N < 11))
170 |       mssim = -Inf;
171 | ssim_map = -Inf;
172 | return
173 | end
174 | window = fspecial('gaussian', 11, 1.5);
175 | L = 255;
176 | if (length(K) == 2)
177 | if (K(1) < 0 || K(2) < 0)
178 |          mssim = -Inf;
179 | ssim_map = -Inf;
180 | return;
181 | end
182 | else
183 |       mssim = -Inf;
184 | ssim_map = -Inf;
185 | return;
186 | end
187 | end
188 |
189 | if (nargin == 4)
190 | [H, W] = size(window);
191 | if ((H*W) < 4 || (H > M) || (W > N))
192 |       mssim = -Inf;
193 | ssim_map = -Inf;
194 | return
195 | end
196 | L = 255;
197 | if (length(K) == 2)
198 | if (K(1) < 0 || K(2) < 0)
199 |          mssim = -Inf;
200 | ssim_map = -Inf;
201 | return;
202 | end
203 | else
204 |       mssim = -Inf;
205 | ssim_map = -Inf;
206 | return;
207 | end
208 | end
209 |
210 | if (nargin == 5)
211 | [H, W] = size(window);
212 | if ((H*W) < 4 || (H > M) || (W > N))
213 |       mssim = -Inf;
214 | ssim_map = -Inf;
215 | return
216 | end
217 | if (length(K) == 2)
218 | if (K(1) < 0 || K(2) < 0)
219 |          mssim = -Inf;
220 | ssim_map = -Inf;
221 | return;
222 | end
223 | else
224 |       mssim = -Inf;
225 | ssim_map = -Inf;
226 | return;
227 | end
228 | end
229 |
230 | C1 = (K(1)*L)^2;
231 | C2 = (K(2)*L)^2;
232 | window = window/sum(sum(window));
233 | img1 = double(img1);
234 | img2 = double(img2);
235 |
236 | mu1 = filter2(window, img1, 'valid');
237 | mu2 = filter2(window, img2, 'valid');
238 | mu1_sq = mu1.*mu1;
239 | mu2_sq = mu2.*mu2;
240 | mu1_mu2 = mu1.*mu2;
241 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
242 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
243 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
244 |
245 | if (C1 > 0 && C2 > 0)
246 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
247 | else
248 | numerator1 = 2*mu1_mu2 + C1;
249 | numerator2 = 2*sigma12 + C2;
250 | denominator1 = mu1_sq + mu2_sq + C1;
251 | denominator2 = sigma1_sq + sigma2_sq + C2;
252 | ssim_map = ones(size(mu1));
253 | index = (denominator1.*denominator2 > 0);
254 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
255 | index = (denominator1 ~= 0) & (denominator2 == 0);
256 | ssim_map(index) = numerator1(index)./denominator1(index);
257 | end
258 |
259 | mssim = mean2(ssim_map);
260 |
261 | end
262 |
--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.py:
--------------------------------------------------------------------------------
1 | '''
2 | calculate the PSNR and SSIM.
3 | same as MATLAB's results
4 | '''
5 | import os
6 | import math
7 | import numpy as np
8 | import cv2
9 | import glob
10 |
11 |
12 | def main():
13 | # Configurations
14 |
15 | # GT - Ground-truth;
16 | # Gen: Generated / Restored / Recovered images
17 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
18 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
19 |
20 | crop_border = 4
21 | suffix = '' # suffix for Gen images
22 | test_Y = False # True: test Y channel only; False: test RGB channels
23 |
24 | PSNR_all = []
25 | SSIM_all = []
26 | img_list = sorted(glob.glob(folder_GT + '/*'))
27 |
28 | if test_Y:
29 | print('Testing Y channel.')
30 | else:
31 | print('Testing RGB channels.')
32 |
33 | for i, img_path in enumerate(img_list):
34 | base_name = os.path.splitext(os.path.basename(img_path))[0]
35 | im_GT = cv2.imread(img_path) / 255.
36 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
37 |
38 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
39 | im_GT_in = bgr2ycbcr(im_GT)
40 | im_Gen_in = bgr2ycbcr(im_Gen)
41 | else:
42 | im_GT_in = im_GT
43 | im_Gen_in = im_Gen
44 |
45 | # crop borders
46 | if im_GT_in.ndim == 3:
47 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
48 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
49 | elif im_GT_in.ndim == 2:
50 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
51 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
52 | else:
53 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
54 |
55 | # calculate PSNR and SSIM
56 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
57 |
58 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
59 | print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
60 | i + 1, base_name, PSNR, SSIM))
61 | PSNR_all.append(PSNR)
62 | SSIM_all.append(SSIM)
63 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
64 | sum(PSNR_all) / len(PSNR_all),
65 | sum(SSIM_all) / len(SSIM_all)))
66 |
67 |
68 | def calculate_psnr(img1, img2):
69 | # img1 and img2 have range [0, 255]
70 | img1 = img1.astype(np.float64)
71 | img2 = img2.astype(np.float64)
72 | mse = np.mean((img1 - img2)**2)
73 | if mse == 0:
74 | return float('inf')
75 | return 20 * math.log10(255.0 / math.sqrt(mse))
76 |
77 |
78 | def ssim(img1, img2):
79 | C1 = (0.01 * 255)**2
80 | C2 = (0.03 * 255)**2
81 |
82 | img1 = img1.astype(np.float64)
83 | img2 = img2.astype(np.float64)
84 | kernel = cv2.getGaussianKernel(11, 1.5)
85 | window = np.outer(kernel, kernel.transpose())
86 |
87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
89 | mu1_sq = mu1**2
90 | mu2_sq = mu2**2
91 | mu1_mu2 = mu1 * mu2
92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
95 |
96 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
97 | (sigma1_sq + sigma2_sq + C2))
98 | return ssim_map.mean()
99 |
100 |
101 | def calculate_ssim(img1, img2):
102 | '''Calculate SSIM.
103 | Produces the same output as MATLAB's implementation.
104 | img1, img2: [0, 255]
105 | '''
106 | if not img1.shape == img2.shape:
107 | raise ValueError('Input images must have the same dimensions.')
108 | if img1.ndim == 2:
109 | return ssim(img1, img2)
110 | elif img1.ndim == 3:
111 | if img1.shape[2] == 3:
112 | ssims = []
113 | for i in range(3):
114 | ssims.append(ssim(img1[..., i], img2[..., i]))  # per-channel SSIM
115 | return np.array(ssims).mean()
116 | elif img1.shape[2] == 1:
117 | return ssim(np.squeeze(img1), np.squeeze(img2))
118 | else:
119 | raise ValueError('Wrong input image dimensions.')
120 |
121 |
122 | def bgr2ycbcr(img, only_y=True):
123 | '''BGR counterpart of MATLAB's rgb2ycbcr (coefficients reordered for BGR input)
124 | only_y: only return Y channel
125 | Input:
126 | uint8, [0, 255]
127 | float, [0, 1]
128 | '''
129 | in_img_type = img.dtype
130 | img = img.astype(np.float32)  # astype returns a copy and must be assigned
131 | if in_img_type != np.uint8:
132 | img *= 255.
133 | # convert
134 | if only_y:
135 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
136 | else:
137 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
138 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
139 | if in_img_type == np.uint8:
140 | rlt = rlt.round()
141 | else:
142 | rlt /= 255.
143 | return rlt.astype(in_img_type)
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
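A quick sanity check of the functions above (a minimal sketch; the uniform 32x32 test arrays are invented for illustration, and the module is assumed to be importable from the working directory):

import numpy as np
from calculate_PSNR_SSIM import calculate_psnr, calculate_ssim

img1 = np.full((32, 32), 100.0)
img2 = np.full((32, 32), 110.0)
# MSE = 100, so PSNR = 20*log10(255/10) ~= 28.13 dB
print(calculate_psnr(img1, img2))
# identical inputs give SSIM = 1 by construction
print(calculate_ssim(img1, img1.copy()))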
--------------------------------------------------------------------------------
/codes/models/SR_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import numpy as np
8 | import models.networks as networks
9 | import models.lr_scheduler as lr_scheduler
10 | from .base_model import BaseModel
11 | from models.loss import CharbonnierLoss, FSLoss, GradientLoss
12 |
13 | logger = logging.getLogger('base')
14 |
15 |
16 | class SRModel(BaseModel):
17 | def __init__(self, opt):
18 | super(SRModel, self).__init__(opt)
19 |
20 | if opt['dist']:
21 | self.rank = torch.distributed.get_rank()
22 | else:
23 | self.rank = -1 # non dist training
24 | train_opt = opt['train']
25 |
26 | # define network and load pretrained models
27 | self.netG = networks.define_G(opt).to(self.device)
28 | if opt['dist']:
29 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
30 | else:
31 | self.netG = DataParallel(self.netG)
32 | # print network
33 | self.print_network()
34 | self.load()
35 |
36 | if self.is_train:
37 | self.netG.train()
38 |
39 | # loss
40 | loss_type = train_opt['pixel_criterion']
41 | self.loss_type = loss_type
42 | if loss_type == 'l1':
43 | self.cri_pix = nn.L1Loss().to(self.device)
44 | elif loss_type == 'l2':
45 | self.cri_pix = nn.MSELoss().to(self.device)
46 | elif loss_type == 'cb':
47 | self.cri_pix = CharbonnierLoss().to(self.device)
48 | else:
49 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
50 | self.l_pix_w = train_opt['pixel_weight']
51 |
52 | # optimizers
53 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
54 | optim_params = []
55 | for k, v in self.netG.named_parameters():  # only params with requires_grad are optimized
56 | if v.requires_grad:
57 | optim_params.append(v)
58 | else:
59 | if self.rank <= 0:
60 | logger.warning('Params [{:s}] will not be optimized.'.format(k))
61 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
62 | weight_decay=wd_G,
63 | betas=(train_opt['beta1'], train_opt['beta2']))
64 | self.optimizers.append(self.optimizer_G)
65 |
66 | # schedulers
67 | if train_opt['lr_scheme'] == 'MultiStepLR':
68 | for optimizer in self.optimizers:
69 | self.schedulers.append(
70 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
71 | restarts=train_opt['restarts'],
72 | weights=train_opt['restart_weights'],
73 | gamma=train_opt['lr_gamma'],
74 | clear_state=train_opt['clear_state']))
75 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
76 | for optimizer in self.optimizers:
77 | self.schedulers.append(
78 | lr_scheduler.CosineAnnealingLR_Restart(
79 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
80 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
81 | else:
82 | raise NotImplementedError('Learning rate scheme [{:s}] is not recognized.'.format(train_opt['lr_scheme']))
83 |
84 | self.log_dict = OrderedDict()
85 |
86 | def feed_data(self, data, need_GT=True):
87 | self.var_L = data['LQ'].to(self.device) # LQ
88 | if need_GT:
89 | self.real_H = data['GT'].to(self.device) # GT
90 |
91 | def mixup_data(self, x, y, alpha=1.0, use_cuda=True):
92 | '''Apply mixup augmentation; returns the mixed inputs and mixed targets.'''
93 | batch_size = x.size()[0]
94 | lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
95 | index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size)
96 | mixed_x = lam * x + (1 - lam) * x[index,:]
97 | mixed_y = lam * y + (1 - lam) * y[index,:]
98 | return mixed_x, mixed_y
99 |
100 | def optimize_parameters(self, step):
101 | self.optimizer_G.zero_grad()
102 |
103 | # optional mixup augmentation (disabled by default)
104 | # self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H)
105 |
106 | self.fake_H = self.netG(self.var_L)
107 | if self.loss_type == 'fs':  # the 'fs'/'grad' branches assume cri_fs/gradloss and their weights are defined elsewhere
108 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) + self.l_fs_w * self.cri_fs(self.fake_H, self.real_H)
109 | elif self.loss_type == 'grad':
110 | l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
111 | lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
112 | l_pix = l1 + lg
113 | elif self.loss_type == 'grad_fs':
114 | l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
115 | lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H)
116 | lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H)
117 | l_pix = l1 + lg + lfs
118 | else:
119 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
120 | l_pix.backward()
121 | self.optimizer_G.step()
122 |
123 | # set log
124 | self.log_dict['l_pix'] = l_pix.item()
125 | if self.loss_type == 'grad':
126 | self.log_dict['l_1'] = l1.item()
127 | self.log_dict['l_grad'] = lg.item()
128 | if self.loss_type == 'grad_fs':
129 | self.log_dict['l_1'] = l1.item()
130 | self.log_dict['l_grad'] = lg.item()
131 | self.log_dict['l_fs'] = lfs.item()
132 |
133 | def test(self):
134 | self.netG.eval()
135 |
136 | with torch.no_grad():
137 | self.fake_H = self.netG(self.var_L)
138 | self.netG.train()
139 |
140 | def test_x8(self):  # self-ensemble: average predictions over 8 flip/transpose augmentations
141 | # from https://github.com/thstkdgus35/EDSR-PyTorch
142 | self.netG.eval()
143 |
144 | def _transform(v, op):
145 | # if self.precision != 'single': v = v.float()
146 | v2np = v.data.cpu().numpy()
147 | if op == 'v':
148 | tfnp = v2np[:, :, :, ::-1].copy()
149 | elif op == 'h':
150 | tfnp = v2np[:, :, ::-1, :].copy()
151 | elif op == 't':
152 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
153 |
154 | ret = torch.Tensor(tfnp).to(self.device)
155 | # if self.precision == 'half': ret = ret.half()
156 |
157 | return ret
158 |
159 | lr_list = [self.var_L]
160 | for tf in 'v', 'h', 't':
161 | lr_list.extend([_transform(t, tf) for t in lr_list])
162 | with torch.no_grad():
163 | sr_list = [self.netG(aug) for aug in lr_list]
164 | for i in range(len(sr_list)):
165 | if i > 3:
166 | sr_list[i] = _transform(sr_list[i], 't')
167 | if i % 4 > 1:
168 | sr_list[i] = _transform(sr_list[i], 'h')
169 | if (i % 4) % 2 == 1:
170 | sr_list[i] = _transform(sr_list[i], 'v')
171 |
172 | output_cat = torch.cat(sr_list, dim=0)
173 | self.fake_H = output_cat.mean(dim=0, keepdim=True)
174 | self.netG.train()
175 |
176 | def get_current_log(self):
177 | return self.log_dict
178 |
179 | def get_current_visuals(self, need_GT=True):
180 | out_dict = OrderedDict()
181 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
182 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
183 | if need_GT:
184 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
185 | return out_dict
186 |
187 | def print_network(self):
188 | s, n = self.get_network_description(self.netG)
189 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
190 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
191 | self.netG.module.__class__.__name__)
192 | else:
193 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
194 | if self.rank <= 0:
195 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
196 | logger.info(s)
197 |
198 | def load(self):
199 | load_path_G = self.opt['path']['pretrain_model_G']
200 | if load_path_G is not None:
201 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
202 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
203 |
204 | # def load(self):
205 | # load_path_G_1 = self.opt['path']['pretrain_model_G_1']
206 | # load_path_G_2 = self.opt['path']['pretrain_model_G_2']
207 | # load_path_Gs=[load_path_G_1, load_path_G_2]
208 |
209 | # load_path_G = self.opt['path']['pretrain_model_G']
210 | # if load_path_G is not None:
211 | # logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
212 | # self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
213 | # if load_path_G_1 is not None:
214 | # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1))
215 | # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2))
216 | # self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load'])
217 |
218 | def save(self, iter_label):
219 | self.save_network(self.netG, 'G', iter_label)
220 |
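SRModel pulls everything from the nested `opt` dictionary parsed from the YAML files under options/. A minimal sketch of the training keys read by `__init__` above (values are illustrative only; BaseModel and `networks.define_G` read further keys, such as the network_G and GPU settings, not shown here):

opt_sketch = {
    'model': 'sr',
    'dist': False,                    # False -> DataParallel; True -> DistributedDataParallel
    'train': {
        'pixel_criterion': 'l1',      # 'l1' | 'l2' | 'cb' (Charbonnier)
        'pixel_weight': 1.0,
        'weight_decay_G': 0,
        'lr_G': 1e-3, 'beta1': 0.9, 'beta2': 0.99,
        'lr_scheme': 'CosineAnnealingLR_Restart',
        'T_period': [250000], 'eta_min': 1e-7,
        'restarts': None, 'restart_weights': None,
    },
    'path': {'pretrain_model_G': None, 'strict_load': True},
}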
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 | # image restoration
8 | if model == 'sr': # PSNR-oriented super resolution
9 | from .SR_model import SRModel as M
10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
11 | from .SRGAN_model import SRGANModel as M
12 | # video restoration
13 | elif model == 'video_base':
14 | from .Video_base_model import VideoBaseModel as M
15 | else:
16 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
17 | m = M(opt)
18 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
19 | return m
20 |
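Usage of the factory (a sketch; `opt` stands for a fully populated option dict as produced by options/options.py, and `lq_batch`/`gt_batch`/`step` are placeholders for dataloader tensors and the iteration counter):

from models import create_model

model = create_model(opt)                        # opt['model'] == 'sr' -> SRModel
model.feed_data({'LQ': lq_batch, 'GT': gt_batch})
model.optimize_parameters(step)
print(model.get_current_log())                   # e.g. {'l_pix': ...}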
--------------------------------------------------------------------------------
/codes/models/archs/PAN_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import models.archs.arch_util as arch_util
6 |
7 | class PA(nn.Module):
8 | '''PA (pixel attention): rescales the input by a per-pixel gate from a 1x1 conv + sigmoid'''
9 | def __init__(self, nf):
10 |
11 | super(PA, self).__init__()
12 | self.conv = nn.Conv2d(nf, nf, 1)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x):
16 |
17 | y = self.conv(x)
18 | y = self.sigmoid(y)
19 | out = torch.mul(x, y)
20 |
21 | return out
22 |
23 | class PAConv(nn.Module):
24 |
25 | def __init__(self, nf, k_size=3):
26 |
27 | super(PAConv, self).__init__()
28 | self.k2 = nn.Conv2d(nf, nf, 1) # 1x1 convolution nf->nf
29 | self.sigmoid = nn.Sigmoid()
30 | self.k3 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
31 | self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
32 |
33 | def forward(self, x):
34 |
35 | y = self.k2(x)
36 | y = self.sigmoid(y)
37 |
38 | out = torch.mul(self.k3(x), y)
39 | out = self.k4(out)
40 |
41 | return out
42 |
43 | class SCPA(nn.Module):
44 |
45 | """SCPA is modified from SCNet (Jiang-Jiang Liu et al. Improving Convolutional Networks with Self-Calibrated Convolutions. In CVPR, 2020)
46 | Github: https://github.com/MCG-NKU/SCNet
47 | """
48 |
49 | def __init__(self, nf, reduction=2, stride=1, dilation=1):
50 | super(SCPA, self).__init__()
51 | group_width = nf // reduction
52 |
53 | self.conv1_a = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)
54 | self.conv1_b = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)
55 |
56 | self.k1 = nn.Sequential(
57 | nn.Conv2d(
58 | group_width, group_width, kernel_size=3, stride=stride,
59 | padding=dilation, dilation=dilation,
60 | bias=False)
61 | )
62 |
63 | self.PAConv = PAConv(group_width)
64 |
65 | self.conv3 = nn.Conv2d(
66 | group_width * reduction, nf, kernel_size=1, bias=False)
67 |
68 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out_a = self.conv1_a(x)
74 | out_b = self.conv1_b(x)
75 | out_a = self.lrelu(out_a)
76 | out_b = self.lrelu(out_b)
77 |
78 | out_a = self.k1(out_a)
79 | out_b = self.PAConv(out_b)
80 | out_a = self.lrelu(out_a)
81 | out_b = self.lrelu(out_b)
82 |
83 | out = self.conv3(torch.cat([out_a, out_b], dim=1))
84 | out += residual
85 |
86 | return out
87 |
88 | class PAN(nn.Module):
89 |
90 | def __init__(self, in_nc, out_nc, nf, unf, nb, scale=4):
91 | super(PAN, self).__init__()
92 | # SCPA
93 | SCPA_block_f = functools.partial(SCPA, nf=nf, reduction=2)
94 | self.scale = scale
95 |
96 | ### first convolution
97 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
98 |
99 | ### main blocks
100 | self.SCPA_trunk = arch_util.make_layer(SCPA_block_f, nb)
101 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
102 |
103 | #### upsampling
104 | self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True)
105 | self.att1 = PA(unf)
106 | self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
107 |
108 | if self.scale == 4:
109 | self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
110 | self.att2 = PA(unf)
111 | self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
112 |
113 | self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True)
114 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
115 |
116 | def forward(self, x):
117 |
118 | fea = self.conv_first(x)
119 | trunk = self.trunk_conv(self.SCPA_trunk(fea))
120 | fea = fea + trunk
121 |
122 | if self.scale == 2 or self.scale == 3:  # only x2/x3/x4 are supported
123 | fea = self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode='nearest'))
124 | fea = self.lrelu(self.att1(fea))
125 | fea = self.lrelu(self.HRconv1(fea))
126 | elif self.scale == 4:
127 | fea = self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))
128 | fea = self.lrelu(self.att1(fea))
129 | fea = self.lrelu(self.HRconv1(fea))
130 | fea = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
131 | fea = self.lrelu(self.att2(fea))
132 | fea = self.lrelu(self.HRconv2(fea))
133 |
134 | out = self.conv_last(fea)
135 |
136 | ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
137 | out = out + ILR
138 | return out
139 |
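A shape check for the x4 model (a sketch; nf=40, unf=24, nb=16 mirror the sizes used in the PAN configs, but any values work, and codes/ is assumed to be on the Python path):

import torch
from models.archs.PAN_arch import PAN

net = PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4)
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 128, 128])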
--------------------------------------------------------------------------------
/codes/models/archs/RCAN_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import models.archs.arch_util as arch_util
5 |
6 | ## Channel Attention (CA) Layer
7 | class CALayer(nn.Module):
8 | def __init__(self, channel, reduction=16):
9 | super(CALayer, self).__init__()
10 | # global average pooling: feature --> point
11 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
12 | # feature channel downscale and upscale --> channel weight
13 | self.conv_du = nn.Sequential(
14 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
17 | nn.Sigmoid()
18 | )
19 |
20 | def forward(self, x):
21 | y = self.avg_pool(x)
22 | y = self.conv_du(y)
23 | return x * y
24 |
25 | class PA(nn.Module):
26 | '''PA (pixel attention): rescales the input by a per-pixel gate from a 1x1 conv + sigmoid'''
27 | def __init__(self, nf):
28 |
29 | super(PA, self).__init__()
30 | self.conv = nn.Conv2d(nf, nf, 1)
31 | self.sigmoid = nn.Sigmoid()
32 |
33 | def forward(self, x):
34 |
35 | y = self.conv(x)
36 | y = self.sigmoid(y)
37 | out = torch.mul(x, y)
38 |
39 | return out
40 |
41 |
42 | ## Residual Channel Attention Block (RCAB)
43 | class RCAB(nn.Module):
44 | def __init__(
45 | self, conv, n_feat, kernel_size, reduction,
46 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
47 |
48 | super(RCAB, self).__init__()
49 | modules_body = []
50 | for i in range(2):
51 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
52 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
53 | if i == 0: modules_body.append(act)
54 | modules_body.append(CALayer(n_feat, reduction))
55 | self.body = nn.Sequential(*modules_body)
56 | self.res_scale = res_scale
57 |
58 | def forward(self, x):
59 | res = self.body(x)
60 | res += x
61 | return res
62 |
63 | class RPAB(nn.Module):
64 | '''Residual Block with PA'''
65 | def __init__(
66 | self, conv, n_feat, kernel_size, reduction,
67 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
68 |
69 | super(RPAB, self).__init__()
70 | modules_body = []
71 | for i in range(2):
72 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
73 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
74 | if i == 0: modules_body.append(act)
75 | modules_body.append(PA(n_feat))
76 | self.body = nn.Sequential(*modules_body)
77 | self.res_scale = res_scale
78 |
79 | def forward(self, x):
80 | res = self.body(x)
81 | res += x
82 | return res
83 |
84 | # Residual Group (RG)
85 | class ResidualGroup(nn.Module):
86 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
87 | super(ResidualGroup, self).__init__()
88 | modules_body = []
89 | modules_body = [
90 | RPAB(
91 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
92 | for _ in range(n_resblocks)]
93 | modules_body.append(conv(n_feat, n_feat, kernel_size))
94 | self.body = nn.Sequential(*modules_body)
95 |
96 | def forward(self, x):
97 | res = self.body(x)
98 | res += x
99 | return res
100 |
101 | # class ResidualGroup(nn.Module):
102 | # def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
103 | # super(ResidualGroup, self).__init__()
104 | # modules_body = []
105 | # modules_body = [
106 | # RCAB(
107 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
108 | # for _ in range(n_resblocks)]
109 | # modules_body.append(conv(n_feat, n_feat, kernel_size))
110 | # self.body = nn.Sequential(*modules_body)
111 |
112 | # def forward(self, x):
113 | # res = self.body(x)
114 | # res += x
115 | # return res
116 |
117 | class Upsampler(nn.Sequential):
118 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
119 |
120 | m = []
121 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
122 | for _ in range(int(math.log(scale, 2))):
123 | m.append(conv(n_feat, 4 * n_feat, 3, bias))
124 | m.append(nn.PixelShuffle(2))
125 | if bn: m.append(nn.BatchNorm2d(n_feat))
126 | if act: m.append(act())
127 | elif scale == 3:
128 | m.append(conv(n_feat, 9 * n_feat, 3, bias))
129 | m.append(nn.PixelShuffle(3))
130 | if bn: m.append(nn.BatchNorm2d(n_feat))
131 | if act: m.append(act())
132 | else:
133 | raise NotImplementedError
134 |
135 | super(Upsampler, self).__init__(*m)
136 |
137 | ## Residual Channel Attention Network (RCAN)
138 | class MRCAN(nn.Module):
139 | ''' modified RCAN '''
140 |
141 | def __init__(self, n_resgroups, n_resblocks, n_feats, res_scale, n_colors, rgb_range, scale, reduction, conv=arch_util.default_conv):
142 | super(MRCAN, self).__init__()
143 |
144 | n_resgroups = n_resgroups
145 | n_resblocks = n_resblocks
146 | n_feats = n_feats
147 | kernel_size = 3
148 | reduction = reduction
149 | scale = scale
150 | act = nn.ReLU(True)
151 |
152 | # RGB mean for DIV2K
153 | rgb_mean = (0.4488, 0.4371, 0.4040)
154 | rgb_std = (1.0, 1.0, 1.0)
155 | self.sub_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std)
156 |
157 | # define head module
158 | modules_head = [conv(n_colors, n_feats, kernel_size)]
159 |
160 | # define body module
161 | modules_body = [
162 | ResidualGroup(
163 | conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \
164 | for _ in range(n_resgroups)]
165 |
166 | modules_body.append(conv(n_feats, n_feats, kernel_size))
167 |
168 | # define tail module
169 | modules_tail = [
170 | Upsampler(conv, scale, n_feats, act=False),
171 | conv(n_feats, n_colors, kernel_size)]
172 |
173 | self.add_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std, 1)
174 |
175 | self.head = nn.Sequential(*modules_head)
176 | self.body = nn.Sequential(*modules_body)
177 | self.tail = nn.Sequential(*modules_tail)
178 |
179 | arch_util.initialize_weights([self.head, self.body, self.tail], 0.1)
180 | def forward(self, x):
181 | x = self.sub_mean(x)
182 | x = self.head(x)
183 | res = self.body(x)
184 | res += x
185 |
186 | x = self.tail(res)
187 | x = self.add_mean(x)
188 |
189 | return x
190 |
191 | class RCAN_PA(nn.Module):
192 | ''' RCAN + PA '''
193 |
194 | def __init__(self, n_resgroups=10, n_resblocks=20, n_feats=64, res_scale=1, n_colors=3, rgb_range=1, scale=4, reduction=16, conv=arch_util.default_conv):
195 | super(RCAN_PA, self).__init__()
196 |
197 | n_resgroups = n_resgroups
198 | n_resblocks = n_resblocks
199 | n_feats = n_feats
200 | kernel_size = 3
201 | reduction = reduction
202 | scale = scale
203 | act = nn.ReLU(True)
204 |
205 | # RGB mean for DIV2K
206 | rgb_mean = (0.4488, 0.4371, 0.4040)
207 | rgb_std = (1.0, 1.0, 1.0)
208 | self.sub_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std)
209 |
210 | # define head module
211 | modules_head = [conv(n_colors, n_feats, kernel_size)]
212 |
213 | # define body module
214 | modules_body = [
215 | ResidualGroup(
216 | conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \
217 | for _ in range(n_resgroups)]
218 |
219 | modules_body.append(conv(n_feats, n_feats, kernel_size))
220 |
221 | # define tail module
222 | modules_tail = [
223 | Upsampler(conv, scale, n_feats, act=False),
224 | conv(n_feats, n_colors, kernel_size)]
225 |
226 | self.add_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std, 1)
227 |
228 | self.head = nn.Sequential(*modules_head)
229 | self.body = nn.Sequential(*modules_body)
230 | self.tail = nn.Sequential(*modules_tail)
231 |
232 | arch_util.initialize_weights([self.head, self.body, self.tail], 0.1)
233 | def forward(self, x):
234 | x = self.sub_mean(x)
235 | x = self.head(x)
236 |
237 | res = self.body(x)
238 | res += x
239 |
240 | x = self.tail(res)
241 | x = self.add_mean(x)
242 |
243 | return x
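The practical difference between CALayer and PA above: channel attention collapses each feature map to a single scalar gate, while pixel attention keeps a full-resolution gate. A small sketch (assuming codes/ is on the Python path):

import torch
from models.archs.RCAN_arch import CALayer, PA

ca, pa = CALayer(channel=64), PA(nf=64)
x = torch.randn(2, 64, 16, 16)
# CA's gate has shape (2, 64, 1, 1); PA's gate has shape (2, 64, 16, 16)
print(ca(x).shape, pa(x).shape)  # both torch.Size([2, 64, 16, 16])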
--------------------------------------------------------------------------------
/codes/models/archs/SRResNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import models.archs.arch_util as arch_util
6 |
7 | class PA(nn.Module):
8 | '''PA (pixel attention): rescales the input by a per-pixel gate from a 1x1 conv + sigmoid'''
9 | def __init__(self, nf):
10 |
11 | super(PA, self).__init__()
12 | self.conv = nn.Conv2d(nf, nf, 1)
13 | self.sigmoid = nn.Sigmoid()
14 |
15 | def forward(self, x):
16 |
17 | y = self.conv(x)
18 | y = self.sigmoid(y)
19 | out = torch.mul(x, y)
20 |
21 | return out
22 |
23 | class ResidualBlock_noBN_PA(nn.Module):
24 | '''Residual block w/o BN
25 | ---Conv-ReLU-Conv-PA-+-
26 | |___________________|
27 | '''
28 |
29 | def __init__(self, nf=64):
30 | super(ResidualBlock_noBN_PA, self).__init__()
31 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
32 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
33 | self.pa = PA(nf)
34 |
35 | # initialization
36 | arch_util.initialize_weights([self.conv1, self.conv2], 0.1)
37 |
38 | def forward(self, x):
39 | identity = x
40 | out = F.relu(self.conv1(x), inplace=True)
41 | out = self.conv2(out)
42 | out = self.pa(out)
43 | return identity + out
44 |
45 | class MSRResNet(nn.Module):
46 | ''' modified SRResNet'''
47 |
48 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
49 | super(MSRResNet, self).__init__()
50 | self.upscale = upscale
51 |
52 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
53 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
54 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
55 |
56 | # upsampling
57 | if self.upscale == 2:
58 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
59 | self.pixel_shuffle = nn.PixelShuffle(2)
60 | elif self.upscale == 3:
61 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
62 | self.pixel_shuffle = nn.PixelShuffle(3)
63 | elif self.upscale == 4:
64 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
65 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
66 | self.pixel_shuffle = nn.PixelShuffle(2)
67 |
68 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
69 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
70 |
71 | # activation function
72 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
73 |
74 | # initialization
75 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1)
76 | if self.upscale == 4:
77 | arch_util.initialize_weights(self.upconv2, 0.1)
78 |
79 | def forward(self, x):
80 | fea = self.lrelu(self.conv_first(x))
81 | out = self.recon_trunk(fea)
82 |
83 | if self.upscale == 4:
84 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
85 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
86 | elif self.upscale == 3 or self.upscale == 2:
87 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
88 |
89 | out = self.conv_last(self.lrelu(self.HRconv(out)))
90 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
91 | out += base
92 | return out
93 |
94 | class MSRResNet_PA(nn.Module):
95 | ''' modified SRResNet + PA'''
96 |
97 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
98 | super(MSRResNet_PA, self).__init__()
99 | self.upscale = upscale
100 |
101 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
102 | basic_block = functools.partial(ResidualBlock_noBN_PA, nf=nf)
103 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
104 |
105 | # upsampling
106 | if self.upscale == 2:
107 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
108 | self.pixel_shuffle = nn.PixelShuffle(2)
109 | elif self.upscale == 3:
110 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
111 | self.pixel_shuffle = nn.PixelShuffle(3)
112 | elif self.upscale == 4:
113 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
114 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
115 | self.pixel_shuffle = nn.PixelShuffle(2)
116 |
117 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
118 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
119 |
120 | # activation function
121 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
122 |
123 | # initialization
124 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
125 | 0.1)
126 | if self.upscale == 4:
127 | arch_util.initialize_weights(self.upconv2, 0.1)
128 |
129 | def forward(self, x):
130 | fea = self.lrelu(self.conv_first(x))
131 | out = self.recon_trunk(fea)
132 |
133 | if self.upscale == 4:
134 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
135 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
136 | elif self.upscale == 3 or self.upscale == 2:
137 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
138 |
139 | out = self.conv_last(self.lrelu(self.HRconv(out)))
140 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
141 | out += base
142 | return out
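Both networks above upsample with the conv + PixelShuffle pattern: a convolution expands the channel count by r^2, then PixelShuffle rearranges those channels into an r-times larger spatial map. A minimal sketch of the pattern:

import torch
import torch.nn as nn

r, nf = 2, 64
up = nn.Sequential(nn.Conv2d(nf, nf * r * r, 3, 1, 1), nn.PixelShuffle(r))
x = torch.randn(1, nf, 24, 24)
print(up(x).shape)  # torch.Size([1, 64, 48, 48])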
--------------------------------------------------------------------------------
/codes/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__init__.py
--------------------------------------------------------------------------------
/codes/models/archs/arch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 | import numpy as np
7 |
8 | # for RCAN
9 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
10 | return nn.Conv2d(
11 | in_channels, out_channels, kernel_size,
12 | padding=(kernel_size//2), bias=bias)
13 |
14 | class MeanShift(nn.Conv2d):
15 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
16 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
17 | std = torch.Tensor(rgb_std)
18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
19 | self.weight.data.div_(std.view(3, 1, 1, 1))
20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
21 | self.bias.data.div_(std)
22 | for p in self.parameters(): p.requires_grad = False  # freeze the fixed mean-shift weights
23 |
24 | # for other networks
25 | def initialize_weights(net_l, scale=1):
26 | if not isinstance(net_l, list):
27 | net_l = [net_l]
28 | for net in net_l:
29 | for m in net.modules():
30 | if isinstance(m, nn.Conv2d):
31 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
32 | m.weight.data *= scale # for residual block
33 | if m.bias is not None:
34 | m.bias.data.zero_()
35 | elif isinstance(m, nn.Linear):
36 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
37 | m.weight.data *= scale
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 | elif isinstance(m, nn.BatchNorm2d):
41 | init.constant_(m.weight, 1)
42 | init.constant_(m.bias.data, 0.0)
43 |
44 |
45 | def make_layer(block, n_layers):
46 | layers = []
47 | for _ in range(n_layers):
48 | layers.append(block())
49 | return nn.Sequential(*layers)
50 |
51 |
52 | class ResidualBlock_noBN(nn.Module):
53 | '''Residual block w/o BN
54 | ---Conv-ReLU-Conv-+-
55 | |________________|
56 | '''
57 |
58 | def __init__(self, nf=64):
59 | super(ResidualBlock_noBN, self).__init__()
60 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
61 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
62 |
63 | # initialization
64 | initialize_weights([self.conv1, self.conv2], 0.1)
65 |
66 | def forward(self, x):
67 | identity = x
68 | out = F.relu(self.conv1(x), inplace=True)
69 | out = self.conv2(out)
70 | return identity + out
71 |
72 |
73 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
74 | """Warp an image or feature map with optical flow
75 | Args:
76 | x (Tensor): size (N, C, H, W)
77 | flow (Tensor): size (N, H, W, 2), normal value
78 | interp_mode (str): 'nearest' or 'bilinear'
79 | padding_mode (str): 'zeros' or 'border' or 'reflection'
80 |
81 | Returns:
82 | Tensor: warped image or feature map
83 | """
84 | assert x.size()[-2:] == flow.size()[1:3]
85 | B, C, H, W = x.size()
86 | # mesh grid
87 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
88 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
89 | grid.requires_grad = False
90 | grid = grid.type_as(x)
91 | vgrid = grid + flow
92 | # scale grid to [-1,1]
93 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
94 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
95 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
96 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
97 | return output
98 |
99 | def scalex4(im):
100 | '''Nearest-neighbor x4 upsampling via per-channel repeat + pixel shuffle'''
101 | im1 = im[:, :1, ...].repeat(1, 16, 1, 1)
102 | im2 = im[:, 1:2, ...].repeat(1, 16, 1, 1)
103 | im3 = im[:, 2:, ...].repeat(1, 16, 1, 1)
104 |
105 | # b, c, h, w = im.shape
106 | # w = torch.randn(b,16,h,w).cuda() * (5e-2)
107 |
108 | # img1 = im1 + im1 * w
109 | # img2 = im2 + im2 * w
110 | # img3 = im3 + im3 * w
111 |
112 | imhr = torch.cat((im1, im2, im3), 1)
113 | imhr = F.pixel_shuffle(imhr, 4)
114 | return imhr
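scalex4 is equivalent to nearest-neighbor x4 upsampling: each channel is copied 16 times and pixel_shuffle(..., 4) distributes the 16 identical copies over each 4x4 output block. A quick equivalence check (a sketch, assuming codes/ is on the Python path):

import torch
import torch.nn.functional as F
from models.archs.arch_util import scalex4

im = torch.randn(2, 3, 5, 7)
assert torch.equal(scalex4(im), F.interpolate(im, scale_factor=4, mode='nearest'))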
--------------------------------------------------------------------------------
/codes/models/archs/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
2 | deform_conv, modulated_deform_conv)
3 |
4 | __all__ = [
5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6 | 'modulated_deform_conv'
7 | ]
8 |
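For orientation in the CUDA-backed deform_conv.py code below: DeformConvPack's offset conv predicts 2*kh*kw values per deformable group (a (dx, dy) pair for every kernel tap), and spatial output sizes follow the standard convolution arithmetic used in DeformConvFunction._output_size. A plain-Python restatement of that arithmetic (sketch only):

def conv_out_size(in_size, kernel, pad, dilation, stride):
    # per-dimension output size, matching DeformConvFunction._output_size
    return (in_size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

print(conv_out_size(64, kernel=3, pad=1, dilation=1, stride=1))  # 64: 3x3 with pad 1 keeps size
offset_channels = 2 * 3 * 3 * 1  # 18 for a 3x3 kernel with 1 deformable group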
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: deform-conv
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | deform_conv.egg-info/PKG-INFO
3 | deform_conv.egg-info/SOURCES.txt
4 | deform_conv.egg-info/dependency_links.txt
5 | deform_conv.egg-info/not-zip-safe
6 | deform_conv.egg-info/top_level.txt
7 | src/deform_conv_cuda.cpp
8 | src/deform_conv_cuda_kernel.cu
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | deform_conv_cuda
2 |
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | from torch.autograd.function import once_differentiable
8 | from torch.nn.modules.utils import _pair
9 |
10 | from . import deform_conv_cuda
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class DeformConvFunction(Function):
16 | @staticmethod
17 | def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
18 | deformable_groups=1, im2col_step=64):
19 | if input is not None and input.dim() != 4:
20 | raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
21 | input.dim()))
22 | ctx.stride = _pair(stride)
23 | ctx.padding = _pair(padding)
24 | ctx.dilation = _pair(dilation)
25 | ctx.groups = groups
26 | ctx.deformable_groups = deformable_groups
27 | ctx.im2col_step = im2col_step
28 |
29 | ctx.save_for_backward(input, offset, weight)
30 |
31 | output = input.new_empty(
32 | DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
33 |
34 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
35 |
36 | if not input.is_cuda:
37 | raise NotImplementedError
38 | else:
39 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
40 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
41 | deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
42 | ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
43 | weight.size(2), ctx.stride[1], ctx.stride[0],
44 | ctx.padding[1], ctx.padding[0],
45 | ctx.dilation[1], ctx.dilation[0], ctx.groups,
46 | ctx.deformable_groups, cur_im2col_step)
47 | return output
48 |
49 | @staticmethod
50 | @once_differentiable
51 | def backward(ctx, grad_output):
52 | input, offset, weight = ctx.saved_tensors
53 |
54 | grad_input = grad_offset = grad_weight = None
55 |
56 | if not grad_output.is_cuda:
57 | raise NotImplementedError
58 | else:
59 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
60 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
61 |
62 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
63 | grad_input = torch.zeros_like(input)
64 | grad_offset = torch.zeros_like(offset)
65 | deform_conv_cuda.deform_conv_backward_input_cuda(
66 | input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
67 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
68 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
69 | ctx.deformable_groups, cur_im2col_step)
70 |
71 | if ctx.needs_input_grad[2]:
72 | grad_weight = torch.zeros_like(weight)
73 | deform_conv_cuda.deform_conv_backward_parameters_cuda(
74 | input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
75 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
76 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
77 | ctx.deformable_groups, 1, cur_im2col_step)
78 |
79 | return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
80 |
81 | @staticmethod
82 | def _output_size(input, weight, padding, dilation, stride):
83 | channels = weight.size(0)
84 | output_size = (input.size(0), channels)
85 | for d in range(input.dim() - 2):
86 | in_size = input.size(d + 2)
87 | pad = padding[d]
88 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
89 | stride_ = stride[d]
90 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
91 | if not all(map(lambda s: s > 0, output_size)):
92 | raise ValueError("convolution input is too small (output would be {})".format('x'.join(
93 | map(str, output_size))))
94 | return output_size
95 |
96 |
97 | class ModulatedDeformConvFunction(Function):
98 | @staticmethod
99 | def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
100 | groups=1, deformable_groups=1):
101 | ctx.stride = stride
102 | ctx.padding = padding
103 | ctx.dilation = dilation
104 | ctx.groups = groups
105 | ctx.deformable_groups = deformable_groups
106 | ctx.with_bias = bias is not None
107 | if not ctx.with_bias:
108 | bias = input.new_empty(1) # fake tensor
109 | if not input.is_cuda:
110 | raise NotImplementedError
111 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \
112 | or input.requires_grad:
113 | ctx.save_for_backward(input, offset, mask, weight, bias)
114 | output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
115 | ctx._bufs = [input.new_empty(0), input.new_empty(0)]
116 | deform_conv_cuda.modulated_deform_conv_cuda_forward(
117 | input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
118 | weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
119 | ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
120 | return output
121 |
122 | @staticmethod
123 | @once_differentiable
124 | def backward(ctx, grad_output):
125 | if not grad_output.is_cuda:
126 | raise NotImplementedError
127 | input, offset, mask, weight, bias = ctx.saved_tensors
128 | grad_input = torch.zeros_like(input)
129 | grad_offset = torch.zeros_like(offset)
130 | grad_mask = torch.zeros_like(mask)
131 | grad_weight = torch.zeros_like(weight)
132 | grad_bias = torch.zeros_like(bias)
133 | deform_conv_cuda.modulated_deform_conv_cuda_backward(
134 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
135 | grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
136 | ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
137 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
138 | if not ctx.with_bias:
139 | grad_bias = None
140 |
141 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
142 | None)
143 |
144 | @staticmethod
145 | def _infer_shape(ctx, input, weight):
146 | n = input.size(0)
147 | channels_out = weight.size(0)
148 | height, width = input.shape[2:4]
149 | kernel_h, kernel_w = weight.shape[2:4]
150 | height_out = (height + 2 * ctx.padding - (ctx.dilation *
151 | (kernel_h - 1) + 1)) // ctx.stride + 1
152 | width_out = (width + 2 * ctx.padding - (ctx.dilation *
153 | (kernel_w - 1) + 1)) // ctx.stride + 1
154 | return n, channels_out, height_out, width_out
155 |
156 |
157 | deform_conv = DeformConvFunction.apply
158 | modulated_deform_conv = ModulatedDeformConvFunction.apply
159 |
160 |
161 | class DeformConv(nn.Module):
162 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
163 | groups=1, deformable_groups=1, bias=False):
164 | super(DeformConv, self).__init__()
165 |
166 | assert not bias
167 |         assert in_channels % groups == 0, \
168 |             'in_channels {} is not divisible by groups {}'.format(
169 |                 in_channels, groups)
170 |         assert out_channels % groups == 0, \
171 |             'out_channels {} is not divisible by groups {}'.format(
172 |                 out_channels, groups)
173 |
174 | self.in_channels = in_channels
175 | self.out_channels = out_channels
176 | self.kernel_size = _pair(kernel_size)
177 | self.stride = _pair(stride)
178 | self.padding = _pair(padding)
179 | self.dilation = _pair(dilation)
180 | self.groups = groups
181 | self.deformable_groups = deformable_groups
182 |
183 | self.weight = nn.Parameter(
184 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
185 |
186 | self.reset_parameters()
187 |
188 | def reset_parameters(self):
189 | n = self.in_channels
190 | for k in self.kernel_size:
191 | n *= k
192 | stdv = 1. / math.sqrt(n)
193 | self.weight.data.uniform_(-stdv, stdv)
194 |
195 | def forward(self, x, offset):
196 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
197 | self.groups, self.deformable_groups)
198 |
199 |
200 | class DeformConvPack(DeformConv):
201 | def __init__(self, *args, **kwargs):
202 | super(DeformConvPack, self).__init__(*args, **kwargs)
203 |
204 | self.conv_offset = nn.Conv2d(
205 | self.in_channels,
206 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
207 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
208 | bias=True)
209 | self.init_offset()
210 |
211 | def init_offset(self):
212 | self.conv_offset.weight.data.zero_()
213 | self.conv_offset.bias.data.zero_()
214 |
215 | def forward(self, x):
216 | offset = self.conv_offset(x)
217 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
218 | self.groups, self.deformable_groups)
219 |
220 |
221 | class ModulatedDeformConv(nn.Module):
222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
223 | groups=1, deformable_groups=1, bias=True):
224 | super(ModulatedDeformConv, self).__init__()
225 | self.in_channels = in_channels
226 | self.out_channels = out_channels
227 | self.kernel_size = _pair(kernel_size)
228 | self.stride = stride
229 | self.padding = padding
230 | self.dilation = dilation
231 | self.groups = groups
232 | self.deformable_groups = deformable_groups
233 | self.with_bias = bias
234 |
235 | self.weight = nn.Parameter(
236 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
237 | if bias:
238 | self.bias = nn.Parameter(torch.Tensor(out_channels))
239 | else:
240 | self.register_parameter('bias', None)
241 | self.reset_parameters()
242 |
243 | def reset_parameters(self):
244 | n = self.in_channels
245 | for k in self.kernel_size:
246 | n *= k
247 | stdv = 1. / math.sqrt(n)
248 | self.weight.data.uniform_(-stdv, stdv)
249 | if self.bias is not None:
250 | self.bias.data.zero_()
251 |
252 | def forward(self, x, offset, mask):
253 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
254 | self.padding, self.dilation, self.groups,
255 | self.deformable_groups)
256 |
257 |
258 | class ModulatedDeformConvPack(ModulatedDeformConv):
259 | def __init__(self, *args, extra_offset_mask=False, **kwargs):
260 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
261 |
262 | self.extra_offset_mask = extra_offset_mask
263 | self.conv_offset_mask = nn.Conv2d(
264 | self.in_channels,
265 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
266 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
267 | bias=True)
268 | self.init_offset()
269 |
270 | def init_offset(self):
271 | self.conv_offset_mask.weight.data.zero_()
272 | self.conv_offset_mask.bias.data.zero_()
273 |
274 | def forward(self, x):
275 | if self.extra_offset_mask:
276 | # x = [input, features]
277 | out = self.conv_offset_mask(x[1])
278 | x = x[0]
279 | else:
280 | out = self.conv_offset_mask(x)
281 | o1, o2, mask = torch.chunk(out, 3, dim=1)
282 | offset = torch.cat((o1, o2), dim=1)
283 | mask = torch.sigmoid(mask)
284 |
285 | offset_mean = torch.mean(torch.abs(offset))
286 | if offset_mean > 100:
287 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
288 |
289 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
290 | self.padding, self.dilation, self.groups,
291 | self.deformable_groups)
292 |
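A minimal usage sketch for the two "Pack" variants above (not repository code; it assumes the deform_conv_cuda extension has been built and a CUDA device is available):

    import torch

    # DeformConvPack predicts its own offsets from the input feature map.
    dcn = DeformConvPack(in_channels=64, out_channels=64, kernel_size=3, padding=1).cuda()
    x = torch.randn(2, 64, 32, 32).cuda()
    y = dcn(x)    # -> (2, 64, 32, 32)

    # ModulatedDeformConvPack additionally predicts a sigmoid mask per sampling point.
    mdcn = ModulatedDeformConvPack(in_channels=64, out_channels=64, kernel_size=3,
                                   padding=1).cuda()
    y2 = mdcn(x)  # -> (2, 64, 32, 32)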
--------------------------------------------------------------------------------
/codes/models/archs/dcn/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | def make_cuda_ext(name, sources):
6 |
7 | return CUDAExtension(
8 |         name=name, sources=list(sources), extra_compile_args={
9 | 'cxx': [],
10 | 'nvcc': [
11 | '-D__CUDA_NO_HALF_OPERATORS__',
12 | '-D__CUDA_NO_HALF_CONVERSIONS__',
13 | '-D__CUDA_NO_HALF2_OPERATORS__',
14 | ]
15 | })
16 |
17 |
18 | setup(
19 | name='deform_conv', ext_modules=[
20 | make_cuda_ext(name='deform_conv_cuda',
21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)
23 |
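This script builds the deform_conv_cuda extension imported by deform_conv.py above. A typical in-place build (assumed invocation, following the standard torch.utils.cpp_extension workflow) is to run python setup.py develop from codes/models/archs/dcn; the -D__CUDA_NO_HALF* flags disable the built-in half-precision operators, a common workaround for compile errors on older CUDA toolchains.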
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
38 |         """Set the learning rate for warm-up.
39 |         lr_groups_l: list of lr_groups, one for each optimizer."""
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | """Get the initial lr, which is set by the scheduler"""
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | # set up warm-up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
67 |
68 | def get_network_description(self, network):
69 |         """Get the string representation and total parameter count of the network"""
70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
71 | network = network.module
72 | return str(network), sum(map(lambda x: x.numel(), network.parameters()))
73 |
74 | def save_network(self, network, network_label, iter_label):
75 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
76 | save_path = os.path.join(self.opt['path']['models'], save_filename)
77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
78 | network = network.module
79 | state_dict = network.state_dict()
80 | for key, param in state_dict.items():
81 | state_dict[key] = param.cpu()
82 | torch.save(state_dict, save_path)
83 |
84 | def load_network(self, load_path, network, strict=True):
85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
86 | network = network.module
87 | load_net = torch.load(load_path)
88 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
89 | for k, v in load_net.items():
90 | if k.startswith('module.'):
91 | load_net_clean[k[7:]] = v
92 | else:
93 | load_net_clean[k] = v
94 | network.load_state_dict(load_net_clean, strict=strict)
95 |
96 | def load_network_part(self, load_path, network, strict=True):
97 |
98 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
99 | network1 = network.module.net1
100 | network2 = network.module.net2
101 | load_net = torch.load(load_path[0])
102 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
103 | for k, v in load_net.items():
104 | if k.startswith('module.'):
105 | load_net_clean[k[7:]] = v
106 | else:
107 | load_net_clean[k] = v
108 | network1.load_state_dict(load_net_clean, strict=strict)
109 |
110 | load_net = torch.load(load_path[1])
111 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
112 | for k, v in load_net.items():
113 | if k.startswith('module.'):
114 | load_net_clean[k[7:]] = v
115 | else:
116 | load_net_clean[k] = v
117 | network2.load_state_dict(load_net_clean, strict=strict)
118 |
119 | def save_training_state(self, epoch, iter_step):
120 | """Save training state during training, which will be used for resuming"""
121 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
122 | for s in self.schedulers:
123 | state['schedulers'].append(s.state_dict())
124 | for o in self.optimizers:
125 | state['optimizers'].append(o.state_dict())
126 | save_filename = '{}.state'.format(iter_step)
127 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
128 | torch.save(state, save_path)
129 |
130 | def resume_training(self, resume_state):
131 | """Resume the optimizers and schedulers for training"""
132 | resume_optimizers = resume_state['optimizers']
133 | resume_schedulers = resume_state['schedulers']
134 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
135 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
136 | for i, o in enumerate(resume_optimizers):
137 | self.optimizers[i].load_state_dict(o)
138 | for i, s in enumerate(resume_schedulers):
139 | self.schedulers[i].load_state_dict(s)
140 |
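A small standalone illustration of the warm-up rule in update_learning_rate() (example numbers only; 7e-4 matches the PAN train configs): while cur_iter < warmup_iter, each parameter group's rate is scaled linearly from 0 toward its initial_lr, after which the scheduler's own value applies.

    warmup_iter, initial_lr = 4000, 7e-4
    for cur_iter in (1000, 2000, 3000, 4000):
        lr = initial_lr * cur_iter / warmup_iter if cur_iter < warmup_iter else initial_lr
        print(cur_iter, lr)  # 0.000175, 0.00035, 0.000525, 0.0007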
--------------------------------------------------------------------------------
/codes/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class GradientLoss(nn.Module):
6 | def __init__(self):
7 | super(GradientLoss, self).__init__()
8 |         filterx = torch.tensor([[-3., 0., 3.], [-10., 0., 10.], [-3., 0., 3.]])  # Scharr x-gradient kernel
9 | self.fx = filterx.expand(1,3,3,3).cuda()
10 |
11 |         filtery = torch.tensor([[-3., -10., -3.], [0., 0., 0.], [3., 10., 3.]])  # Scharr y-gradient kernel
12 | self.fy = filtery.expand(1,3,3,3).cuda()
13 |
14 | def forward(self, x, y):
15 | schxx = F.conv2d(x, self.fx, stride=1, padding=1)
16 | schxy = F.conv2d(x, self.fy, stride=1, padding=1)
17 | gradx = torch.sqrt(torch.pow(schxx, 2) + torch.pow(schxy, 2) + 1e-6)
18 |
19 | schyx = F.conv2d(y, self.fx, stride=1, padding=1)
20 | schyy = F.conv2d(y, self.fy, stride=1, padding=1)
21 | grady = torch.sqrt(torch.pow(schyx, 2) + torch.pow(schyy, 2) + 1e-6)
22 |
23 | loss = F.l1_loss(gradx, grady)
24 | return loss
25 |
26 | class GaussianFilter(nn.Module):
27 | def __init__(self, kernel_size=5, stride=1, padding=4):
28 | super(GaussianFilter, self).__init__()
29 |         # initialize gaussian kernel
30 | mean = (kernel_size - 1) / 2.0
31 | variance = (kernel_size / 6.0) ** 2.0
32 |         # Create an x, y coordinate grid of shape (kernel_size, kernel_size, 2)
33 | x_coord = torch.arange(kernel_size)
34 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
35 | y_grid = x_grid.t()
36 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
37 |
38 | # Calculate the 2-dimensional gaussian kernel
39 | gaussian_kernel = torch.exp(-torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance))
40 |
41 | # Make sure sum of values in gaussian kernel equals 1.
42 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
43 |
44 | # Reshape to 2d depthwise convolutional weight
45 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
46 | gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1)
47 |
48 | # create gaussian filter as convolutional layer
49 | self.gaussian_filter = nn.Conv2d(3, 3, kernel_size, stride=stride, padding=padding, groups=3, bias=False)
50 | self.gaussian_filter.weight.data = gaussian_kernel
51 | self.gaussian_filter.weight.requires_grad = False
52 |
53 | def forward(self, x):
54 | return self.gaussian_filter(x)
55 |
56 |
57 | class FilterLow(nn.Module):
58 | def __init__(self, recursions=1, kernel_size=5, stride=1, padding=True, include_pad=True, gaussian=False):
59 | super(FilterLow, self).__init__()
60 | if padding:
61 | pad = int((kernel_size - 1) / 2)
62 | else:
63 | pad = 0
64 | if gaussian:
65 | self.filter = GaussianFilter(kernel_size=kernel_size, stride=stride, padding=pad)
66 | else:
67 | self.filter = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=pad, count_include_pad=include_pad)
68 | self.recursions = recursions
69 |
70 | def forward(self, img):
71 | for i in range(self.recursions):
72 | img = self.filter(img)
73 | return img
74 |
75 |
76 | class FilterHigh(nn.Module):
77 | def __init__(self, recursions=1, kernel_size=5, stride=1, include_pad=True, normalize=True, gaussian=False):
78 | super(FilterHigh, self).__init__()
79 | self.filter_low = FilterLow(recursions=1, kernel_size=kernel_size, stride=stride, include_pad=include_pad,
80 | gaussian=gaussian)
81 | self.recursions = recursions
82 | self.normalize = normalize
83 |
84 | def forward(self, img):
85 | if self.recursions > 1:
86 | for i in range(self.recursions - 1):
87 | img = self.filter_low(img)
88 | img = img - self.filter_low(img)
89 | if self.normalize:
90 | return 0.5 + img * 0.5
91 | else:
92 | return img
93 |
94 | class FSLoss(nn.Module):
95 | def __init__(self, recursions=1, stride=1, kernel_size=5, gaussian=False):
96 | super(FSLoss, self).__init__()
97 | self.filter = FilterHigh(recursions=recursions, stride=stride, kernel_size=kernel_size, include_pad=False,
98 | gaussian=gaussian)
99 | def forward(self, x, y):
100 | x_ = self.filter(x)
101 | y_ = self.filter(y)
102 | loss = F.l1_loss(x_, y_)
103 | return loss
104 |
105 | class CharbonnierLoss(nn.Module):
106 | """Charbonnier Loss (L1)"""
107 |
108 | def __init__(self, eps=1e-6):
109 | super(CharbonnierLoss, self).__init__()
110 | self.eps = eps
111 |
112 | def forward(self, x, y):
113 | diff = x - y
114 | loss = torch.sum(torch.sqrt(diff * diff + self.eps))
115 | return loss
116 |
117 |
118 | # Define GAN loss: [gan | ragan | lsgan | wgan-gp]
119 | class GANLoss(nn.Module):
120 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
121 | super(GANLoss, self).__init__()
122 | self.gan_type = gan_type.lower()
123 | self.real_label_val = real_label_val
124 | self.fake_label_val = fake_label_val
125 |
126 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
127 | self.loss = nn.BCEWithLogitsLoss()
128 | elif self.gan_type == 'lsgan':
129 | self.loss = nn.MSELoss()
130 | elif self.gan_type == 'wgan-gp':
131 |
132 | def wgan_loss(input, target):
133 | # target is boolean
134 | return -1 * input.mean() if target else input.mean()
135 |
136 | self.loss = wgan_loss
137 | else:
138 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
139 |
140 | def get_target_label(self, input, target_is_real):
141 | if self.gan_type == 'wgan-gp':
142 | return target_is_real
143 | if target_is_real:
144 | return torch.empty_like(input).fill_(self.real_label_val)
145 | else:
146 | return torch.empty_like(input).fill_(self.fake_label_val)
147 |
148 | def forward(self, input, target_is_real):
149 | target_label = self.get_target_label(input, target_is_real)
150 | loss = self.loss(input, target_label)
151 | return loss
152 |
153 |
154 | class GradientPenaltyLoss(nn.Module):
155 | def __init__(self, device=torch.device('cpu')):
156 | super(GradientPenaltyLoss, self).__init__()
157 | self.register_buffer('grad_outputs', torch.Tensor())
158 | self.grad_outputs = self.grad_outputs.to(device)
159 |
160 | def get_grad_outputs(self, input):
161 | if self.grad_outputs.size() != input.size():
162 | self.grad_outputs.resize_(input.size()).fill_(1.0)
163 | return self.grad_outputs
164 |
165 | def forward(self, interp, interp_crit):
166 | grad_outputs = self.get_grad_outputs(interp_crit)
167 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
168 | grad_outputs=grad_outputs, create_graph=True,
169 | retain_graph=True, only_inputs=True)[0]
170 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
171 | grad_interp_norm = grad_interp.norm(2, dim=1)
172 |
173 | loss = ((grad_interp_norm - 1)**2).mean()
174 | return loss
175 |
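A quick sanity check of CharbonnierLoss and GANLoss (standalone sketch with arbitrary random tensors):

    import torch

    charb = CharbonnierLoss(eps=1e-6)
    x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
    print(charb(x, y))    # note: a sum over all elements, not a mean

    gan = GANLoss('lsgan')  # MSE against 1.0 / 0.0 target labels
    pred = torch.randn(4, 1)
    print(gan(pred, True), gan(pred, False))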
--------------------------------------------------------------------------------
/codes/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restarts = [v + 1 for v in self.restarts]
16 | self.restart_weights = weights if weights else [1]
17 | assert len(self.restarts) == len(
18 | self.restart_weights), 'restarts and their weights do not match.'
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
20 |
21 | def get_lr(self):
22 | if self.last_epoch in self.restarts:
23 | if self.clear_state:
24 | self.optimizer.state = defaultdict(dict)
25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
27 | if self.last_epoch not in self.milestones:
28 | return [group['lr'] for group in self.optimizer.param_groups]
29 | return [
30 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
31 | for group in self.optimizer.param_groups
32 | ]
33 |
34 |
35 | class CosineAnnealingLR_Restart(_LRScheduler):
36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
37 | self.T_period = T_period
38 | self.T_max = self.T_period[0] # current T period
39 | self.eta_min = eta_min
40 | self.restarts = restarts if restarts else [0]
41 | self.restarts = [v + 1 for v in self.restarts]
42 | self.restart_weights = weights if weights else [1]
43 | self.last_restart = 0
44 | assert len(self.restarts) == len(
45 | self.restart_weights), 'restarts and their weights do not match.'
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
47 |
48 | def get_lr(self):
49 | if self.last_epoch == 0:
50 | return self.base_lrs
51 | elif self.last_epoch in self.restarts:
52 | self.last_restart = self.last_epoch
53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
57 | return [
58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
60 | ]
61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
63 | (group['lr'] - self.eta_min) + self.eta_min
64 | for group in self.optimizer.param_groups]
65 |
66 |
67 | if __name__ == "__main__":
68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
69 | betas=(0.9, 0.99))
70 | ##############################
71 | # MultiStepLR_Restart
72 | ##############################
73 | ## Original
74 | lr_steps = [200000, 400000, 600000, 800000]
75 | restarts = None
76 | restart_weights = None
77 |
78 | ## two
79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
80 | restarts = [500000]
81 | restart_weights = [1]
82 |
83 | ## four
84 | lr_steps = [
85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
87 | ]
88 | restarts = [250000, 500000, 750000]
89 | restart_weights = [1, 1, 1]
90 |
91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
92 | clear_state=False)
93 |
94 | ##############################
95 | # Cosine Annealing Restart
96 | ##############################
97 | ## two
98 | T_period = [500000, 500000]
99 | restarts = [500000]
100 | restart_weights = [1]
101 |
102 | ## four
103 | T_period = [250000, 250000, 250000, 250000]
104 | restarts = [250000, 500000, 750000]
105 | restart_weights = [1, 1, 1]
106 |
107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
108 | weights=restart_weights)
109 |
110 | ##############################
111 | # Draw figure
112 | ##############################
113 | N_iter = 1000000
114 | lr_l = list(range(N_iter))
115 | for i in range(N_iter):
116 | scheduler.step()
117 | current_lr = optimizer.param_groups[0]['lr']
118 | lr_l[i] = current_lr
119 |
120 | import matplotlib as mpl
121 | from matplotlib import pyplot as plt
122 | import matplotlib.ticker as mtick
123 | mpl.style.use('default')
124 | import seaborn
125 | seaborn.set(style='whitegrid')
126 | seaborn.set_context('paper')
127 |
128 | plt.figure(1)
129 | plt.subplot(111)
130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
131 | plt.title('Title', fontsize=16, color='k')
132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
133 | legend = plt.legend(loc='upper right', shadow=False)
134 | ax = plt.gca()
135 | labels = ax.get_xticks().tolist()
136 | for k, v in enumerate(labels):
137 | labels[k] = str(int(v / 1000)) + 'K'
138 | ax.set_xticklabels(labels)
139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
140 |
141 | ax.set_ylabel('Learning rate')
142 | ax.set_xlabel('Iteration')
143 | fig = plt.gcf()
144 | plt.show()
145 |
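For reference, the chained update in CosineAnnealingLR_Restart.get_lr() reproduces the usual SGDR closed form between restarts: with t the number of steps since the last restart, eta_t = eta_min + (eta_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2, where eta_0 is the (restart-weighted) starting rate and T_max the current period from T_period.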
--------------------------------------------------------------------------------
/codes/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import models.archs.PAN_arch as PAN_arch
3 | import models.archs.SRResNet_arch as SRResNet_arch
4 | import models.archs.RCAN_arch as RCAN_arch
5 |
6 |
7 | # Generator
8 | def define_G(opt):
9 | opt_net = opt['network_G']
10 | which_model = opt_net['which_model_G']
11 |
12 | # image restoration
13 | if which_model == 'PAN':
14 | netG = PAN_arch.PAN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
15 | nf=opt_net['nf'], unf=opt_net['unf'], nb=opt_net['nb'], scale=opt_net['scale'])
16 | elif which_model == 'MSRResNet_PA':
17 | netG = SRResNet_arch.MSRResNet_PA(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
18 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
19 | elif which_model == 'RCAN_PA':
20 | netG = RCAN_arch.RCAN_PA(n_resgroups=opt_net['n_resgroups'], n_resblocks=opt_net['n_resblocks'], n_feats=opt_net['n_feats'], res_scale=opt_net['res_scale'], n_colors=opt_net['n_colors'], rgb_range=opt_net['rgb_range'], scale=opt_net['scale'])
21 | else:
22 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
23 |
24 | return netG
25 |
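A minimal opt dict accepted by define_G, mirroring the PAN settings used in options/test/test_PANx4.yml (sketch only):

    opt = {'network_G': {'which_model_G': 'PAN', 'in_nc': 3, 'out_nc': 3,
                         'nf': 40, 'unf': 24, 'nb': 16, 'scale': 4}}
    netG = define_G(opt)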
--------------------------------------------------------------------------------
/codes/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__init__.py
--------------------------------------------------------------------------------
/codes/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr':
19 | scale = opt['scale']
20 |
21 | # datasets
22 | for phase, dataset in opt['datasets'].items():
23 | phase = phase.split('_')[0]
24 | dataset['phase'] = phase
25 | if opt['distortion'] == 'sr':
26 | dataset['scale'] = scale
27 | is_lmdb = False
28 | if dataset.get('dataroot_GT', None) is not None:
29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
30 | if dataset['dataroot_GT'].endswith('lmdb'):
31 | is_lmdb = True
32 | if dataset.get('dataroot_LQ', None) is not None:
33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
34 | if dataset['dataroot_LQ'].endswith('lmdb'):
35 | is_lmdb = True
36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
37 | if dataset['mode'].endswith('mc'): # for memcached
38 | dataset['data_type'] = 'mc'
39 | dataset['mode'] = dataset['mode'].replace('_mc', '')
40 |
41 | # path
42 | for key, path in opt['path'].items():
43 |         if path and key != 'strict_load':
44 | opt['path'][key] = osp.expanduser(path)
45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
46 | if is_train:
47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
48 | opt['path']['experiments_root'] = experiments_root
49 | opt['path']['models'] = osp.join(experiments_root, 'models')
50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
51 | opt['path']['log'] = experiments_root
52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
53 |
54 | # change some options for debug mode
55 | if 'debug' in opt['name']:
56 | opt['train']['val_freq'] = 8
57 | opt['logger']['print_freq'] = 1
58 | opt['logger']['save_checkpoint_freq'] = 8
59 | else: # test
60 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
61 | opt['path']['results_root'] = results_root
62 | opt['path']['log'] = results_root
63 |
64 | # network
65 | if opt['distortion'] == 'sr':
66 | opt['network_G']['scale'] = scale
67 |
68 | return opt
69 |
70 |
71 | def dict2str(opt, indent_l=1):
72 | '''dict to string for logger'''
73 | msg = ''
74 | for k, v in opt.items():
75 | if isinstance(v, dict):
76 | msg += ' ' * (indent_l * 2) + k + ':[\n'
77 | msg += dict2str(v, indent_l + 1)
78 | msg += ' ' * (indent_l * 2) + ']\n'
79 | else:
80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
81 | return msg
82 |
83 |
84 | class NoneDict(dict):
85 | def __missing__(self, key):
86 | return None
87 |
88 |
89 | # convert to NoneDict, which returns None for missing keys.
90 | def dict_to_nonedict(opt):
91 | if isinstance(opt, dict):
92 | new_opt = dict()
93 | for key, sub_opt in opt.items():
94 | new_opt[key] = dict_to_nonedict(sub_opt)
95 | return NoneDict(**new_opt)
96 | elif isinstance(opt, list):
97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
98 | else:
99 | return opt
100 |
101 |
102 | def check_resume(opt, resume_iter):
103 | '''Check resume states and pretrain_model paths'''
104 | logger = logging.getLogger('base')
105 | if opt['path']['resume_state']:
106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
107 | 'pretrain_model_D', None) is not None:
108 | logger.warning('pretrain_model path will be ignored when resuming training.')
109 |
110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
111 | '{}_G.pth'.format(resume_iter))
112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
113 | if 'gan' in opt['model']:
114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
115 | '{}_D.pth'.format(resume_iter))
116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
117 |
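The intended call sequence, as used by codes/test.py further below:

    import options.options as option

    opt = option.parse('options/test/test_PANx4.yml', is_train=False)
    opt = option.dict_to_nonedict(opt)  # missing keys now read as None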
--------------------------------------------------------------------------------
/codes/options/test/test_PANx2.yml:
--------------------------------------------------------------------------------
1 | name: PANx2_DF2K
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 2
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | save_img: True
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test1:
12 | name: Set5
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X2
16 | test2:
17 | name: Set14
18 | mode: LQGT
19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR
20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X2
21 | test3:
22 | name: B100
23 | mode: LQGT
24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR
25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X2
26 | test4:
27 | name: Urban100
28 | mode: LQGT
29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR
30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X2
31 | test5:
32 | name: Manga109
33 | mode: LQGT
34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR
35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X2
36 |
37 | #### network structures
38 | network_G:
39 | which_model_G: PAN
40 | in_nc: 3
41 | out_nc: 3
42 | nf: 40
43 | unf: 24
44 | nb: 16
45 | scale: 2
46 |
47 | #### path
48 | path:
49 | pretrain_model_G: ../experiments/pretrained_models/PANx2_DF2K.pth
50 |
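This config (and its x3/x4 siblings below) is consumed by the test entry point, e.g. python test.py -opt options/test/test_PANx2.yml, as listed in codes/run_scripts.sh.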
--------------------------------------------------------------------------------
/codes/options/test/test_PANx3.yml:
--------------------------------------------------------------------------------
1 | name: PANx3_DF2K
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 3
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | save_img: True
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test1:
12 | name: Set5
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X3
16 | test2:
17 | name: Set14
18 | mode: LQGT
19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR
20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X3
21 | test3:
22 | name: B100
23 | mode: LQGT
24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR
25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X3
26 | test4:
27 | name: Urban100
28 | mode: LQGT
29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR
30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X3
31 | test5:
32 | name: Manga109
33 | mode: LQGT
34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR
35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X3
36 |
37 | #### network structures
38 | network_G:
39 | which_model_G: PAN
40 | in_nc: 3
41 | out_nc: 3
42 | nf: 40
43 | unf: 24
44 | nb: 16
45 | scale: 3
46 |
47 | #### path
48 | path:
49 | pretrain_model_G: ../experiments/pretrained_models/PANx3_DF2K.pth
50 |
--------------------------------------------------------------------------------
/codes/options/test/test_PANx4.yml:
--------------------------------------------------------------------------------
1 | name: PANx4_DF2K
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | save_img: True
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test1:
12 | name: Set5
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4
16 | test2:
17 | name: Set14
18 | mode: LQGT
19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR
20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4
21 | test3:
22 | name: B100
23 | mode: LQGT
24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR
25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4
26 | test4:
27 | name: Urban100
28 | mode: LQGT
29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR
30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4
31 | test5:
32 | name: Manga109
33 | mode: LQGT
34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR
35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4
36 |
37 | #### network structures
38 | network_G:
39 | which_model_G: PAN
40 | in_nc: 3
41 | out_nc: 3
42 | nf: 40
43 | unf: 24
44 | nb: 16
45 | scale: 4
46 |
47 | #### path
48 | path:
49 | pretrain_model_G: ../experiments/pretrained_models/PANx4_DF2K.pth
50 |
--------------------------------------------------------------------------------
/codes/options/test/test_RCAN.yml:
--------------------------------------------------------------------------------
1 | name: RCANPAx4_DIV2K
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | save_img: False
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test1:
12 | name: Set5
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4
16 | test2:
17 | name: Set14
18 | mode: LQGT
19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR
20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4
21 | test3:
22 | name: B100
23 | mode: LQGT
24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR
25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4
26 | test4:
27 | name: Urban100
28 | mode: LQGT
29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR
30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4
31 | test5:
32 | name: Manga109
33 | mode: LQGT
34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR
35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4
36 |
37 | #### network structures
38 | network_G:
39 | which_model_G: RCAN_PA
40 | n_resgroups: 10
41 | n_resblocks: 20
42 | n_feats: 64
43 | res_scale: 1
44 | n_colors: 3
45 | rgb_range: 1
46 | scale: 4
47 |
48 | #### path
49 | path:
50 | pretrain_model_G: ../experiments/pretrained_models/RCAN_PA_DIV2K.pth
--------------------------------------------------------------------------------
/codes/options/test/test_SRResNet.yml:
--------------------------------------------------------------------------------
1 | name: MSSResNet_PA_DIV2K
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | save_img: True
8 | gpu_ids: [0]
9 |
10 | datasets:
11 | test1:
12 | name: Set5
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4
16 | test2:
17 | name: Set14
18 | mode: LQGT
19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR
20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4
21 | test3:
22 | name: B100
23 | mode: LQGT
24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR
25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4
26 | test4:
27 | name: Urban100
28 | mode: LQGT
29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR
30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4
31 | test5:
32 | name: Manga109
33 | mode: LQGT
34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR
35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4
36 | #### network structures
37 | network_G:
38 | which_model_G: MSRResNet_PA
39 | in_nc: 3
40 | out_nc: 3
41 | nf: 64
42 | nb: 16
43 | upscale: 4
44 |
45 | #### path
46 | path:
47 | pretrain_model_G: ../experiments/pretrained_models/SRResNet_PA_DIV2K.pth
--------------------------------------------------------------------------------
/codes/options/train/train_PANx2.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: PANx2_DF2K
3 | use_tb_logger: True
4 | model: sr
5 | distortion: sr
6 | scale: 2
7 | save_img: False
8 | gpu_ids: [0]
9 |
10 | #### datasets
11 | datasets:
12 | train:
13 | name: DF2K
14 | mode: LQGT
15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HRx2_sub360
16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LRx2_sub180
17 |
18 | use_shuffle: true
19 | n_workers: 6 # per GPU
20 | batch_size: 32
21 | GT_size: 128
22 | use_flip: true
23 | use_rot: true
24 | color: RGB
25 | val:
26 | name: Set5
27 | mode: LQGT
28 | dataroot_GT: ../datasets/Set5/HR
29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X2
30 |
31 | #### network structures
32 | network_G:
33 | which_model_G: PAN
34 | in_nc: 3
35 | out_nc: 3
36 | nf: 40
37 | unf: 24
38 | nb: 16
39 | scale: 2
40 |
41 | #### path
42 | path:
43 | pretrain_model_G: ~
44 | strict_load: true
45 | resume_state: ~
46 |
47 | #### training settings: learning rate scheme, loss
48 | train:
49 | lr_G: !!float 7e-4
50 | lr_scheme: CosineAnnealingLR_Restart
51 | beta1: 0.9
52 | beta2: 0.99
53 | niter: 1000000
54 | warmup_iter: -1 # no warm up
55 | T_period: [250000, 250000, 250000, 250000]
56 | restarts: [250000, 500000, 750000]
57 | restart_weights: [1, 1, 1]
58 | eta_min: !!float 1e-7
59 |
60 | pixel_criterion: l1
61 | pixel_weight: 1.0
62 |
63 | manual_seed: 10
64 | val_freq: !!float 5e3
65 |
66 | #### logger
67 | logger:
68 | print_freq: 100
69 | save_checkpoint_freq: !!float 5e3
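As with the test configs, this file is passed to the training entry point, e.g. python train.py -opt options/train/train_PANx2.yml (see codes/run_scripts.sh).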
--------------------------------------------------------------------------------
/codes/options/train/train_PANx3.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: PANx3_DF2K
3 | use_tb_logger: True
4 | model: sr
5 | distortion: sr
6 | scale: 3
7 | save_img: False
8 | gpu_ids: [0]
9 |
10 | #### datasets
11 | datasets:
12 | train:
13 | name: DF2K
14 | mode: LQGT
15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HRx3_sub360
16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LRx3_sub120
17 |
18 | use_shuffle: true
19 | n_workers: 6 # per GPU
20 | batch_size: 32
21 | GT_size: 192
22 | use_flip: true
23 | use_rot: true
24 | color: RGB
25 | val:
26 | name: Set5
27 | mode: LQGT
28 | dataroot_GT: ../datasets/Set5/HR
29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X3
30 |
31 | #### network structures
32 | network_G:
33 | which_model_G: PAN
34 | in_nc: 3
35 | out_nc: 3
36 | nf: 40
37 | unf: 24
38 | nb: 16
39 | scale: 3
40 |
41 | #### path
42 | path:
43 | pretrain_model_G: ~
44 | strict_load: true
45 | resume_state: ~
46 |
47 | #### training settings: learning rate scheme, loss
48 | train:
49 | lr_G: !!float 7e-4
50 | lr_scheme: CosineAnnealingLR_Restart
51 | beta1: 0.9
52 | beta2: 0.99
53 | niter: 1000000
54 | warmup_iter: -1 # no warm up
55 | T_period: [250000, 250000, 250000, 250000]
56 | restarts: [250000, 500000, 750000]
57 | restart_weights: [1, 1, 1]
58 | eta_min: !!float 1e-7
59 |
60 | pixel_criterion: l1
61 | pixel_weight: 1.0
62 |
63 | manual_seed: 10
64 | val_freq: !!float 5e3
65 |
66 | #### logger
67 | logger:
68 | print_freq: 100
69 | save_checkpoint_freq: !!float 5e3
--------------------------------------------------------------------------------
/codes/options/train/train_PANx4.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: PANx4_DF2K
3 | use_tb_logger: True
4 | model: sr
5 | distortion: sr
6 | scale: 4
7 | save_img: False
8 | gpu_ids: [0]
9 |
10 | #### datasets
11 | datasets:
12 | train:
13 | name: DF2K
14 | mode: LQGT
15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HR_sub360
16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LR_sub90
17 |
18 | use_shuffle: true
19 | n_workers: 6 # per GPU
20 | batch_size: 32
21 | GT_size: 256
22 | use_flip: true
23 | use_rot: true
24 | color: RGB
25 | val:
26 | name: Set5
27 | mode: LQGT
28 | dataroot_GT: ../datasets/Set5/HR
29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4
30 |
31 | #### network structures
32 | network_G:
33 | which_model_G: PAN
34 | in_nc: 3
35 | out_nc: 3
36 | nf: 40
37 | unf: 24
38 | nb: 16
39 | scale: 4
40 |
41 | #### path
42 | path:
43 | pretrain_model_G: ~
44 | strict_load: true
45 | resume_state: ~
46 |
47 | #### training settings: learning rate scheme, loss
48 | train:
49 | lr_G: !!float 7e-4
50 | lr_scheme: CosineAnnealingLR_Restart
51 | beta1: 0.9
52 | beta2: 0.99
53 | niter: 1000000
54 | warmup_iter: -1 # no warm up
55 | T_period: [250000, 250000, 250000, 250000]
56 | restarts: [250000, 500000, 750000]
57 | restart_weights: [1, 1, 1]
58 | eta_min: !!float 1e-7
59 |
60 | pixel_criterion: l1
61 | pixel_weight: 1.0
62 |
63 | manual_seed: 10
64 | val_freq: !!float 5e3
65 |
66 | #### logger
67 | logger:
68 | print_freq: 100
69 | save_checkpoint_freq: !!float 5e3
--------------------------------------------------------------------------------
/codes/options/train/train_RCAN.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: RCANPAx4_DIV2K
3 | use_tb_logger: True
4 | model: sr
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: DIV2K
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DIV2K_train800/HR_sub480
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DIV2K_train800/LR_sub120
16 |
17 | use_shuffle: true
18 | n_workers: 6 # per GPU
19 | batch_size: 16
20 | GT_size: 192
21 | use_flip: true
22 | use_rot: true
23 | color: RGB
24 | val:
25 | name: Set5
26 | mode: LQGT
27 | dataroot_GT: ../datasets/Set5/HR
28 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4
29 |
30 | #### network structures
31 | network_G:
32 | which_model_G: RCAN_PA
33 | n_resgroups: 10
34 | n_resblocks: 20
35 | n_feats: 64
36 | res_scale: 1
37 | n_colors: 3
38 | rgb_range: 1
39 | scale: 4
40 | reduction: 16
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ~
45 | strict_load: true
46 | resume_state: ~
47 |
48 | #### training settings: learning rate scheme, loss
49 | train:
50 | lr_G: !!float 2e-5
51 | lr_scheme: CosineAnnealingLR_Restart
52 | beta1: 0.9
53 | beta2: 0.99
54 | niter: 200000
55 | warmup_iter: -1 # no warm up
56 | T_period: [200000, 200000, 200000, 200000]
57 | restarts: [200000, 400000, 600000]
58 | restart_weights: [1, 1, 1]
59 | eta_min: !!float 1e-7
60 |
61 | pixel_criterion: l1
62 | pixel_weight: 1.0
63 |
64 | manual_seed: 10
65 | val_freq: !!float 2e3
66 |
67 | #### logger
68 | logger:
69 | print_freq: 100
70 | save_checkpoint_freq: !!float 2e3
71 |
--------------------------------------------------------------------------------
/codes/options/train/train_SRResNet.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: MSSResNet_PA_DIV2K
3 | use_tb_logger: true
4 | model: sr
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: DIV2K
13 | mode: LQGT
14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DIV2K_train800/HR_sub480
15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DIV2K_train800/LR_sub120
16 |
17 | use_shuffle: true
18 | n_workers: 6 # per GPU
19 | batch_size: 16
20 | GT_size: 256
21 | use_flip: true
22 | use_rot: true
23 | color: RGB
24 | val:
25 | name: Set5
26 | mode: LQGT
27 | dataroot_GT: ../datasets/Set5/HR
28 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4
29 |
30 | #### network structures
31 | network_G:
32 | which_model_G: MSRResNet_PA
33 | in_nc: 3
34 | out_nc: 3
35 | nf: 64
36 | nb: 16
37 | upscale: 4
38 |
39 | #### path
40 | path:
41 | pretrain_model_G: ~
42 | strict_load: true
43 | resume_state: ~
44 |
45 | #### training settings: learning rate scheme, loss
46 | train:
47 | lr_G: !!float 2e-4
48 | lr_scheme: CosineAnnealingLR_Restart
49 | beta1: 0.9
50 | beta2: 0.99
51 | niter: 500000
52 | warmup_iter: -1 # no warm up
53 | T_period: [250000, 250000, 250000, 250000]
54 | restarts: [250000, 500000, 750000]
55 | restart_weights: [1, 1, 1]
56 | eta_min: !!float 1e-7
57 |
58 | pixel_criterion: l1
59 | pixel_weight: 1.0
60 |
61 | manual_seed: 10
62 | val_freq: !!float 5e3
63 |
64 | #### logger
65 | logger:
66 | print_freq: 100
67 | save_checkpoint_freq: !!float 5e3
68 |
--------------------------------------------------------------------------------
/codes/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | torchsummaryX
4 | lmdb
5 | pyyaml
6 | tb-nightly
7 | future
8 |
--------------------------------------------------------------------------------
/codes/run_scripts.sh:
--------------------------------------------------------------------------------
1 | # Testing x2, x3, x4
2 | python test.py -opt options/test/test_PANx2.yml
3 | python test.py -opt options/test/test_PANx3.yml
4 | python test.py -opt options/test/test_PANx4.yml
5 |
6 |
7 | # Training x2, x3, x4
8 | python train.py -opt options/train/train_PANx2.yml
9 | python train.py -opt options/train/train_PANx3.yml
10 | python train.py -opt options/train/train_PANx4.yml
11 |
12 |
13 | # Training SRResNet_PA or RCAN_PA
14 | python train.py -opt options/train/train_SRResNet.yml
15 | python train.py -opt options/train/train_RCAN.yml
16 |
--------------------------------------------------------------------------------
/codes/scripts/back_projection/backprojection.m:
--------------------------------------------------------------------------------
1 | function [im_h] = backprojection(im_h, im_l, maxIter)
2 |
3 | [row_l, col_l,~] = size(im_l);
4 | [row_h, col_h,~] = size(im_h);
5 |
6 | p = fspecial('gaussian', 5, 1);
7 | p = p.^2;
8 | p = p./sum(p(:));
9 |
10 | im_l = double(im_l);
11 | im_h = double(im_h);
12 |
13 | for ii = 1:maxIter
14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
15 | im_diff = im_l - im_l_s;
16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
20 | end
21 |
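The loop above is iterative back-projection: writing D(.) for the bicubic downscale to the LR size and U(.) for the bicubic upscale back to the HR size, each iteration updates every color channel as im_h <- im_h + p * U(im_l - D(im_h)), where * is 2-D 'same' convolution with the squared-and-renormalized 5x5 Gaussian p.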
--------------------------------------------------------------------------------
/codes/scripts/back_projection/main_bp.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20bp';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | %tic
19 | im_out = backprojection(im_out, im_LR, max_iter);
20 | %toc
21 | imwrite(im_out, fullfile(save_folder, im_name));
22 | end
23 |
--------------------------------------------------------------------------------
/codes/scripts/back_projection/main_reverse_filter.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20if';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | J = imresize(im_LR,4,'bicubic');
19 | %tic
20 | for m = 1:max_iter
21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
22 | end
23 | %toc
24 | imwrite(im_out, fullfile(save_folder, im_name));
25 | end
26 |
--------------------------------------------------------------------------------
/codes/scripts/transfer_params_MSRResNet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 | import torch
4 | try:
5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
6 | import models.archs.SRResNet_arch as SRResNet_arch
7 | except ImportError:
8 | pass
9 |
10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
12 | crt_net = crt_model.state_dict()
13 |
14 | for k, v in crt_net.items():
15 | if k in pretrained_net and 'upconv1' not in k:
16 | crt_net[k] = pretrained_net[k]
17 | print('replace ... ', k)
18 |
19 | # x4 -> x3
20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
26 |
27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
28 |
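The slicing arithmetic above: with nf = 64, the x4 model's upconv1 feeds a PixelShuffle(2) and therefore has 64*4 = 256 output channels, while the x3 model's single PixelShuffle(3) needs 64*9 = 576. The script fills the 576 channels by tiling the pretrained 256-channel weights twice plus their first 64 channels (256 + 256 + 64 = 576), halving the values, presumably to keep the output magnitude roughly unchanged.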
--------------------------------------------------------------------------------
/codes/test.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import logging
3 | import time
4 | import argparse
5 | from collections import OrderedDict
6 |
7 | import options.options as option
8 | import utils.util as util
9 | from data.util import bgr2ycbcr
10 | from data import create_dataset, create_dataloader
11 | from models import create_model
12 |
13 | #### options
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
16 | opt = option.parse(parser.parse_args().opt, is_train=False)
17 | opt = option.dict_to_nonedict(opt)
18 |
19 | util.mkdirs(
20 | (path for key, path in opt['path'].items()
21 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
22 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
23 | screen=True, tofile=True)
24 | logger = logging.getLogger('base')
25 | logger.info(option.dict2str(opt))
26 |
27 | #### Create test dataset and dataloader
28 | test_loaders = []
29 | for phase, dataset_opt in sorted(opt['datasets'].items()):
30 | test_set = create_dataset(dataset_opt)
31 | test_loader = create_dataloader(test_set, dataset_opt)
32 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
33 | test_loaders.append(test_loader)
34 |
35 | model = create_model(opt)
36 | for test_loader in test_loaders:
37 | test_set_name = test_loader.dataset.opt['name']
38 | logger.info('\nTesting [{:s}]...'.format(test_set_name))
39 | test_start_time = time.time()
40 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
41 | util.mkdir(dataset_dir)
42 |
43 | test_results = OrderedDict()
44 | test_results['psnr'] = []
45 | test_results['ssim'] = []
46 | test_results['psnr_y'] = []
47 | test_results['ssim_y'] = []
48 |
49 | for data in test_loader:
50 |         need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
51 | model.feed_data(data, need_GT=need_GT)
52 | img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
53 | img_name = osp.splitext(osp.basename(img_path))[0]
54 |
55 | model.test()
56 | visuals = model.get_current_visuals(need_GT=need_GT)
57 |
58 | sr_img = util.tensor2img(visuals['rlt']) # uint8
59 |
60 | # save images
61 | suffix = opt['suffix']
62 | if suffix:
63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
64 | else:
65 | save_img_path = osp.join(dataset_dir, img_name + '.png')
66 | if opt['save_img']:
67 | util.save_img(sr_img, save_img_path)
68 |
69 | # calculate PSNR and SSIM
70 | if need_GT:
71 | gt_img = util.tensor2img(visuals['GT'])
72 | sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
73 | psnr = util.calculate_psnr(sr_img, gt_img)
74 | ssim = util.calculate_ssim(sr_img, gt_img)
75 | test_results['psnr'].append(psnr)
76 | test_results['ssim'].append(ssim)
77 |
78 |                 if gt_img.shape[2] == 3:  # 3-channel (BGR) image
79 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
80 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
81 |
82 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
83 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
84 | test_results['psnr_y'].append(psnr_y)
85 | test_results['ssim_y'].append(ssim_y)
86 | logger.info(
87 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
88 | format(img_name, psnr, ssim, psnr_y, ssim_y))
89 | else:
90 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
91 | else:
92 | logger.info(img_name)
93 |
94 | if need_GT: # metrics
95 | # Average PSNR/SSIM results
96 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
97 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
98 | logger.info(
99 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
100 | test_set_name, ave_psnr, ave_ssim))
101 | if test_results['psnr_y'] and test_results['ssim_y']:
102 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
103 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
104 | logger.info(
105 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
106 | format(ave_psnr_y, ave_ssim_y))
107 |
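
Typical invocation, run from the codes/ directory against one of the YAML files under options/test/ (edit the dataset and pretrained-model paths inside the YAML first):

    python test.py -opt options/test/test_PANx4.yml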
--------------------------------------------------------------------------------
/codes/test_running_time.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import torch
5 | import cv2
6 | import numpy as np
7 | import models.archs.PAN_arch as PAN_arch
8 | import utils.util as util
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"  # pin to one GPU; adjust for your machine
11 |
12 | def main():
13 | ## test dataset
14 |     test_d = sorted(glob.glob('/mnt/hyzhao/Documents/datasets/DIV2K_test/*.png'))  # hardcoded test-set path; adjust locally
15 |
16 | torch.cuda.current_device()
17 | torch.cuda.empty_cache()
18 | torch.backends.cudnn.benchmark = False
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 |
21 | ## some functions
22 | def readimg(path):
23 | im = cv2.imread(path, cv2.IMREAD_UNCHANGED)
24 | im = im.astype(np.float32) / 255.
25 | im = im[:, :, [2, 1, 0]]
26 | return im
27 |
28 | def img2tensor(img):
29 | imgt = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()[None, ...]
30 | return imgt
31 |
32 | ## load model
33 | scale = 4
34 | model = PAN_arch.PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=scale)
35 | model_weight = torch.load('../experiments/pretrained_models/PANx%d_DF2K.pth'%(scale))
36 | model.load_state_dict(model_weight, strict=True)
37 | model.eval()
38 | for k, v in model.named_parameters():
39 | v.requires_grad = False
40 | model = model.to(device)
41 |
42 | number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
43 |
44 |     ## running
45 | print('-----------------Start Running-----------------')
46 | psnrs = []
47 | times = []
48 |
49 | start = torch.cuda.Event(enable_timing=True)
50 | end = torch.cuda.Event(enable_timing=True)
51 |
52 | for i in range(len(test_d)):
53 | im = readimg(test_d[i])
54 | img_LR = img2tensor(im)
55 | img_LR = img_LR.to(device)
56 |
57 | start.record()
58 | img_SR = model(img_LR)
59 | end.record()
60 |
61 | torch.cuda.synchronize()
62 | times.append(start.elapsed_time(end))
63 |
64 | sr_img = util.tensor2img(img_SR.detach())
65 |
66 |         print('Image: %03d, Time: %.10f ms'%(i+1, times[-1]))
67 |     print('Parameters: %d, Mean Time: %.10f s'%(number_parameters, np.mean(times)/1000.))
68 |
69 | if __name__ == '__main__':
70 |     main()
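
The loop above times the very first image cold, which includes cuDNN initialization; a common refinement (a standalone sketch, assuming a CUDA device is available) is to run a few warm-up iterations before recording. torch.cuda.Event.elapsed_time() reports milliseconds, and the events are asynchronous, so synchronize before reading:

    import torch

    dev = torch.device('cuda')
    net = torch.nn.Conv2d(3, 3, 3, padding=1).to(dev)  # stand-in model
    x = torch.randn(1, 3, 180, 320, device=dev)        # x4 LR size for 720p output

    for _ in range(3):                                 # warm-up, not timed
        net(x)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    net(x)
    end.record()
    torch.cuda.synchronize()                           # wait for kernels before reading
    print('%.3f ms' % start.elapsed_time(end))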
--------------------------------------------------------------------------------
/codes/test_summary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchsummaryX import summary
3 |
4 | import models.archs.PAN_arch as PAN_arch
5 |
6 | model = PAN_arch.PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4)
7 |
8 | summary(model, torch.zeros((1, 3, 32, 32)))
9 |
10 | # input LR x2, HR size is 720p (1280x720); tensors are NCHW, so height first
11 | # summary(model, torch.zeros((1, 3, 360, 640)))
12 | 
13 | # input LR x3, HR size is 720p
14 | # summary(model, torch.zeros((1, 3, 240, 426)))
15 | 
16 | # input LR x4, HR size is 720p
17 | # summary(model, torch.zeros((1, 3, 180, 320)))
18 |
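
If torchsummaryX is not installed, the parameter count alone can be reproduced directly (a sketch, assuming codes/ is on PYTHONPATH):

    import models.archs.PAN_arch as PAN_arch

    model = PAN_arch.PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4)
    print('parameters: {:,}'.format(sum(p.numel() for p in model.parameters())))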
--------------------------------------------------------------------------------
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import torch.nn.functional as F
6 | from datetime import datetime
7 | import random
8 | import logging
9 | from collections import OrderedDict
10 | import numpy as np
11 | import cv2
12 | import torch
13 | from torchvision.utils import make_grid
14 | from shutil import get_terminal_size
15 |
16 | import yaml
17 | try:
18 | from yaml import CLoader as Loader, CDumper as Dumper
19 | except ImportError:
20 | from yaml import Loader, Dumper
21 |
22 |
23 | def OrderedYaml():
24 | '''yaml orderedDict support'''
25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
26 |
27 | def dict_representer(dumper, data):
28 | return dumper.represent_dict(data.items())
29 |
30 | def dict_constructor(loader, node):
31 | return OrderedDict(loader.construct_pairs(node))
32 |
33 | Dumper.add_representer(OrderedDict, dict_representer)
34 | Loader.add_constructor(_mapping_tag, dict_constructor)
35 | return Loader, Dumper
36 |
37 |
38 | ####################
39 | # miscellaneous
40 | ####################
41 |
42 |
43 | def get_timestamp():
44 | return datetime.now().strftime('%y%m%d-%H%M%S')
45 |
46 |
47 | def mkdir(path):
48 | if not os.path.exists(path):
49 | os.makedirs(path)
50 |
51 |
52 | def mkdirs(paths):
53 | if isinstance(paths, str):
54 | mkdir(paths)
55 | else:
56 | for path in paths:
57 | mkdir(path)
58 |
59 |
60 | def mkdir_and_rename(path):
61 | if os.path.exists(path):
62 | new_name = path + '_archived_' + get_timestamp()
63 |         print('Path already exists. Renaming it to [{:s}]'.format(new_name))
64 |         logger = logging.getLogger('base')
65 |         logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
66 | os.rename(path, new_name)
67 | os.makedirs(path)
68 |
69 |
70 | def set_random_seed(seed):
71 | random.seed(seed)
72 | np.random.seed(seed)
73 | torch.manual_seed(seed)
74 | torch.cuda.manual_seed_all(seed)
75 |
76 |
77 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
78 | '''set up logger'''
79 | lg = logging.getLogger(logger_name)
80 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
81 | datefmt='%y-%m-%d %H:%M:%S')
82 | lg.setLevel(level)
83 | if tofile:
84 |         log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
85 | fh = logging.FileHandler(log_file, mode='w')
86 | fh.setFormatter(formatter)
87 | lg.addHandler(fh)
88 | if screen:
89 | sh = logging.StreamHandler()
90 | sh.setFormatter(formatter)
91 | lg.addHandler(sh)
92 |
93 |
94 | ####################
95 | # image convert
96 | ####################
97 | def crop_border(img_list, crop_border):
98 | """Crop borders of images
99 | Args:
100 | img_list (list [Numpy]): HWC
101 | crop_border (int): crop border for each end of height and weight
102 |
103 | Returns:
104 | (list [Numpy]): cropped image list
105 | """
106 | if crop_border == 0:
107 | return img_list
108 | else:
109 | return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
110 |
111 |
112 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
113 | '''
114 | Converts a torch Tensor into an image Numpy array
115 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
116 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
117 | '''
118 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
119 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
120 | n_dim = tensor.dim()
121 | if n_dim == 4:
122 | n_img = len(tensor)
123 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
124 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
125 | elif n_dim == 3:
126 | img_np = tensor.numpy()
127 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
128 | elif n_dim == 2:
129 | img_np = tensor.numpy()
130 | else:
131 | raise TypeError(
132 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
133 | if out_type == np.uint8:
134 | img_np = (img_np * 255.0).round()
135 |         # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
136 | return img_np.astype(out_type)
137 |
138 |
139 | def save_img(img, img_path, mode='RGB'):
140 |     cv2.imwrite(img_path, img)  # img is expected as BGR uint8 (what tensor2img returns); 'mode' is unused
141 |
142 |
143 | def DUF_downsample(x, scale=4):
144 |     """Downsampling with the Gaussian kernel used in the DUF official code
145 |
146 | Args:
147 | x (Tensor, [B, T, C, H, W]): frames to be downsampled.
148 | scale (int): downsampling factor: 2 | 3 | 4.
149 | """
150 |
151 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
152 |
153 | def gkern(kernlen=13, nsig=1.6):
154 | import scipy.ndimage.filters as fi
155 | inp = np.zeros((kernlen, kernlen))
156 | # set element at the middle to one, a dirac delta
157 | inp[kernlen // 2, kernlen // 2] = 1
158 | # gaussian-smooth the dirac, resulting in a gaussian filter mask
159 | return fi.gaussian_filter(inp, nsig)
160 |
161 | B, T, C, H, W = x.size()
162 | x = x.view(-1, 1, H, W)
163 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
164 | r_h, r_w = 0, 0
165 | if scale == 3:
166 | r_h = 3 - (H % 3)
167 | r_w = 3 - (W % 3)
168 | x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
169 |
170 | gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
171 | x = F.conv2d(x, gaussian_filter, stride=scale)
172 | x = x[:, :, 2:-2, 2:-2]
173 | x = x.view(B, T, C, x.size(2), x.size(3))
174 | return x
175 |
176 |
177 | def single_forward(model, inp):
178 |     """PyTorch model forward (single test); it is just a simple wrapper
179 | Args:
180 | model (PyTorch model)
181 | inp (Tensor): inputs defined by the model
182 |
183 | Returns:
184 | output (Tensor): outputs of the model. float, in CPU
185 | """
186 | with torch.no_grad():
187 | model_output = model(inp)
188 | if isinstance(model_output, list) or isinstance(model_output, tuple):
189 | output = model_output[0]
190 | else:
191 | output = model_output
192 | output = output.data.float().cpu()
193 | return output
194 |
195 |
196 | def flipx4_forward(model, inp):
197 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
198 | Args:
199 | model (PyTorch model)
200 | inp (Tensor): inputs defined by the model
201 |
202 | Returns:
203 | output (Tensor): outputs of the model. float, in CPU
204 | """
205 | # normal
206 | output_f = single_forward(model, inp)
207 |
208 | # flip W
209 | output = single_forward(model, torch.flip(inp, (-1, )))
210 | output_f = output_f + torch.flip(output, (-1, ))
211 | # flip H
212 | output = single_forward(model, torch.flip(inp, (-2, )))
213 | output_f = output_f + torch.flip(output, (-2, ))
214 | # flip both H and W
215 | output = single_forward(model, torch.flip(inp, (-2, -1)))
216 | output_f = output_f + torch.flip(output, (-2, -1))
217 |
218 | return output_f / 4
219 |
220 | def flipxrot_forward(model, inp):  # self-ensemble over flips and 90/270 rotations (12 variants, averaged)
221 | # normal
222 | output_f = single_forward(model, inp)
223 |
224 | # flip W
225 | output = single_forward(model, torch.flip(inp, (-1, )))
226 | output_f = output_f + torch.flip(output, (-1, ))
227 |
228 | # flip H
229 | output = single_forward(model, torch.flip(inp, (-2, )))
230 | output_f = output_f + torch.flip(output, (-2, ))
231 |
232 | # flip both H and W
233 | output = single_forward(model, torch.flip(inp, (-2, -1)))
234 | output_f = output_f + torch.flip(output, (-2, -1))
235 |
236 | # rot90
237 | output = single_forward(model, torch.rot90(inp, 1, (-1, -2)))
238 | output_f = output_f + torch.rot90(output, 3, (-1, -2))
239 |
240 | # rot270
241 | output = single_forward(model, torch.rot90(inp, 3, (-1, -2)))
242 | output_f = output_f + torch.rot90(output, 1, (-1, -2))
243 |
244 | # flip W rot90
245 | output = single_forward(model, torch.rot90(torch.flip(inp, (-1, )), 1, (-1, -2)))
246 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-1, ))
247 |
248 | # flip H rot90
249 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, )), 1, (-1, -2)))
250 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-2, ))
251 |
252 | # flip both H and W rot90
253 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, -1)), 1, (-1, -2)))
254 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-2, -1))
255 |
256 | # flip W rot270
257 | output = single_forward(model, torch.rot90(torch.flip(inp, (-1, )), 3, (-1, -2)))
258 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-1, ))
259 |
260 | # flip H rot270
261 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, )), 3, (-1, -2)))
262 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-2, ))
263 |
264 | # flip both H and W rot270
265 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, -1)), 3, (-1, -2)))
266 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-2, -1))
267 |
268 | return output_f / 12
269 |
270 | ####################
271 | # metric
272 | ####################
273 |
274 |
275 | def calculate_psnr(img1, img2):
276 | # img1 and img2 have range [0, 255]
277 | img1 = img1.astype(np.float64)
278 | img2 = img2.astype(np.float64)
279 | mse = np.mean((img1 - img2)**2)
280 | if mse == 0:
281 | return float('inf')
282 | return 20 * math.log10(255.0 / math.sqrt(mse))
283 |
284 |
285 | def ssim(img1, img2):
286 | C1 = (0.01 * 255)**2
287 | C2 = (0.03 * 255)**2
288 |
289 | img1 = img1.astype(np.float64)
290 | img2 = img2.astype(np.float64)
291 | kernel = cv2.getGaussianKernel(11, 1.5)
292 | window = np.outer(kernel, kernel.transpose())
293 |
294 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
295 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
296 | mu1_sq = mu1**2
297 | mu2_sq = mu2**2
298 | mu1_mu2 = mu1 * mu2
299 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
300 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
301 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
302 |
303 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
304 | (sigma1_sq + sigma2_sq + C2))
305 | return ssim_map.mean()
306 |
307 |
308 | def calculate_ssim(img1, img2):
309 | '''calculate SSIM
310 | the same outputs as MATLAB's
311 | img1, img2: [0, 255]
312 | '''
313 | if not img1.shape == img2.shape:
314 | raise ValueError('Input images must have the same dimensions.')
315 | if img1.ndim == 2:
316 | return ssim(img1, img2)
317 | elif img1.ndim == 3:
318 | if img1.shape[2] == 3:
319 | ssims = []
320 | for i in range(3):
321 |                 ssims.append(ssim(img1[..., i], img2[..., i]))
322 | return np.array(ssims).mean()
323 | elif img1.shape[2] == 1:
324 | return ssim(np.squeeze(img1), np.squeeze(img2))
325 | else:
326 | raise ValueError('Wrong input image dimensions.')
327 |
328 |
329 | class ProgressBar(object):
330 | '''A progress bar which can print the progress
331 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
332 | '''
333 |
334 | def __init__(self, task_num=0, bar_width=50, start=True):
335 | self.task_num = task_num
336 | max_bar_width = self._get_max_bar_width()
337 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
338 | self.completed = 0
339 | if start:
340 | self.start()
341 |
342 | def _get_max_bar_width(self):
343 | terminal_width, _ = get_terminal_size()
344 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
345 | if max_bar_width < 10:
346 |             print('terminal width is too small ({}), please consider widening the terminal for '
347 |                   'better progress bar visualization'.format(terminal_width))
348 | max_bar_width = 10
349 | return max_bar_width
350 |
351 | def start(self):
352 | if self.task_num > 0:
353 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
354 | ' ' * self.bar_width, self.task_num, 'Start...'))
355 | else:
356 | sys.stdout.write('completed: 0, elapsed: 0s')
357 | sys.stdout.flush()
358 | self.start_time = time.time()
359 |
360 | def update(self, msg='In progress...'):
361 | self.completed += 1
362 | elapsed = time.time() - self.start_time
363 | fps = self.completed / elapsed
364 | if self.task_num > 0:
365 | percentage = self.completed / float(self.task_num)
366 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
367 | mark_width = int(self.bar_width * percentage)
368 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
369 | sys.stdout.write('\033[2F') # cursor up 2 lines
370 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
371 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
372 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
373 | else:
374 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
375 | self.completed, int(elapsed + 0.5), fps))
376 | sys.stdout.flush()
377 |
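
A short usage sketch for the converters and metrics above (assuming codes/ is on PYTHONPATH):

    import torch
    from utils.util import tensor2img, calculate_psnr, calculate_ssim

    t = torch.rand(1, 3, 64, 64)      # RGB float tensor in [0, 1]
    img = tensor2img(t)               # (64, 64, 3) uint8 in BGR order
    print(calculate_psnr(img, img))   # inf for identical images
    print(calculate_ssim(img, img))   # 1.0 for identical images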
--------------------------------------------------------------------------------
/datasets/Set14/HR/baboon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/baboon.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/barbara.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/barbara.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/bridge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/bridge.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/coastguard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/coastguard.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/comic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/comic.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/face.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/flowers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/flowers.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/foreman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/foreman.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/lenna.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/lenna.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/man.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/man.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/monarch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/monarch.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/pepper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/pepper.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/ppt3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/ppt3.png
--------------------------------------------------------------------------------
/datasets/Set14/HR/zebra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/zebra.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/baboonx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/baboonx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/barbarax2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/barbarax2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/bridgex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/bridgex2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/coastguardx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/coastguardx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/comicx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/comicx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/facex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/facex2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/flowersx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/flowersx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/foremanx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/foremanx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/lennax2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/lennax2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/manx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/manx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/monarchx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/monarchx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/pepperx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/pepperx2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/ppt3x2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/ppt3x2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X2/zebrax2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/zebrax2.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/baboonx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/baboonx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/barbarax3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/barbarax3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/bridgex3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/bridgex3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/coastguardx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/coastguardx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/comicx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/comicx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/facex3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/facex3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/flowersx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/flowersx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/foremanx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/foremanx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/lennax3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/lennax3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/manx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/manx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/monarchx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/monarchx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/pepperx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/pepperx3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/ppt3x3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/ppt3x3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X3/zebrax3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/zebrax3.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/baboonx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/baboonx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/barbarax4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/barbarax4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/bridgex4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/bridgex4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/coastguardx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/coastguardx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/comicx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/comicx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/facex4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/facex4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/flowersx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/flowersx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/foremanx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/foremanx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/lennax4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/lennax4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/manx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/manx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/monarchx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/monarchx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/pepperx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/pepperx4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/ppt3x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/ppt3x4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X4/zebrax4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/zebrax4.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/baboonx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/baboonx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/barbarax8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/barbarax8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/bridgex8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/bridgex8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/coastguardx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/coastguardx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/comicx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/comicx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/facex8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/facex8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/flowersx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/flowersx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/foremanx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/foremanx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/lennax8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/lennax8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/manx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/manx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/monarchx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/monarchx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/pepperx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/pepperx8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/ppt3x8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/ppt3x8.png
--------------------------------------------------------------------------------
/datasets/Set14/LR_bicubic/X8/zebrax8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/zebrax8.png
--------------------------------------------------------------------------------
/datasets/Set5/HR/baby.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/baby.png
--------------------------------------------------------------------------------
/datasets/Set5/HR/bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/bird.png
--------------------------------------------------------------------------------
/datasets/Set5/HR/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/butterfly.png
--------------------------------------------------------------------------------
/datasets/Set5/HR/head.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/head.png
--------------------------------------------------------------------------------
/datasets/Set5/HR/woman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/woman.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X2/babyx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/babyx2.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X2/birdx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/birdx2.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X2/butterflyx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/butterflyx2.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X2/headx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/headx2.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X2/womanx2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/womanx2.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X3/babyx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/babyx3.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X3/birdx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/birdx3.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X3/butterflyx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/butterflyx3.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X3/headx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/headx3.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X3/womanx3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/womanx3.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X4/babyx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/babyx4.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X4/birdx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/birdx4.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X4/butterflyx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/butterflyx4.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X4/headx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/headx4.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X4/womanx4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/womanx4.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X8/babyx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/babyx8.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X8/birdx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/birdx8.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X8/butterflyx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/butterflyx8.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X8/headx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/headx8.png
--------------------------------------------------------------------------------
/datasets/Set5/LR_bicubic/X8/womanx8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/womanx8.png
--------------------------------------------------------------------------------
/experiments/pretrained_models/PANx2_DF2K.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx2_DF2K.pth
--------------------------------------------------------------------------------
/experiments/pretrained_models/PANx3_DF2K.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx3_DF2K.pth
--------------------------------------------------------------------------------
/experiments/pretrained_models/PANx4_DF2K.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx4_DF2K.pth
--------------------------------------------------------------------------------
/show_figs/main.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/show_figs/main.jpg
--------------------------------------------------------------------------------
/show_figs/main.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/show_figs/main.pdf
--------------------------------------------------------------------------------