├── .gitignore ├── README.md ├── codes ├── .ipynb_checkpoints │ └── Test_demo-checkpoint.ipynb ├── Test_demo.ipynb ├── data │ ├── LQGT_dataset.py │ ├── LQ_dataset.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── Color_dataset.cpython-37.pyc │ │ ├── ContinueLQGT_dataset.cpython-37.pyc │ │ ├── LQGT_dataset.cpython-36.pyc │ │ ├── LQGT_dataset.cpython-37.pyc │ │ ├── LQ_dataset.cpython-37.pyc │ │ ├── Vimeo90K_dataset.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── data_sampler.cpython-36.pyc │ │ ├── data_sampler.cpython-37.pyc │ │ ├── util.cpython-36.pyc │ │ ├── util.cpython-37.pyc │ │ └── video_test_dataset.cpython-37.pyc │ ├── data_sampler.py │ └── util.py ├── data_scripts │ ├── __pycache__ │ │ └── generate_mod_LR_bic.cpython-36.pyc │ ├── create_lmdb.py │ ├── extract_subimages.py │ ├── generate_mod_LR_bic.m │ ├── generate_mod_LR_bic.py │ ├── prepare_DIV2K_x4_dataset.sh │ ├── rename.py │ └── test_dataloader.py ├── metrics │ ├── calculate_PSNR_SSIM.m │ └── calculate_PSNR_SSIM.py ├── models │ ├── SR_model.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── SRGAN_model.cpython-37.pyc │ │ ├── SR_model.cpython-36.pyc │ │ ├── SR_model.cpython-37.pyc │ │ ├── Video_base_model.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── base_model.cpython-36.pyc │ │ ├── base_model.cpython-37.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── loss.cpython-37.pyc │ │ ├── lr_scheduler.cpython-36.pyc │ │ ├── lr_scheduler.cpython-37.pyc │ │ ├── networks.cpython-36.pyc │ │ └── networks.cpython-37.pyc │ ├── archs │ │ ├── PAN_arch.py │ │ ├── RCAN_arch.py │ │ ├── SRResNet_arch.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── AWSRN_arch.cpython-37.pyc │ │ │ ├── DUF_arch.cpython-37.pyc │ │ │ ├── Dual_arch.cpython-37.pyc │ │ │ ├── EDVR_arch.cpython-36.pyc │ │ │ ├── EfficientSR_arch.cpython-37.pyc │ │ │ ├── EfficientSR_clean.cpython-37.pyc │ │ │ ├── FSRCNN_arch.cpython-37.pyc │ │ │ ├── KernelMD_arch.cpython-37.pyc │ │ │ ├── 
MSSResNet_deblur_arch.cpython-37.pyc │ │ │ ├── Octave_arch.cpython-37.pyc │ │ │ ├── PAN_arch.cpython-36.pyc │ │ │ ├── PAN_arch.cpython-37.pyc │ │ │ ├── PAN_arch_update.cpython-37.pyc │ │ │ ├── PANet_arch.cpython-37.pyc │ │ │ ├── PANv2_arch.cpython-37.pyc │ │ │ ├── RCAN_arch.cpython-36.pyc │ │ │ ├── RCAN_arch.cpython-37.pyc │ │ │ ├── RRDBNet_arch.cpython-36.pyc │ │ │ ├── RRDBNet_arch.cpython-37.pyc │ │ │ ├── SRResNet_arch.cpython-36.pyc │ │ │ ├── SRResNet_arch.cpython-37.pyc │ │ │ ├── UNet_arch.cpython-37.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── arch_util.cpython-36.pyc │ │ │ ├── arch_util.cpython-37.pyc │ │ │ ├── discriminator_vgg_arch.cpython-36.pyc │ │ │ ├── discriminator_vgg_arch.cpython-37.pyc │ │ │ ├── unet_arch.cpython-37.pyc │ │ │ └── unet_parts.cpython-37.pyc │ │ ├── arch_util.py │ │ └── dcn │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── deform_conv.cpython-36.pyc │ │ │ ├── deform_conv.egg-info │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ ├── not-zip-safe │ │ │ └── top_level.txt │ │ │ ├── deform_conv.py │ │ │ ├── setup.py │ │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ └── deform_conv_cuda_kernel.cu │ ├── base_model.py │ ├── loss.py │ ├── lr_scheduler.py │ └── networks.py ├── options │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── options.cpython-36.pyc │ │ └── options.cpython-37.pyc │ ├── options.py │ ├── test │ │ ├── test_PANx2.yml │ │ ├── test_PANx3.yml │ │ ├── test_PANx4.yml │ │ ├── test_RCAN.yml │ │ └── test_SRResNet.yml │ └── train │ │ ├── train_PANx2.yml │ │ ├── train_PANx3.yml │ │ ├── train_PANx4.yml │ │ ├── train_RCAN.yml │ │ └── train_SRResNet.yml ├── requirements.txt ├── run_scripts.sh ├── scripts │ ├── back_projection │ │ ├── backprojection.m │ │ ├── main_bp.m │ │ └── main_reverse_filter.m │ └── transfer_params_MSRResNet.py ├── test.py ├── test_running_time.py ├── 
test_summary.py ├── train.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── util.cpython-36.pyc │ └── util.cpython-37.pyc │ └── util.py ├── datasets ├── Set14 │ ├── .DS_Store │ ├── HR │ │ ├── baboon.png │ │ ├── barbara.png │ │ ├── bridge.png │ │ ├── coastguard.png │ │ ├── comic.png │ │ ├── face.png │ │ ├── flowers.png │ │ ├── foreman.png │ │ ├── lenna.png │ │ ├── man.png │ │ ├── monarch.png │ │ ├── pepper.png │ │ ├── ppt3.png │ │ └── zebra.png │ └── LR_bicubic │ │ ├── X2 │ │ ├── baboonx2.png │ │ ├── barbarax2.png │ │ ├── bridgex2.png │ │ ├── coastguardx2.png │ │ ├── comicx2.png │ │ ├── facex2.png │ │ ├── flowersx2.png │ │ ├── foremanx2.png │ │ ├── lennax2.png │ │ ├── manx2.png │ │ ├── monarchx2.png │ │ ├── pepperx2.png │ │ ├── ppt3x2.png │ │ └── zebrax2.png │ │ ├── X3 │ │ ├── baboonx3.png │ │ ├── barbarax3.png │ │ ├── bridgex3.png │ │ ├── coastguardx3.png │ │ ├── comicx3.png │ │ ├── facex3.png │ │ ├── flowersx3.png │ │ ├── foremanx3.png │ │ ├── lennax3.png │ │ ├── manx3.png │ │ ├── monarchx3.png │ │ ├── pepperx3.png │ │ ├── ppt3x3.png │ │ └── zebrax3.png │ │ ├── X4 │ │ ├── baboonx4.png │ │ ├── barbarax4.png │ │ ├── bridgex4.png │ │ ├── coastguardx4.png │ │ ├── comicx4.png │ │ ├── facex4.png │ │ ├── flowersx4.png │ │ ├── foremanx4.png │ │ ├── lennax4.png │ │ ├── manx4.png │ │ ├── monarchx4.png │ │ ├── pepperx4.png │ │ ├── ppt3x4.png │ │ └── zebrax4.png │ │ └── X8 │ │ ├── baboonx8.png │ │ ├── barbarax8.png │ │ ├── bridgex8.png │ │ ├── coastguardx8.png │ │ ├── comicx8.png │ │ ├── facex8.png │ │ ├── flowersx8.png │ │ ├── foremanx8.png │ │ ├── lennax8.png │ │ ├── manx8.png │ │ ├── monarchx8.png │ │ ├── pepperx8.png │ │ ├── ppt3x8.png │ │ └── zebrax8.png └── Set5 │ ├── HR │ ├── baby.png │ ├── bird.png │ ├── butterfly.png │ ├── head.png │ └── woman.png │ └── LR_bicubic │ ├── X2 │ ├── babyx2.png │ ├── birdx2.png │ ├── butterflyx2.png │ ├── headx2.png │ └── womanx2.png │ ├── X3 │ ├── babyx3.png │ ├── 
birdx3.png │ ├── butterflyx3.png │ ├── headx3.png │ └── womanx3.png │ ├── X4 │ ├── babyx4.png │ ├── birdx4.png │ ├── butterflyx4.png │ ├── headx4.png │ └── womanx4.png │ └── X8 │ ├── babyx8.png │ ├── birdx8.png │ ├── butterflyx8.png │ ├── headx8.png │ └── womanx8.png ├── experiments └── pretrained_models │ ├── PANx2_DF2K.pth │ ├── PANx3_DF2K.pth │ └── PANx4_DF2K.pth ├── results ├── PANx2_DF2K │ └── test_PANx2_DF2Ktrain_200912-232704.log ├── PANx3_DF2K │ └── test_PANx3_DF2Ktrain_200912-232635.log └── PANx4_DF2K │ └── test_PANx4_DF2Ktrain_200912-232209.log └── show_figs ├── main.jpg └── main.pdf /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PAN [🔥 272K parameters] 2 | ### Lowest parameters in AIM2020 Efficient Super Resolution. 3 | 4 | ## [Paper](https://arxiv.org/abs/2010.01073) | [Video](https://www.bilibili.com/video/BV1Qh411R7vZ/) 5 | ## Efficient Image Super-Resolution Using Pixel Attention 6 | Authors: Hengyuan Zhao, [Xiangtao Kong](https://github.com/Xiangtaokong), [Jingwen He](https://github.com/hejingwenhejingwen), [Yu Qiao](https://scholar.google.com/citations?user=gFtI-8QAAAAJ&hl=zh-CN), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=zh-CN) 7 | 8 | 9 |

10 | 11 |

12 | 13 | ## Dependencies 14 | 15 | - Python >= 3.6 (we recommend [Anaconda](https://www.anaconda.com/download/#linux)) 16 | - [PyTorch >= 1.5.0](https://pytorch.org/) 17 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 18 | - Python packages: `pip install numpy opencv-python lmdb` 19 | - [optional] Python packages: [`pip install tensorboardX`](https://github.com/lanpa/tensorboardX), for visualizing curves. 20 | 21 | # Codes 22 | - Our code is based on [mmsr](https://github.com/open-mmlab/mmsr). 23 | - This repository provides both the testing and the training code. 24 | 25 | 26 | 27 | ## How to Test 28 | 1. Clone this GitHub repo. 29 | ``` 30 | git clone https://github.com/zhaohengyuan1/PAN.git 31 | cd PAN 32 | ``` 33 | 2. Download the five test datasets (Set5, Set14, B100, Urban100, Manga109) from [Google Drive](https://drive.google.com/drive/folders/1lsoyAjsUEyp7gm1t6vZI9j7jr9YzKzcF?usp=sharing) 34 | 35 | 3. Pretrained models have been placed in the `./experiments/pretrained_models/` folder. More models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1_zZqTvvAb_ad4T4-uiIGF9CkNiPrBXGr?usp=sharing). 36 | 37 | 4. Run the test. We provide `x2,x3,x4` pretrained models. 38 | ``` 39 | cd codes 40 | python test.py -opt options/test/test_PANx4.yml 41 | ``` 42 | More testing commands can be found in the `./codes/run_scripts.sh` file. 43 | 5. The output results will be stored in `./results`. (We have put our testing log files in `./results`.) We also provide our testing results on five benchmark datasets on [Google Drive](https://drive.google.com/drive/folders/1F6unBkp6L1oJb_gOgSHYM5ZZbyLImDPH?usp=sharing). 44 | 45 | ## How to Train 46 | 47 | 1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017) from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA) 48 | 49 | 2. 
Generate training patches. Modify the paths of your training datasets in the `./codes/data_scripts/extract_subimages.py` file. 50 | 51 | 3. Run training. 52 | 53 | ``` 54 | python train.py -opt options/train/train_PANx4.yml 55 | ``` 56 | 4. More training commands can be found in the `./codes/run_scripts.sh` file. 57 | 58 | ## Testing the Parameters, Mult-Adds and Running Time 59 | 60 | 1. Testing the parameters and Mult-Adds. 61 | ``` 62 | python test_summary.py 63 | ``` 64 | 65 | 2. Testing the Running Time. 66 | 67 | ``` 68 | python test_running_time.py 69 | ``` 70 | 71 | ## Related Work on AIM2020 72 | Enhanced Quadratic Video Interpolation (winning solution of AIM2020 VTSR Challenge) 73 | [paper](https://arxiv.org/pdf/2009.04642.pdf) | [code](https://github.com/lyh-18/EQVI) 74 | 75 | ## Contact 76 | Email: hubylidayuan@gmail.com 77 | 78 | If you find our work useful, please kindly cite it. 79 | ``` 80 | @inproceedings{zhao2020efficient, 81 | title={Efficient image super-resolution using pixel attention}, 82 | author={Zhao, Hengyuan and Kong, Xiangtao and He, Jingwen and Qiao, Yu and Dong, Chao}, 83 | booktitle={European Conference on Computer Vision}, 84 | pages={56--72}, 85 | year={2020}, 86 | organization={Springer} 87 | } 88 | ``` 89 | 90 | -------------------------------------------------------------------------------- /codes/data/LQGT_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | 9 | 10 | class LQGTDataset(data.Dataset): 11 | """ 12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. 13 | If only GT images are provided, generate LQ images on-the-fly.
14 | """ 15 | 16 | def __init__(self, opt): 17 | super(LQGTDataset, self).__init__() 18 | self.opt = opt 19 | self.data_type = self.opt['data_type'] 20 | self.paths_LQ, self.paths_GT = None, None 21 | self.sizes_LQ, self.sizes_GT = None, None 22 | self.LQ_env, self.GT_env = None, None # environments for lmdb 23 | 24 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 25 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 26 | 27 | assert self.paths_GT, 'Error: GT path is empty.' 28 | if self.paths_LQ and self.paths_GT: 29 | assert len(self.paths_LQ) == len( 30 | self.paths_GT 31 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 32 | len(self.paths_LQ), len(self.paths_GT)) 33 | self.random_scale_list = [1] 34 | 35 | def _init_lmdb(self): 36 | # https://github.com/chainer/chainermn/issues/129 37 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 38 | meminit=False) 39 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 40 | meminit=False) 41 | 42 | def __getitem__(self, index): 43 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): 44 | self._init_lmdb() 45 | GT_path, LQ_path = None, None 46 | scale = self.opt['scale'] 47 | GT_size = self.opt['GT_size'] 48 | 49 | # get GT image 50 | GT_path = self.paths_GT[index] 51 | resolution = [int(s) for s in self.sizes_GT[index].split('_') 52 | ] if self.data_type == 'lmdb' else None 53 | img_GT = util.read_img(self.GT_env, GT_path, resolution) 54 | if img_GT.shape[2] == 1: 55 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 56 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase 57 | img_GT = util.modcrop(img_GT, scale) 58 | if self.opt['color']: # change color space if necessary 59 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 60 | 61 | # get LQ image 
62 | if self.paths_LQ: 63 | LQ_path = self.paths_LQ[index] 64 | resolution = [int(s) for s in self.sizes_LQ[index].split('_') 65 | ] if self.data_type == 'lmdb' else None 66 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 67 | if img_LQ.shape[2] == 1: 68 | # for test gray to color 69 | img_LQ = cv2.cvtColor(img_LQ, cv2.COLOR_GRAY2BGR) 70 | else: # down-sampling on-the-fly 71 | # randomly scale during training 72 | if self.opt['phase'] == 'train': 73 | random_scale = random.choice(self.random_scale_list) 74 | H_s, W_s, _ = img_GT.shape 75 | 76 | def _mod(n, random_scale, scale, thres): 77 | rlt = int(n * random_scale) 78 | rlt = (rlt // scale) * scale 79 | return thres if rlt < thres else rlt 80 | 81 | H_s = _mod(H_s, random_scale, scale, GT_size) 82 | W_s = _mod(W_s, random_scale, scale, GT_size) 83 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) 84 | if img_GT.ndim == 2: 85 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 86 | 87 | H, W, _ = img_GT.shape 88 | # using matlab imresize 89 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 90 | if img_LQ.ndim == 2: 91 | img_LQ = np.expand_dims(img_LQ, axis=2) 92 | 93 | if self.opt['phase'] == 'train': 94 | # if the image size is too small 95 | H, W, _ = img_GT.shape 96 | if H < GT_size or W < GT_size: 97 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) 98 | # using matlab imresize 99 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 100 | if img_LQ.ndim == 2: 101 | img_LQ = np.expand_dims(img_LQ, axis=2) 102 | 103 | H, W, C = img_LQ.shape 104 | LQ_size = GT_size // scale 105 | # randomly crop 106 | rnd_h = random.randint(0, max(0, H - LQ_size)) 107 | rnd_w = random.randint(0, max(0, W - LQ_size)) 108 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] 109 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) 110 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 111 | 112 | # 
print(img_GT.shape, img_LQ.shape) 113 | # augmentation - flip, rotate 114 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], 115 | self.opt['use_rot']) 116 | 117 | if self.opt['color']: # change color space if necessary 118 | img_LQ = util.channel_convert(C, self.opt['color'], 119 | [img_LQ])[0] # TODO during val no definition 120 | 121 | # BGR to RGB, HWC to CHW, numpy to tensor 122 | if img_GT.shape[2] == 3: 123 | img_GT = img_GT[:, :, [2, 1, 0]] 124 | img_LQ = img_LQ[:, :, [2, 1, 0]] 125 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 126 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 127 | 128 | if LQ_path is None: 129 | LQ_path = GT_path 130 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 131 | 132 | def __len__(self): 133 | return len(self.paths_GT) 134 | -------------------------------------------------------------------------------- /codes/data/LQ_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lmdb 3 | import torch 4 | import torch.utils.data as data 5 | import data.util as util 6 | 7 | 8 | class LQDataset(data.Dataset): 9 | '''Read LQ images only in the test phase.''' 10 | 11 | def __init__(self, opt): 12 | super(LQDataset, self).__init__() 13 | self.opt = opt 14 | self.data_type = self.opt['data_type'] 15 | self.paths_LQ, self.paths_GT = None, None 16 | self.LQ_env = None # environment for lmdb 17 | 18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 19 | assert self.paths_LQ, 'Error: LQ paths are empty.' 
20 | 21 | def _init_lmdb(self): 22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 23 | meminit=False) 24 | 25 | def __getitem__(self, index): 26 | if self.data_type == 'lmdb' and self.LQ_env is None: 27 | self._init_lmdb() 28 | LQ_path = None 29 | 30 | # get LQ image 31 | LQ_path = self.paths_LQ[index] 32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_') 33 | ] if self.data_type == 'lmdb' else None 34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 35 | H, W, C = img_LQ.shape 36 | 37 | if self.opt['color']: # change color space if necessary 38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0] 39 | 40 | # BGR to RGB, HWC to CHW, numpy to tensor 41 | if img_LQ.shape[2] == 3: 42 | img_LQ = img_LQ[:, :, [2, 1, 0]] 43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 44 | 45 | return {'LQ': img_LQ, 'LQ_path': LQ_path} 46 | 47 | def __len__(self): 48 | return len(self.paths_LQ) 49 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 
| pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 25 | pin_memory=False) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | mode = dataset_opt['mode'] 30 | # datasets for image restoration 31 | if mode == 'LQ': 32 | from data.LQ_dataset import LQDataset as D 33 | elif mode == 'LQGT': 34 | from data.LQGT_dataset import LQGTDataset as D 35 | elif mode == 'Color': 36 | from data.Color_dataset import ColorDataset as D 37 | elif mode == 'ContinueLQGT': 38 | from data.ContinueLQGT_dataset import ContinueLQGTDataset as D 39 | # datasets for video restoration 40 | elif mode == 'REDS': 41 | from data.REDS_dataset import REDSDataset as D 42 | elif mode == 'Vimeo90K': 43 | from data.Vimeo90K_dataset import Vimeo90KDataset as D 44 | elif mode == 'video_test': 45 | from data.video_test_dataset import VideoTestDataset as D 46 | else: 47 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 48 | dataset = D(dataset_opt) 49 | 50 | logger = logging.getLogger('base') 51 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 52 | dataset_opt['name'])) 53 | return dataset 54 | -------------------------------------------------------------------------------- /codes/data/__pycache__/Color_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/Color_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/ContinueLQGT_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/ContinueLQGT_dataset.cpython-37.pyc 
-------------------------------------------------------------------------------- /codes/data/__pycache__/LQGT_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/LQGT_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQGT_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/LQ_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/LQ_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/Vimeo90K_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/Vimeo90K_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/data_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/data_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/data_sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/data_sampler.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/video_test_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data/__pycache__/video_test_dataset.cpython-37.pyc 
-------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iteration-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 
28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/data_scripts/__pycache__/generate_mod_LR_bic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/data_scripts/__pycache__/generate_mod_LR_bic.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data_scripts/extract_subimages.py: -------------------------------------------------------------------------------- 1 | """A multi-thread tool to crop large images to sub-images for faster IO.""" 2 | import os 3 | import os.path as osp 4 | import sys 5 | from 
multiprocessing import Pool 6 | import numpy as np 7 | import cv2 8 | from PIL import Image 9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 10 | from utils.util import ProgressBar # noqa: E402 11 | import data.util as data_util # noqa: E402 12 | 13 | 14 | def main(): 15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs) 16 | opt = {} 17 | opt['n_thread'] = 20 18 | opt['compression_level'] = 3 # 3 is the default value in cv2 19 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer 20 | # compression time. If read raw images during training, use 0 for faster IO speed. 21 | if mode == 'single': 22 | # opt['input_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_LR_bicubic/X4_blur' 23 | # opt['save_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_LR_bicubic/X4_blur_sub' 24 | opt['input_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_Bic' 25 | opt['save_folder'] = '/mnt/hyzhao/Documents/datasets/DIV2K_train800/DIV2K_train_Bic/Bic_sub480' 26 | opt['crop_sz'] = 480 # the size of each sub-image 27 | opt['step'] = 120 # step of the sliding crop window 28 | opt['thres_sz'] = 48 # size threshold 29 | extract_signle(opt) 30 | elif mode == 'pair': 31 | # GT_folder = '../../datasets/DIV2K/DIV2K_train_HR' 32 | # LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' 33 | # save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub' 34 | # save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' 35 | 36 | GT_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/HR' 37 | LR_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/LR/X3' 38 | save_GT_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/HRx3_sub360' 39 | save_LR_folder = '/mnt/hyzhao/Documents/datasets/DF2K_train/LRx3_sub120' 40 | 41 | scale_ratio = 3 42 | crop_sz = 360 # the size of each sub-image (GT) 43 | step = 180 # step of the sliding crop window (GT) 44 | 
45 | thres_sz = 48 # size threshold 46 | ######################################################################## 47 | # check that all the GT and LR images have correct scale ratio 48 | img_GT_list = data_util._get_paths_from_images(GT_folder) 49 | img_LR_list = data_util._get_paths_from_images(LR_folder) 50 | assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.' 51 | for path_GT, path_LR in zip(img_GT_list, img_LR_list): 52 | img_GT = Image.open(path_GT) 53 | img_LR = Image.open(path_LR) 54 | w_GT, h_GT = img_GT.size 55 | w_LR, h_LR = img_LR.size 56 | assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 57 | w_GT, scale_ratio, w_LR, path_GT) 58 | assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR weight [{:d}] for {:s}.'.format( # noqa: E501 59 | w_GT, scale_ratio, w_LR, path_GT) 60 | # check crop size, step and threshold size 61 | assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format( 62 | scale_ratio) 63 | assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio) 64 | assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format( 65 | scale_ratio) 66 | print('process GT...') 67 | opt['input_folder'] = GT_folder 68 | opt['save_folder'] = save_GT_folder 69 | opt['crop_sz'] = crop_sz 70 | opt['step'] = step 71 | opt['thres_sz'] = thres_sz 72 | extract_signle(opt) 73 | print('process LR...') 74 | opt['input_folder'] = LR_folder 75 | opt['save_folder'] = save_LR_folder 76 | opt['crop_sz'] = crop_sz // scale_ratio 77 | opt['step'] = step // scale_ratio 78 | opt['thres_sz'] = thres_sz // scale_ratio 79 | extract_signle(opt) 80 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len( 81 | data_util._get_paths_from_images( 82 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' 
83 | else: 84 | raise ValueError('Wrong mode.') 85 | 86 | 87 | def extract_signle(opt): 88 | input_folder = opt['input_folder'] 89 | save_folder = opt['save_folder'] 90 | if not osp.exists(save_folder): 91 | os.makedirs(save_folder) 92 | print('mkdir [{:s}] ...'.format(save_folder)) 93 | else: 94 | print('Folder [{:s}] already exists. Exit...'.format(save_folder)) 95 | sys.exit(1) 96 | img_list = data_util._get_paths_from_images(input_folder) 97 | 98 | def update(arg): 99 | pbar.update(arg) 100 | 101 | pbar = ProgressBar(len(img_list)) 102 | 103 | pool = Pool(opt['n_thread']) 104 | for path in img_list: 105 | pool.apply_async(worker, args=(path, opt), callback=update) 106 | pool.close() 107 | pool.join() 108 | print('All subprocesses done.') 109 | 110 | 111 | def worker(path, opt): 112 | crop_sz = opt['crop_sz'] 113 | step = opt['step'] 114 | thres_sz = opt['thres_sz'] 115 | img_name = osp.basename(path) 116 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 117 | 118 | n_channels = len(img.shape) 119 | if n_channels == 2: 120 | h, w = img.shape 121 | elif n_channels == 3: 122 | h, w, c = img.shape 123 | else: 124 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 125 | 126 | h_space = np.arange(0, h - crop_sz + 1, step) 127 | if h - (h_space[-1] + crop_sz) > thres_sz: 128 | h_space = np.append(h_space, h - crop_sz) 129 | w_space = np.arange(0, w - crop_sz + 1, step) 130 | if w - (w_space[-1] + crop_sz) > thres_sz: 131 | w_space = np.append(w_space, w - crop_sz) 132 | 133 | index = 0 134 | for x in h_space: 135 | for y in w_space: 136 | index += 1 137 | if n_channels == 2: 138 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 139 | else: 140 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :] 141 | crop_img = np.ascontiguousarray(crop_img) 142 | cv2.imwrite( 143 | osp.join(opt['save_folder'], 144 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img, 145 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 146 | return 'Processing {:s} 
...'.format(img_name) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_mod_LR_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. 3 | 4 | %% set parameters 5 | % comment the unnecessary line 6 | input_folder = '../../datasets/DIV2K/DIV2K800'; 7 | % save_mod_folder = ''; 8 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4'; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 4; 12 | mod_scale = 4; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will cover ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will cover ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will cover ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore .. 
folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | % LR 57 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 58 | if exist('save_LR_folder', 'var') 59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 60 | end 61 | % Bicubic 62 | if exist('save_bic_folder', 'var') 63 | im_B = imresize(im_LR, up_scale, 'bicubic'); 64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png'])); 65 | end 66 | end 67 | end 68 | end 69 | 70 | %% modcrop 71 | function img = modcrop(img, modulo) 72 | if size(img,3) == 1 73 | sz = size(img); 74 | sz = sz - mod(sz, modulo); 75 | img = img(1:sz(1), 1:sz(2)); 76 | else 77 | tmpsz = size(img); 78 | sz = tmpsz(1:2); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2),:); 81 | end 82 | end 83 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_mod_LR_bic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | 6 | try: 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from data.util import imresize_np 9 | except ImportError: 10 | pass 11 | 12 | 13 | def generate_mod_LR_bic(): 14 | # set parameters 15 | up_scale = 4 16 | mod_scale = 4 17 | # set data dir 18 | sourcedir = '/mnt/hyzhao/Documents/datasets/DF2K_train/HR' 19 | savedir = '/mnt/hyzhao/Documents/datasets/DF2K_train/' 20 | 21 | saveHRpath = os.path.join(savedir, 'HR', 'X' + str(mod_scale)) 22 | saveLRpath = os.path.join(savedir, 'LR', 'X' + str(up_scale)) 23 | saveBicpath = os.path.join(savedir, 'Bic', 'X' + str(up_scale)) 24 | 25 | if not 
os.path.isdir(sourcedir): 26 | print('Error: No source data found') 27 | exit(0) 28 | if not os.path.isdir(savedir): 29 | os.mkdir(savedir) 30 | 31 | if not os.path.isdir(os.path.join(savedir, 'HR')): 32 | os.mkdir(os.path.join(savedir, 'HR')) 33 | if not os.path.isdir(os.path.join(savedir, 'LR')): 34 | os.mkdir(os.path.join(savedir, 'LR')) 35 | if not os.path.isdir(os.path.join(savedir, 'Bic')): 36 | os.mkdir(os.path.join(savedir, 'Bic')) 37 | 38 | if not os.path.isdir(saveHRpath): 39 | os.mkdir(saveHRpath) 40 | else: 41 | print('It will cover ' + str(saveHRpath)) 42 | 43 | if not os.path.isdir(saveLRpath): 44 | os.mkdir(saveLRpath) 45 | else: 46 | print('It will cover ' + str(saveLRpath)) 47 | 48 | if not os.path.isdir(saveBicpath): 49 | os.mkdir(saveBicpath) 50 | else: 51 | print('It will cover ' + str(saveBicpath)) 52 | 53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')] 54 | num_files = len(filepaths) 55 | 56 | # prepare data with augementation 57 | for i in range(num_files): 58 | filename = filepaths[i] 59 | print('No.{} -- Processing {}'.format(i, filename)) 60 | # read image 61 | image = cv2.imread(os.path.join(sourcedir, filename)) 62 | 63 | width = int(np.floor(image.shape[1] / mod_scale)) 64 | height = int(np.floor(image.shape[0] / mod_scale)) 65 | # modcrop 66 | if len(image.shape) == 3: 67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :] 68 | else: 69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width] 70 | # LR 71 | image_LR = imresize_np(image_HR, 1 / up_scale, True) 72 | # bic 73 | image_Bic = imresize_np(image_LR, up_scale, True) 74 | 75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) 76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) 77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) 78 | 79 | 80 | if __name__ == "__main__": 81 | generate_mod_LR_bic() -------------------------------------------------------------------------------- 
/codes/data_scripts/prepare_DIV2K_x4_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | echo "Prepare DIV2K X4 datasets..." 4 | cd ../../datasets 5 | mkdir DIV2K 6 | cd DIV2K 7 | 8 | #### Step 1 9 | echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..." 10 | # GT 11 | FOLDER=DIV2K_train_HR 12 | FILE=DIV2K_train_HR.zip 13 | if [ ! -d "$FOLDER" ]; then 14 | if [ ! -f "$FILE" ]; then 15 | echo "Downloading $FILE..." 16 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE 17 | fi 18 | unzip $FILE 19 | fi 20 | # LR 21 | FOLDER=DIV2K_train_LR_bicubic 22 | FILE=DIV2K_train_LR_bicubic_X4.zip 23 | if [ ! -d "$FOLDER" ]; then 24 | if [ ! -f "$FILE" ]; then 25 | echo "Downloading $FILE..." 26 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE 27 | fi 28 | unzip $FILE 29 | fi 30 | 31 | #### Step 2 32 | echo "Step 2: Rename the LR images..." 33 | cd ../../codes/data_scripts 34 | python rename.py 35 | 36 | #### Step 4 37 | echo "Step 4: Crop to sub-images..." 38 | python extract_subimages.py 39 | 40 | #### Step 5 41 | echo "Step5: Create LMDB files..." 
42 | python create_lmdb.py 43 | -------------------------------------------------------------------------------- /codes/data_scripts/rename.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | def main(): 6 | folder = '../../results/006_RRDBNet_ILRx4_Flickr2K_100w+/DIV2K100' 7 | DIV2K(folder) 8 | print('Finished.') 9 | 10 | 11 | def DIV2K(path): 12 | img_path_l = glob.glob(os.path.join(path, '*')) 13 | for img_path in img_path_l: 14 | # new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') 15 | new_path = img_path.replace('.png', 'x4.png') 16 | os.rename(img_path, new_path) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() -------------------------------------------------------------------------------- /codes/data_scripts/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import math 4 | import torchvision.utils 5 | 6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 7 | from data import create_dataloader, create_dataset # noqa: E402 8 | from utils import util # noqa: E402 9 | 10 | 11 | def main(): 12 | dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub 13 | opt = {} 14 | opt['dist'] = False 15 | opt['gpu_ids'] = [0] 16 | if dataset == 'REDS': 17 | opt['name'] = 'test_REDS' 18 | opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb' 19 | opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' 20 | opt['mode'] = 'REDS' 21 | opt['N_frames'] = 5 22 | opt['phase'] = 'train' 23 | opt['use_shuffle'] = True 24 | opt['n_workers'] = 8 25 | opt['batch_size'] = 16 26 | opt['GT_size'] = 256 27 | opt['LQ_size'] = 64 28 | opt['scale'] = 4 29 | opt['use_flip'] = True 30 | opt['use_rot'] = True 31 | opt['interval_list'] = [1] 32 | opt['random_reverse'] = False 33 | opt['border_mode'] = False 34 | opt['cache_keys'] = None 35 
| opt['data_type'] = 'lmdb' # img | lmdb | mc 36 | elif dataset == 'Vimeo90K': 37 | opt['name'] = 'test_Vimeo90K' 38 | opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' 39 | opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' 40 | opt['mode'] = 'Vimeo90K' 41 | opt['N_frames'] = 7 42 | opt['phase'] = 'train' 43 | opt['use_shuffle'] = True 44 | opt['n_workers'] = 8 45 | opt['batch_size'] = 16 46 | opt['GT_size'] = 256 47 | opt['LQ_size'] = 64 48 | opt['scale'] = 4 49 | opt['use_flip'] = True 50 | opt['use_rot'] = True 51 | opt['interval_list'] = [1] 52 | opt['random_reverse'] = False 53 | opt['border_mode'] = False 54 | opt['cache_keys'] = None 55 | opt['data_type'] = 'lmdb' # img | lmdb | mc 56 | elif dataset == 'DIV2K800_sub': 57 | opt['name'] = 'DIV2K800' 58 | opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' 59 | opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' 60 | opt['mode'] = 'LQGT' 61 | opt['phase'] = 'train' 62 | opt['use_shuffle'] = True 63 | opt['n_workers'] = 8 64 | opt['batch_size'] = 16 65 | opt['GT_size'] = 128 66 | opt['scale'] = 4 67 | opt['use_flip'] = True 68 | opt['use_rot'] = True 69 | opt['color'] = 'RGB' 70 | opt['data_type'] = 'lmdb' # img | lmdb 71 | else: 72 | raise ValueError('Please implement by yourself.') 73 | 74 | util.mkdir('tmp') 75 | train_set = create_dataset(opt) 76 | train_loader = create_dataloader(train_set, opt, opt, None) 77 | nrow = int(math.sqrt(opt['batch_size'])) 78 | padding = 2 if opt['phase'] == 'train' else 0 79 | 80 | print('start...') 81 | for i, data in enumerate(train_loader): 82 | if i > 5: 83 | break 84 | print(i) 85 | if dataset == 'REDS' or dataset == 'Vimeo90K': 86 | LQs = data['LQs'] 87 | else: 88 | LQ = data['LQ'] 89 | GT = data['GT'] 90 | 91 | if dataset == 'REDS' or dataset == 'Vimeo90K': 92 | for j in range(LQs.size(1)): 93 | torchvision.utils.save_image(LQs[:, j, :, :, :], 94 | 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow, 
95 | padding=padding, normalize=False) 96 | else: 97 | torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow, 98 | padding=padding, normalize=False) 99 | torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding, 100 | normalize=False) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /codes/metrics/calculate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | function calculate_PSNR_SSIM() 2 | 3 | % GT and SR folder 4 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'; 5 | folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'; 6 | scale = 4; 7 | suffix = ''; % suffix for SR images 8 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels 9 | if test_Y 10 | fprintf('Tesing Y channel.\n'); 11 | else 12 | fprintf('Tesing RGB channels.\n'); 13 | end 14 | filepaths = dir(fullfile(folder_GT, '*.png')); 15 | PSNR_all = zeros(1, length(filepaths)); 16 | SSIM_all = zeros(1, length(filepaths)); 17 | 18 | for idx_im = 1:length(filepaths) 19 | im_name = filepaths(idx_im).name; 20 | im_GT = imread(fullfile(folder_GT, im_name)); 21 | im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png'])); 22 | 23 | if test_Y % evaluate on Y channel in YCbCr color space 24 | if size(im_GT, 3) == 3 25 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT)); 26 | im_GT_in = im_GT_YCbCr(:,:,1); 27 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR)); 28 | im_SR_in = im_SR_YCbCr(:,:,1); 29 | else 30 | im_GT_in = im2double(im_GT); 31 | im_SR_in = im2double(im_SR); 32 | end 33 | else % evaluate on RGB channels 34 | im_GT_in = im2double(im_GT); 35 | im_SR_in = im2double(im_SR); 36 | end 37 | 38 | % calculate PSNR and SSIM 39 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale); 40 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale); 
41 | fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im)); 42 | end 43 | 44 | fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all)); 45 | end 46 | 47 | function res = calculate_PSNR(GT, SR, border) 48 | % remove border 49 | GT = GT(border+1:end-border, border+1:end-border, :); 50 | SR = SR(border+1:end-border, border+1:end-border, :); 51 | % calculate PNSR (assume in [0,255]) 52 | error = GT(:) - SR(:); 53 | mse = mean(error.^2); 54 | res = 10 * log10(255^2/mse); 55 | end 56 | 57 | function res = calculate_SSIM(GT, SR, border) 58 | GT = GT(border+1:end-border, border+1:end-border, :); 59 | SR = SR(border+1:end-border, border+1:end-border, :); 60 | % calculate SSIM 61 | mssim = zeros(1, size(SR, 3)); 62 | for i = 1:size(SR,3) 63 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i)); 64 | end 65 | res = mean(mssim); 66 | end 67 | 68 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L) 69 | 70 | %======================================================================== 71 | %SSIM Index, Version 1.0 72 | %Copyright(c) 2003 Zhou Wang 73 | %All Rights Reserved. 74 | % 75 | %The author is with Howard Hughes Medical Institute, and Laboratory 76 | %for Computational Vision at Center for Neural Science and Courant 77 | %Institute of Mathematical Sciences, New York University. 78 | % 79 | %---------------------------------------------------------------------- 80 | %Permission to use, copy, or modify this software and its documentation 81 | %for educational and research purposes only and without fee is hereby 82 | %granted, provided that this copyright notice and the original authors' 83 | %names appear on all copies and supporting documentation. This program 84 | %shall not be used, rewritten, or adapted as the basis of a commercial 85 | %software or hardware product without first obtaining permission of the 86 | %authors. 
The authors make no representations about the suitability of 87 | %this software for any purpose. It is provided "as is" without express 88 | %or implied warranty. 89 | %---------------------------------------------------------------------- 90 | % 91 | %This is an implementation of the algorithm for calculating the 92 | %Structural SIMilarity (SSIM) index between two images. Please refer 93 | %to the following paper: 94 | % 95 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 96 | %quality assessment: From error measurement to structural similarity" 97 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. 98 | % 99 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 100 | % 101 | %---------------------------------------------------------------------- 102 | % 103 | %Input : (1) img1: the first image being compared 104 | % (2) img2: the second image being compared 105 | % (3) K: constants in the SSIM index formula (see the above 106 | % reference). defualt value: K = [0.01 0.03] 107 | % (4) window: local window for statistics (see the above 108 | % reference). default widnow is Gaussian given by 109 | % window = fspecial('gaussian', 11, 1.5); 110 | % (5) L: dynamic range of the images. default: L = 255 111 | % 112 | %Output: (1) mssim: the mean SSIM index value between 2 images. 113 | % If one of the images being compared is regarded as 114 | % perfect quality, then mssim can be considered as the 115 | % quality measure of the other image. 116 | % If img1 = img2, then mssim = 1. 117 | % (2) ssim_map: the SSIM index map of the test image. The map 118 | % has a smaller size than the input images. The actual size: 119 | % size(img1) - size(window) + 1. 120 | % 121 | %Default Usage: 122 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 123 | % 124 | % [mssim ssim_map] = ssim_index(img1, img2); 125 | % 126 | %Advanced Usage: 127 | % User defined parameters. 
For example 128 | % 129 | % K = [0.05 0.05]; 130 | % window = ones(8); 131 | % L = 100; 132 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 133 | % 134 | %See the results: 135 | % 136 | % mssim %Gives the mssim value 137 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 138 | % 139 | %======================================================================== 140 | 141 | 142 | if (nargin < 2 || nargin > 5) 143 | mssim = -Inf; 144 | ssim_map = -Inf; 145 | return; 146 | end 147 | 148 | if (~isequal(size(img1), size(img2))) 149 | mssim = -Inf; 150 | ssim_map = -Inf; 151 | return; 152 | end 153 | 154 | [M, N] = size(img1); 155 | 156 | if (nargin == 2) 157 | if ((M < 11) || (N < 11)) 158 | mssim = -Inf; 159 | ssim_map = -Inf; 160 | return 161 | end 162 | window = fspecial('gaussian', 11, 1.5); % 163 | K(1) = 0.01; % default settings 164 | K(2) = 0.03; % 165 | L = 255; % 166 | end 167 | 168 | if (nargin == 3) 169 | if ((M < 11) || (N < 11)) 170 | mssim = -Inf; 171 | ssim_map = -Inf; 172 | return 173 | end 174 | window = fspecial('gaussian', 11, 1.5); 175 | L = 255; 176 | if (length(K) == 2) 177 | if (K(1) < 0 || K(2) < 0) 178 | mssim = -Inf; 179 | ssim_map = -Inf; 180 | return; 181 | end 182 | else 183 | mssim = -Inf; 184 | ssim_map = -Inf; 185 | return; 186 | end 187 | end 188 | 189 | if (nargin == 4) 190 | [H, W] = size(window); 191 | if ((H*W) < 4 || (H > M) || (W > N)) 192 | mssim = -Inf; 193 | ssim_map = -Inf; 194 | return 195 | end 196 | L = 255; 197 | if (length(K) == 2) 198 | if (K(1) < 0 || K(2) < 0) 199 | mssim = -Inf; 200 | ssim_map = -Inf; 201 | return; 202 | end 203 | else 204 | mssim = -Inf; 205 | ssim_map = -Inf; 206 | return; 207 | end 208 | end 209 | 210 | if (nargin == 5) 211 | [H, W] = size(window); 212 | if ((H*W) < 4 || (H > M) || (W > N)) 213 | mssim = -Inf; 214 | ssim_map = -Inf; 215 | return 216 | end 217 | if (length(K) == 2) 218 | if (K(1) < 0 || K(2) < 0) 219 | mssim = -Inf; 220 |
ssim_map = -Inf; 221 | return; 222 | end 223 | else 224 | mssim = -Inf; 225 | ssim_map = -Inf; 226 | return; 227 | end 228 | end 229 | 230 | C1 = (K(1)*L)^2; 231 | C2 = (K(2)*L)^2; 232 | window = window/sum(sum(window)); 233 | img1 = double(img1); 234 | img2 = double(img2); 235 | 236 | mu1 = filter2(window, img1, 'valid'); 237 | mu2 = filter2(window, img2, 'valid'); 238 | mu1_sq = mu1.*mu1; 239 | mu2_sq = mu2.*mu2; 240 | mu1_mu2 = mu1.*mu2; 241 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 242 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 243 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 244 | 245 | if (C1 > 0 && C2 > 0) 246 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 247 | else 248 | numerator1 = 2*mu1_mu2 + C1; 249 | numerator2 = 2*sigma12 + C2; 250 | denominator1 = mu1_sq + mu2_sq + C1; 251 | denominator2 = sigma1_sq + sigma2_sq + C2; 252 | ssim_map = ones(size(mu1)); 253 | index = (denominator1.*denominator2 > 0); 254 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 255 | index = (denominator1 ~= 0) & (denominator2 == 0); 256 | ssim_map(index) = numerator1(index)./denominator1(index); 257 | end 258 | 259 | mssim = mean2(ssim_map); 260 | 261 | end 262 | -------------------------------------------------------------------------------- /codes/metrics/calculate_PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM.
3 | same as MATLAB's results 4 | ''' 5 | import os 6 | import math 7 | import numpy as np 8 | import cv2 9 | import glob 10 | 11 | 12 | def main(): 13 | # Configurations 14 | 15 | # GT - Ground-truth; 16 | # Gen: Generated / Restored / Recovered images 17 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' 18 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' 19 | 20 | crop_border = 4 21 | suffix = '' # suffix for Gen images 22 | test_Y = False # True: test Y channel only; False: test RGB channels 23 | 24 | PSNR_all = [] 25 | SSIM_all = [] 26 | img_list = sorted(glob.glob(folder_GT + '/*')) 27 | 28 | if test_Y: 29 | print('Testing Y channel.') 30 | else: 31 | print('Testing RGB channels.') 32 | 33 | for i, img_path in enumerate(img_list): 34 | base_name = os.path.splitext(os.path.basename(img_path))[0] 35 | im_GT = cv2.imread(img_path) / 255. 36 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. 37 | 38 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 39 | im_GT_in = bgr2ycbcr(im_GT) 40 | im_Gen_in = bgr2ycbcr(im_Gen) 41 | else: 42 | im_GT_in = im_GT 43 | im_Gen_in = im_Gen 44 | 45 | # crop borders 46 | if im_GT_in.ndim == 3: 47 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 48 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 49 | elif im_GT_in.ndim == 2: 50 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 51 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 52 | else: 53 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) 54 | 55 | # calculate PSNR and SSIM 56 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 57 | 58 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 59 | print('{:3d} - {:25}. 
\tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( 60 | i + 1, base_name, PSNR, SSIM)) 61 | PSNR_all.append(PSNR) 62 | SSIM_all.append(SSIM) 63 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( 64 | sum(PSNR_all) / len(PSNR_all), 65 | sum(SSIM_all) / len(SSIM_all))) 66 | 67 | 68 | def calculate_psnr(img1, img2): 69 | # img1 and img2 have range [0, 255] 70 | img1 = img1.astype(np.float64) 71 | img2 = img2.astype(np.float64) 72 | mse = np.mean((img1 - img2)**2) 73 | if mse == 0: 74 | return float('inf') 75 | return 20 * math.log10(255.0 / math.sqrt(mse)) 76 | 77 | 78 | def ssim(img1, img2): 79 | C1 = (0.01 * 255)**2 80 | C2 = (0.03 * 255)**2 81 | 82 | img1 = img1.astype(np.float64) 83 | img2 = img2.astype(np.float64) 84 | kernel = cv2.getGaussianKernel(11, 1.5) 85 | window = np.outer(kernel, kernel.transpose()) 86 | 87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 89 | mu1_sq = mu1**2 90 | mu2_sq = mu2**2 91 | mu1_mu2 = mu1 * mu2 92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 95 | 96 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 97 | (sigma1_sq + sigma2_sq + C2)) 98 | return ssim_map.mean() 99 | 100 | 101 | def calculate_ssim(img1, img2): 102 | '''calculate SSIM 103 | the same outputs as MATLAB's 104 | img1, img2: [0, 255] 105 | ''' 106 | if not img1.shape == img2.shape: 107 | raise ValueError('Input images must have the same dimensions.') 108 | if img1.ndim == 2: 109 | return ssim(img1, img2) 110 | elif img1.ndim == 3: 111 | if img1.shape[2] == 3: 112 | ssims = [] 113 | for i in range(3): 114 | ssims.append(ssim(img1, img2)) 115 | return np.array(ssims).mean() 116 | elif img1.shape[2] == 1: 117 | return ssim(np.squeeze(img1), np.squeeze(img2)) 118 | else: 119 | raise ValueError('Wrong input image 
dimensions.') 120 | 121 | 122 | def bgr2ycbcr(img, only_y=True): 123 | '''same as matlab rgb2ycbcr 124 | only_y: only return Y channel 125 | Input: 126 | uint8, [0, 255] 127 | float, [0, 1] 128 | ''' 129 | in_img_type = img.dtype 130 | img.astype(np.float32) 131 | if in_img_type != np.uint8: 132 | img *= 255. 133 | # convert 134 | if only_y: 135 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 136 | else: 137 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 138 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 139 | if in_img_type == np.uint8: 140 | rlt = rlt.round() 141 | else: 142 | rlt /= 255. 143 | return rlt.astype(in_img_type) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /codes/models/SR_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | import numpy as np 8 | import models.networks as networks 9 | import models.lr_scheduler as lr_scheduler 10 | from .base_model import BaseModel 11 | from models.loss import CharbonnierLoss, FSLoss, GradientLoss 12 | 13 | logger = logging.getLogger('base') 14 | 15 | 16 | class SRModel(BaseModel): 17 | def __init__(self, opt): 18 | super(SRModel, self).__init__(opt) 19 | 20 | if opt['dist']: 21 | self.rank = torch.distributed.get_rank() 22 | else: 23 | self.rank = -1 # non dist training 24 | train_opt = opt['train'] 25 | 26 | # define network and load pretrained models 27 | self.netG = networks.define_G(opt).to(self.device) 28 | if opt['dist']: 29 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 30 | else: 31 | self.netG = DataParallel(self.netG) 32 | # print network 33 | self.print_network() 34 | self.load() 35 | 36 | if 
self.is_train: 37 | self.netG.train() 38 | 39 | # loss 40 | loss_type = train_opt['pixel_criterion'] 41 | self.loss_type = loss_type 42 | if loss_type == 'l1': 43 | self.cri_pix = nn.L1Loss().to(self.device) 44 | elif loss_type == 'l2': 45 | self.cri_pix = nn.MSELoss().to(self.device) 46 | elif loss_type == 'cb': 47 | self.cri_pix = CharbonnierLoss().to(self.device) 48 | else: 49 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 50 | self.l_pix_w = train_opt['pixel_weight'] 51 | 52 | # optimizers 53 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 54 | optim_params = [] 55 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 56 | if v.requires_grad: 57 | optim_params.append(v) 58 | else: 59 | if self.rank <= 0: 60 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 61 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 62 | weight_decay=wd_G, 63 | betas=(train_opt['beta1'], train_opt['beta2'])) 64 | self.optimizers.append(self.optimizer_G) 65 | 66 | # schedulers 67 | if train_opt['lr_scheme'] == 'MultiStepLR': 68 | for optimizer in self.optimizers: 69 | self.schedulers.append( 70 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 71 | restarts=train_opt['restarts'], 72 | weights=train_opt['restart_weights'], 73 | gamma=train_opt['lr_gamma'], 74 | clear_state=train_opt['clear_state'])) 75 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 76 | for optimizer in self.optimizers: 77 | self.schedulers.append( 78 | lr_scheduler.CosineAnnealingLR_Restart( 79 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 80 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 81 | else: 82 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 83 | 84 | self.log_dict = OrderedDict() 85 | 86 | def feed_data(self, data, need_GT=True): 87 | self.var_L = data['LQ'].to(self.device) # 
LQ 88 | if need_GT: 89 | self.real_H = data['GT'].to(self.device) # GT 90 | 91 | def mixup_data(self, x, y, alpha=1.0, use_cuda=True): 92 | '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda''' 93 | batch_size = x.size()[0] 94 | lam = np.random.beta(alpha, alpha) if alpha > 0 else 1 95 | index = torch.randperm(batch_size).cuda() if use_cuda else torch.randperm(batch_size) 96 | mixed_x = lam * x + (1 - lam) * x[index,:] 97 | mixed_y = lam * y + (1 - lam) * y[index,:] 98 | return mixed_x, mixed_y 99 | 100 | def optimize_parameters(self, step): 101 | self.optimizer_G.zero_grad() 102 | 103 | '''add mixup operation''' 104 | # self.var_L, self.real_H = self.mixup_data(self.var_L, self.real_H) 105 | 106 | self.fake_H = self.netG(self.var_L) 107 | if self.loss_type == 'fs': 108 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) + self.l_fs_w * self.cri_fs(self.fake_H, self.real_H) 109 | elif self.loss_type == 'grad': 110 | l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 111 | lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H) 112 | l_pix = l1 + lg 113 | elif self.loss_type == 'grad_fs': 114 | l1 = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 115 | lg = self.l_grad_w * self.gradloss(self.fake_H, self.real_H) 116 | lfs = self.l_fs_w * self.cri_fs(self.fake_H, self.real_H) 117 | l_pix = l1 + lg + lfs 118 | else: 119 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 120 | l_pix.backward() 121 | self.optimizer_G.step() 122 | 123 | # set log 124 | self.log_dict['l_pix'] = l_pix.item() 125 | if self.loss_type == 'grad': 126 | self.log_dict['l_1'] = l1.item() 127 | self.log_dict['l_grad'] = lg.item() 128 | if self.loss_type == 'grad_fs': 129 | self.log_dict['l_1'] = l1.item() 130 | self.log_dict['l_grad'] = lg.item() 131 | self.log_dict['l_fs'] = lfs.item() 132 | 133 | def test(self): 134 | self.netG.eval() 135 | 136 | with torch.no_grad(): 137 | self.fake_H = self.netG(self.var_L) 138 | 
self.netG.train() 139 | 140 | def test_x8(self): 141 | # from https://github.com/thstkdgus35/EDSR-PyTorch 142 | self.netG.eval() 143 | 144 | def _transform(v, op): 145 | # if self.precision != 'single': v = v.float() 146 | v2np = v.data.cpu().numpy() 147 | if op == 'v': 148 | tfnp = v2np[:, :, :, ::-1].copy() 149 | elif op == 'h': 150 | tfnp = v2np[:, :, ::-1, :].copy() 151 | elif op == 't': 152 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 153 | 154 | ret = torch.Tensor(tfnp).to(self.device) 155 | # if self.precision == 'half': ret = ret.half() 156 | 157 | return ret 158 | 159 | lr_list = [self.var_L] 160 | for tf in 'v', 'h', 't': 161 | lr_list.extend([_transform(t, tf) for t in lr_list]) 162 | with torch.no_grad(): 163 | sr_list = [self.netG(aug) for aug in lr_list] 164 | for i in range(len(sr_list)): 165 | if i > 3: 166 | sr_list[i] = _transform(sr_list[i], 't') 167 | if i % 4 > 1: 168 | sr_list[i] = _transform(sr_list[i], 'h') 169 | if (i % 4) % 2 == 1: 170 | sr_list[i] = _transform(sr_list[i], 'v') 171 | 172 | output_cat = torch.cat(sr_list, dim=0) 173 | self.fake_H = output_cat.mean(dim=0, keepdim=True) 174 | self.netG.train() 175 | 176 | def get_current_log(self): 177 | return self.log_dict 178 | 179 | def get_current_visuals(self, need_GT=True): 180 | out_dict = OrderedDict() 181 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 182 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() 183 | if need_GT: 184 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 185 | return out_dict 186 | 187 | def print_network(self): 188 | s, n = self.get_network_description(self.netG) 189 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 190 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 191 | self.netG.module.__class__.__name__) 192 | else: 193 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 194 | if self.rank <= 0: 195 | logger.info('Network G structure: {}, with parameters: 
{:,d}'.format(net_struc_str, n)) 196 | logger.info(s) 197 | 198 | def load(self): 199 | load_path_G = self.opt['path']['pretrain_model_G'] 200 | if load_path_G is not None: 201 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 202 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 203 | 204 | # def load(self): 205 | # load_path_G_1 = self.opt['path']['pretrain_model_G_1'] 206 | # load_path_G_2 = self.opt['path']['pretrain_model_G_2'] 207 | # load_path_Gs=[load_path_G_1, load_path_G_2] 208 | 209 | # load_path_G = self.opt['path']['pretrain_model_G'] 210 | # if load_path_G is not None: 211 | # logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 212 | # self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 213 | # if load_path_G_1 is not None: 214 | # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_1)) 215 | # logger.info('Loading model for 3net [{:s}] ...'.format(load_path_G_2)) 216 | # self.load_network_part(load_path_Gs, self.netG, self.opt['path']['strict_load']) 217 | 218 | def save(self, iter_label): 219 | self.save_network(self.netG, 'G', iter_label) 220 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | # image restoration 8 | if model == 'sr': # PSNR-oriented super resolution 9 | from .SR_model import SRModel as M 10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN 11 | from .SRGAN_model import SRGANModel as M 12 | # video restoration 13 | elif model == 'video_base': 14 | from .Video_base_model import VideoBaseModel as M 15 | else: 16 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 17 | m = M(opt) 18 | logger.info('Model [{:s}] is 
created.'.format(m.__class__.__name__)) 19 | return m 20 | -------------------------------------------------------------------------------- /codes/models/__pycache__/SRGAN_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/SRGAN_model.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/SR_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/SR_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/SR_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/SR_model.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/Video_base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/Video_base_model.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/base_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/base_model.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/lr_scheduler.cpython-36.pyc 
-------------------------------------------------------------------------------- /codes/models/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/PAN_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import models.archs.arch_util as arch_util 6 | 7 | class PA(nn.Module): 8 | '''PA is pixel attention''' 9 | def __init__(self, nf): 10 | 11 | super(PA, self).__init__() 12 | self.conv = nn.Conv2d(nf, nf, 1) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x): 16 | 17 | y = self.conv(x) 18 | y = self.sigmoid(y) 19 | out = torch.mul(x, y) 20 | 21 | return out 22 | 23 | class PAConv(nn.Module): 24 | 25 | def __init__(self, nf, k_size=3): 26 | 27 | super(PAConv, self).__init__() 28 | self.k2 = nn.Conv2d(nf, nf, 1) # 1x1 convolution nf->nf 29 | self.sigmoid = nn.Sigmoid() 30 | self.k3 = nn.Conv2d(nf, nf, 
kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution 31 | self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution 32 | 33 | def forward(self, x): 34 | 35 | y = self.k2(x) 36 | y = self.sigmoid(y) 37 | 38 | out = torch.mul(self.k3(x), y) 39 | out = self.k4(out) 40 | 41 | return out 42 | 43 | class SCPA(nn.Module): 44 | 45 | """SCPA is modified from SCNet (Jiang-Jiang Liu et al. Improving Convolutional Networks with Self-Calibrated Convolutions. In CVPR, 2020) 46 | Github: https://github.com/MCG-NKU/SCNet 47 | """ 48 | 49 | def __init__(self, nf, reduction=2, stride=1, dilation=1): 50 | super(SCPA, self).__init__() 51 | group_width = nf // reduction 52 | 53 | self.conv1_a = nn.Conv2d(nf, group_width, kernel_size=1, bias=False) 54 | self.conv1_b = nn.Conv2d(nf, group_width, kernel_size=1, bias=False) 55 | 56 | self.k1 = nn.Sequential( 57 | nn.Conv2d( 58 | group_width, group_width, kernel_size=3, stride=stride, 59 | padding=dilation, dilation=dilation, 60 | bias=False) 61 | ) 62 | 63 | self.PAConv = PAConv(group_width) 64 | 65 | self.conv3 = nn.Conv2d( 66 | group_width * reduction, nf, kernel_size=1, bias=False) 67 | 68 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out_a= self.conv1_a(x) 74 | out_b = self.conv1_b(x) 75 | out_a = self.lrelu(out_a) 76 | out_b = self.lrelu(out_b) 77 | 78 | out_a = self.k1(out_a) 79 | out_b = self.PAConv(out_b) 80 | out_a = self.lrelu(out_a) 81 | out_b = self.lrelu(out_b) 82 | 83 | out = self.conv3(torch.cat([out_a, out_b], dim=1)) 84 | out += residual 85 | 86 | return out 87 | 88 | class PAN(nn.Module): 89 | 90 | def __init__(self, in_nc, out_nc, nf, unf, nb, scale=4): 91 | super(PAN, self).__init__() 92 | # SCPA 93 | SCPA_block_f = functools.partial(SCPA, nf=nf, reduction=2) 94 | self.scale = scale 95 | 96 | ### first convolution 97 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, 
bias=True) 98 | 99 | ### main blocks 100 | self.SCPA_trunk = arch_util.make_layer(SCPA_block_f, nb) 101 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 102 | 103 | #### upsampling 104 | self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True) 105 | self.att1 = PA(unf) 106 | self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) 107 | 108 | if self.scale == 4: 109 | self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) 110 | self.att2 = PA(unf) 111 | self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True) 112 | 113 | self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True) 114 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 115 | 116 | def forward(self, x): 117 | 118 | fea = self.conv_first(x) 119 | trunk = self.trunk_conv(self.SCPA_trunk(fea)) 120 | fea = fea + trunk 121 | 122 | if self.scale == 2 or self.scale == 3: 123 | fea = self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode='nearest')) 124 | fea = self.lrelu(self.att1(fea)) 125 | fea = self.lrelu(self.HRconv1(fea)) 126 | elif self.scale == 4: 127 | fea = self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')) 128 | fea = self.lrelu(self.att1(fea)) 129 | fea = self.lrelu(self.HRconv1(fea)) 130 | fea = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) 131 | fea = self.lrelu(self.att2(fea)) 132 | fea = self.lrelu(self.HRconv2(fea)) 133 | 134 | out = self.conv_last(fea) 135 | 136 | ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False) 137 | out = out + ILR 138 | return out 139 | -------------------------------------------------------------------------------- /codes/models/archs/RCAN_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import models.archs.arch_util as arch_util 5 | 6 | ## Channel Attention (CA) Layer 7 | class CALayer(nn.Module): 8 | def __init__(self, channel, reduction=16): 9 | super(CALayer, self).__init__() 
10 | # global average pooling: feature --> point 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | # feature channel downscale and upscale --> channel weight 13 | self.conv_du = nn.Sequential( 14 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 17 | nn.Sigmoid() 18 | ) 19 | 20 | def forward(self, x): 21 | y = self.avg_pool(x) 22 | y = self.conv_du(y) 23 | return x * y 24 | 25 | class PA(nn.Module): 26 | '''PA is pixel attention''' 27 | def __init__(self, nf): 28 | 29 | super(PA, self).__init__() 30 | self.conv = nn.Conv2d(nf, nf, 1) 31 | self.sigmoid = nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | 35 | y = self.conv(x) 36 | y = self.sigmoid(y) 37 | out = torch.mul(x, y) 38 | 39 | return out 40 | 41 | 42 | ## Residual Channel Attention Block (RCAB) 43 | class RCAB(nn.Module): 44 | def __init__( 45 | self, conv, n_feat, kernel_size, reduction, 46 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 47 | 48 | super(RCAB, self).__init__() 49 | modules_body = [] 50 | for i in range(2): 51 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 52 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 53 | if i == 0: modules_body.append(act) 54 | modules_body.append(CALayer(n_feat, reduction)) 55 | self.body = nn.Sequential(*modules_body) 56 | self.res_scale = res_scale 57 | 58 | def forward(self, x): 59 | res = self.body(x) 60 | res += x 61 | return res 62 | 63 | class RPAB(nn.Module): 64 | '''Residual Block with PA''' 65 | def __init__( 66 | self, conv, n_feat, kernel_size, reduction, 67 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 68 | 69 | super(RPAB, self).__init__() 70 | modules_body = [] 71 | for i in range(2): 72 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 73 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 74 | if i == 0: modules_body.append(act) 75 | modules_body.append(PA(n_feat)) 76 | 
self.body = nn.Sequential(*modules_body) 77 | self.res_scale = res_scale 78 | 79 | def forward(self, x): 80 | res = self.body(x) 81 | res += x 82 | return res 83 | 84 | # Residual Group (RG) 85 | class ResidualGroup(nn.Module): 86 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 87 | super(ResidualGroup, self).__init__() 88 | modules_body = [] 89 | modules_body = [ 90 | RPAB( 91 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 92 | for _ in range(n_resblocks)] 93 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 94 | self.body = nn.Sequential(*modules_body) 95 | 96 | def forward(self, x): 97 | res = self.body(x) 98 | res += x 99 | return res 100 | 101 | # class ResidualGroup(nn.Module): 102 | # def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 103 | # super(ResidualGroup, self).__init__() 104 | # modules_body = [] 105 | # modules_body = [ 106 | # RCAB( 107 | # conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 108 | # for _ in range(n_resblocks)] 109 | # modules_body.append(conv(n_feat, n_feat, kernel_size)) 110 | # self.body = nn.Sequential(*modules_body) 111 | 112 | # def forward(self, x): 113 | # res = self.body(x) 114 | # res += x 115 | # return res 116 | 117 | class Upsampler(nn.Sequential): 118 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 119 | 120 | m = [] 121 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
122 | for _ in range(int(math.log(scale, 2))): 123 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 124 | m.append(nn.PixelShuffle(2)) 125 | if bn: m.append(nn.BatchNorm2d(n_feat)) 126 | if act: m.append(act()) 127 | elif scale == 3: 128 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 129 | m.append(nn.PixelShuffle(3)) 130 | if bn: m.append(nn.BatchNorm2d(n_feat)) 131 | if act: m.append(act()) 132 | else: 133 | raise NotImplementedError 134 | 135 | super(Upsampler, self).__init__(*m) 136 | 137 | ## Residual Channel Attention Network (RCAN) 138 | class MRCAN(nn.Module): 139 | ''' modified RCAN ''' 140 | 141 | def __init__(self, n_resgroups, n_resblocks, n_feats, res_scale, n_colors, rgb_range, scale, reduction, conv=arch_util.default_conv): 142 | super(MRCAN, self).__init__() 143 | 144 | n_resgroups = n_resgroups 145 | n_resblocks = n_resblocks 146 | n_feats = n_feats 147 | kernel_size = 3 148 | reduction = reduction 149 | scale = scale 150 | act = nn.ReLU(True) 151 | 152 | # RGB mean for DIV2K 153 | rgb_mean = (0.4488, 0.4371, 0.4040) 154 | rgb_std = (1.0, 1.0, 1.0) 155 | self.sub_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std) 156 | 157 | # define head module 158 | modules_head = [conv(n_colors, n_feats, kernel_size)] 159 | 160 | # define body module 161 | modules_body = [ 162 | ResidualGroup( 163 | conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ 164 | for _ in range(n_resgroups)] 165 | 166 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 167 | 168 | # define tail module 169 | modules_tail = [ 170 | Upsampler(conv, scale, n_feats, act=False), 171 | conv(n_feats, n_colors, kernel_size)] 172 | 173 | self.add_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std, 1) 174 | 175 | self.head = nn.Sequential(*modules_head) 176 | self.body = nn.Sequential(*modules_body) 177 | self.tail = nn.Sequential(*modules_tail) 178 | 179 | arch_util.initialize_weights([self.head, self.body, self.tail], 0.1) 180 | def 
forward(self, x): 181 | x = self.sub_mean(x) 182 | x = self.head(x) 183 | res = self.body(x) 184 | res += x 185 | 186 | x = self.tail(res) 187 | x = self.add_mean(x) 188 | 189 | return x 190 | 191 | class RCAN_PA(nn.Module): 192 | ''' RCAN + PA ''' 193 | 194 | def __init__(self, n_resgroups=10, n_resblocks=20, n_feats=64, res_scale=1, n_colors=3, rgb_range=1, scale=4, reduction=16, conv=arch_util.default_conv): 195 | super(RCAN_PA, self).__init__() 196 | 197 | n_resgroups = n_resgroups 198 | n_resblocks = n_resblocks 199 | n_feats = n_feats 200 | kernel_size = 3 201 | reduction = reduction 202 | scale = scale 203 | act = nn.ReLU(True) 204 | 205 | # RGB mean for DIV2K 206 | rgb_mean = (0.4488, 0.4371, 0.4040) 207 | rgb_std = (1.0, 1.0, 1.0) 208 | self.sub_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std) 209 | 210 | # define head module 211 | modules_head = [conv(n_colors, n_feats, kernel_size)] 212 | 213 | # define body module 214 | modules_body = [ 215 | ResidualGroup( 216 | conv, n_feats, kernel_size, reduction, act=act, res_scale=res_scale, n_resblocks=n_resblocks) \ 217 | for _ in range(n_resgroups)] 218 | 219 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 220 | 221 | # define tail module 222 | modules_tail = [ 223 | Upsampler(conv, scale, n_feats, act=False), 224 | conv(n_feats, n_colors, kernel_size)] 225 | 226 | self.add_mean = arch_util.MeanShift(rgb_range, rgb_mean, rgb_std, 1) 227 | 228 | self.head = nn.Sequential(*modules_head) 229 | self.body = nn.Sequential(*modules_body) 230 | self.tail = nn.Sequential(*modules_tail) 231 | 232 | arch_util.initialize_weights([self.head, self.body, self.tail], 0.1) 233 | def forward(self, x): 234 | x = self.sub_mean(x) 235 | x = self.head(x) 236 | 237 | res = self.body(x) 238 | res += x 239 | 240 | x = self.tail(res) 241 | x = self.add_mean(x) 242 | 243 | return x -------------------------------------------------------------------------------- /codes/models/archs/SRResNet_arch.py: 
-------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import models.archs.arch_util as arch_util 6 | 7 | class PA(nn.Module): 8 | '''PA is pixel attention''' 9 | def __init__(self, nf): 10 | 11 | super(PA, self).__init__() 12 | self.conv = nn.Conv2d(nf, nf, 1) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x): 16 | 17 | y = self.conv(x) 18 | y = self.sigmoid(y) 19 | out = torch.mul(x, y) 20 | 21 | return out 22 | 23 | class ResidualBlock_noBN_PA(nn.Module): 24 | '''Residual block w/o BN 25 | ---Conv-ReLU-Conv-PA-+- 26 | |___________________| 27 | ''' 28 | 29 | def __init__(self, nf=64): 30 | super(ResidualBlock_noBN_PA, self).__init__() 31 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 32 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 33 | self.pa = PA(nf) 34 | 35 | # initialization 36 | arch_util.initialize_weights([self.conv1, self.conv2], 0.1) 37 | 38 | def forward(self, x): 39 | identity = x 40 | out = F.relu(self.conv1(x), inplace=True) 41 | out = self.conv2(out) 42 | out = self.pa(out) 43 | return identity + out 44 | 45 | class MSRResNet(nn.Module): 46 | ''' modified SRResNet''' 47 | 48 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 49 | super(MSRResNet, self).__init__() 50 | self.upscale = upscale 51 | 52 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 53 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 54 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 55 | 56 | # upsampling 57 | if self.upscale == 2: 58 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 59 | self.pixel_shuffle = nn.PixelShuffle(2) 60 | elif self.upscale == 3: 61 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 62 | self.pixel_shuffle = nn.PixelShuffle(3) 63 | elif self.upscale == 4: 64 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 65 | 
self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 66 | self.pixel_shuffle = nn.PixelShuffle(2) 67 | 68 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 69 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 70 | 71 | # activation function 72 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 73 | 74 | # initialization 75 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1) 76 | if self.upscale == 4: 77 | arch_util.initialize_weights(self.upconv2, 0.1) 78 | 79 | def forward(self, x): 80 | fea = self.lrelu(self.conv_first(x)) 81 | out = self.recon_trunk(fea) 82 | 83 | if self.upscale == 4: 84 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 85 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 86 | elif self.upscale == 3 or self.upscale == 2: 87 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 88 | 89 | out = self.conv_last(self.lrelu(self.HRconv(out))) 90 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 91 | out += base 92 | return out 93 | 94 | class MSRResNet_PA(nn.Module): 95 | ''' modified SRResNet + PA''' 96 | 97 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 98 | super(MSRResNet_PA, self).__init__() 99 | self.upscale = upscale 100 | 101 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 102 | basic_block = functools.partial(ResidualBlock_noBN_PA, nf=nf) 103 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 104 | 105 | # upsampling 106 | if self.upscale == 2: 107 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 108 | self.pixel_shuffle = nn.PixelShuffle(2) 109 | elif self.upscale == 3: 110 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 111 | self.pixel_shuffle = nn.PixelShuffle(3) 112 | elif self.upscale == 4: 113 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 114 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 115 | self.pixel_shuffle = 
nn.PixelShuffle(2) 116 | 117 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 118 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 119 | 120 | # activation function 121 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 122 | 123 | # initialization 124 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 125 | 0.1) 126 | if self.upscale == 4: 127 | arch_util.initialize_weights(self.upconv2, 0.1) 128 | 129 | def forward(self, x): 130 | fea = self.lrelu(self.conv_first(x)) 131 | out = self.recon_trunk(fea) 132 | 133 | if self.upscale == 4: 134 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 135 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 136 | elif self.upscale == 3 or self.upscale == 2: 137 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 138 | 139 | out = self.conv_last(self.lrelu(self.HRconv(out))) 140 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 141 | out += base 142 | return out -------------------------------------------------------------------------------- /codes/models/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__init__.py -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/AWSRN_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/AWSRN_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/DUF_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/DUF_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/Dual_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/Dual_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/EDVR_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/EDVR_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/EfficientSR_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/EfficientSR_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/EfficientSR_clean.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/EfficientSR_clean.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/FSRCNN_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/FSRCNN_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/KernelMD_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/KernelMD_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/MSSResNet_deblur_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/MSSResNet_deblur_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/Octave_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/Octave_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/PAN_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/PAN_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/PAN_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/PAN_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/PAN_arch_update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/PAN_arch_update.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/PANet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/PANet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/PANv2_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/PANv2_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/RCAN_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/RCAN_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/RCAN_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/RCAN_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/RRDBNet_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/RRDBNet_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/RRDBNet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/RRDBNet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/SRResNet_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/SRResNet_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/SRResNet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/SRResNet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/UNet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/UNet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/arch_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/arch_util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/arch_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/arch_util.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/discriminator_vgg_arch.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/discriminator_vgg_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/discriminator_vgg_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/discriminator_vgg_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/unet_arch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/unet_arch.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/__pycache__/unet_parts.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/__pycache__/unet_parts.cpython-37.pyc -------------------------------------------------------------------------------- /codes/models/archs/arch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | import numpy as np 7 | 8 | # for RCAN 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | 14 | class MeanShift(nn.Conv2d): 15 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 16 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 17 | 
std = torch.Tensor(rgb_std) 18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 19 | self.weight.data.div_(std.view(3, 1, 1, 1)) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 21 | self.bias.data.div_(std) 22 | self.requires_grad = False 23 | 24 | # for other networks 25 | def initialize_weights(net_l, scale=1): 26 | if not isinstance(net_l, list): 27 | net_l = [net_l] 28 | for net in net_l: 29 | for m in net.modules(): 30 | if isinstance(m, nn.Conv2d): 31 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 32 | m.weight.data *= scale # for residual block 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif isinstance(m, nn.Linear): 36 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 37 | m.weight.data *= scale 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | elif isinstance(m, nn.BatchNorm2d): 41 | init.constant_(m.weight, 1) 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | 45 | def make_layer(block, n_layers): 46 | layers = [] 47 | for _ in range(n_layers): 48 | layers.append(block()) 49 | return nn.Sequential(*layers) 50 | 51 | 52 | class ResidualBlock_noBN(nn.Module): 53 | '''Residual block w/o BN 54 | ---Conv-ReLU-Conv-+- 55 | |________________| 56 | ''' 57 | 58 | def __init__(self, nf=64): 59 | super(ResidualBlock_noBN, self).__init__() 60 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 62 | 63 | # initialization 64 | initialize_weights([self.conv1, self.conv2], 0.1) 65 | 66 | def forward(self, x): 67 | identity = x 68 | out = F.relu(self.conv1(x), inplace=True) 69 | out = self.conv2(out) 70 | return identity + out 71 | 72 | 73 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 74 | """Warp an image or feature map with optical flow 75 | Args: 76 | x (Tensor): size (N, C, H, W) 77 | flow (Tensor): size (N, H, W, 2), normal value 78 | interp_mode (str): 'nearest' or 'bilinear' 79 | padding_mode (str): 'zeros' or 'border' or 'reflection' 80 
| 81 | Returns: 82 | Tensor: warped image or feature map 83 | """ 84 | assert x.size()[-2:] == flow.size()[1:3] 85 | B, C, H, W = x.size() 86 | # mesh grid 87 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 88 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 89 | grid.requires_grad = False 90 | grid = grid.type_as(x) 91 | vgrid = grid + flow 92 | # scale grid to [-1,1] 93 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 94 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 95 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 96 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 97 | return output 98 | 99 | def scalex4(im): 100 | '''Nearest Upsampling by myself''' 101 | im1 = im[:, :1, ...].repeat(1, 16, 1, 1) 102 | im2 = im[:, 1:2, ...].repeat(1, 16, 1, 1) 103 | im3 = im[:, 2:, ...].repeat(1, 16, 1, 1) 104 | 105 | # b, c, h, w = im.shape 106 | # w = torch.randn(b,16,h,w).cuda() * (5e-2) 107 | 108 | # img1 = im1 + im1 * w 109 | # img2 = im2 + im2 * w 110 | # img3 = im3 + im3 * w 111 | 112 | imhr = torch.cat((im1, im2, im3), 1) 113 | imhr = F.pixel_shuffle(imhr, 4) 114 | return imhr -------------------------------------------------------------------------------- /codes/models/archs/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, 2 | deform_conv, modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/dcn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/dcn/__pycache__/deform_conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/models/archs/dcn/__pycache__/deform_conv.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: deform-conv 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | deform_conv.egg-info/PKG-INFO 3 | deform_conv.egg-info/SOURCES.txt 4 | deform_conv.egg-info/dependency_links.txt 5 | deform_conv.egg-info/not-zip-safe 6 | deform_conv.egg-info/top_level.txt 7 | src/deform_conv_cuda.cpp 8 | src/deform_conv_cuda_kernel.cu -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | deform_conv_cuda 2 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.autograd.function import once_differentiable 8 | from torch.nn.modules.utils import _pair 9 | 10 | from . import deform_conv_cuda 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class DeformConvFunction(Function): 16 | @staticmethod 17 | def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1, 18 | deformable_groups=1, im2col_step=64): 19 | if input is not None and input.dim() != 4: 20 | raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format( 21 | input.dim())) 22 | ctx.stride = _pair(stride) 23 | ctx.padding = _pair(padding) 24 | ctx.dilation = _pair(dilation) 25 | ctx.groups = groups 26 | ctx.deformable_groups = deformable_groups 27 | ctx.im2col_step = im2col_step 28 | 29 | ctx.save_for_backward(input, offset, weight) 30 | 31 | output = input.new_empty( 32 | DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) 33 | 34 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones 35 | 36 | if not input.is_cuda: 37 | raise NotImplementedError 38 | else: 39 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 40 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 41 | deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output, 42 | ctx.bufs_[0], ctx.bufs_[1], weight.size(3), 43 | weight.size(2), ctx.stride[1], ctx.stride[0], 44 | 
ctx.padding[1], ctx.padding[0], 45 | ctx.dilation[1], ctx.dilation[0], ctx.groups, 46 | ctx.deformable_groups, cur_im2col_step) 47 | return output 48 | 49 | @staticmethod 50 | @once_differentiable 51 | def backward(ctx, grad_output): 52 | input, offset, weight = ctx.saved_tensors 53 | 54 | grad_input = grad_offset = grad_weight = None 55 | 56 | if not grad_output.is_cuda: 57 | raise NotImplementedError 58 | else: 59 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 60 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 61 | 62 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 63 | grad_input = torch.zeros_like(input) 64 | grad_offset = torch.zeros_like(offset) 65 | deform_conv_cuda.deform_conv_backward_input_cuda( 66 | input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0], 67 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], 68 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, 69 | ctx.deformable_groups, cur_im2col_step) 70 | 71 | if ctx.needs_input_grad[2]: 72 | grad_weight = torch.zeros_like(weight) 73 | deform_conv_cuda.deform_conv_backward_parameters_cuda( 74 | input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1], 75 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], 76 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, 77 | ctx.deformable_groups, 1, cur_im2col_step) 78 | 79 | return (grad_input, grad_offset, grad_weight, None, None, None, None, None) 80 | 81 | @staticmethod 82 | def _output_size(input, weight, padding, dilation, stride): 83 | channels = weight.size(0) 84 | output_size = (input.size(0), channels) 85 | for d in range(input.dim() - 2): 86 | in_size = input.size(d + 2) 87 | pad = padding[d] 88 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 89 | stride_ = stride[d] 90 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 91 | if not all(map(lambda s: s > 0, 
output_size)): 92 | raise ValueError("convolution input is too small (output would be {})".format('x'.join( 93 | map(str, output_size)))) 94 | return output_size 95 | 96 | 97 | class ModulatedDeformConvFunction(Function): 98 | @staticmethod 99 | def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, 100 | groups=1, deformable_groups=1): 101 | ctx.stride = stride 102 | ctx.padding = padding 103 | ctx.dilation = dilation 104 | ctx.groups = groups 105 | ctx.deformable_groups = deformable_groups 106 | ctx.with_bias = bias is not None 107 | if not ctx.with_bias: 108 | bias = input.new_empty(1) # fake tensor 109 | if not input.is_cuda: 110 | raise NotImplementedError 111 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \ 112 | or input.requires_grad: 113 | ctx.save_for_backward(input, offset, mask, weight, bias) 114 | output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) 115 | ctx._bufs = [input.new_empty(0), input.new_empty(0)] 116 | deform_conv_cuda.modulated_deform_conv_cuda_forward( 117 | input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2], 118 | weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, 119 | ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias) 120 | return output 121 | 122 | @staticmethod 123 | @once_differentiable 124 | def backward(ctx, grad_output): 125 | if not grad_output.is_cuda: 126 | raise NotImplementedError 127 | input, offset, mask, weight, bias = ctx.saved_tensors 128 | grad_input = torch.zeros_like(input) 129 | grad_offset = torch.zeros_like(offset) 130 | grad_mask = torch.zeros_like(mask) 131 | grad_weight = torch.zeros_like(weight) 132 | grad_bias = torch.zeros_like(bias) 133 | deform_conv_cuda.modulated_deform_conv_cuda_backward( 134 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight, 135 | grad_bias, grad_offset, grad_mask, grad_output, 
weight.shape[2], weight.shape[3], 136 | ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, 137 | ctx.groups, ctx.deformable_groups, ctx.with_bias) 138 | if not ctx.with_bias: 139 | grad_bias = None 140 | 141 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, 142 | None) 143 | 144 | @staticmethod 145 | def _infer_shape(ctx, input, weight): 146 | n = input.size(0) 147 | channels_out = weight.size(0) 148 | height, width = input.shape[2:4] 149 | kernel_h, kernel_w = weight.shape[2:4] 150 | height_out = (height + 2 * ctx.padding - (ctx.dilation * 151 | (kernel_h - 1) + 1)) // ctx.stride + 1 152 | width_out = (width + 2 * ctx.padding - (ctx.dilation * 153 | (kernel_w - 1) + 1)) // ctx.stride + 1 154 | return n, channels_out, height_out, width_out 155 | 156 | 157 | deform_conv = DeformConvFunction.apply 158 | modulated_deform_conv = ModulatedDeformConvFunction.apply 159 | 160 | 161 | class DeformConv(nn.Module): 162 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 163 | groups=1, deformable_groups=1, bias=False): 164 | super(DeformConv, self).__init__() 165 | 166 | assert not bias 167 | assert in_channels % groups == 0, \ 168 | 'in_channels {} cannot be divisible by groups {}'.format( 169 | in_channels, groups) 170 | assert out_channels % groups == 0, \ 171 | 'out_channels {} cannot be divisible by groups {}'.format( 172 | out_channels, groups) 173 | 174 | self.in_channels = in_channels 175 | self.out_channels = out_channels 176 | self.kernel_size = _pair(kernel_size) 177 | self.stride = _pair(stride) 178 | self.padding = _pair(padding) 179 | self.dilation = _pair(dilation) 180 | self.groups = groups 181 | self.deformable_groups = deformable_groups 182 | 183 | self.weight = nn.Parameter( 184 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) 185 | 186 | self.reset_parameters() 187 | 188 | def reset_parameters(self): 189 | n = 
self.in_channels 190 | for k in self.kernel_size: 191 | n *= k 192 | stdv = 1. / math.sqrt(n) 193 | self.weight.data.uniform_(-stdv, stdv) 194 | 195 | def forward(self, x, offset): 196 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 197 | self.groups, self.deformable_groups) 198 | 199 | 200 | class DeformConvPack(DeformConv): 201 | def __init__(self, *args, **kwargs): 202 | super(DeformConvPack, self).__init__(*args, **kwargs) 203 | 204 | self.conv_offset = nn.Conv2d( 205 | self.in_channels, 206 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], 207 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 208 | bias=True) 209 | self.init_offset() 210 | 211 | def init_offset(self): 212 | self.conv_offset.weight.data.zero_() 213 | self.conv_offset.bias.data.zero_() 214 | 215 | def forward(self, x): 216 | offset = self.conv_offset(x) 217 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 218 | self.groups, self.deformable_groups) 219 | 220 | 221 | class ModulatedDeformConv(nn.Module): 222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 223 | groups=1, deformable_groups=1, bias=True): 224 | super(ModulatedDeformConv, self).__init__() 225 | self.in_channels = in_channels 226 | self.out_channels = out_channels 227 | self.kernel_size = _pair(kernel_size) 228 | self.stride = stride 229 | self.padding = padding 230 | self.dilation = dilation 231 | self.groups = groups 232 | self.deformable_groups = deformable_groups 233 | self.with_bias = bias 234 | 235 | self.weight = nn.Parameter( 236 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 237 | if bias: 238 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 239 | else: 240 | self.register_parameter('bias', None) 241 | self.reset_parameters() 242 | 243 | def reset_parameters(self): 244 | n = self.in_channels 245 | for k in 
self.kernel_size: 246 | n *= k 247 | stdv = 1. / math.sqrt(n) 248 | self.weight.data.uniform_(-stdv, stdv) 249 | if self.bias is not None: 250 | self.bias.data.zero_() 251 | 252 | def forward(self, x, offset, mask): 253 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 254 | self.padding, self.dilation, self.groups, 255 | self.deformable_groups) 256 | 257 | 258 | class ModulatedDeformConvPack(ModulatedDeformConv): 259 | def __init__(self, *args, extra_offset_mask=False, **kwargs): 260 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 261 | 262 | self.extra_offset_mask = extra_offset_mask 263 | self.conv_offset_mask = nn.Conv2d( 264 | self.in_channels, 265 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 266 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 267 | bias=True) 268 | self.init_offset() 269 | 270 | def init_offset(self): 271 | self.conv_offset_mask.weight.data.zero_() 272 | self.conv_offset_mask.bias.data.zero_() 273 | 274 | def forward(self, x): 275 | if self.extra_offset_mask: 276 | # x = [input, features] 277 | out = self.conv_offset_mask(x[1]) 278 | x = x[0] 279 | else: 280 | out = self.conv_offset_mask(x) 281 | o1, o2, mask = torch.chunk(out, 3, dim=1) 282 | offset = torch.cat((o1, o2), dim=1) 283 | mask = torch.sigmoid(mask) 284 | 285 | offset_mean = torch.mean(torch.abs(offset)) 286 | if offset_mean > 100: 287 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) 288 | 289 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 290 | self.padding, self.dilation, self.groups, 291 | self.deformable_groups) 292 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import 
BuildExtension, CUDAExtension 3 | 4 | 5 | def make_cuda_ext(name, sources): 6 | 7 | return CUDAExtension( 8 | name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ 9 | 'cxx': [], 10 | 'nvcc': [ 11 | '-D__CUDA_NO_HALF_OPERATORS__', 12 | '-D__CUDA_NO_HALF_CONVERSIONS__', 13 | '-D__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | }) 16 | 17 | 18 | setup( 19 | name='deform_conv', ext_modules=[ 20 | make_cuda_ext(name='deform_conv_cuda', 21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) 22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) 23 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | """Set learning rate for warmup 39 | lr_groups_l: list for lr_groups. 
each for a optimizer""" 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | """Get the initial lr, which is set by the scheduler""" 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | # set up warm-up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups] 67 | 68 | def get_network_description(self, network): 69 | """Get the string and total parameters of the network""" 70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 71 | network = network.module 72 | return str(network), sum(map(lambda x: x.numel(), network.parameters())) 73 | 74 | def save_network(self, network, network_label, iter_label): 75 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 76 | save_path = os.path.join(self.opt['path']['models'], save_filename) 77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 78 | network = network.module 79 | state_dict = network.state_dict() 80 | for key, param in state_dict.items(): 81 | state_dict[key] = param.cpu() 82 | torch.save(state_dict, save_path) 83 | 84 | def load_network(self, load_path, network, strict=True): 85 | if isinstance(network, 
nn.DataParallel) or isinstance(network, DistributedDataParallel): 86 | network = network.module 87 | load_net = torch.load(load_path) 88 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 89 | for k, v in load_net.items(): 90 | if k.startswith('module.'): 91 | load_net_clean[k[7:]] = v 92 | else: 93 | load_net_clean[k] = v 94 | network.load_state_dict(load_net_clean, strict=strict) 95 | 96 | def load_network_part(self, load_path, network, strict=True): 97 | 98 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 99 | network1 = network.module.net1 100 | network2 = network.module.net2 101 | load_net = torch.load(load_path[0]) 102 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 103 | for k, v in load_net.items(): 104 | if k.startswith('module.'): 105 | load_net_clean[k[7:]] = v 106 | else: 107 | load_net_clean[k] = v 108 | network1.load_state_dict(load_net_clean, strict=strict) 109 | 110 | load_net = torch.load(load_path[1]) 111 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
112 | for k, v in load_net.items(): 113 | if k.startswith('module.'): 114 | load_net_clean[k[7:]] = v 115 | else: 116 | load_net_clean[k] = v 117 | network2.load_state_dict(load_net_clean, strict=strict) 118 | 119 | def save_training_state(self, epoch, iter_step): 120 | """Save training state during training, which will be used for resuming""" 121 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 122 | for s in self.schedulers: 123 | state['schedulers'].append(s.state_dict()) 124 | for o in self.optimizers: 125 | state['optimizers'].append(o.state_dict()) 126 | save_filename = '{}.state'.format(iter_step) 127 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 128 | torch.save(state, save_path) 129 | 130 | def resume_training(self, resume_state): 131 | """Resume the optimizers and schedulers for training""" 132 | resume_optimizers = resume_state['optimizers'] 133 | resume_schedulers = resume_state['schedulers'] 134 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 135 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 136 | for i, o in enumerate(resume_optimizers): 137 | self.optimizers[i].load_state_dict(o) 138 | for i, s in enumerate(resume_schedulers): 139 | self.schedulers[i].load_state_dict(s) 140 | -------------------------------------------------------------------------------- /codes/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class GradientLoss(nn.Module): 6 | def __init__(self): 7 | super(GradientLoss, self).__init__() 8 | filterx = torch.tensor([[-3., 0., 3.], [-10., 0., 10.], [-3., 0. , 3.]]) 9 | self.fx = filterx.expand(1,3,3,3).cuda() 10 | 11 | filtery = torch.tensor([[-3., -10, -3.], [0., 0., 0.], [3., 10. 
, 3.]]) 12 | self.fy = filtery.expand(1,3,3,3).cuda() 13 | 14 | def forward(self, x, y): 15 | schxx = F.conv2d(x, self.fx, stride=1, padding=1) 16 | schxy = F.conv2d(x, self.fy, stride=1, padding=1) 17 | gradx = torch.sqrt(torch.pow(schxx, 2) + torch.pow(schxy, 2) + 1e-6) 18 | 19 | schyx = F.conv2d(y, self.fx, stride=1, padding=1) 20 | schyy = F.conv2d(y, self.fy, stride=1, padding=1) 21 | grady = torch.sqrt(torch.pow(schyx, 2) + torch.pow(schyy, 2) + 1e-6) 22 | 23 | loss = F.l1_loss(gradx, grady) 24 | return loss 25 | 26 | class GaussianFilter(nn.Module): 27 | def __init__(self, kernel_size=5, stride=1, padding=4): 28 | super(GaussianFilter, self).__init__() 29 | # initialize guassian kernel 30 | mean = (kernel_size - 1) / 2.0 31 | variance = (kernel_size / 6.0) ** 2.0 32 | # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) 33 | x_coord = torch.arange(kernel_size) 34 | x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) 35 | y_grid = x_grid.t() 36 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 37 | 38 | # Calculate the 2-dimensional gaussian kernel 39 | gaussian_kernel = torch.exp(-torch.sum((xy_grid - mean) ** 2., dim=-1) / (2 * variance)) 40 | 41 | # Make sure sum of values in gaussian kernel equals 1. 
42 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 43 | 44 | # Reshape to 2d depthwise convolutional weight 45 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 46 | gaussian_kernel = gaussian_kernel.repeat(3, 1, 1, 1) 47 | 48 | # create gaussian filter as convolutional layer 49 | self.gaussian_filter = nn.Conv2d(3, 3, kernel_size, stride=stride, padding=padding, groups=3, bias=False) 50 | self.gaussian_filter.weight.data = gaussian_kernel 51 | self.gaussian_filter.weight.requires_grad = False 52 | 53 | def forward(self, x): 54 | return self.gaussian_filter(x) 55 | 56 | 57 | class FilterLow(nn.Module): 58 | def __init__(self, recursions=1, kernel_size=5, stride=1, padding=True, include_pad=True, gaussian=False): 59 | super(FilterLow, self).__init__() 60 | if padding: 61 | pad = int((kernel_size - 1) / 2) 62 | else: 63 | pad = 0 64 | if gaussian: 65 | self.filter = GaussianFilter(kernel_size=kernel_size, stride=stride, padding=pad) 66 | else: 67 | self.filter = nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=pad, count_include_pad=include_pad) 68 | self.recursions = recursions 69 | 70 | def forward(self, img): 71 | for i in range(self.recursions): 72 | img = self.filter(img) 73 | return img 74 | 75 | 76 | class FilterHigh(nn.Module): 77 | def __init__(self, recursions=1, kernel_size=5, stride=1, include_pad=True, normalize=True, gaussian=False): 78 | super(FilterHigh, self).__init__() 79 | self.filter_low = FilterLow(recursions=1, kernel_size=kernel_size, stride=stride, include_pad=include_pad, 80 | gaussian=gaussian) 81 | self.recursions = recursions 82 | self.normalize = normalize 83 | 84 | def forward(self, img): 85 | if self.recursions > 1: 86 | for i in range(self.recursions - 1): 87 | img = self.filter_low(img) 88 | img = img - self.filter_low(img) 89 | if self.normalize: 90 | return 0.5 + img * 0.5 91 | else: 92 | return img 93 | 94 | class FSLoss(nn.Module): 95 | def __init__(self, recursions=1, stride=1, 
kernel_size=5, gaussian=False): 96 | super(FSLoss, self).__init__() 97 | self.filter = FilterHigh(recursions=recursions, stride=stride, kernel_size=kernel_size, include_pad=False, 98 | gaussian=gaussian) 99 | def forward(self, x, y): 100 | x_ = self.filter(x) 101 | y_ = self.filter(y) 102 | loss = F.l1_loss(x_, y_) 103 | return loss 104 | 105 | class CharbonnierLoss(nn.Module): 106 | """Charbonnier Loss (L1)""" 107 | 108 | def __init__(self, eps=1e-6): 109 | super(CharbonnierLoss, self).__init__() 110 | self.eps = eps 111 | 112 | def forward(self, x, y): 113 | diff = x - y 114 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 115 | return loss 116 | 117 | 118 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 119 | class GANLoss(nn.Module): 120 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 121 | super(GANLoss, self).__init__() 122 | self.gan_type = gan_type.lower() 123 | self.real_label_val = real_label_val 124 | self.fake_label_val = fake_label_val 125 | 126 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 127 | self.loss = nn.BCEWithLogitsLoss() 128 | elif self.gan_type == 'lsgan': 129 | self.loss = nn.MSELoss() 130 | elif self.gan_type == 'wgan-gp': 131 | 132 | def wgan_loss(input, target): 133 | # target is boolean 134 | return -1 * input.mean() if target else input.mean() 135 | 136 | self.loss = wgan_loss 137 | else: 138 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 139 | 140 | def get_target_label(self, input, target_is_real): 141 | if self.gan_type == 'wgan-gp': 142 | return target_is_real 143 | if target_is_real: 144 | return torch.empty_like(input).fill_(self.real_label_val) 145 | else: 146 | return torch.empty_like(input).fill_(self.fake_label_val) 147 | 148 | def forward(self, input, target_is_real): 149 | target_label = self.get_target_label(input, target_is_real) 150 | loss = self.loss(input, target_label) 151 | return loss 152 | 153 | 154 | class GradientPenaltyLoss(nn.Module): 
155 | def __init__(self, device=torch.device('cpu')): 156 | super(GradientPenaltyLoss, self).__init__() 157 | self.register_buffer('grad_outputs', torch.Tensor()) 158 | self.grad_outputs = self.grad_outputs.to(device) 159 | 160 | def get_grad_outputs(self, input): 161 | if self.grad_outputs.size() != input.size(): 162 | self.grad_outputs.resize_(input.size()).fill_(1.0) 163 | return self.grad_outputs 164 | 165 | def forward(self, interp, interp_crit): 166 | grad_outputs = self.get_grad_outputs(interp_crit) 167 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 168 | grad_outputs=grad_outputs, create_graph=True, 169 | retain_graph=True, only_inputs=True)[0] 170 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 171 | grad_interp_norm = grad_interp.norm(2, dim=1) 172 | 173 | loss = ((grad_interp_norm - 1)**2).mean() 174 | return loss 175 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restarts = [v + 1 for v in self.restarts] 16 | self.restart_weights = weights if weights else [1] 17 | assert len(self.restarts) == len( 18 | self.restart_weights), 'restarts and their weights do not match.' 
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group['lr'] for group in self.optimizer.param_groups] 29 | return [ 30 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 31 | for group in self.optimizer.param_groups 32 | ] 33 | 34 | 35 | class CosineAnnealingLR_Restart(_LRScheduler): 36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 37 | self.T_period = T_period 38 | self.T_max = self.T_period[0] # current T period 39 | self.eta_min = eta_min 40 | self.restarts = restarts if restarts else [0] 41 | self.restarts = [v + 1 for v in self.restarts] 42 | self.restart_weights = weights if weights else [1] 43 | self.last_restart = 0 44 | assert len(self.restarts) == len( 45 | self.restart_weights), 'restarts and their weights do not match.' 
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | 
############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 | ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /codes/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models.archs.PAN_arch as PAN_arch 3 | import models.archs.SRResNet_arch as SRResNet_arch 4 | 
import models.archs.RCAN_arch as RCAN_arch 5 | 6 | 7 | # Generator 8 | def define_G(opt): 9 | opt_net = opt['network_G'] 10 | which_model = opt_net['which_model_G'] 11 | 12 | # image restoration 13 | if which_model == 'PAN': 14 | netG = PAN_arch.PAN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 15 | nf=opt_net['nf'], unf=opt_net['unf'], nb=opt_net['nb'], scale=opt_net['scale']) 16 | elif which_model == 'MSRResNet_PA': 17 | netG = SRResNet_arch.MSRResNet_PA(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 18 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) 19 | elif which_model == 'RCAN_PA': 20 | netG = RCAN_arch.RCAN_PA(n_resgroups=opt_net['n_resgroups'], n_resblocks=opt_net['n_resblocks'], n_feats=opt_net['n_feats'], res_scale=opt_net['res_scale'], n_colors=opt_net['n_colors'], rgb_range=opt_net['rgb_range'], scale=opt_net['scale']) 21 | else: 22 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) 23 | 24 | return netG 25 | -------------------------------------------------------------------------------- /codes/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__init__.py -------------------------------------------------------------------------------- /codes/options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /codes/options/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__pycache__/options.cpython-36.pyc -------------------------------------------------------------------------------- /codes/options/__pycache__/options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/options/__pycache__/options.cpython-37.pyc -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if 
dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | if dataset.get('dataroot_LQ', None) is not None: 33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 34 | if dataset['dataroot_LQ'].endswith('lmdb'): 35 | is_lmdb = True 36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 46 | if is_train: 47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 48 | opt['path']['experiments_root'] = experiments_root 49 | opt['path']['models'] = osp.join(experiments_root, 'models') 50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 51 | opt['path']['log'] = experiments_root 52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 53 | 54 | # change some options for debug mode 55 | if 'debug' in opt['name']: 56 | opt['train']['val_freq'] = 8 57 | opt['logger']['print_freq'] = 1 58 | opt['logger']['save_checkpoint_freq'] = 8 59 | else: # test 60 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 61 | opt['path']['results_root'] = results_root 62 | opt['path']['log'] = results_root 63 | 64 | # network 65 | if opt['distortion'] == 'sr': 66 | opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | '''dict to string for logger''' 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 
82 | 83 | 84 | class NoneDict(dict): 85 | def __missing__(self, key): 86 | return None 87 | 88 | 89 | # convert to NoneDict, which return None for missing key. 90 | def dict_to_nonedict(opt): 91 | if isinstance(opt, dict): 92 | new_opt = dict() 93 | for key, sub_opt in opt.items(): 94 | new_opt[key] = dict_to_nonedict(sub_opt) 95 | return NoneDict(**new_opt) 96 | elif isinstance(opt, list): 97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 98 | else: 99 | return opt 100 | 101 | 102 | def check_resume(opt, resume_iter): 103 | '''Check resume states and pretrain_model paths''' 104 | logger = logging.getLogger('base') 105 | if opt['path']['resume_state']: 106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 107 | 'pretrain_model_D', None) is not None: 108 | logger.warning('pretrain_model path will be ignored when resuming training.') 109 | 110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 111 | '{}_G.pth'.format(resume_iter)) 112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 113 | if 'gan' in opt['model']: 114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 115 | '{}_D.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 117 | -------------------------------------------------------------------------------- /codes/options/test/test_PANx2.yml: -------------------------------------------------------------------------------- 1 | name: PANx2_DF2K 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 2 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | save_img: True 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test1: 12 | name: Set5 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X2 16 | test2: 17 | name: Set14 18 | mode: LQGT 19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR 20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X2 21 | test3: 22 | name: B100 23 | mode: LQGT 24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR 25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X2 26 | test4: 27 | name: Urban100 28 | mode: LQGT 29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR 30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X2 31 | test5: 32 | name: Manga109 33 | mode: LQGT 34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR 35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X2 36 | 37 | #### network structures 38 | network_G: 39 | which_model_G: PAN 40 | in_nc: 3 41 | out_nc: 3 42 | nf: 40 43 | unf: 24 44 | nb: 16 45 | scale: 2 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ../experiments/pretrained_models/PANx2_DF2K.pth 50 | -------------------------------------------------------------------------------- /codes/options/test/test_PANx3.yml: -------------------------------------------------------------------------------- 1 | name: PANx3_DF2K 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 3 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | save_img: True 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test1: 12 | name: Set5 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X3 16 | test2: 17 | name: Set14 18 | mode: LQGT 19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR 20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X3 21 | test3: 22 | name: B100 23 | mode: LQGT 24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR 25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X3 26 | test4: 27 | name: Urban100 28 | mode: LQGT 29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR 30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X3 31 | test5: 32 | name: Manga109 33 | mode: LQGT 34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR 35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X3 36 | 37 | #### network structures 38 | network_G: 39 | which_model_G: PAN 40 | in_nc: 3 41 | out_nc: 3 42 | nf: 40 43 | unf: 24 44 | nb: 16 45 | scale: 3 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ../experiments/pretrained_models/PANx3_DF2K.pth 50 | -------------------------------------------------------------------------------- /codes/options/test/test_PANx4.yml: -------------------------------------------------------------------------------- 1 | name: PANx4_DF2K 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | save_img: True 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test1: 12 | name: Set5 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4 16 | test2: 17 | name: Set14 18 | mode: LQGT 19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR 20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4 21 | test3: 22 | name: B100 23 | mode: LQGT 24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR 25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4 26 | test4: 27 | name: Urban100 28 | mode: LQGT 29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR 30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4 31 | test5: 32 | name: Manga109 33 | mode: LQGT 34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR 35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4 36 | 37 | #### network structures 38 | network_G: 39 | which_model_G: PAN 40 | in_nc: 3 41 | out_nc: 3 42 | nf: 40 43 | unf: 24 44 | nb: 16 45 | scale: 4 46 | 47 | #### path 48 | path: 49 | pretrain_model_G: ../experiments/pretrained_models/PANx4_DF2K.pth 50 | -------------------------------------------------------------------------------- /codes/options/test/test_RCAN.yml: -------------------------------------------------------------------------------- 1 | name: RCANPAx4_DIV2K 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | save_img: False 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test1: 12 | name: Set5 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4 16 | test2: 17 | name: Set14 18 | mode: LQGT 19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR 20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4 21 | test3: 22 | name: B100 23 | mode: LQGT 24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR 25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4 26 | test4: 27 | name: Urban100 28 | mode: LQGT 29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR 30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4 31 | test5: 32 | name: Manga109 33 | mode: LQGT 34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR 35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4 36 | 37 | #### network structures 38 | network_G: 39 | which_model_G: RCAN_PA 40 | n_resgroups: 10 41 | n_resblocks: 20 42 | n_feats: 64 43 | res_scale: 1 44 | n_colors: 3 45 | rgb_range: 1 46 | scale: 4 47 | 48 | #### path 49 | path: 50 | pretrain_model_G: ../experiments/pretrained_models/RCAN_PA_DIV2K.pth -------------------------------------------------------------------------------- /codes/options/test/test_SRResNet.yml: -------------------------------------------------------------------------------- 1 | name: MSSResNet_PA_DIV2K 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | save_img: True 8 | gpu_ids: [0] 9 | 10 | datasets: 11 | test1: 12 | name: Set5 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set5/HR 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set5/LR_bicubic/X4 16 | test2: 17 | name: Set14 18 | mode: LQGT 19 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Set14/HR 20 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Set14/LR_bicubic/X4 21 | test3: 22 | name: B100 23 | mode: LQGT 24 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/B100/HR 25 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/B100/LR_bicubic/X4 26 | test4: 27 | name: Urban100 28 | mode: LQGT 29 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/HR 30 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Urban100/LR_bicubic/X4 31 | test5: 32 | name: Manga109 33 | mode: LQGT 34 | dataroot_GT: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/HR 35 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/benchmark/Manga109/LR_bicubic/X4 36 | #### network structures 37 | network_G: 38 | which_model_G: MSRResNet_PA 39 | in_nc: 3 40 | out_nc: 3 41 | nf: 64 42 | nb: 16 43 | upscale: 4 44 | 45 | #### path 46 | path: 47 | pretrain_model_G: ../experiments/pretrained_models/SRResNet_PA_DIV2K.pth -------------------------------------------------------------------------------- /codes/options/train/train_PANx2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: PANx2_DF2K 3 | use_tb_logger: True 4 | model: sr 5 | distortion: sr 6 | scale: 2 7 | save_img: False 8 | gpu_ids: [0] 9 | 10 | #### datasets 11 | datasets: 12 | train: 13 | name: DF2K 14 | mode: LQGT 15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HRx2_sub360 16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LRx2_sub180 17 | 18 | use_shuffle: true 19 | n_workers: 6 # per GPU 20 | batch_size: 32 21 | GT_size: 128 22 | 
use_flip: true 23 | use_rot: true 24 | color: RGB 25 | val: 26 | name: Set5 27 | mode: LQGT 28 | dataroot_GT: ../datasets/Set5/HR 29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X2 30 | 31 | #### network structures 32 | network_G: 33 | which_model_G: PAN 34 | in_nc: 3 35 | out_nc: 3 36 | nf: 40 37 | unf: 24 38 | nb: 16 39 | scale: 2 40 | 41 | #### path 42 | path: 43 | pretrain_model_G: ~ 44 | strict_load: true 45 | resume_state: ~ 46 | 47 | #### training settings: learning rate scheme, loss 48 | train: 49 | lr_G: !!float 7e-4 50 | lr_scheme: CosineAnnealingLR_Restart 51 | beta1: 0.9 52 | beta2: 0.99 53 | niter: 1000000 54 | warmup_iter: -1 # no warm up 55 | T_period: [250000, 250000, 250000, 250000] 56 | restarts: [250000, 500000, 750000] 57 | restart_weights: [1, 1, 1] 58 | eta_min: !!float 1e-7 59 | 60 | pixel_criterion: l1 61 | pixel_weight: 1.0 62 | 63 | manual_seed: 10 64 | val_freq: !!float 5e3 65 | 66 | #### logger 67 | logger: 68 | print_freq: 100 69 | save_checkpoint_freq: !!float 5e3 -------------------------------------------------------------------------------- /codes/options/train/train_PANx3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: PANx3_DF2K 3 | use_tb_logger: True 4 | model: sr 5 | distortion: sr 6 | scale: 3 7 | save_img: False 8 | gpu_ids: [0] 9 | 10 | #### datasets 11 | datasets: 12 | train: 13 | name: DF2K 14 | mode: LQGT 15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HRx3_sub360 16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LRx3_sub120 17 | 18 | use_shuffle: true 19 | n_workers: 6 # per GPU 20 | batch_size: 32 21 | GT_size: 192 22 | use_flip: true 23 | use_rot: true 24 | color: RGB 25 | val: 26 | name: Set5 27 | mode: LQGT 28 | dataroot_GT: ../datasets/Set5/HR 29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X3 30 | 31 | #### network structures 32 | network_G: 33 | which_model_G: PAN 34 | in_nc: 3 35 | out_nc: 3 36 | nf: 40 37 | unf: 24 38 | 
nb: 16 39 | scale: 3 40 | 41 | #### path 42 | path: 43 | pretrain_model_G: ~ 44 | strict_load: true 45 | resume_state: ~ 46 | 47 | #### training settings: learning rate scheme, loss 48 | train: 49 | lr_G: !!float 7e-4 50 | lr_scheme: CosineAnnealingLR_Restart 51 | beta1: 0.9 52 | beta2: 0.99 53 | niter: 1000000 54 | warmup_iter: -1 # no warm up 55 | T_period: [250000, 250000, 250000, 250000] 56 | restarts: [250000, 500000, 750000] 57 | restart_weights: [1, 1, 1] 58 | eta_min: !!float 1e-7 59 | 60 | pixel_criterion: l1 61 | pixel_weight: 1.0 62 | 63 | manual_seed: 10 64 | val_freq: !!float 5e3 65 | 66 | #### logger 67 | logger: 68 | print_freq: 100 69 | save_checkpoint_freq: !!float 5e3 -------------------------------------------------------------------------------- /codes/options/train/train_PANx4.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: PANx4_DF2K 3 | use_tb_logger: True 4 | model: sr 5 | distortion: sr 6 | scale: 4 7 | save_img: False 8 | gpu_ids: [0] 9 | 10 | #### datasets 11 | datasets: 12 | train: 13 | name: DF2K 14 | mode: LQGT 15 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DF2K_train/HR_sub360 16 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DF2K_train/LR_sub90 17 | 18 | use_shuffle: true 19 | n_workers: 6 # per GPU 20 | batch_size: 32 21 | GT_size: 256 22 | use_flip: true 23 | use_rot: true 24 | color: RGB 25 | val: 26 | name: Set5 27 | mode: LQGT 28 | dataroot_GT: ../datasets/Set5/HR 29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 30 | 31 | #### network structures 32 | network_G: 33 | which_model_G: PAN 34 | in_nc: 3 35 | out_nc: 3 36 | nf: 40 37 | unf: 24 38 | nb: 16 39 | scale: 4 40 | 41 | #### path 42 | path: 43 | pretrain_model_G: ~ 44 | strict_load: true 45 | resume_state: ~ 46 | 47 | #### training settings: learning rate scheme, loss 48 | train: 49 | lr_G: !!float 7e-4 50 | lr_scheme: CosineAnnealingLR_Restart 51 | beta1: 0.9 52 | beta2: 0.99 53 | niter: 1000000 54 | 
warmup_iter: -1 # no warm up 55 | T_period: [250000, 250000, 250000, 250000] 56 | restarts: [250000, 500000, 750000] 57 | restart_weights: [1, 1, 1] 58 | eta_min: !!float 1e-7 59 | 60 | pixel_criterion: l1 61 | pixel_weight: 1.0 62 | 63 | manual_seed: 10 64 | val_freq: !!float 5e3 65 | 66 | #### logger 67 | logger: 68 | print_freq: 100 69 | save_checkpoint_freq: !!float 5e3 -------------------------------------------------------------------------------- /codes/options/train/train_RCAN.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: RCANPAx4_DIV2K 3 | use_tb_logger: True 4 | model: sr 5 | distortion: sr 6 | scale: 4 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: DIV2K 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DIV2K_train800/HR_sub480 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DIV2K_train800/LR_sub120 16 | 17 | use_shuffle: true 18 | n_workers: 6 # per GPU 19 | batch_size: 16 20 | GT_size: 192 21 | use_flip: true 22 | use_rot: true 23 | color: RGB 24 | val: 25 | name: Set5 26 | mode: LQGT 27 | dataroot_GT: ../datasets/Set5/HR 28 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 29 | 30 | #### network structures 31 | network_G: 32 | which_model_G: RCAN_PA 33 | n_resgroups: 10 34 | n_resblocks: 20 35 | n_feats: 64 36 | res_scale: 1 37 | n_colors: 3 38 | rgb_range: 1 39 | scale: 4 40 | reduction: 16 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ~ 45 | strict_load: true 46 | resume_state: ~ 47 | 48 | #### training settings: learning rate scheme, loss 49 | train: 50 | lr_G: !!float 2e-5 51 | lr_scheme: CosineAnnealingLR_Restart 52 | beta1: 0.9 53 | beta2: 0.99 54 | niter: 200000 55 | warmup_iter: -1 # no warm up 56 | T_period: [200000, 200000, 200000, 200000] 57 | restarts: [200000, 400000, 600000] 58 | restart_weights: [1, 1, 1] 59 | eta_min: !!float 1e-7 60 | 61 | pixel_criterion: l1 62 | pixel_weight: 1.0 63 | 64 | manual_seed: 10 65 
| val_freq: !!float 2e3 66 | 67 | #### logger 68 | logger: 69 | print_freq: 100 70 | save_checkpoint_freq: !!float 2e3 71 | -------------------------------------------------------------------------------- /codes/options/train/train_SRResNet.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: MSSResNet_PA_DIV2K 3 | use_tb_logger: true 4 | model: sr 5 | distortion: sr 6 | scale: 4 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: DIV2K 13 | mode: LQGT 14 | dataroot_GT: /mnt/hyzhao/Documents/datasets/DIV2K_train800/HR_sub480 15 | dataroot_LQ: /mnt/hyzhao/Documents/datasets/DIV2K_train800/LR_sub120 16 | 17 | use_shuffle: true 18 | n_workers: 6 # per GPU 19 | batch_size: 16 20 | GT_size: 256 21 | use_flip: true 22 | use_rot: true 23 | color: RGB 24 | val: 25 | name: Set5 26 | mode: LQGT 27 | dataroot_GT: ../datasets/Set5/HR 28 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 29 | 30 | #### network structures 31 | network_G: 32 | which_model_G: MSRResNet_PA 33 | in_nc: 3 34 | out_nc: 3 35 | nf: 64 36 | nb: 16 37 | upscale: 4 38 | 39 | #### path 40 | path: 41 | pretrain_model_G: ~ 42 | strict_load: true 43 | resume_state: ~ 44 | 45 | #### training settings: learning rate scheme, loss 46 | train: 47 | lr_G: !!float 2e-4 48 | lr_scheme: CosineAnnealingLR_Restart 49 | beta1: 0.9 50 | beta2: 0.99 51 | niter: 500000 52 | warmup_iter: -1 # no warm up 53 | T_period: [250000, 250000, 250000, 250000] 54 | restarts: [250000, 500000, 750000] 55 | restart_weights: [1, 1, 1] 56 | eta_min: !!float 1e-7 57 | 58 | pixel_criterion: l1 59 | pixel_weight: 1.0 60 | 61 | manual_seed: 10 62 | val_freq: !!float 5e3 63 | 64 | #### logger 65 | logger: 66 | print_freq: 100 67 | save_checkpoint_freq: !!float 5e3 68 | -------------------------------------------------------------------------------- /codes/requirements.txt: -------------------------------------------------------------------------------- 1 | 
numpy 2 | opencv-python 3 | torchsummaryX 4 | lmdb 5 | pyyaml 6 | tb-nightly 7 | future 8 | -------------------------------------------------------------------------------- /codes/run_scripts.sh: -------------------------------------------------------------------------------- 1 | # Testing x2, x3, x4 2 | python test.py -opt options/test/test_PANx2.yml 3 | python test.py -opt options/test/test_PANx3.yml 4 | python test.py -opt options/test/test_PANx4.yml 5 | 6 | 7 | # Training x2, x3, x4 8 | python train.py -opt options/train/train_PANx2.yml 9 | python train.py -opt options/train/train_PANx3.yml 10 | python train.py -opt options/train/train_PANx4.yml 11 | 12 | 13 | # Training SRResNet_PA or RCAN_PA 14 | python train.py -opt options/train/train_SRResNet.yml 15 | python train.py -opt options/train/train_RCAN.yml 16 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/backprojection.m: -------------------------------------------------------------------------------- 1 | function [im_h] = backprojection(im_h, im_l, maxIter) 2 | 3 | [row_l, col_l,~] = size(im_l); 4 | [row_h, col_h,~] = size(im_h); 5 | 6 | p = fspecial('gaussian', 5, 1); 7 | p = p.^2; 8 | p = p./sum(p(:)); 9 | 10 | im_l = double(im_l); 11 | im_h = double(im_h); 12 | 13 | for ii = 1:maxIter 14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); 15 | im_diff = im_l - im_l_s; 16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); 17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); 18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); 19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); 20 | end 21 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_bp.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre 
output 5 | save_folder = './results_20bp'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | %tic 19 | im_out = backprojection(im_out, im_LR, max_iter); 20 | %toc 21 | imwrite(im_out, fullfile(save_folder, im_name)); 22 | end 23 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_reverse_filter.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20if'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | J = imresize(im_LR,4,'bicubic'); 19 | %tic 20 | for m = 1:max_iter 21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); 22 | end 23 | %toc 24 | imwrite(im_out, fullfile(save_folder, im_name)); 25 | end 26 | -------------------------------------------------------------------------------- /codes/scripts/transfer_params_MSRResNet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import torch 4 | try: 5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 6 | import models.archs.SRResNet_arch as SRResNet_arch 7 | except ImportError: 
8 | pass 9 | 10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') 11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) 12 | crt_net = crt_model.state_dict() 13 | 14 | for k, v in crt_net.items(): 15 | if k in pretrained_net and 'upconv1' not in k: 16 | crt_net[k] = pretrained_net[k] 17 | print('replace ... ', k) 18 | 19 | # x4 -> x3 20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 26 | 27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') 28 | -------------------------------------------------------------------------------- /codes/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | 7 | import options.options as option 8 | import utils.util as util 9 | from data.util import bgr2ycbcr 10 | from data import create_dataset, create_dataloader 11 | from models import create_model 12 | 13 | #### options 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YMAL file.') 16 | opt = option.parse(parser.parse_args().opt, is_train=False) 17 | opt = option.dict_to_nonedict(opt) 18 | 19 | util.mkdirs( 20 | (path for key, path in opt['path'].items() 21 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 22 | util.setup_logger('base', opt['path']['log'], 'test_' + 
opt['name'], level=logging.INFO, 23 | screen=True, tofile=True) 24 | logger = logging.getLogger('base') 25 | logger.info(option.dict2str(opt)) 26 | 27 | #### Create test dataset and dataloader 28 | test_loaders = [] 29 | for phase, dataset_opt in sorted(opt['datasets'].items()): 30 | test_set = create_dataset(dataset_opt) 31 | test_loader = create_dataloader(test_set, dataset_opt) 32 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 33 | test_loaders.append(test_loader) 34 | 35 | model = create_model(opt) 36 | for test_loader in test_loaders: 37 | test_set_name = test_loader.dataset.opt['name'] 38 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 39 | test_start_time = time.time() 40 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 41 | util.mkdir(dataset_dir) 42 | 43 | test_results = OrderedDict() 44 | test_results['psnr'] = [] 45 | test_results['ssim'] = [] 46 | test_results['psnr_y'] = [] 47 | test_results['ssim_y'] = [] 48 | 49 | for data in test_loader: 50 | need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True 51 | model.feed_data(data, need_GT=need_GT) 52 | img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] 53 | img_name = osp.splitext(osp.basename(img_path))[0] 54 | 55 | model.test() 56 | visuals = model.get_current_visuals(need_GT=need_GT) 57 | 58 | sr_img = util.tensor2img(visuals['rlt']) # uint8 59 | 60 | # save images 61 | suffix = opt['suffix'] 62 | if suffix: 63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') 64 | else: 65 | save_img_path = osp.join(dataset_dir, img_name + '.png') 66 | if opt['save_img']: 67 | util.save_img(sr_img, save_img_path) 68 | 69 | # calculate PSNR and SSIM 70 | if need_GT: 71 | gt_img = util.tensor2img(visuals['GT']) 72 | sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) 73 | psnr = util.calculate_psnr(sr_img, gt_img) 74 | ssim = util.calculate_ssim(sr_img, gt_img) 75 | 
test_results['psnr'].append(psnr) 76 | test_results['ssim'].append(ssim) 77 | 78 | if gt_img.shape[2] == 3: # RGB image 79 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True) 80 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True) 81 | 82 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255) 83 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255) 84 | test_results['psnr_y'].append(psnr_y) 85 | test_results['ssim_y'].append(ssim_y) 86 | logger.info( 87 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'. 88 | format(img_name, psnr, ssim, psnr_y, ssim_y)) 89 | else: 90 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) 91 | else: 92 | logger.info(img_name) 93 | 94 | if need_GT: # metrics 95 | # Average PSNR/SSIM results 96 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 97 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 98 | logger.info( 99 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( 100 | test_set_name, ave_psnr, ave_ssim)) 101 | if test_results['psnr_y'] and test_results['ssim_y']: 102 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) 103 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) 104 | logger.info( 105 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. 
106 | format(ave_psnr_y, ave_ssim_y)) 107 | -------------------------------------------------------------------------------- /codes/test_running_time.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import torch 5 | import cv2 6 | import numpy as np 7 | import models.archs.PAN_arch as PAN_arch 8 | import utils.util as util 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "7" 11 | 12 | def main(): 13 | ## test dataset 14 | test_d = sorted(glob.glob('/mnt/hyzhao/Documents/datasets/DIV2K_test/*.png')) 15 | 16 | torch.cuda.current_device() 17 | torch.cuda.empty_cache() 18 | torch.backends.cudnn.benchmark = False 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | ## some functions 22 | def readimg(path): 23 | im = cv2.imread(path, cv2.IMREAD_UNCHANGED) 24 | im = im.astype(np.float32) / 255. 25 | im = im[:, :, [2, 1, 0]] 26 | return im 27 | 28 | def img2tensor(img): 29 | imgt = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()[None, ...] 
30 | return imgt 31 | 32 | ## load model 33 | scale = 4 34 | model = PAN_arch.PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=scale) 35 | model_weight = torch.load('../experiments/pretrained_models/PANx%d_DF2K.pth'%(scale)) 36 | model.load_state_dict(model_weight, strict=True) 37 | model.eval() 38 | for k, v in model.named_parameters(): 39 | v.requires_grad = False 40 | model = model.to(device) 41 | 42 | number_parameters = sum(map(lambda x: x.numel(), model.parameters())) 43 | 44 | ## runnning 45 | print('-----------------Start Running-----------------') 46 | psnrs = [] 47 | times = [] 48 | 49 | start = torch.cuda.Event(enable_timing=True) 50 | end = torch.cuda.Event(enable_timing=True) 51 | 52 | for i in range(len(test_d)): 53 | im = readimg(test_d[i]) 54 | img_LR = img2tensor(im) 55 | img_LR = img_LR.to(device) 56 | 57 | start.record() 58 | img_SR = model(img_LR) 59 | end.record() 60 | 61 | torch.cuda.synchronize() 62 | times.append(start.elapsed_time(end)) 63 | 64 | sr_img = util.tensor2img(img_SR.detach()) 65 | 66 | print('Image: %03d, Time: %.10f'%(i+1, times[-1])) 67 | print('Paramters: %d, Mean Time: %.10f'%(number_parameters, np.mean(times)/1000.)) 68 | 69 | if __name__ == '__main__': 70 | 71 | main() -------------------------------------------------------------------------------- /codes/test_summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchsummaryX import summary 3 | 4 | import models.archs.PAN_arch as PAN_arch 5 | 6 | model = PAN_arch.PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4) 7 | 8 | summary(model, torch.zeros((1, 3, 32, 32))) 9 | 10 | # input LR x2, HR size is 720p 11 | # summary(model, torch.zeros((1, 3, 640, 360))) 12 | 13 | # input LR x3, HR size is 720p 14 | # summary(model, torch.zeros((1, 3, 426, 240))) 15 | 16 | # input LR x4, HR size is 720p 17 | # summary(model, torch.zeros((1, 3, 320, 180))) 18 | 
-------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/codes/utils/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /codes/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import torch.nn.functional as F 6 | from datetime 
import datetime 7 | import random 8 | import logging 9 | from collections import OrderedDict 10 | import numpy as np 11 | import cv2 12 | import torch 13 | from torchvision.utils import make_grid 14 | from shutil import get_terminal_size 15 | 16 | import yaml 17 | try: 18 | from yaml import CLoader as Loader, CDumper as Dumper 19 | except ImportError: 20 | from yaml import Loader, Dumper 21 | 22 | 23 | def OrderedYaml(): 24 | '''yaml orderedDict support''' 25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 26 | 27 | def dict_representer(dumper, data): 28 | return dumper.represent_dict(data.items()) 29 | 30 | def dict_constructor(loader, node): 31 | return OrderedDict(loader.construct_pairs(node)) 32 | 33 | Dumper.add_representer(OrderedDict, dict_representer) 34 | Loader.add_constructor(_mapping_tag, dict_constructor) 35 | return Loader, Dumper 36 | 37 | 38 | #################### 39 | # miscellaneous 40 | #################### 41 | 42 | 43 | def get_timestamp(): 44 | return datetime.now().strftime('%y%m%d-%H%M%S') 45 | 46 | 47 | def mkdir(path): 48 | if not os.path.exists(path): 49 | os.makedirs(path) 50 | 51 | 52 | def mkdirs(paths): 53 | if isinstance(paths, str): 54 | mkdir(paths) 55 | else: 56 | for path in paths: 57 | mkdir(path) 58 | 59 | 60 | def mkdir_and_rename(path): 61 | if os.path.exists(path): 62 | new_name = path + '_archived_' + get_timestamp() 63 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 64 | logger = logging.getLogger('base') 65 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 66 | os.rename(path, new_name) 67 | os.makedirs(path) 68 | 69 | 70 | def set_random_seed(seed): 71 | random.seed(seed) 72 | np.random.seed(seed) 73 | torch.manual_seed(seed) 74 | torch.cuda.manual_seed_all(seed) 75 | 76 | 77 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 78 | '''set up logger''' 79 | lg = logging.getLogger(logger_name) 80 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 81 | datefmt='%y-%m-%d %H:%M:%S') 82 | lg.setLevel(level) 83 | if tofile: 84 | log_file = os.path.join(root, phase + 'train_{}.log'.format(get_timestamp())) 85 | fh = logging.FileHandler(log_file, mode='w') 86 | fh.setFormatter(formatter) 87 | lg.addHandler(fh) 88 | if screen: 89 | sh = logging.StreamHandler() 90 | sh.setFormatter(formatter) 91 | lg.addHandler(sh) 92 | 93 | 94 | #################### 95 | # image convert 96 | #################### 97 | def crop_border(img_list, crop_border): 98 | """Crop borders of images 99 | Args: 100 | img_list (list [Numpy]): HWC 101 | crop_border (int): crop border for each end of height and weight 102 | 103 | Returns: 104 | (list [Numpy]): cropped image list 105 | """ 106 | if crop_border == 0: 107 | return img_list 108 | else: 109 | return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list] 110 | 111 | 112 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 113 | ''' 114 | Converts a torch Tensor into an image Numpy array 115 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 116 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 117 | ''' 118 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 119 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 120 | n_dim = tensor.dim() 121 | if n_dim == 4: 122 | n_img = len(tensor) 123 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 124 | 
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 125 | elif n_dim == 3: 126 | img_np = tensor.numpy() 127 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 128 | elif n_dim == 2: 129 | img_np = tensor.numpy() 130 | else: 131 | raise TypeError( 132 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 133 | if out_type == np.uint8: 134 | img_np = (img_np * 255.0).round() 135 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 136 | return img_np.astype(out_type) 137 | 138 | 139 | def save_img(img, img_path, mode='RGB'): 140 | cv2.imwrite(img_path, img) 141 | 142 | 143 | def DUF_downsample(x, scale=4): 144 | """Downsamping with Gaussian kernel used in the DUF official code 145 | 146 | Args: 147 | x (Tensor, [B, T, C, H, W]): frames to be downsampled. 148 | scale (int): downsampling factor: 2 | 3 | 4. 149 | """ 150 | 151 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) 152 | 153 | def gkern(kernlen=13, nsig=1.6): 154 | import scipy.ndimage.filters as fi 155 | inp = np.zeros((kernlen, kernlen)) 156 | # set element at the middle to one, a dirac delta 157 | inp[kernlen // 2, kernlen // 2] = 1 158 | # gaussian-smooth the dirac, resulting in a gaussian filter mask 159 | return fi.gaussian_filter(inp, nsig) 160 | 161 | B, T, C, H, W = x.size() 162 | x = x.view(-1, 1, H, W) 163 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter 164 | r_h, r_w = 0, 0 165 | if scale == 3: 166 | r_h = 3 - (H % 3) 167 | r_w = 3 - (W % 3) 168 | x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect') 169 | 170 | gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) 171 | x = F.conv2d(x, gaussian_filter, stride=scale) 172 | x = x[:, :, 2:-2, 2:-2] 173 | x = x.view(B, T, C, x.size(2), x.size(3)) 174 | return x 175 | 176 | 177 | def single_forward(model, inp): 178 | """PyTorch model forward (single 
test), it is just a simple warpper 179 | Args: 180 | model (PyTorch model) 181 | inp (Tensor): inputs defined by the model 182 | 183 | Returns: 184 | output (Tensor): outputs of the model. float, in CPU 185 | """ 186 | with torch.no_grad(): 187 | model_output = model(inp) 188 | if isinstance(model_output, list) or isinstance(model_output, tuple): 189 | output = model_output[0] 190 | else: 191 | output = model_output 192 | output = output.data.float().cpu() 193 | return output 194 | 195 | 196 | def flipx4_forward(model, inp): 197 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W 198 | Args: 199 | model (PyTorch model) 200 | inp (Tensor): inputs defined by the model 201 | 202 | Returns: 203 | output (Tensor): outputs of the model. float, in CPU 204 | """ 205 | # normal 206 | output_f = single_forward(model, inp) 207 | 208 | # flip W 209 | output = single_forward(model, torch.flip(inp, (-1, ))) 210 | output_f = output_f + torch.flip(output, (-1, )) 211 | # flip H 212 | output = single_forward(model, torch.flip(inp, (-2, ))) 213 | output_f = output_f + torch.flip(output, (-2, )) 214 | # flip both H and W 215 | output = single_forward(model, torch.flip(inp, (-2, -1))) 216 | output_f = output_f + torch.flip(output, (-2, -1)) 217 | 218 | return output_f / 4 219 | 220 | def flipxrot_forward(model, inp): 221 | # normal 222 | output_f = single_forward(model, inp) 223 | 224 | # flip W 225 | output = single_forward(model, torch.flip(inp, (-1, ))) 226 | output_f = output_f + torch.flip(output, (-1, )) 227 | 228 | # flip H 229 | output = single_forward(model, torch.flip(inp, (-2, ))) 230 | output_f = output_f + torch.flip(output, (-2, )) 231 | 232 | # flip both H and W 233 | output = single_forward(model, torch.flip(inp, (-2, -1))) 234 | output_f = output_f + torch.flip(output, (-2, -1)) 235 | 236 | # rot90 237 | output = single_forward(model, torch.rot90(inp, 1, (-1, -2))) 238 | output_f = output_f + torch.rot90(output, 3, (-1, -2)) 239 | 240 | 
# rot270 241 | output = single_forward(model, torch.rot90(inp, 3, (-1, -2))) 242 | output_f = output_f + torch.rot90(output, 1, (-1, -2)) 243 | 244 | # flip W rot90 245 | output = single_forward(model, torch.rot90(torch.flip(inp, (-1, )), 1, (-1, -2))) 246 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-1, )) 247 | 248 | # flip H rot90 249 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, )), 1, (-1, -2))) 250 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-2, )) 251 | 252 | # flip both H and W rot90 253 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, -1)), 1, (-1, -2))) 254 | output_f = output_f + torch.flip(torch.rot90(output, 3, (-1, -2)), (-2, -1)) 255 | 256 | # flip W rot270 257 | output = single_forward(model, torch.rot90(torch.flip(inp, (-1, )), 3, (-1, -2))) 258 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-1, )) 259 | 260 | # flip H rot270 261 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, )), 3, (-1, -2))) 262 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-2, )) 263 | 264 | # flip both H and W rot270 265 | output = single_forward(model, torch.rot90(torch.flip(inp, (-2, -1)), 3, (-1, -2))) 266 | output_f = output_f + torch.flip(torch.rot90(output, 1, (-1, -2)), (-2, -1)) 267 | 268 | return output_f / 12 269 | 270 | #################### 271 | # metric 272 | #################### 273 | 274 | 275 | def calculate_psnr(img1, img2): 276 | # img1 and img2 have range [0, 255] 277 | img1 = img1.astype(np.float64) 278 | img2 = img2.astype(np.float64) 279 | mse = np.mean((img1 - img2)**2) 280 | if mse == 0: 281 | return float('inf') 282 | return 20 * math.log10(255.0 / math.sqrt(mse)) 283 | 284 | 285 | def ssim(img1, img2): 286 | C1 = (0.01 * 255)**2 287 | C2 = (0.03 * 255)**2 288 | 289 | img1 = img1.astype(np.float64) 290 | img2 = img2.astype(np.float64) 291 | kernel = cv2.getGaussianKernel(11, 1.5) 292 | window = 
np.outer(kernel, kernel.transpose()) 293 | 294 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 295 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 296 | mu1_sq = mu1**2 297 | mu2_sq = mu2**2 298 | mu1_mu2 = mu1 * mu2 299 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 300 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 301 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 302 | 303 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 304 | (sigma1_sq + sigma2_sq + C2)) 305 | return ssim_map.mean() 306 | 307 | 308 | def calculate_ssim(img1, img2): 309 | '''calculate SSIM 310 | the same outputs as MATLAB's 311 | img1, img2: [0, 255] 312 | ''' 313 | if not img1.shape == img2.shape: 314 | raise ValueError('Input images must have the same dimensions.') 315 | if img1.ndim == 2: 316 | return ssim(img1, img2) 317 | elif img1.ndim == 3: 318 | if img1.shape[2] == 3: 319 | ssims = [] 320 | for i in range(3): 321 | ssims.append(ssim(img1, img2)) 322 | return np.array(ssims).mean() 323 | elif img1.shape[2] == 1: 324 | return ssim(np.squeeze(img1), np.squeeze(img2)) 325 | else: 326 | raise ValueError('Wrong input image dimensions.') 327 | 328 | 329 | class ProgressBar(object): 330 | '''A progress bar which can print the progress 331 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 332 | ''' 333 | 334 | def __init__(self, task_num=0, bar_width=50, start=True): 335 | self.task_num = task_num 336 | max_bar_width = self._get_max_bar_width() 337 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 338 | self.completed = 0 339 | if start: 340 | self.start() 341 | 342 | def _get_max_bar_width(self): 343 | terminal_width, _ = get_terminal_size() 344 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 345 | if max_bar_width < 10: 346 | print('terminal width is too small ({}), please consider widen the 
terminal for better ' 347 | 'progressbar visualization'.format(terminal_width)) 348 | max_bar_width = 10 349 | return max_bar_width 350 | 351 | def start(self): 352 | if self.task_num > 0: 353 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 354 | ' ' * self.bar_width, self.task_num, 'Start...')) 355 | else: 356 | sys.stdout.write('completed: 0, elapsed: 0s') 357 | sys.stdout.flush() 358 | self.start_time = time.time() 359 | 360 | def update(self, msg='In progress...'): 361 | self.completed += 1 362 | elapsed = time.time() - self.start_time 363 | fps = self.completed / elapsed 364 | if self.task_num > 0: 365 | percentage = self.completed / float(self.task_num) 366 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 367 | mark_width = int(self.bar_width * percentage) 368 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 369 | sys.stdout.write('\033[2F') # cursor up 2 lines 370 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 371 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 372 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 373 | else: 374 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 375 | self.completed, int(elapsed + 0.5), fps)) 376 | sys.stdout.flush() 377 | -------------------------------------------------------------------------------- /datasets/Set14/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/.DS_Store -------------------------------------------------------------------------------- /datasets/Set14/HR/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/baboon.png 
-------------------------------------------------------------------------------- /datasets/Set14/HR/barbara.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/barbara.png -------------------------------------------------------------------------------- /datasets/Set14/HR/bridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/bridge.png -------------------------------------------------------------------------------- /datasets/Set14/HR/coastguard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/coastguard.png -------------------------------------------------------------------------------- /datasets/Set14/HR/comic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/comic.png -------------------------------------------------------------------------------- /datasets/Set14/HR/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/face.png -------------------------------------------------------------------------------- /datasets/Set14/HR/flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/flowers.png -------------------------------------------------------------------------------- 
/datasets/Set14/HR/foreman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/foreman.png -------------------------------------------------------------------------------- /datasets/Set14/HR/lenna.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/lenna.png -------------------------------------------------------------------------------- /datasets/Set14/HR/man.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/man.png -------------------------------------------------------------------------------- /datasets/Set14/HR/monarch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/monarch.png -------------------------------------------------------------------------------- /datasets/Set14/HR/pepper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/pepper.png -------------------------------------------------------------------------------- /datasets/Set14/HR/ppt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/ppt3.png -------------------------------------------------------------------------------- /datasets/Set14/HR/zebra.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/HR/zebra.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/baboonx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/baboonx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/barbarax2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/barbarax2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/bridgex2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/bridgex2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/coastguardx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/coastguardx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/comicx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/comicx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/facex2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/facex2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/flowersx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/flowersx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/foremanx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/foremanx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/lennax2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/lennax2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/manx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/manx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/monarchx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/monarchx2.png -------------------------------------------------------------------------------- 
/datasets/Set14/LR_bicubic/X2/pepperx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/pepperx2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/ppt3x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/ppt3x2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X2/zebrax2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X2/zebrax2.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/baboonx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/baboonx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/barbarax3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/barbarax3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/bridgex3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/bridgex3.png 
-------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/coastguardx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/coastguardx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/comicx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/comicx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/facex3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/facex3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/flowersx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/flowersx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/foremanx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/foremanx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/lennax3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/lennax3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/manx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/manx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/monarchx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/monarchx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/pepperx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/pepperx3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/ppt3x3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/ppt3x3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X3/zebrax3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X3/zebrax3.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/baboonx4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/baboonx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/barbarax4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/barbarax4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/bridgex4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/bridgex4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/coastguardx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/coastguardx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/comicx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/comicx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/facex4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/facex4.png 
-------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/flowersx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/flowersx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/foremanx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/foremanx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/lennax4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/lennax4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/manx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/manx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/monarchx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/monarchx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/pepperx4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/pepperx4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/ppt3x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/ppt3x4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X4/zebrax4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X4/zebrax4.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/baboonx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/baboonx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/barbarax8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/barbarax8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/bridgex8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/bridgex8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/coastguardx8.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/coastguardx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/comicx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/comicx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/facex8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/facex8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/flowersx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/flowersx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/foremanx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/foremanx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/lennax8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/lennax8.png 
-------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/manx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/manx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/monarchx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/monarchx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/pepperx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/pepperx8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/ppt3x8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/ppt3x8.png -------------------------------------------------------------------------------- /datasets/Set14/LR_bicubic/X8/zebrax8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set14/LR_bicubic/X8/zebrax8.png -------------------------------------------------------------------------------- /datasets/Set5/HR/baby.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/baby.png -------------------------------------------------------------------------------- /datasets/Set5/HR/bird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/bird.png -------------------------------------------------------------------------------- /datasets/Set5/HR/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/butterfly.png -------------------------------------------------------------------------------- /datasets/Set5/HR/head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/head.png -------------------------------------------------------------------------------- /datasets/Set5/HR/woman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/HR/woman.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X2/babyx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/babyx2.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X2/birdx2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/birdx2.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X2/butterflyx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/butterflyx2.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X2/headx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/headx2.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X2/womanx2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X2/womanx2.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X3/babyx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/babyx3.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X3/birdx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/birdx3.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X3/butterflyx3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/butterflyx3.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X3/headx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/headx3.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X3/womanx3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X3/womanx3.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X4/babyx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/babyx4.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X4/birdx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/birdx4.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X4/butterflyx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/butterflyx4.png -------------------------------------------------------------------------------- 
/datasets/Set5/LR_bicubic/X4/headx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/headx4.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X4/womanx4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X4/womanx4.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X8/babyx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/babyx8.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X8/birdx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/birdx8.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X8/butterflyx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/butterflyx8.png -------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X8/headx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/headx8.png 
-------------------------------------------------------------------------------- /datasets/Set5/LR_bicubic/X8/womanx8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/datasets/Set5/LR_bicubic/X8/womanx8.png -------------------------------------------------------------------------------- /experiments/pretrained_models/PANx2_DF2K.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx2_DF2K.pth -------------------------------------------------------------------------------- /experiments/pretrained_models/PANx3_DF2K.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx3_DF2K.pth -------------------------------------------------------------------------------- /experiments/pretrained_models/PANx4_DF2K.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/experiments/pretrained_models/PANx4_DF2K.pth -------------------------------------------------------------------------------- /show_figs/main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/show_figs/main.jpg -------------------------------------------------------------------------------- /show_figs/main.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhaohengyuan1/PAN/4d045d18592c5771cd3990da88f292812774e538/show_figs/main.pdf 
--------------------------------------------------------------------------------