├── .gitattributes ├── .gitignore ├── README.md ├── codes ├── data │ ├── LQGT_dataset.py │ ├── LQ_dataset.py │ ├── README.md │ ├── Rank_IMIM_Pair_dataset.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── LQGT_dataset.cpython-36.pyc │ │ ├── LRHR_dataset.cpython-36.pyc │ │ ├── LR_dataset.cpython-36.pyc │ │ ├── Rank_IMIM_Pair_dataset.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── data_sampler.cpython-36.pyc │ │ └── util.cpython-36.pyc │ ├── data_sampler.py │ └── util.py ├── data_scripts │ ├── create_lmdb.py │ ├── extract_subimages.py │ ├── generate_LR_Vimeo90K.m │ ├── generate_mod_LR_bic.m │ ├── generate_mod_LR_bic.py │ ├── prepare_DIV2K_x4_dataset.sh │ ├── rename.py │ └── test_dataloader.py ├── metrics │ ├── calculate_PSNR_SSIM.m │ └── calculate_PSNR_SSIM.py ├── models │ ├── RankSRGAN_model.py │ ├── Ranker_model.py │ ├── SRGAN_model.py │ ├── SR_model.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── RankSRGAN.cpython-36.pyc │ │ ├── RankSRGAN_model.cpython-36.pyc │ │ ├── Ranker_model.cpython-36.pyc │ │ ├── SRGAN_model.cpython-36.pyc │ │ ├── SRGAN_rank_model.cpython-36.pyc │ │ ├── SR_model.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── base_model.cpython-36.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── lr_scheduler.cpython-36.pyc │ │ └── networks.cpython-36.pyc │ ├── archs │ │ ├── RRDBNet_arch.py │ │ ├── RankSRGAN_arch.py │ │ ├── SRResNet_arch.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── RRDBNet_arch.cpython-36.pyc │ │ │ ├── RankSRGAN_arch.cpython-36.pyc │ │ │ ├── SRResNet_arch.cpython-36.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── arch_util.cpython-36.pyc │ │ │ └── discriminator_vgg_arch.cpython-36.pyc │ │ ├── arch_util.py │ │ └── discriminator_vgg_arch.py │ ├── base_model.py │ ├── loss.py │ ├── lr_scheduler.py │ └── networks.py ├── options │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── options.cpython-36.pyc │ ├── options.py │ ├── test │ │ └── test_RankSRGAN.yml │ └── train │ │ ├── train_RankSRGAN.yml │ │ 
├── train_Ranker.yml │ │ ├── train_SRGAN.yml │ │ └── train_SRResNet.yml ├── run_scripts.sh ├── scripts │ ├── README.md │ ├── __pycache__ │ │ ├── arch_util.cpython-36.pyc │ │ └── block.cpython-36.pyc │ ├── arch_util.py │ ├── back_projection │ │ ├── backprojection.m │ │ ├── main_bp.m │ │ └── main_reverse_filter.m │ ├── block.py │ ├── calparameters.py │ ├── create_lmdb.py │ ├── extract_subimgs_single.py │ ├── generate_mod_LR_bic.m │ ├── transfer.py │ └── transfer_params_MSRResNet.py ├── test.py ├── train.py ├── train_niqe.py ├── train_rank.py └── utils │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── rank_test.cpython-36.pyc │ └── util.cpython-36.pyc │ ├── perceptualmetric │ ├── calc_NIQE.m │ ├── computefeature.m │ ├── computemean.m │ ├── computequality.m │ ├── convert_shave_image.m │ ├── estimateaggdparam.m │ ├── estimatemodelparam.m │ └── modelparameters.mat │ ├── rank_test.py │ └── util.py ├── datasets ├── README.md └── generate_rankdataset │ ├── README.md │ ├── generate_rankdataset.m │ ├── generate_train_ranklabel.m │ ├── generate_valid_ranklabel.m │ ├── move_valid.py │ └── utils │ ├── calc_NIQE.m │ ├── compute_sharpness.m │ ├── computemean.m │ ├── convert_shave_image.m │ ├── get_sigle_patch.m │ ├── niqe_release │ ├── computefeature.m │ ├── computemean.m │ ├── computequality.m │ ├── estimateaggdparam.m │ ├── estimatemodelparam.m │ ├── modelparameters.mat │ └── readme.txt │ ├── parcal_niqe.m │ └── save_patch_img.m ├── experiments ├── pretrained_models │ └── readme.md └── readme.md ├── figures ├── method.png ├── readme.md └── visual_results1.png └── results └── readme.md /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 
| results/* 3 | tb_logger/* 4 | experiments/* 5 | !data/readme.md 6 | !experiments/pretrained_models/readme.md 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RankSRGAN 2 | ### [Paper](https://arxiv.org/abs/1908.06382) | [Supplementary file](https://arxiv.org/abs/1908.06382) | [Project Page](https://wenlongzhang0724.github.io/Projects/RankSRGAN) 3 | ### RankSRGAN: Generative Adversarial Networks with Ranker for Image Super-Resolution 4 | 5 | By [Wenlong Zhang](https://wenlongzhang0724.github.io/), [Yihao Liu](http://xpixel.group/2010/03/29/yihaoliu.html), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en), [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/) 6 | 7 | 8 | --- 9 | 10 | 11 |


18 | 19 | ### Dependencies 20 | 21 | - Python 3 ([Anaconda](https://www.anaconda.com/download/#linux) is recommended) 22 | - [PyTorch >= 1.0.0](https://pytorch.org/) 23 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 24 | - Python packages: `pip install numpy opencv-python lmdb` 25 | - [optional] Python packages: [`pip install tensorboardX`](https://github.com/lanpa/tensorboardX), for visualizing curves. 26 | 27 | # Codes 28 | - We have updated the code based on [mmsr](https://github.com/open-mmlab/mmsr). 29 | The old version can be downloaded from [Google Drive](https://drive.google.com/drive/folders/13ZOwv0HIa_hrtnYAOM9cTzA-AVKqiF8c?usp=sharing). 30 | - This version is still under testing. We will provide more details of RankSRGAN in the near future. 31 | ## How to Test 32 | 1. Clone this GitHub repo. 33 | ``` 34 | git clone https://github.com/WenlongZhang0724/RankSRGAN.git 35 | cd RankSRGAN 36 | ``` 37 | 2. Place your own **low-resolution images** in the `./LR` folder. 38 | 3. Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing) and place them in `./experiments/pretrained_models/`. We provide three Ranker models and three RankSRGAN models (see the [model list](experiments/pretrained_models)). 39 | 4. Run the test. We provide RankSRGAN models for NIQE, Ma, and PI; you can configure which one to use in `test.py`. 40 | ``` 41 | python test.py -opt options/test/test_RankSRGAN.yml 42 | ``` 43 | 5. The results are saved in the `./results` folder. 44 | 45 | ## How to Train 46 | ### Train Ranker 47 | 1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017) from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA). 48 | 2. Generate the rank dataset with [./datasets/generate_rankdataset/](datasets/generate_rankdataset). 49 | 3.
Run the command: 50 | ```bash 51 | python train_rank.py -opt options/train/train_Ranker.yml 52 | ``` 53 | 54 | ### Train RankSRGAN 55 | We use a PSNR-oriented pretrained SR model to initialize the parameters for better quality. 56 | 57 | 1. Prepare the datasets (usually DIV2K). 58 | 2. Prepare the PSNR-oriented pretrained model. You can use `mmsr_SRResNet_pretrain.pth`, which can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing). 59 | 3. Modify the configuration file `options/train/train_RankSRGAN.yml`. 60 | 4. Run the command: 61 | ```bash 62 | python train.py -opt options/train/train_RankSRGAN.yml 63 | ``` 64 | or 65 | 66 | ```bash 67 | python train_niqe.py -opt options/train/train_RankSRGAN.yml 68 | ``` 69 | `train.py` outputs convergence curves with PSNR; `train_niqe.py` outputs convergence curves with both NIQE and PSNR. 70 | 71 | ## Acknowledgement 72 | - Part of this code was written by [Yihao Liu](http://xpixel.group/2010/03/29/yihaoliu.html). 73 | - This code is based on [BasicSR](https://github.com/xinntao/BasicSR). 74 | -------------------------------------------------------------------------------- /codes/data/LQGT_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | 9 | class LQGTDataset(data.Dataset): 10 | """ 11 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. 12 | If only GT images are provided, generate LQ images on-the-fly.
13 | """ 14 | 15 | def __init__(self, opt): 16 | super(LQGTDataset, self).__init__() 17 | self.opt = opt 18 | self.data_type = self.opt['data_type'] 19 | self.paths_LQ, self.paths_GT = None, None 20 | self.sizes_LQ, self.sizes_GT = None, None 21 | self.LQ_env, self.GT_env = None, None # environments for lmdb 22 | 23 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 24 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 25 | assert self.paths_GT, 'Error: GT path is empty.' 26 | if self.paths_LQ and self.paths_GT: 27 | assert len(self.paths_LQ) == len( 28 | self.paths_GT 29 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 30 | len(self.paths_LQ), len(self.paths_GT)) 31 | self.random_scale_list = [1] 32 | 33 | def _init_lmdb(self): 34 | # https://github.com/chainer/chainermn/issues/129 35 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 36 | meminit=False) 37 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 38 | meminit=False) 39 | 40 | def __getitem__(self, index): 41 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): 42 | self._init_lmdb() 43 | GT_path, LQ_path = None, None 44 | scale = self.opt['scale'] 45 | GT_size = self.opt['GT_size'] 46 | 47 | # get GT image 48 | GT_path = self.paths_GT[index] 49 | resolution = [int(s) for s in self.sizes_GT[index].split('_') 50 | ] if self.data_type == 'lmdb' else None 51 | img_GT = util.read_img(self.GT_env, GT_path, resolution) 52 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase 53 | img_GT = util.modcrop(img_GT, scale) 54 | if self.opt['color']: # change color space if necessary 55 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 56 | 57 | # get LQ image 58 | if self.paths_LQ: 59 | LQ_path = self.paths_LQ[index] 60 | resolution = [int(s) for s 
in self.sizes_LQ[index].split('_') 61 | ] if self.data_type == 'lmdb' else None 62 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 63 | else: # down-sampling on-the-fly 64 | # randomly scale during training 65 | if self.opt['phase'] == 'train': 66 | random_scale = random.choice(self.random_scale_list) 67 | H_s, W_s, _ = img_GT.shape 68 | 69 | def _mod(n, random_scale, scale, thres): 70 | rlt = int(n * random_scale) 71 | rlt = (rlt // scale) * scale 72 | return thres if rlt < thres else rlt 73 | 74 | H_s = _mod(H_s, random_scale, scale, GT_size) 75 | W_s = _mod(W_s, random_scale, scale, GT_size) 76 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) 77 | if img_GT.ndim == 2: 78 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 79 | 80 | H, W, _ = img_GT.shape 81 | # using matlab imresize 82 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 83 | if img_LQ.ndim == 2: 84 | img_LQ = np.expand_dims(img_LQ, axis=2) 85 | 86 | if self.opt['phase'] == 'train': 87 | # if the image size is too small 88 | H, W, _ = img_GT.shape 89 | if H < GT_size or W < GT_size: 90 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) 91 | # using matlab imresize 92 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 93 | if img_LQ.ndim == 2: 94 | img_LQ = np.expand_dims(img_LQ, axis=2) 95 | 96 | H, W, C = img_LQ.shape 97 | LQ_size = GT_size // scale 98 | 99 | # randomly crop 100 | rnd_h = random.randint(0, max(0, H - LQ_size)) 101 | rnd_w = random.randint(0, max(0, W - LQ_size)) 102 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] 103 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) 104 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 105 | 106 | # augmentation - flip, rotate 107 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], 108 | self.opt['use_rot']) 109 | 110 | if self.opt['color']: # change color space if necessary 111 | img_LQ = 
util.channel_convert(img_LQ.shape[2], self.opt['color'], 112 | [img_LQ])[0] # use the actual channel count; C is only set in the training branch 113 | 114 | # BGR to RGB, HWC to CHW, numpy to tensor 115 | if img_GT.shape[2] == 3: 116 | img_GT = img_GT[:, :, [2, 1, 0]] 117 | img_LQ = img_LQ[:, :, [2, 1, 0]] 118 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 119 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 120 | 121 | if LQ_path is None: 122 | LQ_path = GT_path 123 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 124 | 125 | def __len__(self): 126 | return len(self.paths_GT) 127 | -------------------------------------------------------------------------------- /codes/data/LQ_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lmdb 3 | import torch 4 | import torch.utils.data as data 5 | import data.util as util 6 | 7 | 8 | class LQDataset(data.Dataset): 9 | '''Read LQ images only in the test phase.''' 10 | 11 | def __init__(self, opt): 12 | super(LQDataset, self).__init__() 13 | self.opt = opt 14 | self.data_type = self.opt['data_type'] 15 | self.paths_LQ, self.paths_GT = None, None 16 | self.LQ_env = None # environment for lmdb 17 | 18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 19 | assert self.paths_LQ, 'Error: LQ paths are empty.'
20 | 21 | def _init_lmdb(self): 22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 23 | meminit=False) 24 | 25 | def __getitem__(self, index): 26 | if self.data_type == 'lmdb' and self.LQ_env is None: 27 | self._init_lmdb() 28 | LQ_path = None 29 | 30 | # get LQ image 31 | LQ_path = self.paths_LQ[index] 32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_') 33 | ] if self.data_type == 'lmdb' else None 34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 35 | H, W, C = img_LQ.shape 36 | 37 | if self.opt['color']: # change color space if necessary 38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0] 39 | 40 | # BGR to RGB, HWC to CHW, numpy to tensor 41 | if img_LQ.shape[2] == 3: 42 | img_LQ = img_LQ[:, :, [2, 1, 0]] 43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 44 | 45 | return {'LQ': img_LQ, 'LQ_path': LQ_path} 46 | 47 | def __len__(self): 48 | return len(self.paths_LQ) 49 | -------------------------------------------------------------------------------- /codes/data/README.md: -------------------------------------------------------------------------------- 1 | 2 | Dataloader 3 | 4 | - Uses OpenCV (`cv2`) to read and process images. 5 | 6 | - Reads from **image** files or from **.lmdb** files for faster IO. 7 | - How to create an .lmdb file? See [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py). 8 | 9 | - Can downsample images with a MATLAB-style `bicubic` function, implemented in [`util.py`](util.py); note that it is relatively slow. More about the [`matlab bicubic` function](https://github.com/xinntao/BasicSR/wiki/Matlab-bicubic-imresize). 10 | 11 | 12 | ## Contents 13 | 14 | - `LQ_dataset`: reads only LQ images, in the test phase where there are no GT images. 15 | - `LQGT_dataset`: reads LQ and GT pairs from an image folder or lmdb files. If only GT images are provided, it downsamples them on-the-fly.
Used in SR and SRGAN training and validation phases. 16 | - `Rank_IMIM_Pair_dataset`: reads rank data pairs from image folders. 17 | ## How To Prepare Data 18 | ### SR, SRGAN 19 | 1. Prepare the images. You can download **classical SR** datasets (including BSD200, T91, General100; Set5, Set14, urban100, BSD100, manga109; historical) from [Google Drive](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/18fJzAHIg8Zpkc-2seGRW4Q). The DIV2K dataset can be downloaded from the [official DIV2K page](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or from [Baidu Drive](https://pan.baidu.com/s/1LUj90_skqlVw4rjRVeEoiw). 20 | 21 | 1. For faster IO, you can create lmdb files for the training dataset. See [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py). 22 | 23 | 1. We use the DIV2K dataset to train the SR and SRGAN models. 24 | 1. Since DIV2K images are large, we first crop them into sub-images using [`codes/scripts/extract_subimgs_single.py`](../scripts/extract_subimgs_single.py). 25 | 1. Generate LR images in MATLAB with [`codes/scripts/generate_mod_LR_bic.m`](../scripts/generate_mod_LR_bic.m). If you already have LR images, you can skip this step. Make sure the LR and HR folders contain the same number of images. 26 | 1. Generate .lmdb files if needed using [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py). 27 | 1. Modify the configurations in `options/train/xxx.yml` before training, e.g., `dataroot_GT` and `dataroot_LQ`. 28 | 29 | ### data augmentation 30 | 31 | We use random cropping, random flipping/rotation, and (optionally) random scaling for data augmentation.
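The paired augmentation described above can be sketched as follows. This is a minimal illustration, assuming LQ/GT arrays in HWC layout with the GT exactly `scale` times larger; the function name `paired_random_crop_augment` is hypothetical. The crop arithmetic mirrors `LQGT_dataset.py` (crop the LQ patch first, then multiply its offsets by `scale` for the GT patch), while the flip/transpose scheme is an assumption, since `util.augment` is not shown here.

```python
import random
import numpy as np


def paired_random_crop_augment(img_lq, img_gt, gt_size, scale, use_flip=True, use_rot=True):
    """Hypothetical helper: crop aligned LQ/GT patches and apply one shared flip/transpose.

    img_lq: HxWxC array; img_gt: (H*scale)x(W*scale)xC array.
    """
    lq_size = gt_size // scale
    h, w = img_lq.shape[:2]
    # pick the LQ patch origin, then map it to GT coordinates by multiplying with scale
    top = random.randint(0, max(0, h - lq_size))
    left = random.randint(0, max(0, w - lq_size))
    lq = img_lq[top:top + lq_size, left:left + lq_size, :]
    gt = img_gt[top * scale:top * scale + gt_size, left * scale:left * scale + gt_size, :]

    # draw the random flags once so both patches receive the identical transform
    hflip = use_flip and random.random() < 0.5
    vflip = use_rot and random.random() < 0.5
    transpose = use_rot and random.random() < 0.5

    def _aug(img):
        if hflip:
            img = img[:, ::-1, :]  # horizontal flip
        if vflip:
            img = img[::-1, :, :]  # vertical flip
        if transpose:
            # swap H and W; combined with the flips this yields the 90-degree rotations
            img = img.transpose(1, 0, 2)
        return np.ascontiguousarray(img)

    return _aug(lq), _aug(gt)
```

Drawing the random flags once and applying the same `_aug` to both patches keeps LQ and GT pixel-aligned; sampling them separately per image would break the correspondence.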
32 | 33 | ### wiki 34 | 35 | [Color-conversion-in-SR](https://github.com/xinntao/BasicSR/wiki/Color-conversion-in-SR) 36 | 37 | 38 | 42 | -------------------------------------------------------------------------------- /codes/data/Rank_IMIM_Pair_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | from itertools import combinations 9 | from scipy.special import comb 10 | 11 | 12 | class RANK_IMIM_Pair_Dataset(data.Dataset): 13 | ''' 14 | Read image pairs from three folders (img1, img2, img3) together with their rank label scores. 15 | During training, a random pair of folders is chosen; during validation, only img1 is used. 16 | Images are matched across folders by their sorted order, so please check the naming convention. 17 | ''' 18 | 19 | def name(self): 20 | return 'RANK_IMIM_Pair_Dataset' 21 | 22 | def __init__(self, opt, is_train): 23 | super(RANK_IMIM_Pair_Dataset, self).__init__() 24 | self.opt = opt 25 | 26 | self.is_train = is_train 27 | 28 | # read image list from lmdb or image files 29 | 30 | self.paths_img1, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img1']) 31 | self.paths_img2, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img2']) 32 | self.paths_img3, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img3']) 33 | 34 | self.img_env1 = None 35 | self.img_env2 = None 36 | self.img_env3 = None 37 | 38 | self.label_path = opt['dataroot_label_file'] 39 | 40 | # get image label scores 41 | self.label = {} 42 | f = open(self.label_path, 'r') 43 | for line in f.readlines(): 44 | line = line.strip().split() 45 | self.label[line[0]] = line[1] 46 | f.close() 47 | 48 | assert self.paths_img1, 'Error: img1 paths are empty.'
49 | 50 | # self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5] 51 | self.random_scale_list = None 52 | 53 | def __getitem__(self, index): 54 | 55 | if self.is_train: 56 | # get img1 and img1 label score 57 | # choice = random.choice(['img1_img2','img1_img2','img1_img2','img1_img3','img2_img3']) #Oversampling for hard sample 58 | choice = random.choice(['img1_img2', 'img1_img3', 'img2_img3']) 59 | 60 | # print(choice) 61 | 62 | if choice == 'img1_img2': 63 | img1_path = self.paths_img1[index] 64 | img1 = util.read_img(self.img_env1, img1_path) 65 | img2_path = self.paths_img2[index] 66 | img2 = util.read_img(self.img_env2, img2_path) 67 | elif choice == 'img1_img3': 68 | img1_path = self.paths_img1[index] 69 | img1 = util.read_img(self.img_env1, img1_path) 70 | img2_path = self.paths_img3[index] 71 | img2 = util.read_img(self.img_env3, img2_path) 72 | 73 | elif choice == 'img2_img3': 74 | img1_path = self.paths_img2[index] 75 | img1 = util.read_img(self.img_env2, img1_path) 76 | img2_path = self.paths_img3[index] 77 | img2 = util.read_img(self.img_env3, img2_path) 78 | 79 | 80 | img1_name = img1_path.split('/')[-1] 81 | img1_score = np.array(float(self.label[img1_name]), dtype='float') 82 | img1_score = img1_score.reshape(1) 83 | 84 | img2_name = img2_path.split('/')[-1] 85 | img2_score = np.array(float(self.label[img2_name]), dtype='float') 86 | img2_score = img2_score.reshape(1) 87 | 88 | if img1.shape[2] == 3: 89 | img1 = img1[:, :, [2, 1, 0]] 90 | img1 = torch.from_numpy(np.ascontiguousarray(np.transpose(img1, (2, 0, 1)))).float() 91 | img1_score = torch.from_numpy(img1_score).float() 92 | 93 | if img2.shape[2] == 3: 94 | img2 = img2[:, :, [2, 1, 0]] 95 | img2 = torch.from_numpy(np.ascontiguousarray(np.transpose(img2, (2, 0, 1)))).float() 96 | img2_score = torch.from_numpy(img2_score).float() 97 | 98 | # print('img1:'+img1_name,' & ','img2:'+img2_name) 99 | 100 | else: 101 | # get img1 102 | img1_path = self.paths_img1[index] 103 | img1 = 
util.read_img(self.img_env1, img1_path) 104 | 105 | img1_name = img1_path.split('/')[-1] 106 | img1_score = np.array(float(self.label[img1_name]), dtype='float') 107 | img1_score = img1_score.reshape(1) 108 | 109 | if img1.shape[2] == 3: 110 | img1 = img1[:, :, [2, 1, 0]] 111 | img1 = torch.from_numpy(np.ascontiguousarray(np.transpose(img1, (2, 0, 1)))).float() 112 | img1_score = torch.from_numpy(img1_score).float() 113 | # print('img1:'+img1_name) 114 | 115 | # not useful 116 | img2_path = img1_path 117 | img2 = img1 118 | img2_score = img1_score 119 | 120 | # exit() 121 | 122 | return {'img1': img1, 'img2': img2, 'img1_path': img1_path, 'img2_path': img2_path, 'score1': img1_score, 123 | 'score2': img2_score} 124 | 125 | def __len__(self): 126 | return len(self.paths_img1) 127 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 25 | pin_memory=False) 26 | 27 | def create_dataset(dataset_opt, is_train = True): 28 | mode = 
dataset_opt['mode'] 29 | # datasets for image restoration 30 | if mode == 'LQ': 31 | from data.LQ_dataset import LQDataset as D 32 | elif mode == 'LQGT': 33 | from data.LQGT_dataset import LQGTDataset as D 34 | elif mode == 'RANK_IMIM_Pair': 35 | from data.Rank_IMIM_Pair_dataset import RANK_IMIM_Pair_Dataset as D 36 | else: 37 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 38 | if 'RANK_IMIM_Pair' in mode: 39 | dataset = D(dataset_opt, is_train = is_train) 40 | else: 41 | dataset = D(dataset_opt) 42 | logger = logging.getLogger('base') 43 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 44 | dataset_opt['name'])) 45 | return dataset 46 | -------------------------------------------------------------------------------- /codes/data/__pycache__/LQGT_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/LRHR_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LRHR_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/LR_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LR_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/Rank_IMIM_Pair_dataset.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/Rank_IMIM_Pair_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/data_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/data_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iteration-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 
14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/data_scripts/extract_subimages.py: 
-------------------------------------------------------------------------------- 1 | """A multi-process tool to crop large images into sub-images for faster IO.""" 2 | import os 3 | import os.path as osp 4 | import sys 5 | from multiprocessing import Pool 6 | import numpy as np 7 | import cv2 8 | from PIL import Image 9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 10 | from utils.util import ProgressBar # noqa: E402 11 | import data.util as data_util # noqa: E402 12 | 13 | 14 | def main(): 15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs) 16 | opt = {} 17 | opt['n_thread'] = 20 18 | opt['compression_level'] = 3 # 3 is the default value in cv2 19 | # CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller size and longer 20 | # compression time. If raw images are read during training, use 0 for faster IO. 21 | if mode == 'single': 22 | opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR' 23 | opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub' 24 | opt['crop_sz'] = 480 # the size of each sub-image 25 | opt['step'] = 240 # step of the sliding crop window 26 | opt['thres_sz'] = 48 # size threshold 27 | extract_signle(opt) 28 | elif mode == 'pair': 29 | GT_folder = '../../datasets/DIV2K/DIV2K_train_HR' 30 | LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' 31 | save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub' 32 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4' 33 | scale_ratio = 4 34 | crop_sz = 480 # the size of each sub-image (GT) 35 | step = 240 # step of the sliding crop window (GT) 36 | thres_sz = 48 # size threshold 37 | ######################################################################## 38 | # check that all the GT and LR images have correct scale ratio 39 | img_GT_list = data_util._get_paths_from_images(GT_folder) 40 | img_LR_list = data_util._get_paths_from_images(LR_folder) 41 | assert len(img_GT_list) == len(img_LR_list), 'different
length of GT_folder and LR_folder.' 42 | for path_GT, path_LR in zip(img_GT_list, img_LR_list): 43 | img_GT = Image.open(path_GT) 44 | img_LR = Image.open(path_LR) 45 | w_GT, h_GT = img_GT.size 46 | w_LR, h_LR = img_LR.size 47 | assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format( # noqa: E501 48 | w_GT, scale_ratio, w_LR, path_GT) 49 | assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format( # noqa: E501 50 | h_GT, scale_ratio, h_LR, path_GT) 51 | # check crop size, step and threshold size 52 | assert crop_sz % scale_ratio == 0, 'crop size is not a multiple of {:d}.'.format( 53 | scale_ratio) 54 | assert step % scale_ratio == 0, 'step is not a multiple of {:d}.'.format(scale_ratio) 55 | assert thres_sz % scale_ratio == 0, 'thres_sz is not a multiple of {:d}.'.format( 56 | scale_ratio) 57 | print('process GT...') 58 | opt['input_folder'] = GT_folder 59 | opt['save_folder'] = save_GT_folder 60 | opt['crop_sz'] = crop_sz 61 | opt['step'] = step 62 | opt['thres_sz'] = thres_sz 63 | extract_signle(opt) 64 | print('process LR...') 65 | opt['input_folder'] = LR_folder 66 | opt['save_folder'] = save_LR_folder 67 | opt['crop_sz'] = crop_sz // scale_ratio 68 | opt['step'] = step // scale_ratio 69 | opt['thres_sz'] = thres_sz // scale_ratio 70 | extract_signle(opt) 71 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len( 72 | data_util._get_paths_from_images( 73 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' 74 | else: 75 | raise ValueError('Wrong mode.') 76 | 77 | 78 | def extract_signle(opt): 79 | input_folder = opt['input_folder'] 80 | save_folder = opt['save_folder'] 81 | if not osp.exists(save_folder): 82 | os.makedirs(save_folder) 83 | print('mkdir [{:s}] ...'.format(save_folder)) 84 | else: 85 | print('Folder [{:s}] already exists.
Exit...'.format(save_folder)) 86 | sys.exit(1) 87 | img_list = data_util._get_paths_from_images(input_folder) 88 | 89 | def update(arg): 90 | pbar.update(arg) 91 | 92 | pbar = ProgressBar(len(img_list)) 93 | 94 | pool = Pool(opt['n_thread']) 95 | for path in img_list: 96 | pool.apply_async(worker, args=(path, opt), callback=update) 97 | pool.close() 98 | pool.join() 99 | print('All subprocesses done.') 100 | 101 | 102 | def worker(path, opt): 103 | crop_sz = opt['crop_sz'] 104 | step = opt['step'] 105 | thres_sz = opt['thres_sz'] 106 | img_name = osp.basename(path) 107 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 108 | 109 | n_channels = len(img.shape) 110 | if n_channels == 2: 111 | h, w = img.shape 112 | elif n_channels == 3: 113 | h, w, c = img.shape 114 | else: 115 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 116 | 117 | h_space = np.arange(0, h - crop_sz + 1, step) 118 | if h - (h_space[-1] + crop_sz) > thres_sz: 119 | h_space = np.append(h_space, h - crop_sz) 120 | w_space = np.arange(0, w - crop_sz + 1, step) 121 | if w - (w_space[-1] + crop_sz) > thres_sz: 122 | w_space = np.append(w_space, w - crop_sz) 123 | 124 | index = 0 125 | for x in h_space: 126 | for y in w_space: 127 | index += 1 128 | if n_channels == 2: 129 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 130 | else: 131 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :] 132 | crop_img = np.ascontiguousarray(crop_img) 133 | cv2.imwrite( 134 | osp.join(opt['save_folder'], 135 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img, 136 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 137 | return 'Processing {:s} ...'.format(img_name) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_LR_Vimeo90K.m: -------------------------------------------------------------------------------- 1 | function generate_LR_Vimeo90K() 2 | %% matlab code to 
genetate bicubic-downsampled for Vimeo90K dataset 3 | 4 | up_scale = 4; 5 | mod_scale = 4; 6 | idx = 0; 7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); 8 | for i = 1 : length(filepaths) 9 | [~,imname,ext] = fileparts(filepaths(i).name); 10 | folder_path = filepaths(i).folder; 11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); 12 | if ~exist(save_LR_folder, 'dir') 13 | mkdir(save_LR_folder); 14 | end 15 | if isempty(imname) 16 | disp('Ignore . folder.'); 17 | elseif strcmp(imname, '.') 18 | disp('Ignore .. folder.'); 19 | else 20 | idx = idx + 1; 21 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 22 | fprintf(str_rlt); 23 | % read image 24 | img = imread(fullfile(folder_path, [imname, ext])); 25 | img = im2double(img); 26 | % modcrop 27 | img = modcrop(img, mod_scale); 28 | % LR 29 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 30 | if exist('save_LR_folder', 'var') 31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 32 | end 33 | end 34 | end 35 | end 36 | 37 | %% modcrop 38 | function img = modcrop(img, modulo) 39 | if size(img,3) == 1 40 | sz = size(img); 41 | sz = sz - mod(sz, modulo); 42 | img = img(1:sz(1), 1:sz(2)); 43 | else 44 | tmpsz = size(img); 45 | sz = tmpsz(1:2); 46 | sz = sz - mod(sz, modulo); 47 | img = img(1:sz(1), 1:sz(2),:); 48 | end 49 | end 50 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_mod_LR_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. 

%% set parameters
% comment out the folders you do not need
input_folder = '../../datasets/DIV2K/DIV2K800';
% save_mod_folder = '';
save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
% save_bic_folder = '';

up_scale = 4;
mod_scale = 4;

if exist('save_mod_folder', 'var')
    if exist(save_mod_folder, 'dir')
        disp(['It will overwrite ', save_mod_folder]);
    else
        mkdir(save_mod_folder);
    end
end
if exist('save_LR_folder', 'var')
    if exist(save_LR_folder, 'dir')
        disp(['It will overwrite ', save_LR_folder]);
    else
        mkdir(save_LR_folder);
    end
end
if exist('save_bic_folder', 'var')
    if exist(save_bic_folder, 'dir')
        disp(['It will overwrite ', save_bic_folder]);
    else
        mkdir(save_bic_folder);
    end
end

idx = 0;
filepaths = dir(fullfile(input_folder,'*.*'));
for i = 1 : length(filepaths)
    [paths,imname,ext] = fileparts(filepaths(i).name);
    if isempty(imname)
        disp('Ignore . folder.');
    elseif strcmp(imname, '.')
        disp('Ignore .. folder.');
    else
        idx = idx + 1;
        str_rlt = sprintf('%d\t%s.\n', idx, imname);
        fprintf(str_rlt);
        % read image
        img = imread(fullfile(input_folder, [imname, ext]));
        img = im2double(img);
        % modcrop
        img = modcrop(img, mod_scale);
        if exist('save_mod_folder', 'var')
            imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
        end
        % LR
        im_LR = imresize(img, 1/up_scale, 'bicubic');
        if exist('save_LR_folder', 'var')
            imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
        end
        % Bicubic
        if exist('save_bic_folder', 'var')
            im_B = imresize(im_LR, up_scale, 'bicubic');
            imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
        end
    end
end
end

%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
    sz = size(img);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2));
else
    tmpsz = size(img);
    sz = tmpsz(1:2);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2),:);
end
end

--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.py:
--------------------------------------------------------------------------------
import os
import sys
import cv2
import numpy as np

try:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from data.util import imresize_np
except ImportError:
    pass


def generate_mod_LR_bic():
    # set parameters
    up_scale = 4
    mod_scale = 4
    # set data dir
    sourcedir = '/data/datasets/img'
    savedir = '/data/datasets/mod'

    saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
    saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
    saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))

    if not os.path.isdir(sourcedir):
        print('Error: No source data found')
        sys.exit(1)
    if not os.path.isdir(savedir):
        os.mkdir(savedir)

    if not os.path.isdir(os.path.join(savedir, 'HR')):
        os.mkdir(os.path.join(savedir, 'HR'))
    if not os.path.isdir(os.path.join(savedir, 'LR')):
        os.mkdir(os.path.join(savedir, 'LR'))
    if not os.path.isdir(os.path.join(savedir, 'Bic')):
        os.mkdir(os.path.join(savedir, 'Bic'))

    if not os.path.isdir(saveHRpath):
        os.mkdir(saveHRpath)
    else:
        print('It will overwrite ' + str(saveHRpath))

    if not os.path.isdir(saveLRpath):
        os.mkdir(saveLRpath)
    else:
        print('It will overwrite ' + str(saveLRpath))

    if not os.path.isdir(saveBicpath):
        os.mkdir(saveBicpath)
    else:
        print('It will overwrite ' + str(saveBicpath))

    filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
    num_files = len(filepaths)

    # generate mod-cropped HR, bicubic-downsampled LR and bicubic-upsampled images
    for i in range(num_files):
        filename = filepaths[i]
        print('No.{} -- Processing {}'.format(i, filename))
        # read image
        image = cv2.imread(os.path.join(sourcedir, filename))

        width = int(np.floor(image.shape[1] / mod_scale))
        height = int(np.floor(image.shape[0] / mod_scale))
        # modcrop
        if len(image.shape) == 3:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
        else:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width]
        # LR
        image_LR = imresize_np(image_HR, 1 / up_scale, True)
        # bic
        image_Bic = imresize_np(image_LR, up_scale, True)

        cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)


if __name__ == "__main__":
    generate_mod_LR_bic()

--------------------------------------------------------------------------------
/codes/data_scripts/prepare_DIV2K_x4_dataset.sh:
--------------------------------------------------------------------------------
echo "Prepare DIV2K X4 datasets..."
cd ../../datasets
mkdir -p DIV2K
cd DIV2K

#### Step 1
echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
# GT
FOLDER=DIV2K_train_HR
FILE=DIV2K_train_HR.zip
if [ ! -d "$FOLDER" ]; then
    if [ ! -f "$FILE" ]; then
        echo "Downloading $FILE..."
        wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
    fi
    unzip $FILE
fi
# LR
FOLDER=DIV2K_train_LR_bicubic
FILE=DIV2K_train_LR_bicubic_X4.zip
if [ ! -d "$FOLDER" ]; then
    if [ ! -f "$FILE" ]; then
        echo "Downloading $FILE..."
        wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
    fi
    unzip $FILE
fi

#### Step 2
echo "Step 2: Rename the LR images..."
cd ../../codes/data_scripts
python rename.py

#### Step 3
echo "Step 3: Crop to sub-images..."
python extract_subimages.py

#### Step 4
echo "Step 4: Create LMDB files..."
python create_lmdb.py

--------------------------------------------------------------------------------
/codes/data_scripts/rename.py:
--------------------------------------------------------------------------------
import os
import glob


def main():
    folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
    DIV2K(folder)
    print('Finished.')


def DIV2K(path):
    # strip the scale suffix from the file name only (e.g. 0001x4.png -> 0001.png),
    # so that directory names on the path are never modified
    img_path_l = glob.glob(os.path.join(path, '*'))
    for img_path in img_path_l:
        dirname, basename = os.path.split(img_path)
        new_name = basename.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
        os.rename(img_path, os.path.join(dirname, new_name))


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/codes/data_scripts/test_dataloader.py:
--------------------------------------------------------------------------------
import sys
import os.path as osp
import math
import torchvision.utils

sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from data import create_dataloader, create_dataset  # noqa: E402
from utils import util  # noqa: E402


def main():
    # REDS | Vimeo90K | DIV2K800_sub
    # note: this repository's data/ folder only contains the LQGT, LQ and Rank_IMIM_Pair
    # dataset classes, so the REDS and Vimeo90K modes are not expected to work here
    dataset = 'DIV2K800_sub'
    opt = {}
    opt['dist'] = False
    opt['gpu_ids'] = [0]
    if dataset == 'REDS':
        opt['name'] = 'test_REDS'
        opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
        opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        opt['mode'] = 'REDS'
        opt['N_frames'] = 5
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 64
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'lmdb'  # img | lmdb | mc
    elif dataset == 'Vimeo90K':
        opt['name'] = 'test_Vimeo90K'
        opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        opt['mode'] = 'Vimeo90K'
        opt['N_frames'] = 7
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 256
        opt['LQ_size'] = 64
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['interval_list'] = [1]
        opt['random_reverse'] = False
        opt['border_mode'] = False
        opt['cache_keys'] = None
        opt['data_type'] = 'lmdb'  # img | lmdb | mc
    elif dataset == 'DIV2K800_sub':
        opt['name'] = 'DIV2K800'
        opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
        opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
        opt['mode'] = 'LQGT'
        opt['phase'] = 'train'
        opt['use_shuffle'] = True
        opt['n_workers'] = 8
        opt['batch_size'] = 16
        opt['GT_size'] = 128
        opt['scale'] = 4
        opt['use_flip'] = True
        opt['use_rot'] = True
        opt['color'] = 'RGB'
        opt['data_type'] = 'lmdb'  # img | lmdb
    else:
        raise ValueError('Please implement by yourself.')

    util.mkdir('tmp')
    train_set = create_dataset(opt)
    train_loader = create_dataloader(train_set, opt, opt, None)
    nrow = int(math.sqrt(opt['batch_size']))
    padding = 2 if opt['phase'] == 'train' else 0

    print('start...')
    for i, data in enumerate(train_loader):
        if i > 5:
            break
        print(i)
        if dataset == 'REDS' or dataset == 'Vimeo90K':
            LQs = data['LQs']
        else:
            LQ = data['LQ']
        GT = data['GT']

        if dataset == 'REDS' or dataset == 'Vimeo90K':
            for j in range(LQs.size(1)):
                torchvision.utils.save_image(LQs[:, j, :, :, :],
                                             'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
                                             padding=padding, normalize=False)
        else:
            torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
                                         padding=padding, normalize=False)
        torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
                                     normalize=False)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
function calculate_PSNR_SSIM()

% GT and SR folder
folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
scale = 4;
suffix = ''; % suffix for SR images
test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
if test_Y
    fprintf('Testing Y channel.\n');
else
    fprintf('Testing RGB channels.\n');
end
filepaths = dir(fullfile(folder_GT, '*.png'));
PSNR_all = zeros(1, length(filepaths));
SSIM_all = zeros(1, length(filepaths));

for idx_im = 1:length(filepaths)
    im_name = filepaths(idx_im).name;
    im_GT = imread(fullfile(folder_GT, im_name));
    im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));

    if test_Y % evaluate on Y channel in YCbCr color space
        if size(im_GT, 3) == 3
            im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
            im_GT_in = im_GT_YCbCr(:,:,1);
            im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
            im_SR_in = im_SR_YCbCr(:,:,1);
        else
            im_GT_in = im2double(im_GT);
            im_SR_in = im2double(im_SR);
        end
    else % evaluate on RGB channels
        im_GT_in = im2double(im_GT);
        im_SR_in = im2double(im_SR);
    end

    % calculate PSNR and SSIM
    PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
    SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
    fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
end

fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
end

function res = calculate_PSNR(GT, SR, border)
% remove border
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate PSNR (assume in [0,255])
error = GT(:) - SR(:);
mse = mean(error.^2);
res = 10 * log10(255^2/mse);
end

function res = calculate_SSIM(GT, SR, border)
GT = GT(border+1:end-border, border+1:end-border, :);
SR = SR(border+1:end-border, border+1:end-border, :);
% calculate SSIM
mssim = zeros(1, size(SR, 3));
for i = 1:size(SR,3)
    [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
end
res = mean(mssim);
end

function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)

%========================================================================
%SSIM Index, Version 1.0
%Copyright(c) 2003 Zhou Wang
%All Rights Reserved.
%
%The author is with Howard Hughes Medical Institute, and Laboratory
%for Computational Vision at Center for Neural Science and Courant
%Institute of Mathematical Sciences, New York University.
%
%----------------------------------------------------------------------
%Permission to use, copy, or modify this software and its documentation
%for educational and research purposes only and without fee is hereby
%granted, provided that this copyright notice and the original authors'
%names appear on all copies and supporting documentation. This program
%shall not be used, rewritten, or adapted as the basis of a commercial
%software or hardware product without first obtaining permission of the
%authors. The authors make no representations about the suitability of
%this software for any purpose. It is provided "as is" without express
%or implied warranty.
%----------------------------------------------------------------------
%
%This is an implementation of the algorithm for calculating the
%Structural SIMilarity (SSIM) index between two images. Please refer
%to the following paper:
%
%Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
%quality assessment: From error visibility to structural similarity,"
%IEEE Transactions on Image Processing, vol. 13, no. 4, Apr. 2004.
%
%Kindly report any suggestions or corrections to zhouwang@ieee.org
%
%----------------------------------------------------------------------
%
%Input : (1) img1: the first image being compared
%        (2) img2: the second image being compared
%        (3) K: constants in the SSIM index formula (see the above
%            reference). default value: K = [0.01 0.03]
%        (4) window: local window for statistics (see the above
%            reference). default window is Gaussian given by
%            window = fspecial('gaussian', 11, 1.5);
%        (5) L: dynamic range of the images. default: L = 255
%
%Output: (1) mssim: the mean SSIM index value between 2 images.
%            If one of the images being compared is regarded as
%            perfect quality, then mssim can be considered as the
%            quality measure of the other image.
%            If img1 = img2, then mssim = 1.
%        (2) ssim_map: the SSIM index map of the test image. The map
%            has a smaller size than the input images. The actual size:
%            size(img1) - size(window) + 1.
%
%Default Usage:
%   Given 2 test images img1 and img2, whose dynamic range is 0-255
%
%   [mssim ssim_map] = ssim_index(img1, img2);
%
%Advanced Usage:
%   User defined parameters. For example
%
%   K = [0.05 0.05];
%   window = ones(8);
%   L = 100;
%   [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
%
%See the results:
%
%   mssim                        %Gives the mssim value
%   imshow(max(0, ssim_map).^4)  %Shows the SSIM index map
%
%========================================================================


if (nargin < 2 || nargin > 5)
    mssim = -Inf;
    ssim_map = -Inf;
    return;
end

if ~isequal(size(img1), size(img2))
    mssim = -Inf;
    ssim_map = -Inf;
    return;
end

[M, N] = size(img1);

if (nargin == 2)
    if ((M < 11) || (N < 11))
        mssim = -Inf;
        ssim_map = -Inf;
        return
    end
    window = fspecial('gaussian', 11, 1.5); %
    K(1) = 0.01; % default settings
    K(2) = 0.03; %
    L = 255; %
end

if (nargin == 3)
    if ((M < 11) || (N < 11))
        mssim = -Inf;
        ssim_map = -Inf;
        return
    end
    window = fspecial('gaussian', 11, 1.5);
    L = 255;
    if (length(K) == 2)
        if (K(1) < 0 || K(2) < 0)
            mssim = -Inf;
            ssim_map = -Inf;
            return;
        end
    else
        mssim = -Inf;
        ssim_map = -Inf;
        return;
    end
end

if (nargin == 4)
    [H, W] = size(window);
    if ((H*W) < 4 || (H > M) || (W > N))
        mssim = -Inf;
        ssim_map = -Inf;
        return
    end
    L = 255;
    if (length(K) == 2)
        if (K(1) < 0 || K(2) < 0)
            mssim = -Inf;
            ssim_map = -Inf;
            return;
        end
    else
        mssim = -Inf;
        ssim_map = -Inf;
        return;
    end
end

if (nargin == 5)
    [H, W] = size(window);
    if ((H*W) < 4 || (H > M) || (W > N))
        mssim = -Inf;
        ssim_map = -Inf;
        return
    end
    if (length(K) == 2)
        if (K(1) < 0 || K(2) < 0)
            mssim = -Inf;
            ssim_map = -Inf;
            return;
        end
    else
        mssim = -Inf;
        ssim_map = -Inf;
        return;
    end
end

C1 = (K(1)*L)^2;
C2 = (K(2)*L)^2;
window = window/sum(sum(window));
img1 = double(img1);
img2 = double(img2);

mu1 = filter2(window, img1, 'valid');
mu2 = filter2(window, img2, 'valid');
mu1_sq = mu1.*mu1;
mu2_sq = mu2.*mu2;
mu1_mu2 = mu1.*mu2;
sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;

if (C1 > 0 && C2 > 0)
    ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
else
    numerator1 = 2*mu1_mu2 + C1;
    numerator2 = 2*sigma12 + C2;
    denominator1 = mu1_sq + mu2_sq + C1;
    denominator2 = sigma1_sq + sigma2_sq + C2;
    ssim_map = ones(size(mu1));
    index = (denominator1.*denominator2 > 0);
    ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
    index = (denominator1 ~= 0) & (denominator2 == 0);
    ssim_map(index) = numerator1(index)./denominator1(index);
end

mssim = mean2(ssim_map);

end

--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.py:
--------------------------------------------------------------------------------
'''
calculate the PSNR and SSIM.
same as MATLAB's results
'''
import os
import math
import numpy as np
import cv2
import glob


def main():
    # Configurations

    # GT - Ground-truth;
    # Gen: Generated / Restored / Recovered images
    folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
    folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'

    crop_border = 4
    suffix = ''  # suffix for Gen images
    test_Y = False  # True: test Y channel only; False: test RGB channels

    PSNR_all = []
    SSIM_all = []
    img_list = sorted(glob.glob(folder_GT + '/*'))

    if test_Y:
        print('Testing Y channel.')
    else:
        print('Testing RGB channels.')

    for i, img_path in enumerate(img_list):
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        im_GT = cv2.imread(img_path) / 255.
        im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.

        if test_Y and im_GT.shape[2] == 3:  # evaluate on Y channel in YCbCr color space
            im_GT_in = bgr2ycbcr(im_GT)
            im_Gen_in = bgr2ycbcr(im_Gen)
        else:
            im_GT_in = im_GT
            im_Gen_in = im_Gen

        # crop borders
        if im_GT_in.ndim == 3:
            cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
            cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
        elif im_GT_in.ndim == 2:
            cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
            cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
        else:
            raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))

        # calculate PSNR and SSIM
        PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)

        SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
        print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
            i + 1, base_name, PSNR, SSIM))
        PSNR_all.append(PSNR)
        SSIM_all.append(SSIM)
    print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
        sum(PSNR_all) / len(PSNR_all),
        sum(SSIM_all) / len(SSIM_all)))


def calculate_psnr(img1, img2):
    # img1 and img2 have range [0, 255]
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim(img1, img2):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            # compute SSIM on each channel separately, then average
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def bgr2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy; the caller's array is not modified
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/codes/models/Ranker_model.py:
--------------------------------------------------------------------------------
import os
from collections import OrderedDict
import logging

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.nn.parallel import DataParallel, DistributedDataParallel

import models.networks as networks
from .base_model import BaseModel

logger = logging.getLogger('base')


class Ranker_Model(BaseModel):
    def name(self):
        return 'Ranker_Model'

    def __init__(self, opt):
        super(Ranker_Model, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define networks and load pretrained models
        self.netR = networks.define_R(opt).to(self.device)
        if opt['dist']:
            self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()])
        else:
            self.netR = DataParallel(self.netR)
        self.load()

        if self.is_train:
            self.netR.train()

            # loss
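            # How the ranking loss below behaves (summary of the documented
            # torch.nn.MarginRankingLoss formula; not specific to this repo):
            #     loss(x1, x2, y) = max(0, -y * (x1 - x2) + margin)
            # averaged over the batch by default. In optimize_parameters(),
            # x1 / x2 are the clamped Ranker outputs for img1 / img2 and y is
            # the {-1, +1} label built in feed_data() (+1 when score1 >= score2).
            # With margin=0.5 and y = +1, the loss reaches zero only once
            # predict_score1 exceeds predict_score2 by at least 0.5
            # (symmetrically for y = -1).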
            self.RankLoss = nn.MarginRankingLoss(margin=0.5)
            self.RankLoss.to(self.device)
            # note: despite its name this is an L1 loss, and it is not used in optimize_parameters()
            self.L2Loss = nn.L1Loss()
            self.L2Loss.to(self.device)
            # optimizers
            self.optimizers = []
            wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
            optim_params = []
            for k, v in self.netR.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    print('WARNING: params [%s] will not be optimized.' % k)
            self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'], weight_decay=wd_R)
            print('Weight_decay:%f' % wd_R)
            self.optimizers.append(self.optimizer_R)

            # schedulers
            self.schedulers = []
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], train_opt['lr_gamma']))
            else:
                raise NotImplementedError(
                    'Learning rate scheme [{:s}] is not implemented; only MultiStepLR is supported.'.format(
                        train_opt['lr_scheme']))

            self.log_dict = OrderedDict()

        print('---------- Model initialized ------------------')
        self.print_network()
        print('-----------------------------------------------')

    def feed_data(self, data, need_img2=True):
        # input img1
        self.input_img1 = data['img1'].to(self.device)

        # label score1
        self.label_score1 = data['score1'].to(self.device)

        if need_img2:
            # input img2
            self.input_img2 = data['img2'].to(self.device)

            # label score2
            self.label_score2 = data['score2'].to(self.device)

            # rank label: the comparison gives a boolean tensor (a ByteTensor in older
            # PyTorch versions); cast it to float and map {0, 1} to {-1, +1}, the target
            # format expected by nn.MarginRankingLoss
            self.label = self.label_score1 >= self.label_score2
            self.label = self.label.float()
            self.label = (self.label - 0.5) * 2

    def optimize_parameters(self, step):
        self.optimizer_R.zero_grad()
        self.predict_score1 = self.netR(self.input_img1)
        self.predict_score2 = self.netR(self.input_img2)

        self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
        self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)

        l_rank = self.RankLoss(self.predict_score1, self.predict_score2, self.label)

        l_rank.backward()
        self.optimizer_R.step()

        # set log
        self.log_dict['l_rank'] = l_rank.item()

    def test(self):
        self.netR.eval()
        with torch.no_grad():
            self.predict_score1 = self.netR(self.input_img1)
        self.netR.train()

    def get_current_log(self):
        return self.log_dict

    def get_current_visuals(self, need_HR=True):
        out_dict = OrderedDict()
        out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()

        return out_dict

    def print_network(self):
        s, n = self.get_network_description(self.netR)
        if isinstance(self.netR, (DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
                                             self.netR.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netR.__class__.__name__)
        logger.info('Network R structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
        logger.info(s)

    def load(self):
        load_path_R = self.opt['path']['pretrain_model_R']
        if load_path_R is not None:
            logger.info('Loading pretrained model for R [{:s}] ...'.format(load_path_R))
            self.load_network(load_path_R, self.netR)

    def save(self, iter_step):
        self.save_network(self.netR, 'R', iter_step)

--------------------------------------------------------------------------------
/codes/models/SR_model.py:
--------------------------------------------------------------------------------
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.loss import CharbonnierLoss

logger = logging.getLogger('base')


class SRModel(BaseModel):
    def __init__(self, opt):
        super(SRModel, self).__init__(opt)

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                else:
                    if self.rank <= 0:
                        logger.warning('Params [{:s}] will not be optimized.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
self.schedulers.append( 68 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 69 | restarts=train_opt['restarts'], 70 | weights=train_opt['restart_weights'], 71 | gamma=train_opt['lr_gamma'], 72 | clear_state=train_opt['clear_state'])) 73 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 74 | for optimizer in self.optimizers: 75 | self.schedulers.append( 76 | lr_scheduler.CosineAnnealingLR_Restart( 77 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 78 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 79 | else: 80 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 81 | 82 | self.log_dict = OrderedDict() 83 | 84 | def feed_data(self, data, need_GT=True): 85 | self.var_L = data['LQ'].to(self.device) # LQ 86 | if need_GT: 87 | self.real_H = data['GT'].to(self.device) # GT 88 | 89 | def optimize_parameters(self, step): 90 | self.optimizer_G.zero_grad() 91 | self.fake_H = self.netG(self.var_L) 92 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 93 | l_pix.backward() 94 | self.optimizer_G.step() 95 | 96 | # set log 97 | self.log_dict['l_pix'] = l_pix.item() 98 | 99 | def test(self): 100 | self.netG.eval() 101 | with torch.no_grad(): 102 | self.fake_H = self.netG(self.var_L) 103 | self.netG.train() 104 | 105 | def test_x8(self): 106 | # from https://github.com/thstkdgus35/EDSR-PyTorch 107 | self.netG.eval() 108 | 109 | def _transform(v, op): 110 | # if self.precision != 'single': v = v.float() 111 | v2np = v.data.cpu().numpy() 112 | if op == 'v': 113 | tfnp = v2np[:, :, :, ::-1].copy() 114 | elif op == 'h': 115 | tfnp = v2np[:, :, ::-1, :].copy() 116 | elif op == 't': 117 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 118 | 119 | ret = torch.Tensor(tfnp).to(self.device) 120 | # if self.precision == 'half': ret = ret.half() 121 | 122 | return ret 123 | 124 | lr_list = [self.var_L] 125 | for tf in 'v', 'h', 't': 126 | lr_list.extend([_transform(t, tf) for t in 
lr_list]) 127 | with torch.no_grad(): 128 | sr_list = [self.netG(aug) for aug in lr_list] 129 | for i in range(len(sr_list)): 130 | if i > 3: 131 | sr_list[i] = _transform(sr_list[i], 't') 132 | if i % 4 > 1: 133 | sr_list[i] = _transform(sr_list[i], 'h') 134 | if (i % 4) % 2 == 1: 135 | sr_list[i] = _transform(sr_list[i], 'v') 136 | 137 | output_cat = torch.cat(sr_list, dim=0) 138 | self.fake_H = output_cat.mean(dim=0, keepdim=True) 139 | self.netG.train() 140 | 141 | def get_current_log(self): 142 | return self.log_dict 143 | 144 | def get_current_visuals(self, need_GT=True): 145 | out_dict = OrderedDict() 146 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 147 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() 148 | if need_GT: 149 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 150 | return out_dict 151 | 152 | def print_network(self): 153 | s, n = self.get_network_description(self.netG) 154 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 155 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 156 | self.netG.module.__class__.__name__) 157 | else: 158 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 159 | if self.rank <= 0: 160 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 161 | logger.info(s) 162 | 163 | def load(self): 164 | load_path_G = self.opt['path']['pretrain_model_G'] 165 | if load_path_G is not None: 166 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 167 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 168 | 169 | def save(self, iter_label): 170 | self.save_network(self.netG, 'G', iter_label) 171 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def 
create_model(opt): 6 | model = opt['model'] 7 | # image restoration 8 | if model == 'sr': # PSNR-oriented super resolution 9 | from .SR_model import SRModel as M 10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN 11 | from .SRGAN_model import SRGANModel as M 12 | elif model == 'ranksrgan': # GAN-based super resolution 13 | from .RankSRGAN_model import SRGANModel as M 14 | # Ranker 15 | elif model == 'rank': 16 | from .Ranker_model import Ranker_Model as M 17 | else: 18 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 19 | m = M(opt) 20 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 21 | return m 22 | -------------------------------------------------------------------------------- /codes/models/__pycache__/RankSRGAN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/RankSRGAN.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/RankSRGAN_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/RankSRGAN_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/Ranker_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/Ranker_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/SRGAN_model.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/SRGAN_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/SRGAN_rank_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/SRGAN_rank_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/SR_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/SR_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/__pycache__/loss.cpython-36.pyc 
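The Ranker's pairwise training target can be illustrated without PyTorch. The sketch below mirrors in plain Python what `Ranker_Model.feed_data` and `optimize_parameters` compute. Each pair label is +1 when `score1 >= score2` and -1 otherwise. The predicted scores are clamped to [-5, 5], and the loss follows the formula documented for `nn.MarginRankingLoss`, `max(0, -y * (x1 - x2) + margin)`, averaged over the batch (the default `reduction='mean'`). The helper names `rank_labels`, `clamp`, and `margin_ranking_loss` are hypothetical and not part of the repository.

```python
def rank_labels(score1, score2):
    """Map label-score pairs to +1.0 (score1 >= score2) or -1.0, as feed_data() does."""
    return [1.0 if s1 >= s2 else -1.0 for s1, s2 in zip(score1, score2)]


def clamp(x, lo=-5.0, hi=5.0):
    """Scalar equivalent of torch.clamp(x, min=lo, max=hi)."""
    return max(lo, min(hi, x))


def margin_ranking_loss(pred1, pred2, labels, margin=0.5):
    """Mean over the batch of max(0, -y * (x1 - x2) + margin), with clamped predictions."""
    losses = []
    for x1, x2, y in zip(pred1, pred2, labels):
        x1, x2 = clamp(x1), clamp(x2)
        losses.append(max(0.0, -y * (x1 - x2) + margin))
    return sum(losses) / len(losses)


labels = rank_labels([3.0, 1.0], [2.0, 1.0])  # ties count as +1
print(labels)  # [1.0, 1.0]
print(margin_ranking_loss([2.0, 0.0], [1.0, 0.3], labels))
```

In this example the first pair already exceeds the 0.5 margin and contributes 0. The second pair is ordered the wrong way and contributes 0.3 + 0.5 = 0.8, so the mean loss is about 0.4. Because of the clamp, a prediction gap can never exceed 10, so one pair contributes at most 10.5.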
--------------------------------------------------------------------------------
/codes/models/archs/RRDBNet_arch.py:
--------------------------------------------------------------------------------
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
                                     0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
--------------------------------------------------------------------------------
/codes/models/archs/RankSRGAN_arch.py:
--------------------------------------------------------------------------------
import functools
import torch.nn as nn
import models.archs.arch_util as arch_util

####################
# Generator for RankSRGAN
####################
class SRResNet(nn.Module):

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(SRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.recon_trunk = arch_util.make_layer(basic_block, nb)
        self.LRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # upsampling
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.relu = nn.ReLU(inplace=True)

        # initialization
        arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
                                     0.1)
        if self.upscale == 4:
            arch_util.initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        fea = self.conv_first(x)
        out = self.recon_trunk(fea)
        out = self.LRconv(out)

        if self.upscale == 4:
            out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))
            out = self.relu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale == 3 or self.upscale == 2:
            out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))

        out = self.conv_last(self.relu(self.HRconv(out)))

        return out

####################
# Discriminator with patchsize 296 for RankSRGAN
####################
class Discriminator_VGG_296(nn.Module):
    def __init__(self, in_nc, nf):
        super(Discriminator_VGG_296, self).__init__()
        # feature shapes below assume nf=64 and a 296x296 input
        # [64, 296, 296]
        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 148, 148]
        self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 74, 74]
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
        # [256, 37, 37]
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
        # [512, 18, 18]
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)

        # conv4_1 outputs [512, 9, 9], flattened to 512 * 9 * 9 features
        self.linear1 = nn.Linear(512 * 9 * 9, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
        fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))

        fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
        fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))

        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))

        fea = fea.view(fea.size(0), -1)
        fea = self.lrelu(self.linear1(fea))
        out = self.linear2(fea)
        return out


####################
# Ranker
####################

class Ranker_VGG12_296(nn.Module):
    def __init__(self, in_nc, nf):
        super(Ranker_VGG12_296, self).__init__()
        # feature shapes below assume nf=64 and a 296x296 input
        # [64, 296, 296]
        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=True)
        self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 148, 148]
        self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=True)
        self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=True)
        self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 74, 74]
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=True)
        self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=True)
        self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
        # [256, 37, 37]
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=True)
        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=True)
        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
        # [512, 18, 18]
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=True)
        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)

        # classifier: 512 (= nf * 8) channels after global average pooling
        self.classifier = nn.Sequential(
            nn.Linear(512, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1)
        )

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
        fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))

        fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
        fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))

        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))

        # global average pooling over the (square) feature map
        fea = nn.AvgPool2d(fea.size()[2])(fea)
        fea = fea.view(fea.size(0), -1)
        out = self.classifier(fea)
        return out
--------------------------------------------------------------------------------
/codes/models/archs/SRResNet_arch.py:
--------------------------------------------------------------------------------
import functools
import torch.nn as nn
import torch.nn.functional as F
import models.archs.arch_util as arch_util


class MSRResNet(nn.Module):
    ''' modified SRResNet'''

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
        self.recon_trunk = arch_util.make_layer(basic_block, nb)

        # upsampling
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
                                     0.1)
        if self.upscale == 4:
            arch_util.initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        fea = self.lrelu(self.conv_first(x))
        out = self.recon_trunk(fea)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale == 3 or self.upscale == 2:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.HRconv(out)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
--------------------------------------------------------------------------------
/codes/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/archs/__init__.py
--------------------------------------------------------------------------------
/codes/models/archs/arch_util.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out
--------------------------------------------------------------------------------
/codes/models/archs/discriminator_vgg_arch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision


class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc, nf):
        super(Discriminator_VGG_128, self).__init__()
        # [64, 128, 128]
        self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
        # [64, 64, 64]
        self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
        self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
        # [128, 32, 32]
        self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
        self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
        # [256, 16, 16]
        self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
        # [512, 8, 8]
        self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
        self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)

        self.linear1 = nn.Linear(512 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))

        fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
        fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))

        fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
        fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))

        fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
        fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))

        fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
        fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))

        fea = fea.view(fea.size(0), -1)
        fea = self.lrelu(self.linear1(fea))
        out = self.linear2(fea)
        return out


class VGGFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        self.use_input_norm = use_input_norm
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485 * 2 - 1, 0.456 * 2 - 1, 0.406 * 2 - 1] if input in range [-1, 1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # freeze the VGG weights: no backpropagation into the feature extractor
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        # Assume input range is [0, 1]
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output
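`Discriminator_VGG_128` and `Discriminator_VGG_296` differ mainly in the input size of `linear1`. Each of the five stages ends with a 4×4 convolution with stride 2 and padding 1. Per the standard `nn.Conv2d` shape formula with dilation 1, such a convolution maps a side length `n` to `(n + 2*1 - 4) // 2 + 1`. The 3×3, stride-1, padding-1 convolutions leave the size unchanged. The sketch below traces a square input through those five stages. Its helper names are illustrative and not from the repository. With `nf=64` the last stage has 512 channels, which gives the `512 * 4 * 4` and `512 * 9 * 9` feature sizes used by the two discriminators. `Ranker_VGG12_296` instead average-pools its final map to 1×1, so its classifier takes 512 inputs.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output side length of a square nn.Conv2d with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_discriminator(size, n_stages=5):
    """Side length before the first stage and after each stride-2 stage."""
    sizes = [size]
    for _ in range(n_stages):
        size = conv_out(size)
        sizes.append(size)
    return sizes


for patch in (128, 296):
    sizes = trace_discriminator(patch)
    final = sizes[-1]
    # channels after the last stage: nf * 8 = 512 when nf = 64
    print(patch, sizes, 'linear1 in_features =', 512 * final * final)
```

This prints `128 [128, 64, 32, 16, 8, 4] linear1 in_features = 8192` and `296 [296, 148, 74, 37, 18, 9] linear1 in_features = 41472`. The odd size 37 is rounded down to 18 by the integer division, which is why a 296×296 patch ends at 9×9 rather than 9.25×9.25.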
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
import os
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class BaseModel():
    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def save(self, label):
        pass

    def load(self):
        pass

    def _set_lr(self, lr_groups_l):
        """Set the learning rate for warm-up.
        lr_groups_l: list of lr_groups, one per optimizer"""
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler"""
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, cur_iter, warmup_iter=-1):
        for scheduler in self.schedulers:
            scheduler.step()
        # set up warm-up learning rate
        if cur_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # linear warm-up: lr = initial_lr * cur_iter / warmup_iter
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    def get_network_description(self, network):
        """Get the string and total parameters of the network"""
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        return str(network), sum(map(lambda x: x.numel(), network.parameters()))

    def save_network(self, network, network_label, iter_label):
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(self.opt['path']['models'], save_filename)
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def load_network(self, load_path, network, strict=True):
        if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
            network = network.module
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # strip the 'module.' prefix added by (Distributed)DataParallel
        for k, v in load_net.items():
            if k.startswith('module.'):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict)

    def save_training_state(self, epoch, iter_step):
        """Save training state during training, which will be used for resuming"""
        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
        for s in self.schedulers:
            state['schedulers'].append(s.state_dict())
        for o in self.optimizers:
            state['optimizers'].append(o.state_dict())
        save_filename = '{}.state'.format(iter_step)
        save_path = os.path.join(self.opt['path']['training_state'], save_filename)
        torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Resume the optimizers and schedulers for training"""
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)
--------------------------------------------------------------------------------
/codes/models/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff =
x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [gan | ragan | lsgan | wgan-gp] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23 | self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 68 | grad_outputs=grad_outputs,
create_graph=True, 69 | retain_graph=True, only_inputs=True)[0] 70 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 71 | grad_interp_norm = grad_interp.norm(2, dim=1) 72 | 73 | loss = ((grad_interp_norm - 1)**2).mean() 74 | return loss 75 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restarts = [v + 1 for v in self.restarts] 16 | self.restart_weights = weights if weights else [1] 17 | assert len(self.restarts) == len( 18 | self.restart_weights), 'restarts and their weights do not match.' 
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group['lr'] for group in self.optimizer.param_groups] 29 | return [ 30 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 31 | for group in self.optimizer.param_groups 32 | ] 33 | 34 | 35 | class CosineAnnealingLR_Restart(_LRScheduler): 36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 37 | self.T_period = T_period 38 | self.T_max = self.T_period[0] # current T period 39 | self.eta_min = eta_min 40 | self.restarts = restarts if restarts else [0] 41 | self.restarts = [v + 1 for v in self.restarts] 42 | self.restart_weights = weights if weights else [1] 43 | self.last_restart = 0 44 | assert len(self.restarts) == len( 45 | self.restart_weights), 'restarts and their weights do not match.' 
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | 
############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 | ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /codes/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models.archs.SRResNet_arch as SRResNet_arch 3 | import models.archs.discriminator_vgg_arch as 
SRGAN_arch 4 | import models.archs.RRDBNet_arch as RRDBNet_arch 5 | import models.archs.RankSRGAN_arch as RankSRGAN_arch 6 | 7 | 8 | # Generator 9 | def define_G(opt): 10 | opt_net = opt['network_G'] 11 | which_model = opt_net['which_model_G'] 12 | 13 | # image restoration 14 | if which_model == 'MSRResNet': 15 | netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 16 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) 17 | elif which_model == 'RRDBNet': 18 | netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 19 | nf=opt_net['nf'], nb=opt_net['nb']) 20 | elif which_model == 'SRResNet': 21 | netG = RankSRGAN_arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 22 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) 23 | else: 24 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) 25 | 26 | return netG 27 | 28 | 29 | # Discriminator 30 | def define_D(opt): 31 | opt_net = opt['network_D'] 32 | which_model = opt_net['which_model_D'] 33 | 34 | if which_model == 'discriminator_vgg_128': 35 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 36 | elif which_model == 'discriminator_vgg_296': 37 | netD = RankSRGAN_arch.Discriminator_VGG_296(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 38 | else: 39 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 40 | return netD 41 | 42 | # Define network used for perceptual loss 43 | def define_F(opt, use_bn=False): 44 | gpu_ids = opt['gpu_ids'] 45 | device = torch.device('cuda' if gpu_ids else 'cpu') 46 | # PyTorch pretrained VGG19-54, before ReLU. 
47 | if use_bn: 48 | feature_layer = 49 49 | else: 50 | feature_layer = 34 51 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 52 | use_input_norm=True, device=device) 53 | netF.eval() # No need to train 54 | return netF 55 | 56 | # Define network used for rank-content loss 57 | def define_R(opt): 58 | opt_net = opt['network_R'] 59 | which_model = opt_net['which_model_R'] 60 | 61 | if which_model == 'Ranker_VGG12': 62 | netR = RankSRGAN_arch.Ranker_VGG12_296(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 63 | else: 64 | raise NotImplementedError('Ranker model [{:s}] is not recognized'.format(which_model)) 65 | 66 | return netR 67 | -------------------------------------------------------------------------------- /codes/options/README.md: -------------------------------------------------------------------------------- 1 | # Configurations 2 | - Use **json** files to configure options. 3 | - Convert the json file to python dict. 4 | - Support `//` comments and use `null` for `None`. 5 | 6 | ## Table 7 | Click for detailed explanations for each json file. 8 | 9 | 1. [RankSRGAN_NIQE.json](#RankSRGAN_NIQE_json) 10 | 2. 
[Ranker.json](#Ranker_json) 11 | 12 | ## RankSRGAN_NIQE_json 13 | ```c++ 14 | { 15 | "name": "RankSRGANx4_NIQE" 16 | ,"model":"ranksrgan" // model type 17 | ,"scale": 4 18 | ,"gpu_ids": [2] // specify GPUs, actually it sets the `CUDA_VISIBLE_DEVICES` 19 | 20 | ,"datasets": { // configure the training and validation datasets 21 | "train": { // training dataset configurations 22 | "name": "DIV2K" 23 | ,"mode": "LRHR" 24 | ,"dataroot_HR": "/home/wlzhang/BasicSR12/data/DIV2K800_sub.lmdb" // HR data root 25 | ,"dataroot_LR": "/home/wlzhang/BasicSR12/data/DIV2K800_sub_bicLRx4.lmdb" // LR data root 26 | ,"subset_file": null 27 | ,"use_shuffle": true 28 | ,"n_workers": 8 // number of data load workers 29 | ,"batch_size": 8 30 | ,"HR_size": 296 // 128 for SRGAN | 296 for RankSRGAN, cropped HR patch size 31 | ,"use_flip": true // whether to use horizontal and vertical flips 32 | ,"use_rot": true // whether to use rotations: 90, 180, 270 degrees 33 | , "random_flip": false 34 | , "random_scale": false 35 | } 36 | , "val": { // validation dataset configurations 37 | "name": "val_PIRM" 38 | ,"mode": "LRHR" 39 | ,"dataroot_HR": "/home/wlzhang/BasicSR12/data/val/PIRMtestHR" 40 | ,"dataroot_LR": "/home/wlzhang/BasicSR12/data/val/PIRMtest" 41 | } 42 | } 43 | 44 | ,"path": { 45 | "root": "/home/wlzhang/RankSRGAN", // root path 46 | // "resume_state": "../experiments/RankSRGANx4_NIQE/training_state/152000.state", // resume training from iteration 152000 47 | "pretrain_model_G": "/home/wlzhang/RankSRGAN/experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth", // G network pretrain model 48 | "pretrain_model_R": "/home/wlzhang/RankSRGAN/experiments/pretrained_models/Ranker_NIQE.pth", // R network pretrain model 49 | 50 | "experiments_root": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE", 51 | "models": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE/models", 52 | "log": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE", 53 | "val_images":
"/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE/val_images" 54 | } 55 | 56 | ,"network_G": { // configurations for the network G 57 | "which_model_G": "sr_resnet" 58 | ,"norm_type": null // null | "batch", norm type 59 | ,"mode": "CNA" // Convolution mode: CNA for Conv-Norm-Activation 60 | ,"nf": 64 // number of features for each layer 61 | ,"nb": 16 // number of blocks 62 | ,"in_nc": 3 // input channels 63 | ,"out_nc": 3 // output channels 64 | ,"group": 1 65 | } 66 | ,"network_D": { // configurations for the network D 67 | "which_model_D": "discriminator_vgg_128" 68 | ,"norm_type": "batch" 69 | ,"act_type": "leakyrelu" 70 | ,"mode": "CNA" 71 | ,"nf": 64 72 | ,"in_nc": 3 73 | }, 74 | "network_R": { 75 | "which_model_R": "Ranker_VGG12", 76 | "norm_type": "batch", 77 | "act_type": "leakyrelu", 78 | "mode": "CNA", 79 | "nf": 64, 80 | "in_nc": 3 81 | }, 82 | "train": { // training strategies 83 | "lr_G": 0.0001, // initial learning rate for G 84 | "train_D": 1, 85 | "weight_decay_G": 0, 86 | "beta1_G": 0.9, 87 | "lr_D": 0.0001, // initial learning rate for D 88 | "weight_decay_D": 0, 89 | "beta1_D": 0.9, 90 | "lr_scheme": "MultiStepLR", // learning rate decay scheme 91 | "lr_steps": [ 92 | 50000, 93 | 100000, 94 | 200000, 95 | 300000 96 | ], 97 | "lr_gamma": 0.5, 98 | "pixel_criterion": "l1", // "l1" | "l2", pixel criterion 99 | "pixel_weight": 0, 100 | "feature_criterion": "l1", // perceptual criterion (VGG loss) 101 | "feature_weight": 1, 102 | "gan_type": "vanilla", // GAN type 103 | "gan_weight": 0.005, 104 | "D_update_ratio": 1, 105 | "D_init_iters": 0, 106 | "R_weight": 0.03, // rank-content loss 107 | "R_bias": 0, 108 | "manual_seed": 0, 109 | "niter": 500000.0, // total training iterations 110 | "val_freq": 5000 // validation frequency 111 | }, 112 | "logger": { // logger configurations 113 | "print_freq": 200 114 | ,"save_checkpoint_freq": 5000 115 | }, 116 | "timestamp": "180804-004247", 117 | "is_train": true, 118 | "fine_tune": false 119 | }
120 | 121 | ``` 122 | ## Ranker_json 123 | 124 | ```c++ 125 | { 126 | "name": "Ranker_NIQE" // 127 | ,"use_tb_logger": true // use tensorboard_logger 128 | ,"model":"rank" 129 | ,"scale": 4 130 | ,"gpu_ids": [2,5] // specify GPUs, actually it sets the `CUDA_VISIBLE_DEVICES` 131 | ,"datasets": { // configure the training and validation rank datasets 132 | "train": { // training dataset configurations 133 | "name": "DF2K_train_rankdataset" 134 | ,"mode": "RANK_IMIM_Pair" 135 | ,"dataroot_HR": null 136 | ,"dataroot_LR":null 137 | ,"dataroot_img1": "/home/wlzhang/data/rankdataset/DF2K_train_patch_esrgan/" // Rankdataset: Perceptual level1 data root 138 | ,"dataroot_img2": "/home/wlzhang/data/rankdataset/DF2K_train_patch_srgan/" // Rankdataset: Perceptual level2 data root 139 | ,"dataroot_img3": "/home/wlzhang/data/rankdataset/DF2K_train_patch_srres/" // Rankdataset: Perceptual level3 data root 140 | ,"dataroot_label_file": "/home/wlzhang/data/rankdataset/DF2K_train_patch_label.txt" // Rankdataset: Perceptual rank label root 141 | ,"subset_file": null 142 | ,"use_shuffle": true 143 | ,"n_workers": 8 // number of data load workers 144 | ,"batch_size": 32 145 | ,"HR_size": 128 146 | ,"use_flip": true 147 | ,"use_rot": true 148 | } 149 | , "val": { 150 | "name": "DF2K_valid_rankdataset" // validation dataset configurations 151 | ,"mode": "RANK_IMIM_Pair" 152 | ,"dataroot_HR": null 153 | ,"dataroot_LR":null 154 | ,"dataroot_img1": "/home/wlzhang/data/rankdataset/DF2K_test_patch_all/" 155 | ,"dataroot_label_file": "/home/wlzhang/data/rankdataset/DF2K_test_patch_label.txt" 156 | } 157 | } 158 | 159 | ,"path": { // root path 160 | "root": "/home/wlzhang/RankSRGAN", 161 | "experiments_root": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE", 162 | "models": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE/models", 163 | "log": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE", 164 | "val_images": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE/val_images" 165 | } 166 | 167 | 
,"network_G": { 168 | "which_model_G": "sr_resnet" 169 | ,"norm_type": null 170 | ,"mode": "CNA" 171 | ,"nf": 64 172 | ,"nb": 16 173 | ,"in_nc": 3 174 | ,"out_nc": 3 175 | ,"group": 1 176 | } 177 | ,"network_R": { // configurations for the network Ranker 178 | "which_model_R": "Ranker_VGG12" 179 | ,"norm_type": "batch" // null | "batch", norm type 180 | ,"act_type": "leakyrelu" 181 | ,"mode": "CNA" 182 | ,"nf": 64 183 | ,"nb": 16 184 | ,"in_nc": 3 185 | ,"out_nc": 3 186 | 187 | } 188 | 189 | ,"train": { // training strategies 190 | "lr_R": 1e-3 // initial learning rate for R 191 | ,"weight_decay_R": 1e-4 192 | ,"beta1_G": 0.9 193 | ,"lr_D": 1e-4 194 | ,"weight_decay_D": 0 195 | ,"beta1_D": 0.9 196 | ,"lr_scheme": "MultiStepLR" 197 | ,"lr_steps": [100000, 200000] // iterations at which the learning rate is multiplied by lr_gamma 198 | 199 | ,"lr_gamma": 0.5 200 | 201 | // ,"pixel_criterion": "l1" 202 | // ,"pixel_weight": 1 203 | // ,"feature_criterion": "l1" 204 | // ,"feature_weight": 1 205 | // ,"gan_type": "vanilla" 206 | // ,"gan_weight": 5e-3 207 | 208 | ,"D_update_ratio": 1 209 | ,"D_init_iters": 0 210 | 211 | ,"manual_seed": 0 212 | ,"niter": 400000 // total training iterations 213 | ,"val_freq": 5000 // validation frequency 214 | } 215 | 216 | ,"logger": { // logger configurations 217 | "print_freq": 200 218 | ,"save_checkpoint_freq": 5000 219 | } 220 | } 221 | ``` -------------------------------------------------------------------------------- /codes/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__init__.py -------------------------------------------------------------------------------- /codes/options/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/options/__pycache__/options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__pycache__/options.cpython-36.pyc -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | if dataset.get('dataroot_LQ', None) is not None: 33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 34 | if dataset['dataroot_LQ'].endswith('lmdb'): 35 | is_lmdb = True 36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | 
dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 46 | if is_train: 47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 48 | opt['path']['experiments_root'] = experiments_root 49 | opt['path']['models'] = osp.join(experiments_root, 'models') 50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 51 | opt['path']['log'] = experiments_root 52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 53 | 54 | # change some options for debug mode 55 | if 'debug' in opt['name']: 56 | opt['train']['val_freq'] = 8 57 | opt['logger']['print_freq'] = 1 58 | opt['logger']['save_checkpoint_freq'] = 8 59 | else: # test 60 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 61 | opt['path']['results_root'] = results_root 62 | opt['path']['log'] = results_root 63 | 64 | # network 65 | if opt['distortion'] == 'sr': 66 | opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | '''dict to string for logger''' 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 82 | 83 | 84 | class NoneDict(dict): 85 | def __missing__(self, key): 86 | return None 87 | 88 | 89 | # convert to NoneDict, which returns None for missing keys.
def dict_to_nonedict(opt): 91 | if isinstance(opt, dict): 92 | new_opt = dict() 93 | for key, sub_opt in opt.items(): 94 | new_opt[key] = dict_to_nonedict(sub_opt) 95 | return NoneDict(**new_opt) 96 | elif isinstance(opt, list): 97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 98 | else: 99 | return opt 100 | 101 | 102 | def check_resume(opt, resume_iter): 103 | '''Check resume states and pretrain_model paths''' 104 | logger = logging.getLogger('base') 105 | if opt['path']['resume_state']: 106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 107 | 'pretrain_model_D', None) is not None: 108 | logger.warning('pretrain_model path will be ignored when resuming training.') 109 | 110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 111 | '{}_G.pth'.format(resume_iter)) 112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 113 | if 'gan' in opt['model']: 114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 115 | '{}_D.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 117 | -------------------------------------------------------------------------------- /codes/options/test/test_RankSRGAN.yml: -------------------------------------------------------------------------------- 1 | name: RankSRGANx4 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border during evaluation.
If None (~), crop `scale` pixels 7 | gpu_ids: [3] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: set14 12 | mode: LQGT 13 | dataroot_GT: /home/wlzhang/BasicSR12/data/val/Set14_mod 14 | dataroot_LQ: 15 | test_2: # the 2nd test dataset 16 | name: PIRMtest 17 | mode: LQGT 18 | dataroot_GT: /home/wlzhang/RankSRGAN/data/val/PIRMtestHR 19 | dataroot_LQ: /home/wlzhang/RankSRGAN/data/val/PIRMtest 20 | 21 | #### network structures 22 | network_G: 23 | which_model_G: SRResNet # SRResNet for RankSRGAN 24 | in_nc: 3 25 | out_nc: 3 26 | nf: 64 27 | nb: 16 28 | upscale: 4 29 | 30 | #### path 31 | # Download pretrained models from https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing 32 | path: 33 | pretrain_model_G: ../experiments/pretrained_models/mmsr_RankSRGAN_NIQE.pth 34 | -------------------------------------------------------------------------------- /codes/options/train/train_RankSRGAN.yml: -------------------------------------------------------------------------------- 1 | # Not exactly the same as SRGAN in 2 | # With 16 Residual blocks w/o BN 3 | 4 | #### general settings 5 | name: 002_RankSRGANx4_SRResNetx4_DIV2K 6 | use_tb_logger: true 7 | model: ranksrgan 8 | distortion: sr 9 | scale: 4 10 | gpu_ids: [0,1] 11 | 12 | #### datasets 13 | datasets: 14 | train: 15 | name: DIV2K 16 | mode: LQGT 17 | dataroot_GT: /home/wlzhang/data/DIV2K/DIV2K_train_HR 18 | dataroot_LQ: 19 | 20 | use_shuffle: true 21 | n_workers: 6 # per GPU 22 | batch_size: 8 23 | GT_size: 296 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | val: 28 | name: val_Pirm_test100 29 | mode: LQGT 30 | dataroot_GT: /home/wlzhang/BasicSR12/data/val/PIRMtestHR 31 | dataroot_LQ: /home/wlzhang/BasicSR12/data/val/PIRMtest 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: SRResNet # SRResNet for RankSRGAN 36 | in_nc: 3 37 | out_nc: 3 38 | nf: 64 39 | nb: 16 40 | upscale: 4 41 | network_D: 42 | which_model_D: discriminator_vgg_296 43 |
in_nc: 3 44 | nf: 64 45 | 46 | network_R: 47 | which_model_R: Ranker_VGG12 48 | in_nc: 3 49 | nf: 64 50 | 51 | #### path 52 | # Download pretrained models from https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing 53 | path: 54 | pretrain_model_G: ../experiments/pretrained_models/mmsr_SRResNet_pretrain.pth 55 | pretrain_model_R: ../experiments/pretrained_models/mmsr_Ranker_NIQE.pth 56 | 57 | strict_load: true 58 | resume_state: ~ 59 | 60 | #### training settings: learning rate scheme, loss 61 | train: 62 | lr_G: !!float 1e-4 63 | weight_decay_G: 0 64 | beta1_G: 0.9 65 | beta2_G: 0.99 66 | lr_D: !!float 1e-4 67 | weight_decay_D: 0 68 | beta1_D: 0.9 69 | beta2_D: 0.99 70 | lr_scheme: MultiStepLR 71 | 72 | niter: 500000 73 | warmup_iter: -1 # no warm up 74 | lr_steps: [50000, 100000, 200000, 300000] 75 | lr_gamma: 0.5 76 | 77 | pixel_criterion: l1 78 | pixel_weight: 0 79 | feature_criterion: l1 80 | feature_weight: 1 81 | R_weight: !!float 3e-2 # rank-content loss 82 | R_bias: 0 83 | gan_type: gan # gan | ragan 84 | gan_weight: !!float 5e-3 85 | 86 | D_update_ratio: 1 87 | D_init_iters: 0 88 | 89 | manual_seed: 10 90 | val_freq: !!float 2e3 91 | 92 | #### logger 93 | logger: 94 | print_freq: 200 95 | save_checkpoint_freq: !!float 2e3 96 | -------------------------------------------------------------------------------- /codes/options/train/train_Ranker.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_VGG_Ranker_DF2K 3 | use_tb_logger: true 4 | model: rank 5 | distortion: sr 6 | scale: 4 7 | gpu_ids: [0,1] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: Rankdataset 13 | mode: RANK_IMIM_Pair 14 | dataroot_img1: /home/wlzhang/data/rankdataset/DF2K_train_patch_esrgan/ 15 | dataroot_img2: /home/wlzhang/data/rankdataset/DF2K_train_patch_srgan/ 16 | dataroot_img3: /home/wlzhang/data/rankdataset/DF2K_train_patch_srres/ 17 | dataroot_label_file: 
/home/wlzhang/data/rankdataset/DF2K_train_patch_label.txt # Rankdataset: Perceptual rank label root 18 | 19 | use_shuffle: true 20 | n_workers: 6 # per GPU 21 | batch_size: 16 22 | GT_size: 128 23 | use_flip: true 24 | use_rot: true 25 | color: RGB 26 | val: 27 | name: DF2K_valid_rankdataset 28 | mode: RANK_IMIM_Pair 29 | dataroot_img1: /home/wlzhang/data/rankdataset/DF2K_test_patch_all/ 30 | dataroot_label_file: /home/wlzhang/data/rankdataset/DF2K_test_patch_label.txt 31 | 32 | #### network structures 33 | network_G: 34 | which_model_G: RRDBNet 35 | in_nc: 3 36 | out_nc: 3 37 | nf: 64 38 | nb: 23 39 | network_R: # configurations for the network Ranker 40 | which_model_R: Ranker_VGG12 41 | in_nc: 3 42 | nf: 64 43 | 44 | #### path 45 | path: 46 | strict_load: true 47 | resume_state: ~ 48 | 49 | #### training settings: learning rate scheme, loss 50 | train: 51 | lr_R: !!float 1e-4 52 | weight_decay_R: !!float 1e-4 53 | beta1_G: 0.9 54 | lr_scheme: MultiStepLR 55 | 56 | niter: 400000 57 | warmup_iter: -1 # no warm up 58 | lr_steps: [100000, 200000] 59 | lr_gamma: 0.5 60 | 61 | manual_seed: 10 62 | val_freq: !!float 5e3 63 | 64 | #### logger 65 | logger: 66 | print_freq: 200 67 | save_checkpoint_freq: !!float 5e3 68 | -------------------------------------------------------------------------------- /codes/options/train/train_SRGAN.yml: -------------------------------------------------------------------------------- 1 | # Not exactly the same as SRGAN in 2 | # With 16 Residual blocks w/o BN 3 | 4 | #### general settings 5 | name: 003_SRGANx4_MSRResNetx4Ini_DIV2K 6 | use_tb_logger: true 7 | model: srgan 8 | distortion: sr 9 | scale: 4 10 | gpu_ids: [1] 11 | 12 | #### datasets 13 | datasets: 14 | train: 15 | name: DIV2K 16 | mode: LQGT 17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb 18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb 19 | 20 | use_shuffle: true 21 | n_workers: 6 # per GPU 22 | batch_size: 16 23 | GT_size: 128 24 | use_flip: true 25 | 
use_rot: true 26 | color: RGB 27 | val: 28 | name: val_set14 29 | mode: LQGT 30 | dataroot_GT: ../datasets/val_set14/Set14 31 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: MSRResNet 36 | in_nc: 3 37 | out_nc: 3 38 | nf: 64 39 | nb: 16 40 | upscale: 4 41 | network_D: 42 | which_model_D: discriminator_vgg_128 43 | in_nc: 3 44 | nf: 64 45 | 46 | #### path 47 | path: 48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth 49 | strict_load: true 50 | resume_state: ~ 51 | 52 | #### training settings: learning rate scheme, loss 53 | train: 54 | lr_G: !!float 1e-4 55 | weight_decay_G: 0 56 | beta1_G: 0.9 57 | beta2_G: 0.99 58 | lr_D: !!float 1e-4 59 | weight_decay_D: 0 60 | beta1_D: 0.9 61 | beta2_D: 0.99 62 | lr_scheme: MultiStepLR 63 | 64 | niter: 400000 65 | warmup_iter: -1 # no warm up 66 | lr_steps: [50000, 100000, 200000, 300000] 67 | lr_gamma: 0.5 68 | 69 | pixel_criterion: l1 70 | pixel_weight: !!float 1e-2 71 | feature_criterion: l1 72 | feature_weight: 1 73 | gan_type: gan # gan | ragan 74 | gan_weight: !!float 5e-3 75 | 76 | D_update_ratio: 1 77 | D_init_iters: 0 78 | 79 | manual_seed: 10 80 | val_freq: !!float 5e3 81 | 82 | #### logger 83 | logger: 84 | print_freq: 100 85 | save_checkpoint_freq: !!float 5e3 86 | -------------------------------------------------------------------------------- /codes/options/train/train_SRResNet.yml: -------------------------------------------------------------------------------- 1 | # Not exactly the same as SRResNet in 2 | # With 16 Residual blocks w/o BN 3 | 4 | #### general settings 5 | name: 004_MSRResNetx4_scratch_DIV2K 6 | use_tb_logger: true 7 | model: sr 8 | distortion: sr 9 | scale: 4 10 | gpu_ids: [0] 11 | 12 | #### datasets 13 | datasets: 14 | train: 15 | name: DIV2K 16 | mode: LQGT 17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb 18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb 19 | 20 | use_shuffle: true 21 | 
n_workers: 6 # per GPU 22 | batch_size: 16 23 | GT_size: 128 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | val: 28 | name: val_set5 29 | mode: LQGT 30 | dataroot_GT: ../datasets/val_set5/Set5 31 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: MSRResNet 36 | in_nc: 3 37 | out_nc: 3 38 | nf: 64 39 | nb: 16 40 | upscale: 4 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ~ 45 | strict_load: true 46 | resume_state: ~ 47 | 48 | #### training settings: learning rate scheme, loss 49 | train: 50 | lr_G: !!float 2e-4 51 | lr_scheme: CosineAnnealingLR_Restart 52 | beta1: 0.9 53 | beta2: 0.99 54 | niter: 1000000 55 | warmup_iter: -1 # no warm up 56 | T_period: [250000, 250000, 250000, 250000] 57 | restarts: [250000, 500000, 750000] 58 | restart_weights: [1, 1, 1] 59 | eta_min: !!float 1e-7 60 | 61 | pixel_criterion: l1 62 | pixel_weight: 1.0 63 | 64 | manual_seed: 10 65 | val_freq: !!float 5e3 66 | 67 | #### logger 68 | logger: 69 | print_freq: 100 70 | save_checkpoint_freq: !!float 5e3 71 | -------------------------------------------------------------------------------- /codes/run_scripts.sh: -------------------------------------------------------------------------------- 1 | # image SR training 2 | 3 | python train.py -opt options/train/train_RankSRGAN.yml # Validation with PSNR 4 | python train_niqe.py -opt options/train/train_RankSRGAN.yml # Validation with PSNR and NIQE 5 | 6 | # Ranker training 7 | 8 | python train_rank.py -opt options/train/train_Ranker.yml 9 | -------------------------------------------------------------------------------- /codes/scripts/README.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | We provide some useful scripts here.
3 | 4 | ## List 5 | 6 | | Name | Description | 7 | |:---:|:---:| 8 | | back projection | `Matlab` codes for back projection | 9 | -------------------------------------------------------------------------------- /codes/scripts/__pycache__/arch_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/scripts/__pycache__/arch_util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/scripts/__pycache__/block.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/scripts/__pycache__/block.cpython-36.pyc -------------------------------------------------------------------------------- /codes/scripts/arch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def make_layer(block, n_layers): 28 | layers = [] 29 | for _ in range(n_layers): 30 | layers.append(block()) 31 | return nn.Sequential(*layers) 32 | 33 | 34 | class ResidualBlock_noBN(nn.Module): 
35 | '''Residual block w/o BN 36 | ---Conv-ReLU-Conv-+- 37 | |________________| 38 | ''' 39 | 40 | def __init__(self, nf=64): 41 | super(ResidualBlock_noBN, self).__init__() 42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 44 | 45 | # initialization 46 | initialize_weights([self.conv1, self.conv2], 0.1) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = F.relu(self.conv1(x), inplace=True) 51 | out = self.conv2(out) 52 | return identity + out 53 | 54 | 55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 56 | """Warp an image or feature map with optical flow. 57 | Args: 58 | x (Tensor): size (N, C, H, W) 59 | flow (Tensor): size (N, H, W, 2), displacement in pixels (not normalized) 60 | interp_mode (str): 'nearest' or 'bilinear' 61 | padding_mode (str): 'zeros' or 'border' or 'reflection' 62 | 63 | Returns: 64 | Tensor: warped image or feature map 65 | """ 66 | assert x.size()[-2:] == flow.size()[1:3] 67 | B, C, H, W = x.size() 68 | # mesh grid 69 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 70 | grid = torch.stack((grid_x, grid_y), 2).float() # H, W, 2; last dim ordered (x, y) 71 | grid.requires_grad = False 72 | grid = grid.type_as(x) 73 | vgrid = grid + flow 74 | # scale grid to [-1, 1]; dividing by (size - 1) follows the align_corners=True convention 75 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 76 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 77 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 78 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 79 | return output 80 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/backprojection.m: -------------------------------------------------------------------------------- 1 | function [im_h] = backprojection(im_h, im_l, maxIter) 2 | 3 | [row_l, col_l,~] = size(im_l); 4 | [row_h, col_h,~] = size(im_h); 5 | 6 | p = fspecial('gaussian', 5, 1); 7 | p = p.^2; 8 | p = p./sum(p(:)); 9
| 10 | im_l = double(im_l); 11 | im_h = double(im_h); 12 | 13 | for ii = 1:maxIter 14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); 15 | im_diff = im_l - im_l_s; 16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); 17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); 18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); 19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); 20 | end 21 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_bp.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20bp'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | %tic 19 | im_out = backprojection(im_out, im_LR, max_iter); 20 | %toc 21 | imwrite(im_out, fullfile(save_folder, im_name)); 22 | end 23 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_reverse_filter.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20if'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = 
im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | J = imresize(im_LR,4,'bicubic'); 19 | %tic 20 | for m = 1:max_iter 21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); 22 | end 23 | %toc 24 | imwrite(im_out, fullfile(save_folder, im_name)); 25 | end 26 | -------------------------------------------------------------------------------- /codes/scripts/calparameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchsummary import summary 5 | from collections import OrderedDict 6 | import torch 7 | import torch.nn as nn 8 | import block as B 9 | 10 | class Discriminator_VGG_128(nn.Module): 11 | def __init__(self, in_nc = 3, base_nf = 64, norm_type='batch', act_type='leakyrelu', mode='CNA'): 12 | super(Discriminator_VGG_128, self).__init__() 13 | # features 14 | # hxw, c 15 | # 128, 64 16 | 17 | conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \ 18 | mode=mode) 19 | conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \ 20 | act_type=act_type, mode=mode) 21 | # 64, 64 22 | conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \ 23 | act_type=act_type, mode=mode) 24 | conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \ 25 | act_type=act_type, mode=mode) 26 | # 32, 128 27 | conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \ 28 | act_type=act_type, mode=mode) 29 | conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \ 30 | act_type=act_type, mode=mode) 31 | # 16, 256 32 | conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ 33 | act_type=act_type, mode=mode) 34 | conv7 =
B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ 35 | act_type=act_type, mode=mode) 36 | # 8, 512 37 | conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \ 38 | act_type=act_type, mode=mode) 39 | conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \ 40 | act_type=act_type, mode=mode) 41 | # 4, 512 42 | self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\ 43 | conv9) 44 | 45 | # classifier 46 | self.classifier = nn.Sequential( 47 | nn.Linear(512 * 9 * 9, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1)) # in_features = 512 * s * s, where s depends on the input patch: 400 -> 12, 256 -> 8, 200 -> 6, 296 -> 9 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = x.view(x.size(0), -1) 52 | x = self.classifier(x) 53 | return x 54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0 55 | model = Discriminator_VGG_128().to(device) 56 | 57 | summary(model, (3, 296, 296), device=device.type) -------------------------------------------------------------------------------- /codes/scripts/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import glob 4 | import pickle 5 | import lmdb 6 | import cv2 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | from utils.progress_bar import ProgressBar 10 | 11 | # configurations 12 | img_folder = '/sdd/BasicSR_datasets/DIV2K800/DIV2K800/*' # glob matching pattern 13 | lmdb_save_path = '/sdd/BasicSR_datasets/DIV2K800/DIV2K800.lmdb' # must end with .lmdb 14 | 15 | img_list = sorted(glob.glob(img_folder)) 16 | dataset = [] 17 | data_size = 0 18 | 19 | print('Read images...') 20 | pbar = ProgressBar(len(img_list)) 21 | for i, v in enumerate(img_list): 22 | pbar.update('Read {}'.format(v)) 23 | img = cv2.imread(v, cv2.IMREAD_UNCHANGED) 24 | dataset.append(img) 25 | data_size += img.nbytes 26 | env =
lmdb.open(lmdb_save_path, map_size=data_size * 10) 27 | print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list))) 28 | 29 | pbar = ProgressBar(len(img_list)) 30 | with env.begin(write=True) as txn: # txn is a Transaction object 31 | for i, v in enumerate(img_list): 32 | pbar.update('Write {}'.format(v)) 33 | base_name = os.path.splitext(os.path.basename(v))[0] 34 | key = base_name.encode('ascii') 35 | data = dataset[i] 36 | if dataset[i].ndim == 2: 37 | H, W = dataset[i].shape 38 | C = 1 39 | else: 40 | H, W, C = dataset[i].shape 41 | meta_key = (base_name + '.meta').encode('ascii') 42 | meta = '{:d}, {:d}, {:d}'.format(H, W, C) 43 | # The encode is only essential in Python 3 44 | txn.put(key, data) 45 | txn.put(meta_key, meta.encode('ascii')) 46 | print('Finish writing lmdb.') 47 | 48 | # create keys cache 49 | keys_cache_file = os.path.join(lmdb_save_path, '_keys_cache.p') 50 | env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False) 51 | with env.begin(write=False) as txn: 52 | print('Create lmdb keys cache: {}'.format(keys_cache_file)) 53 | keys = [key.decode('ascii') for key, _ in txn.cursor()] 54 | pickle.dump(keys, open(keys_cache_file, "wb")) 55 | print('Finish creating lmdb keys cache.') 56 | -------------------------------------------------------------------------------- /codes/scripts/extract_subimgs_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from multiprocessing import Pool 4 | import time 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | def main(): 10 | GT_dir = '/media/sdc/wlzhang/data/DIV2K_train_HR' 11 | save_GT_dir = '/media/sdc/wlzhang/data/DIV2K800_sub2' 12 | n_thread = 20 13 | 14 | print('Parent process %s.' 
% os.getpid()) 15 | start = time.time() 16 | 17 | p = Pool(n_thread) 18 | # read all files to a list 19 | all_files = [] 20 | for root, _, fnames in sorted(os.walk(GT_dir)): 21 | full_path = [os.path.join(root, x) for x in fnames] 22 | all_files.extend(full_path) 23 | # cut into subtasks 24 | def chunkify(lst, n): # for non-continuous chunks 25 | return [lst[i::n] for i in range(n)] 26 | 27 | sub_lists = chunkify(all_files, n_thread) 28 | # call workers 29 | for i in range(n_thread): 30 | p.apply_async(worker, args=(sub_lists[i], save_GT_dir)) 31 | print('Waiting for all subprocesses done...') 32 | p.close() 33 | p.join() 34 | end = time.time() 35 | print('All subprocesses done. Using time {} sec.'.format(end - start)) 36 | 37 | 38 | def worker(GT_paths, save_GT_dir): 39 | crop_sz = 480 40 | step = 240 41 | thres_sz = 48 42 | 43 | for GT_path in GT_paths: 44 | base_name = os.path.basename(GT_path) 45 | print(base_name, os.getpid()) 46 | img_GT = cv2.imread(GT_path, cv2.IMREAD_UNCHANGED) 47 | 48 | n_channels = len(img_GT.shape) 49 | if n_channels == 2: 50 | h, w = img_GT.shape 51 | elif n_channels == 3: 52 | h, w, c = img_GT.shape 53 | else: 54 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 55 | 56 | h_space = np.arange(0, h - crop_sz + 1, step) 57 | if h - (h_space[-1] + crop_sz) > thres_sz: 58 | h_space = np.append(h_space, h - crop_sz) 59 | w_space = np.arange(0, w - crop_sz + 1, step) 60 | if w - (w_space[-1] + crop_sz) > thres_sz: 61 | w_space = np.append(w_space, w - crop_sz) 62 | index = 0 63 | for x in h_space: 64 | for y in w_space: 65 | index += 1 66 | if n_channels == 2: 67 | crop_img = img_GT[x:x + crop_sz, y:y + crop_sz] 68 | else: 69 | crop_img = img_GT[x:x + crop_sz, y:y + crop_sz, :] 70 | 71 | crop_img = np.ascontiguousarray(crop_img) 72 | index_str = '{:03d}'.format(index) 73 | # var = np.var(crop_img / 255) 74 | # if var > 0.008: 75 | # print(index_str, var) 76 | cv2.imwrite(os.path.join(save_GT_dir, base_name.replace('.png', \ 77 
| '_s'+index_str+'.png')), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, 0]) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /codes/scripts/generate_mod_LR_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% MATLAB code to generate mod-cropped images, bicubic-downsampled LR images and bicubic-upsampled images. 3 | 4 | %% set parameters 5 | % comment out the lines you do not need 6 | input_folder = '/sdd/wlzhang/data/DF2K_HR'; 7 | %save_mod_folder = '/home/wlzhang/BasicSR/data_samples/urban100_mod'; 8 | save_LR_folder = '/sdd/wlzhang/data/DF2K800_bicLRx4'; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 4; 12 | mod_scale = 4; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will overwrite files in ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will overwrite files in ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will overwrite files in ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore ..
folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | 57 | %sigma = 1.8; 58 | %image_kernel = fspecial('gaussian',21, sigma); % 7x7 Gaussian kernel k_d with width 1.6 59 | %img = imfilter(img,double(image_kernel),'replicate'); % blur 60 | 61 | % LR 62 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 63 | if exist('save_LR_folder', 'var') 64 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 65 | end 66 | % Bicubic 67 | if exist('save_bic_folder', 'var') 68 | im_B = imresize(im_LR, up_scale, 'bicubic'); 69 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png'])); 70 | end 71 | end 72 | end 73 | end 74 | 75 | %% modcrop 76 | function img = modcrop(img, modulo) 77 | if size(img,3) == 1 78 | sz = size(img); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2)); 81 | else 82 | tmpsz = size(img); 83 | sz = tmpsz(1:2); 84 | sz = sz - mod(sz, modulo); 85 | img = img(1:sz(1), 1:sz(2),:); 86 | end 87 | end 88 | -------------------------------------------------------------------------------- /codes/scripts/transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchsummary import summary 5 | from collections import OrderedDict 6 | import torch 7 | import torch.nn as nn 8 | import block as B 9 | import os 10 | import math 11 | import functools 12 | import arch_util as arch_util 13 | 14 | class SRResNet(nn.Module): 15 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4, norm_type= None , act_type='relu', \ 16 | mode='CNA', res_scale=1, upsample_mode='pixelshuffle'): 17 | super(SRResNet, 
self).__init__() 18 | n_upscale = int(math.log(upscale, 2)) 19 | if upscale == 3: 20 | n_upscale = 1 21 | 22 | fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None) 23 | resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\ 24 | mode=mode, res_scale=res_scale) for _ in range(nb)] 25 | LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode) 26 | 27 | if upsample_mode == 'upconv': 28 | upsample_block = B.upconv_blcok 29 | elif upsample_mode == 'pixelshuffle': 30 | upsample_block = B.pixelshuffle_block 31 | else: 32 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 33 | if upscale == 3: 34 | upsampler = upsample_block(nf, nf, 3, act_type=act_type) 35 | else: 36 | upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)] 37 | HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type) 38 | HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None) 39 | 40 | self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),\ 41 | *upsampler, HR_conv0, HR_conv1) 42 | 43 | def forward(self, x): 44 | x = self.model(x) 45 | return x 46 | class mmsrSRResNet(nn.Module): 47 | ''' modified SRResNet''' 48 | 49 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 50 | super(mmsrSRResNet, self).__init__() 51 | self.upscale = upscale 52 | 53 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 54 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 55 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 56 | self.LRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 57 | 58 | # upsampling 59 | if self.upscale == 2: 60 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 61 | self.pixel_shuffle = nn.PixelShuffle(2) 62 | elif self.upscale == 3: 63 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 64 | 
self.pixel_shuffle = nn.PixelShuffle(3) 65 | elif self.upscale == 4: 66 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 67 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 68 | self.pixel_shuffle = nn.PixelShuffle(2) 69 | 70 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 71 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 72 | 73 | # activation function 74 | self.relu = nn.ReLU(inplace=True) 75 | # self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 76 | 77 | # initialization 78 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 79 | 0.1) 80 | if self.upscale == 4: 81 | arch_util.initialize_weights(self.upconv2, 0.1) 82 | 83 | def forward(self, x): 84 | fea = self.conv_first(x) 85 | out = self.recon_trunk(fea) 86 | out = self.LRconv(out) 87 | 88 | if self.upscale == 4: 89 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea))) 90 | out = self.relu(self.pixel_shuffle(self.upconv2(out))) 91 | elif self.upscale == 3 or self.upscale == 2: 92 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea))) 93 | 94 | out = self.conv_last(self.relu(self.HRconv(out))) 95 | 96 | return out 97 | 98 | 99 | def transfer_network(load_path, network,ordereddict, strict=True): 100 | load_net = torch.load(load_path) 101 | load_net_dict = OrderedDict() # remove unnecessary 'module.' 102 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
103 | load_model_key = [] 104 | for k, v in load_net.items(): 105 | load_net_dict[k] = v 106 | load_model_key.append(k) 107 | 108 | i = 0 109 | for param_tensor in network.state_dict(): # map checkpoint tensors onto the target network's keys by position 110 | load_net_clean[param_tensor] = load_net_dict[load_model_key[i]] 111 | print('-------') 112 | print(param_tensor) 113 | print(load_model_key[i]) 114 | i=i+1 115 | print(i) 116 | 117 | torch.save(load_net_clean, '/home/wlzhang/mmsr/experiments/pretrained_models/mmsr_SRResNet_pretrain.pth') 118 | network.load_state_dict(load_net_clean, strict=strict) 119 | 120 | net_old = SRResNet() 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0 122 | model = net_old.to(device) 123 | 124 | # print("Model's state_dict:") 125 | # for param_tensor in model.state_dict(): 126 | # print(param_tensor, "\t", model.state_dict()[param_tensor].size()) 127 | 128 | net_new = mmsrSRResNet() 129 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0 130 | model2 = net_new.to(device) 131 | 132 | print("Model2's state_dict:") 133 | ordereddict = [] 134 | for param_tensor in model2.state_dict(): 135 | ordereddict.append(param_tensor) 136 | # print(param_tensor, "\t", model2.state_dict()[param_tensor].size()) 137 | 138 | # print("key state_dict:") 139 | # print(ordereddict) 140 | 141 | transfer_network('/home/wlzhang/mmsr/experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth', net_new, ordereddict) 142 | 143 | 144 | # print("key:") 145 | # print(ordereddict) 146 | # summary(model, (3, 296, 296)) -------------------------------------------------------------------------------- /codes/scripts/transfer_params_MSRResNet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import torch 4 | try: 5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 6 | import models.archs.SRResNet_arch as SRResNet_arch 7 | except ImportError: 8 | pass 9 | 10 |
pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') 11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) 12 | crt_net = crt_model.state_dict() 13 | 14 | for k, v in crt_net.items(): 15 | if k in pretrained_net and 'upconv1' not in k: 16 | crt_net[k] = pretrained_net[k] 17 | print('replace ... ', k) 18 | 19 | # x4 -> x3 20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 26 | 27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') 28 | -------------------------------------------------------------------------------- /codes/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | 7 | import options.options as option 8 | import utils.util as util 9 | from data.util import bgr2ycbcr 10 | from data import create_dataset, create_dataloader 11 | from models import create_model 12 | 13 | #### options 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.') 16 | opt = option.parse(parser.parse_args().opt, is_train=False) 17 | opt = option.dict_to_nonedict(opt) 18 | 19 | util.mkdirs( 20 | (path for key, path in opt['path'].items() 21 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 22 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 23 |
level=logging.INFO, 23 | screen=True, tofile=True) 24 | logger = logging.getLogger('base') 25 | logger.info(option.dict2str(opt)) 26 | 27 | #### Create test dataset and dataloader 28 | test_loaders = [] 29 | for phase, dataset_opt in sorted(opt['datasets'].items()): 30 | test_set = create_dataset(dataset_opt) 31 | test_loader = create_dataloader(test_set, dataset_opt) 32 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 33 | test_loaders.append(test_loader) 34 | 35 | model = create_model(opt) 36 | for test_loader in test_loaders: 37 | test_set_name = test_loader.dataset.opt['name'] 38 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 39 | test_start_time = time.time() 40 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 41 | util.mkdir(dataset_dir) 42 | 43 | test_results = OrderedDict() 44 | test_results['psnr'] = [] 45 | test_results['ssim'] = [] 46 | test_results['psnr_y'] = [] 47 | test_results['ssim_y'] = [] 48 | 49 | for data in test_loader: 50 | need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True 51 | model.feed_data(data, need_GT=need_GT) 52 | img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] 53 | img_name = osp.splitext(osp.basename(img_path))[0] 54 | 55 | model.test() 56 | visuals = model.get_current_visuals(need_GT=need_GT) 57 | 58 | sr_img = util.tensor2img(visuals['rlt']) # uint8 59 | 60 | # save images 61 | suffix = opt['suffix'] 62 | if suffix: 63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') 64 | else: 65 | save_img_path = osp.join(dataset_dir, img_name + '.png') 66 | util.save_img(sr_img, save_img_path) 67 | 68 | # calculate PSNR and SSIM 69 | if need_GT: 70 | gt_img = util.tensor2img(visuals['GT']) 71 | sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale']) 72 | psnr = util.calculate_psnr(sr_img, gt_img) 73 | ssim = util.calculate_ssim(sr_img, gt_img) 74 | test_results['psnr'].append(psnr) 75 | 
test_results['ssim'].append(ssim) 76 | 77 | if gt_img.shape[2] == 3: # RGB image 78 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True) 79 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True) 80 | 81 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255) 82 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255) 83 | test_results['psnr_y'].append(psnr_y) 84 | test_results['ssim_y'].append(ssim_y) 85 | logger.info( 86 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'. 87 | format(img_name, psnr, ssim, psnr_y, ssim_y)) 88 | else: 89 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) 90 | else: 91 | logger.info(img_name) 92 | 93 | if need_GT: # metrics 94 | # Average PSNR/SSIM results 95 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 96 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 97 | logger.info( 98 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( 99 | test_set_name, ave_psnr, ave_ssim)) 100 | if test_results['psnr_y'] and test_results['ssim_y']: 101 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) 102 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) 103 | logger.info( 104 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. 105 | format(ave_psnr_y, ave_ssim_y)) 106 | -------------------------------------------------------------------------------- /codes/utils/README.md: -------------------------------------------------------------------------------- 1 | # Utils 2 | 3 | ## Tensorboard Logger (tb_logger) 4 | 5 | [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) is a visualization tool for monitoring and comparing training loss, validation PSNR, etc. 6 | 7 | You can turn it on or off in the YAML option file with the key `use_tb_logger`. 8 | 9 | ### Install 10 | 1.
`pip install tensorflow` - probably the easiest way to get TensorBoard, although it installs TensorFlow as well. 11 | 1. `pip install tensorboard_logger` - installs [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger). 12 | 13 | ### Run 14 | 1. In a terminal, run `tensorboard --logdir xxx/xxx`. 15 | 1. Open the TensorBoard UI at http://localhost:6006 in your browser. 16 | -------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/rank_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/rank_test.cpython-36.pyc -------------------------------------------------------------------------------- /codes/utils/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/calc_NIQE.m: -------------------------------------------------------------------------------- 1 | function NIQE = 
calc_NIQE(input_image_path,shave_width) 2 | 3 | %% Loading model 4 | load modelparameters.mat 5 | blocksizerow = 96; 6 | blocksizecol = 96; 7 | blockrowoverlap = 0; 8 | blockcoloverlap = 0; 9 | %% Reading the image (Y channel, borders shaved) 10 | NIQE = []; 11 | input_image = convert_shave_image(imread(input_image_path),shave_width); 12 | 13 | % Calculating scores 14 | NIQE = computequality(input_image,blocksizerow,blocksizecol,... 15 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam); 16 | 17 | end 18 | 19 | 20 | -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/computefeature.m: -------------------------------------------------------------------------------- 1 | function feat = computefeature(structdis) 2 | 3 | % Input - MSCN coefficients 4 | % Output - the 18-dimensional feature vector 5 | 6 | feat = []; 7 | 8 | 9 | 10 | [alpha betal betar] = estimateaggdparam(structdis(:)); 11 | 12 | feat = [feat;alpha;(betal+betar)/2]; 13 | 14 | shifts = [ 0 1;1 0 ;1 1;1 -1]; 15 | 16 | for itr_shift =1:4 17 | 18 | shifted_structdis = circshift(structdis,shifts(itr_shift,:)); 19 | pair = structdis(:).*shifted_structdis(:); 20 | [alpha betal betar] = estimateaggdparam(pair); 21 | meanparam = (betar-betal)*(gamma(2/alpha)/gamma(1/alpha)); 22 | feat = [feat;alpha;meanparam;betal;betar]; 23 | 24 | end 25 | -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/computemean.m: -------------------------------------------------------------------------------- 1 | function val = computemean(patch) 2 | 3 | val = mean2(patch); -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/computequality.m: -------------------------------------------------------------------------------- 1 | function quality = computequality(im,blocksizerow,blocksizecol,... 
2 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam) 3 | 4 | % Input 5 | % im - Image whose quality needs to be computed 6 | % blocksizerow - Height of the blocks into which the image is divided 7 | % blocksizecol - Width of the blocks into which the image is divided 8 | % blockrowoverlap - Amount of vertical overlap between blocks 9 | % blockcoloverlap - Amount of horizontal overlap between blocks 10 | % mu_prisparam - mean of multivariate Gaussian model 11 | % cov_prisparam - covariance of multivariate Gaussian model 12 | 13 | % For good performance, it is advisable to build the multivariate Gaussian model 14 | % from patches of the same size as those into which the distorted image is divided 15 | 16 | % Output 17 | %quality - Quality (NIQE score) of the input distorted image; lower is better 18 | 19 | % Example call 20 | %quality = computequality(im,96,96,0,0,mu_prisparam,cov_prisparam) 21 | 22 | % --------------------------------------------------------------- 23 | %Number of features 24 | % 18 features at each scale 25 | featnum = 18; 26 | %---------------------------------------------------------------- 27 | %Compute features 28 | if(size(im,3)==3) 29 | im = rgb2gray(im); 30 | end 31 | im = double(im); 32 | [row col] = size(im); 33 | block_rownum = floor(row/blocksizerow); 34 | block_colnum = floor(col/blocksizecol); 35 | 36 | im = im(1:block_rownum*blocksizerow,1:block_colnum*blocksizecol); 37 | [row col] = size(im); 38 | block_rownum = floor(row/blocksizerow); 39 | block_colnum = floor(col/blocksizecol); 40 | im = im(1:block_rownum*blocksizerow, ... 
41 | 1:block_colnum*blocksizecol); 42 | window = fspecial('gaussian',7,7/6); 43 | window = window/sum(sum(window)); 44 | scalenum = 2; 45 | warning('off') 46 | 47 | feat = []; 48 | 49 | 50 | for itr_scale = 1:scalenum 51 | 52 | 53 | mu = imfilter(im,window,'replicate'); 54 | mu_sq = mu.*mu; 55 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq)); 56 | structdis = (im-mu)./(sigma+1); 57 | 58 | 59 | 60 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ... 61 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ... 62 | @computefeature); 63 | feat_scale = reshape(feat_scale,[featnum .... 64 | size(feat_scale,1)*size(feat_scale,2)/featnum]); 65 | feat_scale = feat_scale'; 66 | 67 | 68 | if(itr_scale == 1) 69 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ... 70 | [blockrowoverlap blockcoloverlap],@computemean); 71 | sharpness = sharpness(:); 72 | end 73 | 74 | 75 | feat = [feat feat_scale]; 76 | 77 | im =imresize(im,0.5); 78 | 79 | end 80 | 81 | 82 | % Fit a MVG model to distorted patch features 83 | distparam = feat; 84 | mu_distparam = nanmean(distparam); 85 | cov_distparam = nancov(distparam); 86 | 87 | % Compute quality 88 | invcov_param = pinv((cov_prisparam+cov_distparam)/2); 89 | quality = sqrt((mu_prisparam-mu_distparam)* ... 90 | invcov_param*(mu_prisparam-mu_distparam)'); 91 | 92 | -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/convert_shave_image.m: -------------------------------------------------------------------------------- 1 | function shaved = convert_shave_image(input_image,shave_width) 2 | 3 | % Converting to y channel only 4 | image_ychannel = rgb2ycbcr(input_image); 5 | image_ychannel = image_ychannel(:,:,1); 6 | 7 | % Shaving image 8 | shaved = image_ychannel(1+shave_width:end-shave_width,... 
9 | 1+shave_width:end-shave_width); 10 | 11 | end -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/estimateaggdparam.m: -------------------------------------------------------------------------------- 1 | function [alpha betal betar] = estimateaggdparam(vec) 2 | 3 | 4 | gam = 0.2:0.001:10; 5 | r_gam = ((gamma(2./gam)).^2)./(gamma(1./gam).*gamma(3./gam)); 6 | 7 | 8 | leftstd = sqrt(mean((vec(vec<0)).^2)); 9 | rightstd = sqrt(mean((vec(vec>0)).^2)); 10 | 11 | gammahat = leftstd/rightstd; 12 | rhat = (mean(abs(vec)))^2/mean((vec).^2); 13 | rhatnorm = (rhat*(gammahat^3 +1)*(gammahat+1))/((gammahat^2 +1)^2); 14 | [min_difference, array_position] = min((r_gam - rhatnorm).^2); 15 | alpha = gam(array_position); 16 | 17 | betal = leftstd *sqrt(gamma(1/alpha)/gamma(3/alpha)); 18 | betar = rightstd*sqrt(gamma(1/alpha)/gamma(3/alpha)); 19 | 20 | 21 | -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/estimatemodelparam.m: -------------------------------------------------------------------------------- 1 | function [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,... 
2 | blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th) 3 | 4 | % Input 5 | % folderpath - Folder containing the pristine images 6 | % blocksizerow - Height of the blocks into which the image is divided 7 | % blocksizecol - Width of the blocks into which the image is divided 8 | % blockrowoverlap - Amount of vertical overlap between blocks 9 | % blockcoloverlap - Amount of horizontal overlap between blocks 10 | % sh_th - Sharpness threshold; only blocks sharper than sh_th times the maximum block sharpness are kept 11 | %Output 12 | %mu_prisparam - mean of multivariate Gaussian model 13 | %cov_prisparam - covariance of multivariate Gaussian model 14 | 15 | % Example call 16 | 17 | %[mu_prisparam cov_prisparam] = estimatemodelparam('pristine',96,96,0,0,0.75); 18 | 19 | 20 | %---------------------------------------------------------------- 21 | % Find the names of images in the folder (assumes Windows-style ls output: one name per row, '.' and '..' first) 22 | current = pwd; 23 | cd(sprintf('%s',folderpath)) 24 | names = ls; 25 | names = names(3:end,:); 26 | cd(current) 27 | % --------------------------------------------------------------- 28 | %Number of features 29 | % 18 features at each scale 30 | featnum = 18; 31 | % --------------------------------------------------------------- 32 | % Make the directory for storing the features 33 | mkdir(sprintf('local_risquee_prisfeatures')) 34 | % --------------------------------------------------------------- 35 | % Compute pristine image features 36 | for itr = 1:size(names,1) 37 | itr 38 | im = imread(sprintf('%s\\%s',folderpath,names(itr,:))); 39 | if(size(im,3)==3) 40 | im = rgb2gray(im); 41 | end 42 | im = double(im); 43 | [row col] = size(im); 44 | block_rownum = floor(row/blocksizerow); 45 | block_colnum = floor(col/blocksizecol); 46 | im = im(1:block_rownum*blocksizerow, ... 
47 | 1:block_colnum*blocksizecol); 48 | window = fspecial('gaussian',7,7/6); 49 | window = window/sum(sum(window)); 50 | scalenum = 2; 51 | warning('off') 52 | 53 | feat = []; 54 | 55 | 56 | for itr_scale = 1:scalenum 57 | 58 | 59 | mu = imfilter(im,window,'replicate'); 60 | mu_sq = mu.*mu; 61 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq)); 62 | structdis = (im-mu)./(sigma+1); 63 | 64 | 65 | 66 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ... 67 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ... 68 | @computefeature); 69 | feat_scale = reshape(feat_scale,[featnum .... 70 | size(feat_scale,1)*size(feat_scale,2)/featnum]); 71 | feat_scale = feat_scale'; 72 | 73 | 74 | if(itr_scale == 1) 75 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ... 76 | [blockrowoverlap blockcoloverlap],@computemean); 77 | sharpness = sharpness(:); 78 | end 79 | 80 | 81 | feat = [feat feat_scale]; 82 | 83 | im =imresize(im,0.5); 84 | 85 | end 86 | 87 | save(sprintf('local_risquee_prisfeatures\\prisfeatures_local%d.mat',... 
88 | itr),'feat','sharpness'); 89 | end 90 | 91 | 92 | 93 | %---------------------------------------------- 94 | % Load pristine image features 95 | prisparam = []; 96 | current = pwd; 97 | cd(sprintf('%s','local_risquee_prisfeatures')) 98 | names = ls; 99 | names = names(3:end,:); 100 | cd(current) 101 | for itr = 1:size(names,1) 102 | % Load the features and select the only features 103 | load(sprintf('local_risquee_prisfeatures\\%s',strtrim(names(itr,:)))); 104 | IX = find(sharpness(:) >sh_th*max(sharpness(:))); 105 | feat = feat(IX,:); 106 | prisparam = [prisparam; feat]; 107 | 108 | end 109 | %---------------------------------------------- 110 | % Compute model parameters 111 | mu_prisparam = nanmean(prisparam); 112 | cov_prisparam = nancov(prisparam); 113 | %---------------------------------------------- 114 | % Save features in the mat file 115 | save('modelparameters_new.mat','mu_prisparam','cov_prisparam'); 116 | %---------------------------------------------- 117 | -------------------------------------------------------------------------------- /codes/utils/perceptualmetric/modelparameters.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/perceptualmetric/modelparameters.mat -------------------------------------------------------------------------------- /codes/utils/rank_test.py: -------------------------------------------------------------------------------- 1 | 2 | def rank_pair_test(predict_file ,label_file): 3 | predict_score = {} 4 | label_score = {} 5 | f1 = open(predict_file ,'r') 6 | f2 = open(label_file ,'r') 7 | 8 | for line in f1.readlines(): 9 | line = line.strip().split() 10 | img_name = line[0] 11 | img_score = line[1] 12 | predict_score[img_name] = float(img_score) 13 | 14 | for line in f2.readlines(): 15 | line = line.strip().split() 16 | img_name = line[0] 17 | img_score = line[1] 18 | 
label_score[img_name] = float(img_score) 19 | 20 | keys_list = list(predict_score.keys()) 21 | keys_list.sort() 22 | 23 | cursor = keys_list[0].split('_')[0] 24 | class_num = 0 25 | for key in keys_list: 26 | if cursor == key.split('_')[0]: 27 | class_num += 1 28 | else: 29 | break 30 | count = 0 31 | positive = 0 32 | for idx in range(0 ,len(keys_list) ,class_num): 33 | for i in range(idx ,idx +class_num): 34 | for j in range( i +1 ,idx +class_num): 35 | 36 | real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1 37 | 38 | predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1 39 | 40 | count += 1 41 | if real_rank == predict_rank: 42 | positive += 1 43 | 44 | # print('%d/%d ' %(positive ,count)) 45 | accuracy = positive /count 46 | # print('Aligned Pair Accuracy: %f ' %accuracy) 47 | 48 | count1 = 1 49 | count2 = 1 50 | positive1 = 0 51 | positive2 = 0 52 | 53 | for idx in range(0 ,len(keys_list) ,class_num): 54 | 55 | i = idx 56 | j = i+ 1 57 | real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1 58 | 59 | predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1 60 | 61 | count += 1 62 | if real_rank == 1: 63 | count1 += 1 64 | if real_rank == predict_rank: 65 | positive1 += 1 66 | if real_rank == -1: 67 | count2 += 1 68 | if real_rank == predict_rank: 69 | positive2 += 1 70 | 71 | # print('%d/%d' % (positive1, count1)) 72 | accuracy_esrganbig1 = positive1 / count1 73 | # print('accuracy_esrganbig: %f' % accuracy_esrganbig1) 74 | # 75 | # print('%d/%d' % (positive2, count2)) 76 | accuracy_srganbig1 = positive2 / count2 77 | # print('accuracy_srganbig: %f' % accuracy_srganbig1) 78 | count1 = 1 79 | count2 = 1 80 | positive1 = 0 81 | positive2 = 0 82 | for idx in range(0, len(keys_list), class_num): 83 | 84 | i = idx 85 | j = i + 2 86 | real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1 87 | 88 | predict_rank = 1 if 
predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1 89 | 90 | count += 1 91 | if real_rank == 1: 92 | count1 += 1 93 | if real_rank == predict_rank: 94 | positive1 += 1 95 | if real_rank == -1: 96 | count2 += 1 97 | if real_rank == predict_rank: 98 | positive2 += 1 99 | 100 | # print('%d/%d' % (positive1, count1)) 101 | # accuracy_esrganbig = positive1 / count1 102 | # print('accuracy2: %f' % accuracy_esrganbig) 103 | # 104 | # print('%d/%d' % (positive2, count2)) 105 | # accuracy_srganbig = positive2 / count2 106 | # print('accuracy2: %f' % accuracy_srganbig) 107 | 108 | count1 = 1 109 | count2 = 1 110 | positive1 = 0 111 | positive2 = 0 112 | 113 | for idx in range(0, len(keys_list), class_num): 114 | 115 | i = idx + 1 116 | j = i + 1 117 | real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1 118 | predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1 119 | 120 | count += 1 121 | if real_rank == 1: 122 | count1 += 1 123 | if real_rank == predict_rank: 124 | positive1 += 1 125 | if real_rank == -1: 126 | count2 += 1 127 | if real_rank == predict_rank: 128 | positive2 += 1 129 | 130 | # print('%d/%d' % (positive1, count1)) 131 | # accuracy_esrganbig = positive1 / count1 132 | # print('accuracy3: %f' % accuracy_esrganbig) 133 | 134 | # print('%d/%d' % (positive2, count2)) 135 | # accuracy_srganbig = positive2 / count2 136 | # print('accuracy3: %f' % accuracy_srganbig) 137 | 138 | return accuracy, accuracy_esrganbig1, accuracy_srganbig1 -------------------------------------------------------------------------------- /datasets/generate_rankdataset/README.md: -------------------------------------------------------------------------------- 1 | ## How To Prepare Rank Dataset 2 | ### Prepare perceptual data 3 | 1. Prepare Three levels SR Models. 
You can download [SRResNet (SRResNet_bicx4_in3nf64nb16.pth), 4 | SRGAN (SRGAN.pth), ESRGAN (ESRGAN_SuperSR.pth)] from 5 | [Google Drive](https://drive.google.com/drive/folders/16DkwrBa4cIqAoTbGU_bKMYoATcXC4IT6?usp=sharing) 6 | or [Baidu Drive](https://pan.baidu.com/s/1HFZokeAWne9oUkmJBnGr-A). You can place them in [`./experiments/pretrained_models/`](../../master/experiments/pretrained_models/). 7 | 8 | 2. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017) 9 | from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or 10 | [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA). 11 | 3. Generate the three levels of SR images by following 'How to test' with [`codes/options/test/test_RankSRGAN.yml`](../../master/codes/options/test/test_RankSRGAN.yml). 12 | ### Generate rank dataset 13 | 4. **Training dataset:** Use [`./datasets/generate_rankdataset/generate_rankdataset.m`](../../master/datasets/generate_rankdataset/generate_rankdataset.m) 14 | to crop three-level training patches at matching locations. 15 | 5. **Validation dataset:** Use [`./datasets/generate_rankdataset/move_valid.py`](../../master/datasets/generate_rankdataset/move_valid.py) 16 | to move a random 10% of the training patches (all three levels) into the validation folders. 17 | 6. **Rank label:** Use [`./datasets/generate_rankdataset/generate_train_ranklabel.m`](../../master/datasets/generate_rankdataset/generate_train_ranklabel.m) 18 | to compute the training rank labels (NIQE scores). 19 | Use [`./datasets/generate_rankdataset/generate_valid_ranklabel.m`](../../master/datasets/generate_rankdataset/generate_valid_ranklabel.m) 20 | to compute the validation rank labels (NIQE scores).
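The label files from step 6 are plain text with one `name score` pair per line. For each patch, the `fprintf` loop in `generate_*_ranklabel.m` writes the srres, srgan and esrgan entries in that order. The sketch below is a minimal Python illustration of that format, not code from this repository. It assumes the written names are patch filenames such as `12_srgan.png`; the exact string returned by `parcal_niqe` is an assumption. It formats such a file, parses it back, and computes a pairwise ordering accuracy in the spirit of `codes/utils/rank_test.py`. The helper names `format_rank_labels`, `parse_rank_labels` and `pairwise_accuracy` are hypothetical. Lower NIQE means better perceptual quality.

```python
from collections import defaultdict

# Order in which generate_*_ranklabel.m writes the three levels of each patch.
LEVELS = ("srres", "srgan", "esrgan")


def format_rank_labels(niqe_scores):
    """Hypothetical helper: {patch_id: {level: niqe}} -> label-file text.

    Assumes names of the form '<id>_<level>.png' (an assumption about parcal_niqe).
    Scores use '%f'-style formatting (6 decimals), like the MATLAB fprintf.
    """
    lines = []
    for patch_id in sorted(niqe_scores, key=int):
        for level in LEVELS:
            lines.append(f"{patch_id}_{level}.png {niqe_scores[patch_id][level]:f}")
    return "\n".join(lines) + "\n"


def parse_rank_labels(text):
    """Hypothetical helper: label-file text -> {name: score}; malformed lines are skipped."""
    scores = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) == 2:
            scores[parts[0]] = float(parts[1])
    return scores


def pairwise_accuracy(predicted, labels):
    """Fraction of within-patch pairs whose predicted order matches the label order.

    Pairs are grouped by the id before the first '_', and the order is compared
    with '>=' as rank_test.py does.
    """
    groups = defaultdict(list)
    for name in labels:
        groups[name.split("_")[0]].append(name)
    agree = total = 0
    for names in groups.values():
        names.sort()
        for i in range(len(names)):
            for j in range(i + 1, len(names)):
                a, b = names[i], names[j]
                real = labels[a] >= labels[b]
                pred = predicted[a] >= predicted[b]
                agree += int(real == pred)
                total += 1
    return agree / total if total else 0.0
```

`rank_pair_test` steps through the sorted keys in strides of `class_num`, so it assumes every patch contributes the same number of entries. Grouping explicitly by the id prefix, as here, avoids relying on that assumption.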
21 | 22 | 23 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/generate_rankdataset.m: -------------------------------------------------------------------------------- 1 | 2 | Level1_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_ESRGAN/'; 3 | Level2_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_srgan/'; 4 | Level3_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_srres/'; 5 | 6 | Level1_filepaths = dir(fullfile(Level1_folder,'*.png')); 7 | Level2_filepaths = dir(fullfile(Level2_folder,'*.png')); 8 | Level3_filepaths = dir(fullfile(Level3_folder,'*.png')); 9 | 10 | Level1_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_esrgan'; 11 | Level2_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srgan'; 12 | Level3_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srres'; 13 | 14 | mkdir(Level3_patchsave_path) 15 | mkdir(Level2_patchsave_path) 16 | mkdir(Level1_patchsave_path) 17 | 18 | addpath utils 19 | 20 | patch_sz = 296; 21 | stride = 148; %300 148 200 22 | blocksizerow = 96; 23 | blocksizecol = 96; 24 | blockrowoverlap = 0; 25 | blockcoloverlap = 0; 26 | 27 | total_count = 0; 28 | selected_count = 0; 29 | display_flag = 0; 30 | save_count = 0; 31 | count_class = 0; 32 | fprintf('-------- Starting -----------\n'); 33 | for k = 1 : length(Level1_filepaths) 34 | tic; 35 | fprintf('Processing img: %s\n',Level1_filepaths(k).name); 36 | fprintf('Processing SRGAN img: %s\n',Level2_filepaths(k).name); 37 | 38 | level1_img = imread(fullfile(Level1_folder,Level1_filepaths(k).name)); 39 | level2_img = imread(fullfile(Level2_folder,Level2_filepaths(k).name)); 40 | %srres_img = imread(fullfile(Level3_folder,srres_filepaths(k).name)); 41 | 42 | if display_flag == 1 43 | subplot(1,2,1); 44 | imshow(level1_img); 45 | end 46 | img_height = size(level1_img,1); 47 | img_width = size(level1_img,2); 48 | 49 | h_num = 
ceil((img_height-patch_sz)/stride); 50 | w_num = ceil((img_width-patch_sz)/stride); 51 | count = 0; 52 | patch_list = zeros(patch_sz,patch_sz,3,1); 53 | 54 | % esrgan_patch_list1 = zeros(patch_sz,patch_sz,3,200); 55 | % srres_patch_list1 = zeros(patch_sz,patch_sz,3,200); 56 | %srgan_patch_list1 = zeros(patch_sz,patch_sz,3,200); 57 | 58 | location_list = zeros(1,2,15); 59 | i = 0; 60 | for h = 1:stride:img_height-patch_sz 61 | for w = 1:stride:img_width-patch_sz 62 | count = count +1; 63 | total_count = total_count + 1; 64 | level1_img_patch = level1_img(h:h+patch_sz-1,w:w+patch_sz-1,:); 65 | if(size(level1_img_patch,3)==3) 66 | im = rgb2gray(level1_img_patch); 67 | else, im = level1_img_patch; end % grayscale patches are used as-is 68 | im = double(im); 69 | 70 | sharpness = compute_sharpness(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap); 71 | 72 | if sharpness >= 40 73 | i = i+1; 74 | if display_flag == 1 75 | subplot(1,2,1); 76 | rectangle('Position',[w,h,patch_sz,patch_sz],'edgecolor','r'); 77 | end 78 | selected_count = selected_count + 1; 79 | level1_patch_list(:,:,:,i) = level1_img_patch; 80 | location_list(:,:,i) = [h,w]; 81 | 82 | else 83 | if display_flag == 1 84 | subplot(1,2,1); 85 | rectangle('Position',[w,h,patch_sz,patch_sz],'edgecolor','k'); 86 | end 87 | end 88 | 89 | 90 | end 91 | end 92 | fprintf('patch number:%d ',i); 93 | count1 = 0; 94 | count2 = 0; 95 | count3 = 0; 96 | for j = 1:i 97 | 98 | level1_save_patch = uint8(level1_patch_list(:,:,:,j)); 99 | level1_NIQE = calc_NIQE(level1_save_patch); 100 | 101 | patch_h = location_list(1,1,j); 102 | patch_w = location_list(1,2,j); 103 | Level3_img_patch=get_sigle_patch(Level3_folder,Level1_filepaths(k).name,... 104 | Level3_patchsave_path,... 105 | [num2str(save_count) '_srres.png'],... 106 | patch_h,patch_w,patch_sz); 107 | level2_img_patch=get_sigle_patch(Level2_folder,Level1_filepaths(k).name,... 108 | Level2_patchsave_path,... 109 | [num2str(save_count) '_srgan.png'],... 
110 | patch_h,patch_w,patch_sz); 111 | level2_NIQE = calc_NIQE(level2_img_patch); 112 | Level3_NIQE = calc_NIQE(Level3_img_patch); 113 | 114 | if abs(level2_NIQE - level1_NIQE) > 0.1 115 | count1 = count1+1; 116 | level1_patch_list1(:,:,:,count1) = level1_save_patch; 117 | level2_patch_list1(:,:,:,count1) = level2_img_patch; 118 | Level3_patch_list1(:,:,:,count1) = Level3_img_patch; 119 | end 120 | 121 | end 122 | fprintf('distance good:%d ',count1); 123 | if count1 < 200 %200 124 | for idx= 1:count1 125 | save_count = save_count + 1; 126 | 127 | save_name = [num2str(save_count) '_esrgan.png']; 128 | level1_patch = uint8(level1_patch_list1(:,:,:,idx)); 129 | imwrite(level1_patch,fullfile(Level1_patchsave_path,save_name)); 130 | 131 | save_name = [num2str(save_count) '_srgan.png']; 132 | level2_patch = uint8(level2_patch_list1(:,:,:,idx)); 133 | imwrite(level2_patch,fullfile(Level2_patchsave_path,save_name)); 134 | 135 | save_name = [num2str(save_count) '_srres.png']; 136 | Level3_patch = uint8(Level3_patch_list1(:,:,:,idx)); 137 | imwrite(Level3_patch,fullfile(Level3_patchsave_path,save_name)); 138 | 139 | end 140 | else 141 | rand_order = randperm(count1); 142 | for idx= 1:200 143 | save_count = save_count + 1; 144 | 145 | save_name = [num2str(save_count) '_esrgan.png']; 146 | level1_patch = uint8(level1_patch_list1(:,:,:,rand_order(idx))); 147 | imwrite(level1_patch,fullfile(Level1_patchsave_path,save_name)); 148 | 149 | save_name = [num2str(save_count) '_srgan.png']; 150 | level2_patch = uint8(level2_patch_list1(:,:,:,rand_order(idx))); 151 | imwrite(level2_patch,fullfile(Level2_patchsave_path,save_name)); 152 | 153 | save_name = [num2str(save_count) '_srres.png']; 154 | Level3_patch = uint8(Level3_patch_list1(:,:,:,rand_order(idx))); 155 | imwrite(Level3_patch,fullfile(Level3_patchsave_path,save_name)); 156 | 157 | end 158 | end 159 | fprintf('Current save patches:%d\n',save_count); 160 | 161 | end 162 | toc; 163 | fprintf('Total image patch: 
%d\n',total_count); 164 | fprintf('Selected image patch: %d\n',selected_count); 165 | fprintf('Generated image patch: %d\n',save_count); 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/generate_train_ranklabel.m: -------------------------------------------------------------------------------- 1 | 2 | %% set path 3 | start = clock; 4 | data_path = '../../data/Rank_dataset_test/'; 5 | 6 | level1_path = [data_path,'DF2K_train_patch_esrgan']; 7 | level2_path = [data_path,'DF2K_train_patch_srgan']; 8 | level3_path = [data_path,'DF2K_train_patch_srres']; 9 | 10 | ranklabel_path = [data_path,'DF2K_train_NIQE.txt']; 11 | 12 | level1_dir = fullfile(pwd,level1_path); 13 | level2_dir = fullfile(pwd,level2_path); 14 | level3_dir = fullfile(pwd,level3_path); 15 | 16 | % Number of pixels to shave off image borders when calculating scores 17 | shave_width = 4; 18 | 19 | % Set verbose option 20 | verbose = true; 21 | %% Calculate scores and save 22 | addpath utils 23 | addpath(genpath(fullfile(pwd,'utils'))); 24 | 25 | %% Reading file list 26 | level1_file_list = dir([level1_dir,'/*.png']); 27 | level2_file_list = dir([level2_dir,'/*.png']); 28 | level3_file_list = dir([level3_dir,'/*.png']); 29 | 30 | im_num = length(level1_file_list) 31 | %fprintf(' %f\n',im_num); 32 | 33 | %% Calculating scores 34 | % (the label file is opened for writing only after all scores are computed) 35 | tic; 36 | pp = parpool('local',28); 37 | pp.IdleTimeout = 9800 38 | disp('Already initialized'); 39 | fprintf('-------- Starting -----------\n'); 40 | parfor
ii=(1:im_num) 52 | [scoresname,scoresniqe] = parcal_niqe(ii,level1_dir,level1_file_list,im_num) 53 | level1_name{ii} = scoresname; 54 | level1_niqe(ii) = scoresniqe; 55 | end 56 | 57 | toc; 58 | delete(pp) 59 | txtfp = fopen(ranklabel_path,'w'); 60 | for ii=(1:im_num) 61 | fprintf(txtfp,'%s',level3_name{ii}); 62 | fprintf(txtfp,' %f\n',level3_niqe(ii)); 63 | fprintf(txtfp,'%s',level2_name{ii}); 64 | fprintf(txtfp,' %f\n',level2_niqe(ii)); 65 | fprintf(txtfp,'%s',level1_name{ii}); 66 | fprintf(txtfp,' %f\n',level1_niqe(ii)); 67 | end 68 | 69 | fclose(txtfp); 70 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/generate_valid_ranklabel.m: -------------------------------------------------------------------------------- 1 | 2 | %% set path 3 | start = clock; 4 | data_path = '../../data/Rank_dataset_test/'; 5 | 6 | level1_path = [data_path,'DF2K_valid_patch_esrgan']; 7 | level2_path = [data_path,'DF2K_valid_patch_srgan']; 8 | level3_path = [data_path,'DF2K_valid_patch_srres']; 9 | 10 | ranklabel_path = [data_path,'DF2K_valid_NIQE.txt']; 11 | 12 | level1_dir = fullfile(pwd,level1_path); 13 | level2_dir = fullfile(pwd,level2_path); 14 | level3_dir = fullfile(pwd,level3_path); 15 | 16 | % Number of pixels to shave off image borders when calculating scores 17 | shave_width = 4; 18 | 19 | % Set verbose option 20 | verbose = true; 21 | %% Calculate scores and save 22 | addpath utils 23 | addpath(genpath(fullfile(pwd,'utils'))); 24 | 25 | %% Reading file list 26 | level1_file_list = dir([level1_dir,'/*.png']); 27 | level2_file_list = dir([level2_dir,'/*.png']); 28 | level3_file_list = dir([level3_dir,'/*.png']); 29 | 30 | im_num = length(level1_file_list) 31 | %fprintf(' %f\n',im_num); 32 | 33 | %% Calculating scores 34 | % (the label file is opened for writing only after all scores are computed) 35 | tic; 36 | pp = parpool('local',28); 37 | pp.IdleTimeout = 9800 38 | disp('Already initialized'); 39 | fprintf('-------- 
Starting -----------\n'); 40 | parfor
ii=(1:im_num) 41 | [scoresname,scoresniqe] = parcal_niqe(ii,level3_dir,level3_file_list,im_num) 42 | level3_name{ii} = scoresname; 43 | level3_niqe(ii) = scoresniqe; 44 | 45 | end 46 | parfor ii=(1:im_num) 47 | [scoresname,scoresniqe] = parcal_niqe(ii,level2_dir,level2_file_list,im_num) 48 | level2_name{ii} = scoresname; 49 | level2_niqe(ii) = scoresniqe; 50 | end 51 | parfor ii=(1:im_num) 52 | [scoresname,scoresniqe] = parcal_niqe(ii,level1_dir,level1_file_list,im_num) 53 | level1_name{ii} = scoresname; 54 | level1_niqe(ii) = scoresniqe; 55 | end 56 | 57 | toc; 58 | delete(pp) 59 | txtfp = fopen(ranklabel_path,'w'); 60 | for ii=(1:im_num) 61 | fprintf(txtfp,'%s',level3_name{ii}); 62 | fprintf(txtfp,' %f\n',level3_niqe(ii)); 63 | fprintf(txtfp,'%s',level2_name{ii}); 64 | fprintf(txtfp,' %f\n',level2_niqe(ii)); 65 | fprintf(txtfp,'%s',level1_name{ii}); 66 | fprintf(txtfp,' %f\n',level1_niqe(ii)); 67 | end 68 | 69 | fclose(txtfp); 70 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/move_valid.py: -------------------------------------------------------------------------------- 1 | import os, random, shutil 2 | Level1_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_esrgan/' 3 | Level2_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srgan/' 4 | Level3_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srres/' 5 | 6 | Level1_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_esrgan/' 7 | Level2_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_srgan/' 8 | Level3_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_srres/' 9 | 10 | if not os.path.exists(Level1_valid_patchsave_path): 11 | os.makedirs(Level1_valid_patchsave_path) 12 | else: 13 | print('exists') 14 | 15 | if not
os.path.exists(Level2_valid_patchsave_path): 16 | os.makedirs(Level2_valid_patchsave_path) 17 | else: 18 | print('exists') 19 | 20 | if not os.path.exists(Level3_valid_patchsave_path): 21 | os.makedirs(Level3_valid_patchsave_path) 22 | else: 23 | print('exists') 24 | 25 | pathDir = os.listdir(Level1_patchsave_path) # original location of the image patches 26 | filenumber = len(pathDir) 27 | rate = 0.1 28 | picknumber = int(filenumber * rate) 29 | 30 | sample = random.sample(pathDir, picknumber) 31 | 32 | 33 | for name in sample: 34 | 35 | name = "".join(name) 36 | name = name.split('_') 37 | print(name[0]) 38 | 39 | shutil.move(Level1_patchsave_path+name[0]+'_esrgan.png', Level1_valid_patchsave_path) 40 | shutil.move(Level2_patchsave_path+name[0]+'_srgan.png', Level2_valid_patchsave_path) 41 | shutil.move(Level3_patchsave_path+name[0]+'_srres.png', Level3_valid_patchsave_path) 42 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/calc_NIQE.m: -------------------------------------------------------------------------------- 1 | function NIQE = calc_NIQE(input_image) 2 | 3 | addpath(genpath(fullfile(pwd,'utils'))); 4 | 5 | %% Loading model 6 | load modelparameters.mat 7 | blocksizerow = 96; 8 | blocksizecol = 96; 9 | blockrowoverlap = 0; 10 | blockcoloverlap = 0; 11 | 12 | %% Reading file list 13 | %scores = struct([]); 14 | 15 | 16 | % Calculating scores 17 | NIQE = computequality(input_image,blocksizerow,blocksizecol,... 
18 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam); 19 | % perceptual_score = ([scores(ii).NIQE] + (10 - [scores(ii).Ma])) / 2; 20 | % perceptual_score = scores(ii).NIQE; 21 | % fprintf([' perceptual scores is: ',num2str(perceptual_score),' ',scores(ii).name,'']); 22 | 23 | end 24 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/compute_sharpness.m: -------------------------------------------------------------------------------- 1 | function sharpness = compute_sharpness(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap) 2 | 3 | window = fspecial('gaussian',7,7/6); 4 | window = window/sum(sum(window)); 5 | 6 | sigma = sqrt(imfilter((im-imfilter(im,window,'replicate')).*(im-imfilter(im,window,'replicate')),window,'replicate')); 7 | paper_sharpness = blkproc(sigma,[blocksizerow blocksizecol],[blockrowoverlap blockcoloverlap],@computemean); 8 | paper_sharpness = paper_sharpness(:); 9 | sharpness = sum(paper_sharpness); 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | end 20 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/computemean.m: -------------------------------------------------------------------------------- 1 | function val = computemean(patch) 2 | 3 | val = mean2(patch); -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/convert_shave_image.m: -------------------------------------------------------------------------------- 1 | function shaved = convert_shave_image(input_image,shave_width) 2 | 3 | % Converting to y channel only 4 | image_ychannel = rgb2ycbcr(input_image); 5 | image_ychannel = image_ychannel(:,:,1); 6 | 7 | % Shaving image 8 | shaved = image_ychannel(1+shave_width:end-shave_width,... 
9 | 1+shave_width:end-shave_width); 10 | 11 | end -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/get_sigle_patch.m: -------------------------------------------------------------------------------- 1 | function img_patch=get_sigle_patch(file_path,file_name,... 2 | save_dir,save_name,h,w,patch_sz) 3 | 4 | img = imread(fullfile(file_path,file_name)); 5 | img_patch = img(h:h+patch_sz-1,w:w+patch_sz-1,:); 6 | 7 | 8 | end 9 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/computefeature.m: -------------------------------------------------------------------------------- 1 | function feat = computefeature(structdis) 2 | 3 | % Input - MSCN (mean-subtracted contrast-normalized) coefficients 4 | % Output - the 18-dimensional feature vector 5 | 6 | feat = []; 7 | 8 | 9 | 10 | [alpha betal betar] = estimateaggdparam(structdis(:)); 11 | 12 | feat = [feat;alpha;(betal+betar)/2]; 13 | 14 | shifts = [ 0 1;1 0 ;1 1;1 -1]; 15 | 16 | for itr_shift =1:4 17 | 18 | shifted_structdis = circshift(structdis,shifts(itr_shift,:)); 19 | pair = structdis(:).*shifted_structdis(:); 20 | [alpha betal betar] = estimateaggdparam(pair); 21 | meanparam = (betar-betal)*(gamma(2/alpha)/gamma(1/alpha)); 22 | feat = [feat;alpha;meanparam;betal;betar]; 23 | 24 | end 25 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/computemean.m: -------------------------------------------------------------------------------- 1 | function val = computemean(patch) 2 | 3 | val = mean2(patch); -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/computequality.m: -------------------------------------------------------------------------------- 1 | function quality = computequality(im,blocksizerow,blocksizecol,...
2 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam) 3 | 4 | % Input 5 | % im - Image whose quality needs to be computed 6 | % blocksizerow - Height of the blocks into which the image is divided 7 | % blocksizecol - Width of the blocks into which the image is divided 8 | % blockrowoverlap - Amount of vertical overlap between blocks 9 | % blockcoloverlap - Amount of horizontal overlap between blocks 10 | % mu_prisparam - mean of multivariate Gaussian model 11 | % cov_prisparam - covariance of multivariate Gaussian model 12 | 13 | % For good performance, it is advisable to build the multivariate Gaussian model 14 | % using patches of the same size as those into which the distorted image is divided 15 | 16 | % Output 17 | %quality - Quality of the input distorted image 18 | 19 | % Example call 20 | %quality = computequality(im,96,96,0,0,mu_prisparam,cov_prisparam) 21 | 22 | % --------------------------------------------------------------- 23 | %Number of features 24 | % 18 features at each scale 25 | featnum = 18; 26 | %---------------------------------------------------------------- 27 | %Compute features 28 | if(size(im,3)==3) 29 | im = rgb2gray(im); 30 | end 31 | im = double(im); 32 | [row col] = size(im); 33 | block_rownum = floor(row/blocksizerow); 34 | block_colnum = floor(col/blocksizecol); 35 | 36 | im = im(1:block_rownum*blocksizerow,1:block_colnum*blocksizecol); 37 | [row col] = size(im); 38 | block_rownum = floor(row/blocksizerow); 39 | block_colnum = floor(col/blocksizecol); 40 | im = im(1:block_rownum*blocksizerow, ...
41 | 1:block_colnum*blocksizecol); 42 | window = fspecial('gaussian',7,7/6); 43 | window = window/sum(sum(window)); 44 | scalenum = 2; 45 | warning('off') 46 | 47 | feat = []; 48 | 49 | 50 | for itr_scale = 1:scalenum 51 | 52 | 53 | mu = imfilter(im,window,'replicate'); 54 | mu_sq = mu.*mu; 55 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq)); 56 | structdis = (im-mu)./(sigma+1); 57 | 58 | 59 | 60 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ... 61 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ... 62 | @computefeature); 63 | feat_scale = reshape(feat_scale,[featnum .... 64 | size(feat_scale,1)*size(feat_scale,2)/featnum]); 65 | feat_scale = feat_scale'; 66 | 67 | 68 | if(itr_scale == 1) 69 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ... 70 | [blockrowoverlap blockcoloverlap],@computemean); 71 | sharpness = sharpness(:); 72 | end 73 | 74 | 75 | feat = [feat feat_scale]; 76 | 77 | im =imresize(im,0.5); 78 | 79 | end 80 | 81 | 82 | % Fit a MVG model to distorted patch features 83 | distparam = feat; 84 | mu_distparam = nanmean(distparam); 85 | cov_distparam = nancov(distparam); 86 | 87 | % Compute quality 88 | invcov_param = pinv((cov_prisparam+cov_distparam)/2); 89 | quality = sqrt((mu_prisparam-mu_distparam)* ... 
90 | invcov_param*(mu_prisparam-mu_distparam)'); 91 | 92 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/estimateaggdparam.m: -------------------------------------------------------------------------------- 1 | function [alpha betal betar] = estimateaggdparam(vec) 2 | 3 | 4 | gam = 0.2:0.001:10; 5 | r_gam = ((gamma(2./gam)).^2)./(gamma(1./gam).*gamma(3./gam)); 6 | 7 | 8 | leftstd = sqrt(mean((vec(vec<0)).^2)); 9 | rightstd = sqrt(mean((vec(vec>0)).^2)); 10 | 11 | gammahat = leftstd/rightstd; 12 | rhat = (mean(abs(vec)))^2/mean((vec).^2); 13 | rhatnorm = (rhat*(gammahat^3 +1)*(gammahat+1))/((gammahat^2 +1)^2); 14 | [min_difference, array_position] = min((r_gam - rhatnorm).^2); 15 | alpha = gam(array_position); 16 | 17 | betal = leftstd *sqrt(gamma(1/alpha)/gamma(3/alpha)); 18 | betar = rightstd*sqrt(gamma(1/alpha)/gamma(3/alpha)); 19 | 20 | 21 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/estimatemodelparam.m: -------------------------------------------------------------------------------- 1 | function [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,... 
2 | blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th) 3 | 4 | % Input 5 | % folderpath - Folder containing the pristine images 6 | % blocksizerow - Height of the blocks into which the image is divided 7 | % blocksizecol - Width of the blocks into which the image is divided 8 | % blockrowoverlap - Amount of vertical overlap between blocks 9 | % blockcoloverlap - Amount of horizontal overlap between blocks 10 | % sh_th - The sharpness threshold level 11 | %Output 12 | %mu_prisparam - mean of multivariate Gaussian model 13 | %cov_prisparam - covariance of multivariate Gaussian model 14 | 15 | % Example call 16 | 17 | %[mu_prisparam cov_prisparam] = estimatemodelparam('pristine',96,96,0,0,0.75); 18 | 19 | 20 | %---------------------------------------------------------------- 21 | % Find the names of images in the folder 22 | current = pwd; 23 | cd(sprintf('%s',folderpath)) 24 | names = ls; 25 | names = names(3:end,:); 26 | cd(current) 27 | % --------------------------------------------------------------- 28 | %Number of features 29 | % 18 features at each scale 30 | featnum = 18; 31 | % --------------------------------------------------------------- 32 | % Make the directory for storing the features 33 | mkdir(sprintf('local_risquee_prisfeatures')) 34 | % --------------------------------------------------------------- 35 | % Compute pristine image features 36 | for itr = 1:size(names,1) 37 | itr 38 | im = imread(sprintf('%s\\%s',folderpath,names(itr,:))); 39 | if(size(im,3)==3) 40 | im = rgb2gray(im); 41 | end 42 | im = double(im); 43 | [row col] = size(im); 44 | block_rownum = floor(row/blocksizerow); 45 | block_colnum = floor(col/blocksizecol); 46 | im = im(1:block_rownum*blocksizerow, ...
47 | 1:block_colnum*blocksizecol); 48 | window = fspecial('gaussian',7,7/6); 49 | window = window/sum(sum(window)); 50 | scalenum = 2; 51 | warning('off') 52 | 53 | feat = []; 54 | 55 | 56 | for itr_scale = 1:scalenum 57 | 58 | 59 | mu = imfilter(im,window,'replicate'); 60 | mu_sq = mu.*mu; 61 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq)); 62 | structdis = (im-mu)./(sigma+1); 63 | 64 | 65 | 66 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ... 67 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ... 68 | @computefeature); 69 | feat_scale = reshape(feat_scale,[featnum .... 70 | size(feat_scale,1)*size(feat_scale,2)/featnum]); 71 | feat_scale = feat_scale'; 72 | 73 | 74 | if(itr_scale == 1) 75 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ... 76 | [blockrowoverlap blockcoloverlap],@computemean); 77 | sharpness = sharpness(:); 78 | end 79 | 80 | 81 | feat = [feat feat_scale]; 82 | 83 | im =imresize(im,0.5); 84 | 85 | end 86 | 87 | save(sprintf('local_risquee_prisfeatures\\prisfeatures_local%d.mat',... 
88 | itr),'feat','sharpness'); 89 | end 90 | 91 | 92 | 93 | %---------------------------------------------- 94 | % Load pristine image features 95 | prisparam = []; 96 | current = pwd; 97 | cd(sprintf('%s','local_risquee_prisfeatures')) 98 | names = ls; 99 | names = names(3:end,:); 100 | cd(current) 101 | for itr = 1:size(names,1) 102 | % Load the features and select the only features 103 | load(sprintf('local_risquee_prisfeatures\\%s',strtrim(names(itr,:)))); 104 | IX = find(sharpness(:) >sh_th*max(sharpness(:))); 105 | feat = feat(IX,:); 106 | prisparam = [prisparam; feat]; 107 | 108 | end 109 | %---------------------------------------------- 110 | % Compute model parameters 111 | mu_prisparam = nanmean(prisparam); 112 | cov_prisparam = nancov(prisparam); 113 | %---------------------------------------------- 114 | % Save features in the mat file 115 | save('modelparameters_new.mat','mu_prisparam','cov_prisparam'); 116 | %---------------------------------------------- 117 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/modelparameters.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/datasets/generate_rankdataset/utils/niqe_release/modelparameters.mat -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/niqe_release/readme.txt: -------------------------------------------------------------------------------- 1 | NIQE Software release. 2 | 3 | ======================================================================= 4 | -----------COPYRIGHT NOTICE STARTS WITH THIS LINE------------ 5 | Copyright (c) 2011 The University of Texas at Austin 6 | All rights reserved. 
7 | 8 | Permission is hereby granted, without written agreement and without license or royalty fees, to use, copy, 9 | modify, and distribute this code (the source files) and its documentation for 10 | any purpose, provided that the copyright notice in its entirety appear in all copies of this code, and the 11 | original source of this code, Laboratory for Image and Video Engineering (LIVE, http://live.ece.utexas.edu) 12 | and Center for Perceptual Systems (CPS, http://www.cps.utexas.edu) at the University of Texas at Austin (UT Austin, 13 | http://www.utexas.edu), is acknowledged in any publication that reports research using this code. The research 14 | is to be cited in the bibliography as: 15 | 16 | 1) A. Mittal, R. Soundararajan and A. C. Bovik, "NIQE Software Release", 17 | URL: http://live.ece.utexas.edu/research/quality/niqe.zip, 2012. 18 | 19 | 2) A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", submitted to IEEE Signal Processing Letters, 2012. 20 | 21 | IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT AUSTIN BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, 22 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF THIS DATABASE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF TEXAS 23 | AT AUSTIN HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | THE UNIVERSITY OF TEXAS AT AUSTIN SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 26 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE DATABASE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, 27 | AND THE UNIVERSITY OF TEXAS AT AUSTIN HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
28 | 29 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------% 30 | 31 | Author : Anish Mittal 32 | Version : 1.0 33 | 34 | The authors are with the Laboratory for Image and Video Engineering 35 | (LIVE), Department of Electrical and Computer Engineering, The 36 | University of Texas at Austin, Austin, TX. 37 | 38 | Kindly report any suggestions or corrections to mittal.anish@gmail.com 39 | 40 | ======================================================================= 41 | 42 | This is a demonstration of the Naturalness Image Quality Evaluator (NIQE) index. The algorithm is described in: 43 | 44 | A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", submitted to IEEE Signal Processing Letters, 2012. 45 | 46 | You can change this program as you like and use it anywhere, but please 47 | refer to its original source (cite our paper and our web page at 48 | http://live.ece.utexas.edu/research/quality/niqe_release.zip). 49 | 50 | ======================================================================= 51 | Running on Matlab 52 | 53 | Input : A test image loaded in an array 54 | 55 | Output: A quality score of the image. A higher value indicates lower quality. 56 | 57 | Usage: 58 | 59 | 1. Load the image, for example 60 | 61 | im = imread('testimage1.bmp'); 62 | 63 | 2. Load the parameters of the pristine multivariate Gaussian model. 64 | 65 | 66 | load modelparameters.mat; 67 | 68 | 69 | The images used for making the current model may be viewed at http://live.ece.utexas.edu/research/quality/pristinedata.zip 70 | 71 | 72 | 3.
Initialize different parameters 73 | 74 | Height of the block 75 | blocksizerow = 96; 76 | Width of the block 77 | blocksizecol = 96; 78 | Vertical overlap between blocks 79 | blockrowoverlap = 0; 80 | Horizontal overlap between blocks 81 | blockcoloverlap = 0; 82 | 83 | For good performance, it is advisable to divide the distorted image into patches of the same size as those used to construct the multivariate Gaussian model. 84 | 85 | 4. Call this function to calculate the quality score: 86 | 87 | 88 | qualityscore = computequality(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam) 89 | 90 | Sample execution is also shown in example.m 91 | 92 | 93 | ======================================================================= 94 | 95 | MATLAB files (provided with release): example.m, computefeature.m, computemean.m, computequality.m, estimateaggdparam.m and estimatemodelparam.m 96 | 97 | Image Files: image1.bmp, image2.bmp, image3.bmp and image4.bmp 98 | 99 | Dependencies: Mat file: modelparameters.mat provided with release 100 | 101 | ======================================================================= 102 | 103 | Note on training: 104 | This release version of NIQE was trained on 125 pristine images with the patch size set to 96x96 and a sharpness threshold of 0.75. 105 | 106 | Training the model 107 | 108 | If the user wants to retrain the model using a different set of pristine images or different patch sizes, they can do so 109 | with the following function.
The images used for making the current model may be viewed at http://live.ece.utexas.edu/research/quality/pristinedata.zip 110 | 111 | Folder containing the pristine images 112 | folderpath = 'pristine'; 113 | Height of the block 114 | blocksizerow = 96; 115 | Width of the block 116 | blocksizecol = 96; 117 | Vertical overlap between blocks 118 | blockrowoverlap = 0; 119 | Horizontal overlap between blocks 120 | blockcoloverlap = 0; 121 | The sharpness threshold level 122 | sh_th = 0.75; 123 | 124 | 125 | [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th) 126 | ======================================================================= 127 | -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/parcal_niqe.m: -------------------------------------------------------------------------------- 1 | function [scoresname,scoresNIQE] = parcal_niqe(ii,input_dir,file_list,im_num) 2 | 3 | load modelparameters.mat 4 | fprintf(['\nCalculating scores for image ',num2str(ii),' / ',num2str(im_num)]); 5 | % Reading and converting images 6 | input_image_path = fullfile(input_dir,file_list(ii).name); 7 | input_image = convert_shave_image(imread(input_image_path),4); 8 | 9 | % Calculating scores 10 | scoresname = file_list(ii).name; 11 | 12 | scoresNIQE = computequality(input_image,96,96,... 13 | 0,0,mu_prisparam,cov_prisparam); 14 | 15 | fprintf([' perceptual score is: ',num2str(scoresNIQE),' ',scoresname,'']); 16 | end -------------------------------------------------------------------------------- /datasets/generate_rankdataset/utils/save_patch_img.m: -------------------------------------------------------------------------------- 1 | function img_patch=save_patch_img(file_path,file_name,...
2 | save_dir,save_name,h,w,patch_sz) 3 | 4 | img = imread(fullfile(file_path,file_name)); 5 | img_patch = img(h:h+patch_sz-1,w:w+patch_sz-1,:); 6 | 7 | imwrite(img_patch,fullfile(save_dir,save_name)); 8 | 9 | end 10 | -------------------------------------------------------------------------------- /experiments/pretrained_models/readme.md: -------------------------------------------------------------------------------- 1 | ## Place pretrained models here. 2 | 3 | ### Pretrained models 4 | 5 | 1. `SRResNet_bicx4_in3nf64nb16.pth`: the well-trained PSNR-oriented SRResNet model. 6 | 2. `SRGAN.pth`: the pretrained SRGAN model implemented with [BasicSR](https://github.com/xinntao/BasicSR). 7 | 3. `ESRGAN_SuperSR.pth`: the pretrained [ESRGAN_SuperSR](https://github.com/xinntao/ESRGAN) model. 8 | 9 | ### Three pretrained Ranker models 10 | 1. `Ranker_NIQE.pth`: the well-trained Ranker for the **NIQE** metric. 11 | 2. `Ranker_Ma.pth`: the well-trained Ranker for the **Ma** metric. 12 | 3. `Ranker_PI.pth`: the well-trained Ranker for the **PI** metric. 13 | 14 | ### RankSRGAN models 15 | 1. `RankSRGAN_NIQE.pth`: RankSRGAN optimized toward the **NIQE** metric. 16 | 2. `RankSRGAN_Ma.pth`: RankSRGAN optimized toward the **Ma** metric. 17 | 3. `RankSRGAN_PI.pth`: RankSRGAN optimized toward the **PI** metric. 18 | 19 | 20 | 21 | *Note that* the pretrained models were trained on LR images downsampled with the `MATLAB bicubic` kernel. 22 | If your LR images are downsampled with a different kernel, the results may contain artifacts.
23 | -------------------------------------------------------------------------------- /experiments/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figures/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/figures/method.png -------------------------------------------------------------------------------- /figures/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figures/visual_results1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/figures/visual_results1.png -------------------------------------------------------------------------------- /results/readme.md: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
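The `move_valid.py` script in `datasets/generate_rankdataset/` hard-codes absolute paths, draws an unseeded sample, and recovers each patch id with `name.split('_')[0]`, which only works if ids contain no underscores. Below is a minimal sketch of the same 10% train/validation split, assuming patches are named `<id>_<suffix>.png` with the suffixes `esrgan`, `srgan` and `srres` that the script uses. `split_validation`, `SUFFIXES` and the parameter names are hypothetical illustrations, not part of the repository.

```python
import os
import random
import shutil

# Assumed suffixes, matching the three levels handled by move_valid.py.
SUFFIXES = ("esrgan", "srgan", "srres")


def split_validation(train_dirs, valid_dirs, rate=0.1, seed=0):
    """Move a `rate` fraction of patch triplets from train to valid dirs.

    train_dirs / valid_dirs: dicts mapping each suffix to a directory path.
    Returns the sorted list of moved patch ids.
    """
    for d in valid_dirs.values():
        os.makedirs(d, exist_ok=True)  # no error if the folder already exists

    # Collect ids from the first level, stripping only the trailing
    # "_<suffix>.png" so ids that contain underscores stay intact.
    first = SUFFIXES[0]
    tail = f"_{first}.png"
    ids = sorted(
        name[: -len(tail)]
        for name in os.listdir(train_dirs[first])
        if name.endswith(tail)
    )

    # Sorting before sampling plus a seeded RNG makes the split reproducible.
    rng = random.Random(seed)
    picked = rng.sample(ids, int(len(ids) * rate))

    # Move the matching patch from every level so the triplets stay aligned.
    for patch_id in picked:
        for suffix in SUFFIXES:
            src = os.path.join(train_dirs[suffix], f"{patch_id}_{suffix}.png")
            shutil.move(src, valid_dirs[suffix])
    return sorted(picked)
```

To reproduce the original layout, pass the three `DF2K_train_patch_*` folders as `train_dirs` and the three `DF2K_valid_patch_*` folders as `valid_dirs`. Keying everything on the full id, rather than the first underscore-separated token, keeps the three quality levels of each patch in the same split.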