├── .gitattributes
├── .gitignore
├── README.md
├── codes
│   ├── data
│   │   ├── LQGT_dataset.py
│   │   ├── LQ_dataset.py
│   │   ├── README.md
│   │   ├── Rank_IMIM_Pair_dataset.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── LQGT_dataset.cpython-36.pyc
│   │   │   ├── LRHR_dataset.cpython-36.pyc
│   │   │   ├── LR_dataset.cpython-36.pyc
│   │   │   ├── Rank_IMIM_Pair_dataset.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── data_sampler.cpython-36.pyc
│   │   │   └── util.cpython-36.pyc
│   │   ├── data_sampler.py
│   │   └── util.py
│   ├── data_scripts
│   │   ├── create_lmdb.py
│   │   ├── extract_subimages.py
│   │   ├── generate_LR_Vimeo90K.m
│   │   ├── generate_mod_LR_bic.m
│   │   ├── generate_mod_LR_bic.py
│   │   ├── prepare_DIV2K_x4_dataset.sh
│   │   ├── rename.py
│   │   └── test_dataloader.py
│   ├── metrics
│   │   ├── calculate_PSNR_SSIM.m
│   │   └── calculate_PSNR_SSIM.py
│   ├── models
│   │   ├── RankSRGAN_model.py
│   │   ├── Ranker_model.py
│   │   ├── SRGAN_model.py
│   │   ├── SR_model.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── RankSRGAN.cpython-36.pyc
│   │   │   ├── RankSRGAN_model.cpython-36.pyc
│   │   │   ├── Ranker_model.cpython-36.pyc
│   │   │   ├── SRGAN_model.cpython-36.pyc
│   │   │   ├── SRGAN_rank_model.cpython-36.pyc
│   │   │   ├── SR_model.cpython-36.pyc
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── base_model.cpython-36.pyc
│   │   │   ├── loss.cpython-36.pyc
│   │   │   ├── lr_scheduler.cpython-36.pyc
│   │   │   └── networks.cpython-36.pyc
│   │   ├── archs
│   │   │   ├── RRDBNet_arch.py
│   │   │   ├── RankSRGAN_arch.py
│   │   │   ├── SRResNet_arch.py
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── RRDBNet_arch.cpython-36.pyc
│   │   │   │   ├── RankSRGAN_arch.cpython-36.pyc
│   │   │   │   ├── SRResNet_arch.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── arch_util.cpython-36.pyc
│   │   │   │   └── discriminator_vgg_arch.cpython-36.pyc
│   │   │   ├── arch_util.py
│   │   │   └── discriminator_vgg_arch.py
│   │   ├── base_model.py
│   │   ├── loss.py
│   │   ├── lr_scheduler.py
│   │   └── networks.py
│   ├── options
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   └── options.cpython-36.pyc
│   │   ├── options.py
│   │   ├── test
│   │   │   └── test_RankSRGAN.yml
│   │   └── train
│   │       ├── train_RankSRGAN.yml
│   │       ├── train_Ranker.yml
│   │       ├── train_SRGAN.yml
│   │       └── train_SRResNet.yml
│   ├── run_scripts.sh
│   ├── scripts
│   │   ├── README.md
│   │   ├── __pycache__
│   │   │   ├── arch_util.cpython-36.pyc
│   │   │   └── block.cpython-36.pyc
│   │   ├── arch_util.py
│   │   ├── back_projection
│   │   │   ├── backprojection.m
│   │   │   ├── main_bp.m
│   │   │   └── main_reverse_filter.m
│   │   ├── block.py
│   │   ├── calparameters.py
│   │   ├── create_lmdb.py
│   │   ├── extract_subimgs_single.py
│   │   ├── generate_mod_LR_bic.m
│   │   ├── transfer.py
│   │   └── transfer_params_MSRResNet.py
│   ├── test.py
│   ├── train.py
│   ├── train_niqe.py
│   ├── train_rank.py
│   └── utils
│       ├── README.md
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── rank_test.cpython-36.pyc
│       │   └── util.cpython-36.pyc
│       ├── perceptualmetric
│       │   ├── calc_NIQE.m
│       │   ├── computefeature.m
│       │   ├── computemean.m
│       │   ├── computequality.m
│       │   ├── convert_shave_image.m
│       │   ├── estimateaggdparam.m
│       │   ├── estimatemodelparam.m
│       │   └── modelparameters.mat
│       ├── rank_test.py
│       └── util.py
├── datasets
│   ├── README.md
│   └── generate_rankdataset
│       ├── README.md
│       ├── generate_rankdataset.m
│       ├── generate_train_ranklabel.m
│       ├── generate_valid_ranklabel.m
│       ├── move_valid.py
│       └── utils
│           ├── calc_NIQE.m
│           ├── compute_sharpness.m
│           ├── computemean.m
│           ├── convert_shave_image.m
│           ├── get_sigle_patch.m
│           ├── niqe_release
│           │   ├── computefeature.m
│           │   ├── computemean.m
│           │   ├── computequality.m
│           │   ├── estimateaggdparam.m
│           │   ├── estimatemodelparam.m
│           │   ├── modelparameters.mat
│           │   └── readme.txt
│           ├── parcal_niqe.m
│           └── save_patch_img.m
├── experiments
│   ├── pretrained_models
│   │   └── readme.md
│   └── readme.md
├── figures
│   ├── method.png
│   ├── readme.md
│   └── visual_results1.png
└── results
    └── readme.md
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | results/*
3 | tb_logger/*
4 | experiments/*
5 | !data/readme.md
6 | !experiments/pretrained_models/readme.md
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RankSRGAN
2 | ### [Paper](https://arxiv.org/abs/1908.06382) | [Supplementary file](https://arxiv.org/abs/1908.06382) | [Project Page](https://wenlongzhang0724.github.io/Projects/RankSRGAN)
3 | ### RankSRGAN: Generative Adversarial Networks with Ranker for Image Super-Resolution
4 |
5 | By [Wenlong Zhang](https://wenlongzhang0724.github.io/), [Yihao Liu](http://xpixel.group/2010/03/29/yihaoliu.html), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en), [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/)
6 |
7 |
8 | ---
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ### Dependencies
20 |
21 | - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
22 | - [PyTorch >= 1.0.0](https://pytorch.org/)
23 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
24 | - Python packages: `pip install numpy opencv-python lmdb`
25 | - [optional] Python packages: [`pip install tensorboardX`](https://github.com/lanpa/tensorboardX), for visualizing training curves.
26 |
27 | # Codes
28 | - We have updated the code based on [mmsr](https://github.com/open-mmlab/mmsr).
29 | The old version can be downloaded from [Google Drive](https://drive.google.com/drive/folders/13ZOwv0HIa_hrtnYAOM9cTzA-AVKqiF8c?usp=sharing).
30 | - This version is under testing. We will provide more details of RankSRGAN in the near future.
31 | ## How to Test
32 | 1. Clone this github repo.
33 | ```
34 | git clone https://github.com/WenlongZhang0724/RankSRGAN.git
35 | cd RankSRGAN
36 | ```
37 | 2. Place your own **low-resolution images** in the `./LR` folder.
38 | 3. Download pretrained models from [Google Drive](https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing). Place the models in `./experiments/pretrained_models/`. We provide three Ranker models and three RankSRGAN models (see [model list](experiments/pretrained_models)).
39 | 4. Run the test. We provide RankSRGAN models trained with NIQE, Ma, and PI; you can configure which model to use in `test.py`.
40 | ```
41 | python test.py -opt options/test/test_RankSRGAN.yml
42 | ```
43 | 5. The results are in the `./results` folder.
44 |
45 | ## How to Train
46 | ### Train Ranker
47 | 1. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017) from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA)
48 | 2. Generate the rank dataset following [./datasets/generate_rankdataset/](datasets/generate_rankdataset).
49 | 3. Run command:
50 | ```
51 | python train_rank.py -opt options/train/train_Ranker.yml
52 | ```
53 |
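The Ranker is a siamese network trained so that its predicted scores preserve the ordering given by the perceptual-metric labels in the rank dataset. As a rough sketch of one training step with PyTorch's built-in margin ranking loss (the tiny `ranker` network, `margin=0.5`, and the optimizer settings are illustrative assumptions, not the configuration in `train_Ranker.yml`):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real Ranker defined under codes/models/archs.
ranker = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
                       nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 1))
rank_loss = nn.MarginRankingLoss(margin=0.5)  # max(0, -y * (s1 - s2) + margin)
optimizer = torch.optim.Adam(ranker.parameters(), lr=1e-4)

def ranker_step(batch):
    """One step on a batch from Rank_IMIM_Pair_dataset."""
    s1 = ranker(batch['img1'])  # predicted score for the first image
    s2 = ranker(batch['img2'])  # predicted score for the second image
    y = torch.sign(batch['score1'] - batch['score2'])  # +1 if img1 has the higher label
    loss = rank_loss(s1, s2, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```
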
54 | ### Train RankSRGAN
55 | We use a PSNR-oriented pretrained SR model to initialize the parameters for better quality.
56 |
57 | 1. Prepare datasets, usually the DIV2K dataset.
58 | 2. Prepare the PSNR-oriented pretrained model. You can use `mmsr_SRResNet_pretrain.pth`, which can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing), as the pretrained model.
59 | 3. Modify the configuration file `options/train/train_RankSRGAN.yml`.
60 | 4. Run command:
61 | ```
62 | python train.py -opt options/train/train_RankSRGAN.yml
63 | ```
64 | or
65 |
66 | ```
67 | python train_niqe.py -opt options/train/train_RankSRGAN.yml
68 | ```
69 | `train.py` outputs convergence curves with PSNR; `train_niqe.py` outputs convergence curves with both NIQE and PSNR.
70 |
71 | ## Acknowledgement
72 | - Part of this code was written by [Yihao Liu](http://xpixel.group/2010/03/29/yihaoliu.html).
73 | - This code is based on [BasicSR](https://github.com/xinntao/BasicSR).
74 |
--------------------------------------------------------------------------------
/codes/data/LQGT_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import lmdb
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 |
9 | class LQGTDataset(data.Dataset):
10 | """
11 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
12 | If only GT images are provided, generate LQ images on-the-fly.
13 | """
14 |
15 | def __init__(self, opt):
16 | super(LQGTDataset, self).__init__()
17 | self.opt = opt
18 | self.data_type = self.opt['data_type']
19 | self.paths_LQ, self.paths_GT = None, None
20 | self.sizes_LQ, self.sizes_GT = None, None
21 | self.LQ_env, self.GT_env = None, None # environments for lmdb
22 |
23 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
24 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
25 | assert self.paths_GT, 'Error: GT path is empty.'
26 | if self.paths_LQ and self.paths_GT:
27 | assert len(self.paths_LQ) == len(
28 | self.paths_GT
29 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
30 | len(self.paths_LQ), len(self.paths_GT))
31 | self.random_scale_list = [1]
32 |
33 | def _init_lmdb(self):
34 | # https://github.com/chainer/chainermn/issues/129
35 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
36 | meminit=False)
37 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
38 | meminit=False)
39 |
40 | def __getitem__(self, index):
41 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
42 | self._init_lmdb()
43 | GT_path, LQ_path = None, None
44 | scale = self.opt['scale']
45 | GT_size = self.opt['GT_size']
46 |
47 | # get GT image
48 | GT_path = self.paths_GT[index]
49 | resolution = [int(s) for s in self.sizes_GT[index].split('_')
50 | ] if self.data_type == 'lmdb' else None
51 | img_GT = util.read_img(self.GT_env, GT_path, resolution)
52 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase
53 | img_GT = util.modcrop(img_GT, scale)
54 | if self.opt['color']: # change color space if necessary
55 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
56 |
57 | # get LQ image
58 | if self.paths_LQ:
59 | LQ_path = self.paths_LQ[index]
60 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
61 | ] if self.data_type == 'lmdb' else None
62 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
63 | else: # down-sampling on-the-fly
64 | # randomly scale during training
65 | if self.opt['phase'] == 'train':
66 | random_scale = random.choice(self.random_scale_list)
67 | H_s, W_s, _ = img_GT.shape
68 |
69 | def _mod(n, random_scale, scale, thres):
70 | rlt = int(n * random_scale)
71 | rlt = (rlt // scale) * scale
72 | return thres if rlt < thres else rlt
73 |
74 | H_s = _mod(H_s, random_scale, scale, GT_size)
75 | W_s = _mod(W_s, random_scale, scale, GT_size)
76 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
77 | if img_GT.ndim == 2:
78 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
79 |
80 | H, W, _ = img_GT.shape
81 | # using matlab imresize
82 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
83 | if img_LQ.ndim == 2:
84 | img_LQ = np.expand_dims(img_LQ, axis=2)
85 |
86 | if self.opt['phase'] == 'train':
87 | # if the image size is too small
88 | H, W, _ = img_GT.shape
89 | if H < GT_size or W < GT_size:
90 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
91 | # using matlab imresize
92 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
93 | if img_LQ.ndim == 2:
94 | img_LQ = np.expand_dims(img_LQ, axis=2)
95 |
96 | H, W, C = img_LQ.shape
97 | LQ_size = GT_size // scale
98 |
99 | # randomly crop
100 | rnd_h = random.randint(0, max(0, H - LQ_size))
101 | rnd_w = random.randint(0, max(0, W - LQ_size))
102 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
103 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
104 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
105 |
106 | # augmentation - flip, rotate
107 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
108 | self.opt['use_rot'])
109 |
110 | if self.opt['color']: # change color space if necessary
111 | img_LQ = util.channel_convert(C, self.opt['color'],
112 | [img_LQ])[0] # TODO during val no definition
113 |
114 | # BGR to RGB, HWC to CHW, numpy to tensor
115 | if img_GT.shape[2] == 3:
116 | img_GT = img_GT[:, :, [2, 1, 0]]
117 | img_LQ = img_LQ[:, :, [2, 1, 0]]
118 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
119 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
120 |
121 | if LQ_path is None:
122 | LQ_path = GT_path
123 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
124 |
125 | def __len__(self):
126 | return len(self.paths_GT)
127 |
--------------------------------------------------------------------------------
/codes/data/LQ_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import lmdb
3 | import torch
4 | import torch.utils.data as data
5 | import data.util as util
6 |
7 |
8 | class LQDataset(data.Dataset):
9 | '''Read LQ images only in the test phase.'''
10 |
11 | def __init__(self, opt):
12 | super(LQDataset, self).__init__()
13 | self.opt = opt
14 | self.data_type = self.opt['data_type']
15 | self.paths_LQ, self.paths_GT = None, None
16 | self.LQ_env = None # environment for lmdb
17 |
18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
19 | assert self.paths_LQ, 'Error: LQ paths are empty.'
20 |
21 | def _init_lmdb(self):
22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
23 | meminit=False)
24 |
25 | def __getitem__(self, index):
26 | if self.data_type == 'lmdb' and self.LQ_env is None:
27 | self._init_lmdb()
28 | LQ_path = None
29 |
30 | # get LQ image
31 | LQ_path = self.paths_LQ[index]
32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
33 | ] if self.data_type == 'lmdb' else None
34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
35 | H, W, C = img_LQ.shape
36 |
37 | if self.opt['color']: # change color space if necessary
38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
39 |
40 | # BGR to RGB, HWC to CHW, numpy to tensor
41 | if img_LQ.shape[2] == 3:
42 | img_LQ = img_LQ[:, :, [2, 1, 0]]
43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
44 |
45 | return {'LQ': img_LQ, 'LQ_path': LQ_path}
46 |
47 | def __len__(self):
48 | return len(self.paths_LQ)
49 |
--------------------------------------------------------------------------------
/codes/data/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Dataloader
3 |
4 | - use opencv (`cv2`) to read and process images.
5 |
6 | - read from **image** files OR from **.lmdb** for fast IO speed.
7 | - How to create .lmdb file? Please see [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py).
8 |
9 | - can downsample images using the `matlab bicubic` function, though it is a bit slow. Implemented in [`util.py`](util.py). More about the [`matlab bicubic` function](https://github.com/xinntao/BasicSR/wiki/Matlab-bicubic-imresize).
10 |
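A minimal, hypothetical sketch of the write path such a script follows (the reader in [`util.py`](util.py) restores each image from its raw bytes plus an `H_W_C` size string; the exact keys and meta layout of `create_lmdb.py` may differ):

```python
import glob
import os
import cv2
import lmdb

img_paths = sorted(glob.glob('../../datasets/DIV2K/DIV2K800_sub/*.png'))  # placeholder folder
env = lmdb.open('../../datasets/DIV2K/DIV2K800_sub.lmdb', map_size=1099511627776)

sizes = {}
with env.begin(write=True) as txn:
    for path in img_paths:
        key = os.path.splitext(os.path.basename(path))[0]
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR, HWC, uint8
        h, w, c = img.shape
        sizes[key] = '{}_{}_{}'.format(h, w, c)  # the 'H_W_C' string the loader splits on '_'
        txn.put(key.encode('ascii'), img.tobytes())  # raw bytes, read back via np.frombuffer
env.close()
```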
11 |
12 | ## Contents
13 |
14 | - `LQ_dataset`: reads only LQ (LR) images, for the test phase when no GT images are available.
15 | - `LQGT_dataset`: reads LQ and GT pairs from image folders or lmdb files. If only GT images are provided, it downsamples them on-the-fly. Used in the SR and SRGAN training and validation phases.
16 | - `Rank_IMIM_Pair_dataset`: reads rank data pairs from image folders.
17 | ## How To Prepare Data
18 | ### SR, SRGAN
19 | 1. Prepare the images. You can download **classical SR** datasets (including BSD200, T91, General100; Set5, Set14, urban100, BSD100, manga109; historical) from [Google Drive](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/18fJzAHIg8Zpkc-2seGRW4Q). The DIV2K dataset can be downloaded from the [DIV2K official page](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or from [Baidu Drive](https://pan.baidu.com/s/1LUj90_skqlVw4rjRVeEoiw).
20 |
21 | 1. For faster IO speed, you can create lmdb files for the training dataset. Please see [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py).
22 |
23 | 1. We use the DIV2K dataset for training the SR and SRGAN models.
24 | 1. Since DIV2K images are large, we first crop them into sub-images using [`codes/scripts/extract_subimgs_single.py`](../scripts/extract_subimgs_single.py).
25 | 1. Generate LR images using MATLAB with [`codes/scripts/generate_mod_LR_bic.m`](../scripts/generate_mod_LR_bic.m); a Python sketch follows this list. If you already have LR images, you can skip this step. Please make sure the LR and HR folders contain the same number of images.
26 | 1. Generate the .lmdb file, if needed, using [`codes/scripts/create_lmdb.py`](../scripts/create_lmdb.py).
27 | 1. Modify the configurations in `options/train/xxx.yml` for training, e.g., `dataroot_GT`, `dataroot_LQ`.
28 |
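If you prefer to stay in Python, the MATLAB-style bicubic resize is available as `imresize_np` in [`util.py`](util.py); it is the same call the dataloader uses for on-the-fly downsampling. A minimal sketch (file names are placeholders):

```python
import sys
import cv2
sys.path.append('..')  # assumption: run from a sub-folder of codes so that `data` is importable
from data.util import imresize_np

hr = cv2.imread('0001.png', cv2.IMREAD_UNCHANGED) / 255.  # placeholder HR image, in [0, 1]
lr = imresize_np(hr, 1 / 4, True)  # x4 MATLAB-style bicubic downsampling
cv2.imwrite('0001_bicLRx4.png', (lr * 255.).round().clip(0, 255).astype('uint8'))
```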
29 | ### data augmentation
30 |
31 | We use random crop, random flip/rotation, and (optionally) random scale for data augmentation.
32 |
33 | ### wiki
34 |
35 | [Color-conversion-in-SR](https://github.com/xinntao/BasicSR/wiki/Color-conversion-in-SR)
36 |
37 |
38 |
42 |
--------------------------------------------------------------------------------
/codes/data/Rank_IMIM_Pair_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 | from itertools import combinations
9 | from scipy.special import comb
10 |
11 |
12 | class RANK_IMIM_Pair_Dataset(data.Dataset):
13 | '''
14 |     Read a pair of images drawn from the three SR result levels of the rank dataset,
15 |     together with their perceptual label scores, for training the Ranker.
16 |     The pairing is ensured by the 'sorted' function, so please check the naming convention.
17 | '''
18 |
19 | def name(self):
20 | return 'RANK_IMIM_Pair_Dataset'
21 |
22 | def __init__(self, opt, is_train):
23 | super(RANK_IMIM_Pair_Dataset, self).__init__()
24 | self.opt = opt
25 |
26 | self.is_train = is_train
27 |
28 | # read image list from lmdb or image files
29 |
30 | self.paths_img1, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img1'])
31 | self.paths_img2, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img2'])
32 | self.paths_img3, self.sizes_GT = util.get_image_paths(opt['data_type'], opt['dataroot_img3'])
33 |
34 | self.img_env1 = None
35 | self.img_env2 = None
36 | self.img_env3 = None
37 |
38 | self.label_path = opt['dataroot_label_file']
39 |
40 | # get image label scores
41 | self.label = {}
42 | f = open(self.label_path, 'r')
43 | for line in f.readlines():
44 | line = line.strip().split()
45 | self.label[line[0]] = line[1]
46 | f.close()
47 |
48 | assert self.paths_img1, 'Error: img1 paths are empty.'
49 |
50 | # self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
51 | self.random_scale_list = None
52 |
53 | def __getitem__(self, index):
54 |
55 | if self.is_train:
56 | # get img1 and img1 label score
57 | # choice = random.choice(['img1_img2','img1_img2','img1_img2','img1_img3','img2_img3']) #Oversampling for hard sample
58 | choice = random.choice(['img1_img2', 'img1_img3', 'img2_img3'])
59 |
60 | # print(choice)
61 |
62 | if choice == 'img1_img2':
63 | img1_path = self.paths_img1[index]
64 | img1 = util.read_img(self.img_env1, img1_path)
65 | img2_path = self.paths_img2[index]
66 | img2 = util.read_img(self.img_env2, img2_path)
67 | elif choice == 'img1_img3':
68 | img1_path = self.paths_img1[index]
69 | img1 = util.read_img(self.img_env1, img1_path)
70 | img2_path = self.paths_img3[index]
71 | img2 = util.read_img(self.img_env3, img2_path)
72 |
73 | elif choice == 'img2_img3':
74 | img1_path = self.paths_img2[index]
75 | img1 = util.read_img(self.img_env2, img1_path)
76 | img2_path = self.paths_img3[index]
77 | img2 = util.read_img(self.img_env3, img2_path)
78 |
79 |
80 | img1_name = img1_path.split('/')[-1]
81 | img1_score = np.array(float(self.label[img1_name]), dtype='float')
82 | img1_score = img1_score.reshape(1)
83 |
84 | img2_name = img2_path.split('/')[-1]
85 | img2_score = np.array(float(self.label[img2_name]), dtype='float')
86 | img2_score = img2_score.reshape(1)
87 |
88 | if img1.shape[2] == 3:
89 | img1 = img1[:, :, [2, 1, 0]]
90 | img1 = torch.from_numpy(np.ascontiguousarray(np.transpose(img1, (2, 0, 1)))).float()
91 | img1_score = torch.from_numpy(img1_score).float()
92 |
93 | if img2.shape[2] == 3:
94 | img2 = img2[:, :, [2, 1, 0]]
95 | img2 = torch.from_numpy(np.ascontiguousarray(np.transpose(img2, (2, 0, 1)))).float()
96 | img2_score = torch.from_numpy(img2_score).float()
97 |
98 | # print('img1:'+img1_name,' & ','img2:'+img2_name)
99 |
100 | else:
101 | # get img1
102 | img1_path = self.paths_img1[index]
103 | img1 = util.read_img(self.img_env1, img1_path)
104 |
105 | img1_name = img1_path.split('/')[-1]
106 | img1_score = np.array(float(self.label[img1_name]), dtype='float')
107 | img1_score = img1_score.reshape(1)
108 |
109 | if img1.shape[2] == 3:
110 | img1 = img1[:, :, [2, 1, 0]]
111 | img1 = torch.from_numpy(np.ascontiguousarray(np.transpose(img1, (2, 0, 1)))).float()
112 | img1_score = torch.from_numpy(img1_score).float()
113 | # print('img1:'+img1_name)
114 |
115 | # not useful
116 | img2_path = img1_path
117 | img2 = img1
118 | img2_score = img1_score
119 |
120 | # exit()
121 |
122 | return {'img1': img1, 'img2': img2, 'img1_path': img1_path, 'img2_path': img2_path, 'score1': img1_score,
123 | 'score2': img2_score}
124 |
125 | def __len__(self):
126 | return len(self.paths_img1)
127 |
--------------------------------------------------------------------------------
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | """create dataset and dataloader"""
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
25 | pin_memory=False)
26 |
27 | def create_dataset(dataset_opt, is_train = True):
28 | mode = dataset_opt['mode']
29 | # datasets for image restoration
30 | if mode == 'LQ':
31 | from data.LQ_dataset import LQDataset as D
32 | elif mode == 'LQGT':
33 | from data.LQGT_dataset import LQGTDataset as D
34 | elif mode == 'RANK_IMIM_Pair':
35 | from data.Rank_IMIM_Pair_dataset import RANK_IMIM_Pair_Dataset as D
36 | else:
37 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
38 | if 'RANK_IMIM_Pair' in mode:
39 | dataset = D(dataset_opt, is_train = is_train)
40 | else:
41 | dataset = D(dataset_opt)
42 | logger = logging.getLogger('base')
43 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
44 | dataset_opt['name']))
45 | return dataset
46 |
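
def _example_usage():
    """Illustrative sketch only; not called by the training code.

    The option keys follow codes/data_scripts/test_dataloader.py; the dataset
    paths are placeholders.
    """
    dataset_opt = {
        'name': 'DIV2K800', 'mode': 'LQGT', 'phase': 'train', 'data_type': 'img',
        'dataroot_GT': '/path/to/DIV2K800_sub',  # placeholder
        'dataroot_LQ': '/path/to/DIV2K800_sub_bicLRx4',  # placeholder
        'n_workers': 8, 'batch_size': 16, 'GT_size': 128, 'scale': 4,
        'use_flip': True, 'use_rot': True, 'color': 'RGB', 'use_shuffle': True,
    }
    opt = {'dist': False, 'gpu_ids': [0]}
    train_set = create_dataset(dataset_opt)
    train_loader = create_dataloader(train_set, dataset_opt, opt, sampler=None)
    for batch in train_loader:
        lq, gt = batch['LQ'], batch['GT']  # RGB, CHW float tensors in [0, 1]
        break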
--------------------------------------------------------------------------------
/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LQGT_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/LRHR_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LRHR_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/LR_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/LR_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/Rank_IMIM_Pair_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/Rank_IMIM_Pair_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/data_sampler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/data_sampler.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/data/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler.
3 | Supports enlarging the dataset for *iteration-oriented* training, saving the time of
4 | restarting the dataloader after each epoch.
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
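
def _example_usage(train_set, world_size, rank):
    """Illustrative sketch only; assumes torch.distributed is already initialized.

    ratio=100 enlarges the sampled index space so that one sampler 'epoch'
    spans many passes over the data, avoiding frequent dataloader restarts
    in iteration-oriented training.
    """
    from torch.utils.data import DataLoader

    sampler = DistIterSampler(train_set, num_replicas=world_size, rank=rank, ratio=100)
    loader = DataLoader(train_set, batch_size=16, shuffle=False, sampler=sampler,
                        num_workers=4, pin_memory=True)
    for epoch in range(3):
        sampler.set_epoch(epoch)  # reseeds the deterministic shuffle each epoch
        for _batch in loader:
            pass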
--------------------------------------------------------------------------------
/codes/data_scripts/extract_subimages.py:
--------------------------------------------------------------------------------
1 | """A multi-thread tool to crop large images to sub-images for faster IO."""
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 | import numpy as np
7 | import cv2
8 | from PIL import Image
9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
10 | from utils.util import ProgressBar # noqa: E402
11 | import data.util as data_util # noqa: E402
12 |
13 |
14 | def main():
15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
16 | opt = {}
17 | opt['n_thread'] = 20
18 | opt['compression_level'] = 3 # 3 is the default value in cv2
19 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
20 | # compression time. If read raw images during training, use 0 for faster IO speed.
21 | if mode == 'single':
22 | opt['input_folder'] = '../../datasets/DIV2K/DIV2K_train_HR'
23 | opt['save_folder'] = '../../datasets/DIV2K/DIV2K800_sub'
24 | opt['crop_sz'] = 480 # the size of each sub-image
25 | opt['step'] = 240 # step of the sliding crop window
26 | opt['thres_sz'] = 48 # size threshold
27 | extract_signle(opt)
28 | elif mode == 'pair':
29 | GT_folder = '../../datasets/DIV2K/DIV2K_train_HR'
30 | LR_folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
31 | save_GT_folder = '../../datasets/DIV2K/DIV2K800_sub'
32 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4'
33 | scale_ratio = 4
34 | crop_sz = 480 # the size of each sub-image (GT)
35 | step = 240 # step of the sliding crop window (GT)
36 | thres_sz = 48 # size threshold
37 | ########################################################################
38 | # check that all the GT and LR images have correct scale ratio
39 | img_GT_list = data_util._get_paths_from_images(GT_folder)
40 | img_LR_list = data_util._get_paths_from_images(LR_folder)
41 | assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
42 | for path_GT, path_LR in zip(img_GT_list, img_LR_list):
43 | img_GT = Image.open(path_GT)
44 | img_LR = Image.open(path_LR)
45 | w_GT, h_GT = img_GT.size
46 | w_LR, h_LR = img_LR.size
47 |         assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format(  # noqa: E501
48 |             w_GT, scale_ratio, w_LR, path_GT)
49 |         assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format(  # noqa: E501
50 |             h_GT, scale_ratio, h_LR, path_GT)
51 | # check crop size, step and threshold size
52 | assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
53 | scale_ratio)
54 | assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
55 | assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
56 | scale_ratio)
57 | print('process GT...')
58 | opt['input_folder'] = GT_folder
59 | opt['save_folder'] = save_GT_folder
60 | opt['crop_sz'] = crop_sz
61 | opt['step'] = step
62 | opt['thres_sz'] = thres_sz
63 | extract_signle(opt)
64 | print('process LR...')
65 | opt['input_folder'] = LR_folder
66 | opt['save_folder'] = save_LR_folder
67 | opt['crop_sz'] = crop_sz // scale_ratio
68 | opt['step'] = step // scale_ratio
69 | opt['thres_sz'] = thres_sz // scale_ratio
70 | extract_signle(opt)
71 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
72 | data_util._get_paths_from_images(
73 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
74 | else:
75 | raise ValueError('Wrong mode.')
76 |
77 |
78 | def extract_signle(opt):
79 | input_folder = opt['input_folder']
80 | save_folder = opt['save_folder']
81 | if not osp.exists(save_folder):
82 | os.makedirs(save_folder)
83 | print('mkdir [{:s}] ...'.format(save_folder))
84 | else:
85 | print('Folder [{:s}] already exists. Exit...'.format(save_folder))
86 | sys.exit(1)
87 | img_list = data_util._get_paths_from_images(input_folder)
88 |
89 | def update(arg):
90 | pbar.update(arg)
91 |
92 | pbar = ProgressBar(len(img_list))
93 |
94 | pool = Pool(opt['n_thread'])
95 | for path in img_list:
96 | pool.apply_async(worker, args=(path, opt), callback=update)
97 | pool.close()
98 | pool.join()
99 | print('All subprocesses done.')
100 |
101 |
102 | def worker(path, opt):
103 | crop_sz = opt['crop_sz']
104 | step = opt['step']
105 | thres_sz = opt['thres_sz']
106 | img_name = osp.basename(path)
107 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
108 |
109 | n_channels = len(img.shape)
110 | if n_channels == 2:
111 | h, w = img.shape
112 | elif n_channels == 3:
113 | h, w, c = img.shape
114 | else:
115 | raise ValueError('Wrong image shape - {}'.format(n_channels))
116 |
117 | h_space = np.arange(0, h - crop_sz + 1, step)
118 | if h - (h_space[-1] + crop_sz) > thres_sz:
119 | h_space = np.append(h_space, h - crop_sz)
120 | w_space = np.arange(0, w - crop_sz + 1, step)
121 | if w - (w_space[-1] + crop_sz) > thres_sz:
122 | w_space = np.append(w_space, w - crop_sz)
123 |
124 | index = 0
125 | for x in h_space:
126 | for y in w_space:
127 | index += 1
128 | if n_channels == 2:
129 | crop_img = img[x:x + crop_sz, y:y + crop_sz]
130 | else:
131 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
132 | crop_img = np.ascontiguousarray(crop_img)
133 | cv2.imwrite(
134 | osp.join(opt['save_folder'],
135 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
136 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
137 | return 'Processing {:s} ...'.format(img_name)
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
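
def _example_grid():
    """Illustrative check of the crop grid for a hypothetical 2040x1356 DIV2K
    image with crop_sz=480, step=240, thres_sz=48."""
    h, w, crop_sz, step, thres_sz = 1356, 2040, 480, 240, 48
    h_space = np.arange(0, h - crop_sz + 1, step)  # [0, 240, 480, 720]
    if h - (h_space[-1] + crop_sz) > thres_sz:
        h_space = np.append(h_space, h - crop_sz)  # leftover 156 > 48, add 876 -> 5 rows
    w_space = np.arange(0, w - crop_sz + 1, step)  # [0, 240, ..., 1440], 7 values
    if w - (w_space[-1] + crop_sz) > thres_sz:
        w_space = np.append(w_space, w - crop_sz)  # leftover 120 > 48, add 1560 -> 8 cols
    print(len(h_space) * len(w_space), 'sub-images per image')  # 40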
--------------------------------------------------------------------------------
/codes/data_scripts/generate_LR_Vimeo90K.m:
--------------------------------------------------------------------------------
1 | function generate_LR_Vimeo90K()
2 | %% matlab code to generate bicubic-downsampled images for the Vimeo90K dataset
3 |
4 | up_scale = 4;
5 | mod_scale = 4;
6 | idx = 0;
7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
8 | for i = 1 : length(filepaths)
9 | [~,imname,ext] = fileparts(filepaths(i).name);
10 | folder_path = filepaths(i).folder;
11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
12 | if ~exist(save_LR_folder, 'dir')
13 | mkdir(save_LR_folder);
14 | end
15 | if isempty(imname)
16 | disp('Ignore . folder.');
17 | elseif strcmp(imname, '.')
18 | disp('Ignore .. folder.');
19 | else
20 | idx = idx + 1;
21 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
22 | fprintf(str_rlt);
23 | % read image
24 | img = imread(fullfile(folder_path, [imname, ext]));
25 | img = im2double(img);
26 | % modcrop
27 | img = modcrop(img, mod_scale);
28 | % LR
29 | im_LR = imresize(img, 1/up_scale, 'bicubic');
30 | if exist('save_LR_folder', 'var')
31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
32 | end
33 | end
34 | end
35 | end
36 |
37 | %% modcrop
38 | function img = modcrop(img, modulo)
39 | if size(img,3) == 1
40 | sz = size(img);
41 | sz = sz - mod(sz, modulo);
42 | img = img(1:sz(1), 1:sz(2));
43 | else
44 | tmpsz = size(img);
45 | sz = tmpsz(1:2);
46 | sz = sz - mod(sz, modulo);
47 | img = img(1:sz(1), 1:sz(2),:);
48 | end
49 | end
50 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.m:
--------------------------------------------------------------------------------
1 | function generate_mod_LR_bic()
2 | %% matlab code to generate mod images, bicubic-downsampled LR, and bicubic-upsampled images.
3 |
4 | %% set parameters
5 | % comment the unnecessary line
6 | input_folder = '../../datasets/DIV2K/DIV2K800';
7 | % save_mod_folder = '';
8 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
9 | % save_bic_folder = '';
10 |
11 | up_scale = 4;
12 | mod_scale = 4;
13 |
14 | if exist('save_mod_folder', 'var')
15 | if exist(save_mod_folder, 'dir')
16 | disp(['It will cover ', save_mod_folder]);
17 | else
18 | mkdir(save_mod_folder);
19 | end
20 | end
21 | if exist('save_LR_folder', 'var')
22 | if exist(save_LR_folder, 'dir')
23 | disp(['It will cover ', save_LR_folder]);
24 | else
25 | mkdir(save_LR_folder);
26 | end
27 | end
28 | if exist('save_bic_folder', 'var')
29 | if exist(save_bic_folder, 'dir')
30 | disp(['It will cover ', save_bic_folder]);
31 | else
32 | mkdir(save_bic_folder);
33 | end
34 | end
35 |
36 | idx = 0;
37 | filepaths = dir(fullfile(input_folder,'*.*'));
38 | for i = 1 : length(filepaths)
39 | [paths,imname,ext] = fileparts(filepaths(i).name);
40 | if isempty(imname)
41 | disp('Ignore . folder.');
42 | elseif strcmp(imname, '.')
43 | disp('Ignore .. folder.');
44 | else
45 | idx = idx + 1;
46 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
47 | fprintf(str_rlt);
48 | % read image
49 | img = imread(fullfile(input_folder, [imname, ext]));
50 | img = im2double(img);
51 | % modcrop
52 | img = modcrop(img, mod_scale);
53 | if exist('save_mod_folder', 'var')
54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
55 | end
56 | % LR
57 | im_LR = imresize(img, 1/up_scale, 'bicubic');
58 | if exist('save_LR_folder', 'var')
59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
60 | end
61 | % Bicubic
62 | if exist('save_bic_folder', 'var')
63 | im_B = imresize(im_LR, up_scale, 'bicubic');
64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
65 | end
66 | end
67 | end
68 | end
69 |
70 | %% modcrop
71 | function img = modcrop(img, modulo)
72 | if size(img,3) == 1
73 | sz = size(img);
74 | sz = sz - mod(sz, modulo);
75 | img = img(1:sz(1), 1:sz(2));
76 | else
77 | tmpsz = size(img);
78 | sz = tmpsz(1:2);
79 | sz = sz - mod(sz, modulo);
80 | img = img(1:sz(1), 1:sz(2),:);
81 | end
82 | end
83 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 |
6 | try:
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | from data.util import imresize_np
9 | except ImportError:
10 | pass
11 |
12 |
13 | def generate_mod_LR_bic():
14 | # set parameters
15 | up_scale = 4
16 | mod_scale = 4
17 | # set data dir
18 | sourcedir = '/data/datasets/img'
19 | savedir = '/data/datasets/mod'
20 |
21 | saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
22 | saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
23 | saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
24 |
25 | if not os.path.isdir(sourcedir):
26 | print('Error: No source data found')
27 | exit(0)
28 | if not os.path.isdir(savedir):
29 | os.mkdir(savedir)
30 |
31 | if not os.path.isdir(os.path.join(savedir, 'HR')):
32 | os.mkdir(os.path.join(savedir, 'HR'))
33 | if not os.path.isdir(os.path.join(savedir, 'LR')):
34 | os.mkdir(os.path.join(savedir, 'LR'))
35 | if not os.path.isdir(os.path.join(savedir, 'Bic')):
36 | os.mkdir(os.path.join(savedir, 'Bic'))
37 |
38 | if not os.path.isdir(saveHRpath):
39 | os.mkdir(saveHRpath)
40 | else:
41 | print('It will cover ' + str(saveHRpath))
42 |
43 | if not os.path.isdir(saveLRpath):
44 | os.mkdir(saveLRpath)
45 | else:
46 | print('It will cover ' + str(saveLRpath))
47 |
48 | if not os.path.isdir(saveBicpath):
49 | os.mkdir(saveBicpath)
50 | else:
51 | print('It will cover ' + str(saveBicpath))
52 |
53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
54 | num_files = len(filepaths)
55 |
56 | # prepare data with augementation
57 | for i in range(num_files):
58 | filename = filepaths[i]
59 | print('No.{} -- Processing {}'.format(i, filename))
60 | # read image
61 | image = cv2.imread(os.path.join(sourcedir, filename))
62 |
63 | width = int(np.floor(image.shape[1] / mod_scale))
64 | height = int(np.floor(image.shape[0] / mod_scale))
65 | # modcrop
66 | if len(image.shape) == 3:
67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
68 | else:
69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width]
70 | # LR
71 | image_LR = imresize_np(image_HR, 1 / up_scale, True)
72 | # bic
73 | image_Bic = imresize_np(image_LR, up_scale, True)
74 |
75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
78 |
79 |
80 | if __name__ == "__main__":
81 | generate_mod_LR_bic()
82 |
--------------------------------------------------------------------------------
/codes/data_scripts/prepare_DIV2K_x4_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | echo "Prepare DIV2K X4 datasets..."
4 | cd ../../datasets
5 | mkdir DIV2K
6 | cd DIV2K
7 |
8 | #### Step 1
9 | echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
10 | # GT
11 | FOLDER=DIV2K_train_HR
12 | FILE=DIV2K_train_HR.zip
13 | if [ ! -d "$FOLDER" ]; then
14 | if [ ! -f "$FILE" ]; then
15 | echo "Downloading $FILE..."
16 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
17 | fi
18 | unzip $FILE
19 | fi
20 | # LR
21 | FOLDER=DIV2K_train_LR_bicubic
22 | FILE=DIV2K_train_LR_bicubic_X4.zip
23 | if [ ! -d "$FOLDER" ]; then
24 | if [ ! -f "$FILE" ]; then
25 | echo "Downloading $FILE..."
26 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
27 | fi
28 | unzip $FILE
29 | fi
30 |
31 | #### Step 2
32 | echo "Step 2: Rename the LR images..."
33 | cd ../../codes/data_scripts
34 | python rename.py
35 |
36 | #### Step 3
37 | echo "Step 3: Crop to sub-images..."
38 | python extract_subimages.py
39 |
40 | #### Step 4
41 | echo "Step 4: Create LMDB files..."
42 | python create_lmdb.py
43 |
--------------------------------------------------------------------------------
/codes/data_scripts/rename.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 |
5 | def main():
6 | folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
7 | DIV2K(folder)
8 | print('Finished.')
9 |
10 |
11 | def DIV2K(path):
12 | img_path_l = glob.glob(os.path.join(path, '*'))
13 | for img_path in img_path_l:
14 | new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
15 | os.rename(img_path, new_path)
16 |
17 |
18 | if __name__ == "__main__":
19 | main()
--------------------------------------------------------------------------------
/codes/data_scripts/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 | import math
4 | import torchvision.utils
5 |
6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
7 | from data import create_dataloader, create_dataset # noqa: E402
8 | from utils import util # noqa: E402
9 |
10 |
11 | def main():
12 | dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
13 | opt = {}
14 | opt['dist'] = False
15 | opt['gpu_ids'] = [0]
16 | if dataset == 'REDS':
17 | opt['name'] = 'test_REDS'
18 | opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
19 | opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
20 | opt['mode'] = 'REDS'
21 | opt['N_frames'] = 5
22 | opt['phase'] = 'train'
23 | opt['use_shuffle'] = True
24 | opt['n_workers'] = 8
25 | opt['batch_size'] = 16
26 | opt['GT_size'] = 256
27 | opt['LQ_size'] = 64
28 | opt['scale'] = 4
29 | opt['use_flip'] = True
30 | opt['use_rot'] = True
31 | opt['interval_list'] = [1]
32 | opt['random_reverse'] = False
33 | opt['border_mode'] = False
34 | opt['cache_keys'] = None
35 | opt['data_type'] = 'lmdb' # img | lmdb | mc
36 | elif dataset == 'Vimeo90K':
37 | opt['name'] = 'test_Vimeo90K'
38 | opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
39 | opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
40 | opt['mode'] = 'Vimeo90K'
41 | opt['N_frames'] = 7
42 | opt['phase'] = 'train'
43 | opt['use_shuffle'] = True
44 | opt['n_workers'] = 8
45 | opt['batch_size'] = 16
46 | opt['GT_size'] = 256
47 | opt['LQ_size'] = 64
48 | opt['scale'] = 4
49 | opt['use_flip'] = True
50 | opt['use_rot'] = True
51 | opt['interval_list'] = [1]
52 | opt['random_reverse'] = False
53 | opt['border_mode'] = False
54 | opt['cache_keys'] = None
55 | opt['data_type'] = 'lmdb' # img | lmdb | mc
56 | elif dataset == 'DIV2K800_sub':
57 | opt['name'] = 'DIV2K800'
58 | opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
59 | opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
60 | opt['mode'] = 'LQGT'
61 | opt['phase'] = 'train'
62 | opt['use_shuffle'] = True
63 | opt['n_workers'] = 8
64 | opt['batch_size'] = 16
65 | opt['GT_size'] = 128
66 | opt['scale'] = 4
67 | opt['use_flip'] = True
68 | opt['use_rot'] = True
69 | opt['color'] = 'RGB'
70 | opt['data_type'] = 'lmdb' # img | lmdb
71 | else:
72 | raise ValueError('Please implement by yourself.')
73 |
74 | util.mkdir('tmp')
75 | train_set = create_dataset(opt)
76 | train_loader = create_dataloader(train_set, opt, opt, None)
77 | nrow = int(math.sqrt(opt['batch_size']))
78 | padding = 2 if opt['phase'] == 'train' else 0
79 |
80 | print('start...')
81 | for i, data in enumerate(train_loader):
82 | if i > 5:
83 | break
84 | print(i)
85 | if dataset == 'REDS' or dataset == 'Vimeo90K':
86 | LQs = data['LQs']
87 | else:
88 | LQ = data['LQ']
89 | GT = data['GT']
90 |
91 | if dataset == 'REDS' or dataset == 'Vimeo90K':
92 | for j in range(LQs.size(1)):
93 | torchvision.utils.save_image(LQs[:, j, :, :, :],
94 | 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
95 | padding=padding, normalize=False)
96 | else:
97 | torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
98 | padding=padding, normalize=False)
99 | torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
100 | normalize=False)
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
1 | function calculate_PSNR_SSIM()
2 |
3 | % GT and SR folder
4 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
5 | folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
6 | scale = 4;
7 | suffix = ''; % suffix for SR images
8 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
9 | if test_Y
10 |     fprintf('Testing Y channel.\n');
11 | else
12 |     fprintf('Testing RGB channels.\n');
13 | end
14 | filepaths = dir(fullfile(folder_GT, '*.png'));
15 | PSNR_all = zeros(1, length(filepaths));
16 | SSIM_all = zeros(1, length(filepaths));
17 |
18 | for idx_im = 1:length(filepaths)
19 | im_name = filepaths(idx_im).name;
20 | im_GT = imread(fullfile(folder_GT, im_name));
21 | im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
22 |
23 | if test_Y % evaluate on Y channel in YCbCr color space
24 | if size(im_GT, 3) == 3
25 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
26 | im_GT_in = im_GT_YCbCr(:,:,1);
27 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
28 | im_SR_in = im_SR_YCbCr(:,:,1);
29 | else
30 | im_GT_in = im2double(im_GT);
31 | im_SR_in = im2double(im_SR);
32 | end
33 | else % evaluate on RGB channels
34 | im_GT_in = im2double(im_GT);
35 | im_SR_in = im2double(im_SR);
36 | end
37 |
38 | % calculate PSNR and SSIM
39 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
40 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
41 | fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
42 | end
43 |
44 | fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
45 | end
46 |
47 | function res = calculate_PSNR(GT, SR, border)
48 | % remove border
49 | GT = GT(border+1:end-border, border+1:end-border, :);
50 | SR = SR(border+1:end-border, border+1:end-border, :);
51 | % calculate PSNR (assume in [0,255])
52 | error = GT(:) - SR(:);
53 | mse = mean(error.^2);
54 | res = 10 * log10(255^2/mse);
55 | end
56 |
57 | function res = calculate_SSIM(GT, SR, border)
58 | GT = GT(border+1:end-border, border+1:end-border, :);
59 | SR = SR(border+1:end-border, border+1:end-border, :);
60 | % calculate SSIM
61 | mssim = zeros(1, size(SR, 3));
62 | for i = 1:size(SR,3)
63 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
64 | end
65 | res = mean(mssim);
66 | end
67 |
68 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
69 |
70 | %========================================================================
71 | %SSIM Index, Version 1.0
72 | %Copyright(c) 2003 Zhou Wang
73 | %All Rights Reserved.
74 | %
75 | %The author is with Howard Hughes Medical Institute, and Laboratory
76 | %for Computational Vision at Center for Neural Science and Courant
77 | %Institute of Mathematical Sciences, New York University.
78 | %
79 | %----------------------------------------------------------------------
80 | %Permission to use, copy, or modify this software and its documentation
81 | %for educational and research purposes only and without fee is hereby
82 | %granted, provided that this copyright notice and the original authors'
83 | %names appear on all copies and supporting documentation. This program
84 | %shall not be used, rewritten, or adapted as the basis of a commercial
85 | %software or hardware product without first obtaining permission of the
86 | %authors. The authors make no representations about the suitability of
87 | %this software for any purpose. It is provided "as is" without express
88 | %or implied warranty.
89 | %----------------------------------------------------------------------
90 | %
91 | %This is an implementation of the algorithm for calculating the
92 | %Structural SIMilarity (SSIM) index between two images. Please refer
93 | %to the following paper:
94 | %
95 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
96 | %quality assessment: From error measurement to structural similarity"
97 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004.
98 | %
99 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
100 | %
101 | %----------------------------------------------------------------------
102 | %
103 | %Input : (1) img1: the first image being compared
104 | % (2) img2: the second image being compared
105 | % (3) K: constants in the SSIM index formula (see the above
106 | % reference). defualt value: K = [0.01 0.03]
107 | % (4) window: local window for statistics (see the above
108 | % reference). default widnow is Gaussian given by
109 | % window = fspecial('gaussian', 11, 1.5);
110 | % (5) L: dynamic range of the images. default: L = 255
111 | %
112 | %Output: (1) mssim: the mean SSIM index value between 2 images.
113 | % If one of the images being compared is regarded as
114 | % perfect quality, then mssim can be considered as the
115 | % quality measure of the other image.
116 | % If img1 = img2, then mssim = 1.
117 | % (2) ssim_map: the SSIM index map of the test image. The map
118 | % has a smaller size than the input images. The actual size:
119 | % size(img1) - size(window) + 1.
120 | %
121 | %Default Usage:
122 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
123 | %
124 | % [mssim ssim_map] = ssim_index(img1, img2);
125 | %
126 | %Advanced Usage:
127 | % User defined parameters. For example
128 | %
129 | % K = [0.05 0.05];
130 | % window = ones(8);
131 | % L = 100;
132 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
133 | %
134 | %See the results:
135 | %
136 | % mssim %Gives the mssim value
137 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
138 | %
139 | %========================================================================
140 |
141 |
142 | if (nargin < 2 || nargin > 5)
143 | ssim_index = -Inf;
144 | ssim_map = -Inf;
145 | return;
146 | end
147 |
148 | if (size(img1) ~= size(img2))
149 | ssim_index = -Inf;
150 | ssim_map = -Inf;
151 | return;
152 | end
153 |
154 | [M, N] = size(img1);
155 |
156 | if (nargin == 2)
157 | if ((M < 11) || (N < 11))
158 | ssim_index = -Inf;
159 | ssim_map = -Inf;
160 | return
161 | end
162 | window = fspecial('gaussian', 11, 1.5); %
163 | K(1) = 0.01; % default settings
164 | K(2) = 0.03; %
165 | L = 255; %
166 | end
167 |
168 | if (nargin == 3)
169 | if ((M < 11) || (N < 11))
170 | ssim_index = -Inf;
171 | ssim_map = -Inf;
172 | return
173 | end
174 | window = fspecial('gaussian', 11, 1.5);
175 | L = 255;
176 | if (length(K) == 2)
177 | if (K(1) < 0 || K(2) < 0)
178 | ssim_index = -Inf;
179 | ssim_map = -Inf;
180 | return;
181 | end
182 | else
183 | ssim_index = -Inf;
184 | ssim_map = -Inf;
185 | return;
186 | end
187 | end
188 |
189 | if (nargin == 4)
190 | [H, W] = size(window);
191 | if ((H*W) < 4 || (H > M) || (W > N))
192 | ssim_index = -Inf;
193 | ssim_map = -Inf;
194 | return
195 | end
196 | L = 255;
197 | if (length(K) == 2)
198 | if (K(1) < 0 || K(2) < 0)
199 | ssim_index = -Inf;
200 | ssim_map = -Inf;
201 | return;
202 | end
203 | else
204 | ssim_index = -Inf;
205 | ssim_map = -Inf;
206 | return;
207 | end
208 | end
209 |
210 | if (nargin == 5)
211 | [H, W] = size(window);
212 | if ((H*W) < 4 || (H > M) || (W > N))
213 | ssim_index = -Inf;
214 | ssim_map = -Inf;
215 | return
216 | end
217 | if (length(K) == 2)
218 | if (K(1) < 0 || K(2) < 0)
219 | ssim_index = -Inf;
220 | ssim_map = -Inf;
221 | return;
222 | end
223 | else
224 | ssim_index = -Inf;
225 | ssim_map = -Inf;
226 | return;
227 | end
228 | end
229 |
230 | C1 = (K(1)*L)^2;
231 | C2 = (K(2)*L)^2;
232 | window = window/sum(sum(window));
233 | img1 = double(img1);
234 | img2 = double(img2);
235 |
236 | mu1 = filter2(window, img1, 'valid');
237 | mu2 = filter2(window, img2, 'valid');
238 | mu1_sq = mu1.*mu1;
239 | mu2_sq = mu2.*mu2;
240 | mu1_mu2 = mu1.*mu2;
241 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
242 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
243 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
244 |
245 | if (C1 > 0 && C2 > 0)
246 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
247 | else
248 | numerator1 = 2*mu1_mu2 + C1;
249 | numerator2 = 2*sigma12 + C2;
250 | denominator1 = mu1_sq + mu2_sq + C1;
251 | denominator2 = sigma1_sq + sigma2_sq + C2;
252 | ssim_map = ones(size(mu1));
253 | index = (denominator1.*denominator2 > 0);
254 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
255 | index = (denominator1 ~= 0) & (denominator2 == 0);
256 | ssim_map(index) = numerator1(index)./denominator1(index);
257 | end
258 |
259 | mssim = mean2(ssim_map);
260 |
261 | end
262 |
--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.py:
--------------------------------------------------------------------------------
1 | '''
2 | calculate the PSNR and SSIM.
3 | same as MATLAB's results
4 | '''
5 | import os
6 | import math
7 | import numpy as np
8 | import cv2
9 | import glob
10 |
11 |
12 | def main():
13 | # Configurations
14 |
15 | # GT - Ground-truth;
16 | # Gen: Generated / Restored / Recovered images
17 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
18 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
19 |
20 | crop_border = 4
21 | suffix = '' # suffix for Gen images
22 | test_Y = False # True: test Y channel only; False: test RGB channels
23 |
24 | PSNR_all = []
25 | SSIM_all = []
26 | img_list = sorted(glob.glob(folder_GT + '/*'))
27 |
28 | if test_Y:
29 | print('Testing Y channel.')
30 | else:
31 | print('Testing RGB channels.')
32 |
33 | for i, img_path in enumerate(img_list):
34 | base_name = os.path.splitext(os.path.basename(img_path))[0]
35 | im_GT = cv2.imread(img_path) / 255.
36 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
37 |
38 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
39 | im_GT_in = bgr2ycbcr(im_GT)
40 | im_Gen_in = bgr2ycbcr(im_Gen)
41 | else:
42 | im_GT_in = im_GT
43 | im_Gen_in = im_Gen
44 |
45 | # crop borders
46 | if im_GT_in.ndim == 3:
47 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
48 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
49 | elif im_GT_in.ndim == 2:
50 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
51 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
52 | else:
53 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
54 |
55 | # calculate PSNR and SSIM
56 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
57 |
58 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
59 | print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
60 | i + 1, base_name, PSNR, SSIM))
61 | PSNR_all.append(PSNR)
62 | SSIM_all.append(SSIM)
63 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
64 | sum(PSNR_all) / len(PSNR_all),
65 | sum(SSIM_all) / len(SSIM_all)))
66 |
67 |
68 | def calculate_psnr(img1, img2):
69 | # img1 and img2 have range [0, 255]
70 | img1 = img1.astype(np.float64)
71 | img2 = img2.astype(np.float64)
72 | mse = np.mean((img1 - img2)**2)
73 | if mse == 0:
74 | return float('inf')
75 | return 20 * math.log10(255.0 / math.sqrt(mse))
76 |
77 |
78 | def ssim(img1, img2):
79 | C1 = (0.01 * 255)**2
80 | C2 = (0.03 * 255)**2
81 |
82 | img1 = img1.astype(np.float64)
83 | img2 = img2.astype(np.float64)
84 | kernel = cv2.getGaussianKernel(11, 1.5)
85 | window = np.outer(kernel, kernel.transpose())
86 |
87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
89 | mu1_sq = mu1**2
90 | mu2_sq = mu2**2
91 | mu1_mu2 = mu1 * mu2
92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
95 |
96 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
97 | (sigma1_sq + sigma2_sq + C2))
98 | return ssim_map.mean()
99 |
100 |
101 | def calculate_ssim(img1, img2):
102 | '''calculate SSIM
103 | the same outputs as MATLAB's
104 | img1, img2: [0, 255]
105 | '''
106 | if not img1.shape == img2.shape:
107 | raise ValueError('Input images must have the same dimensions.')
108 | if img1.ndim == 2:
109 | return ssim(img1, img2)
110 | elif img1.ndim == 3:
111 | if img1.shape[2] == 3:
112 | ssims = []
113 | for i in range(3):
114 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
115 | return np.array(ssims).mean()
116 | elif img1.shape[2] == 1:
117 | return ssim(np.squeeze(img1), np.squeeze(img2))
118 | else:
119 | raise ValueError('Wrong input image dimensions.')
120 |
121 |
122 | def bgr2ycbcr(img, only_y=True):
123 | '''same as matlab rgb2ycbcr
124 | only_y: only return Y channel
125 | Input:
126 | uint8, [0, 255]
127 | float, [0, 1]
128 | '''
129 | in_img_type = img.dtype
130 | img = img.astype(np.float32)
131 | if in_img_type != np.uint8:
132 | img *= 255.
133 | # convert
134 | if only_y:
135 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
136 | else:
137 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
138 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
139 | if in_img_type == np.uint8:
140 | rlt = rlt.round()
141 | else:
142 | rlt /= 255.
143 | return rlt.astype(in_img_type)
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
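A minimal usage sketch for the two metric functions above; the image paths are hypothetical placeholders, and both functions expect arrays in the [0, 255] range.

import cv2

gt = cv2.imread('Set5/baby.png').astype('float64')       # hypothetical path
gen = cv2.imread('results/baby.png').astype('float64')   # hypothetical path
crop = 4                                                 # crop border = scale factor
gt_c = gt[crop:-crop, crop:-crop, :]
gen_c = gen[crop:-crop, crop:-crop, :]
print('PSNR: {:.4f} dB, SSIM: {:.4f}'.format(
    calculate_psnr(gt_c, gen_c), calculate_ssim(gt_c, gen_c)))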
--------------------------------------------------------------------------------
/codes/models/Ranker_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import logging
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.optim import lr_scheduler
8 | from torch.nn.parallel import DataParallel, DistributedDataParallel
9 |
10 | import models.networks as networks
11 | from .base_model import BaseModel
12 |
13 | logger = logging.getLogger('base')
14 |
15 |
16 | class Ranker_Model(BaseModel):
17 | def name(self):
18 | return 'Ranker_Model'
19 |
20 | def __init__(self, opt):
21 | super(Ranker_Model, self).__init__(opt)
22 |
23 | if opt['dist']:
24 | self.rank = torch.distributed.get_rank()
25 | else:
26 | self.rank = -1 # non dist training
27 | train_opt = opt['train']
28 |
29 | # define networks and load pretrained models
30 | self.netR = networks.define_R(opt).to(self.device)
31 | if opt['dist']:
32 | self.netR = DistributedDataParallel(self.netR, device_ids=[torch.cuda.current_device()])
33 | else:
34 | self.netR = DataParallel(self.netR)
35 | self.load()
36 |
37 | if self.is_train:
38 | self.netR.train()
39 |
40 | # loss
41 | self.RankLoss = nn.MarginRankingLoss(margin=0.5)
42 | self.RankLoss.to(self.device)
43 | self.L2Loss = nn.L1Loss()  # an L1 criterion despite the name; unused by optimize_parameters
44 | self.L2Loss.to(self.device)
45 | # optimizers
46 | self.optimizers = []
47 | wd_R = train_opt['weight_decay_R'] if train_opt['weight_decay_R'] else 0
48 | optim_params = []
49 | for k, v in self.netR.named_parameters(): # can optimize for a part of the model
50 | if v.requires_grad:
51 | optim_params.append(v)
52 | else:
53 | logger.warning('Params [{:s}] will not be optimized.'.format(k))
54 | self.optimizer_R = torch.optim.Adam(optim_params, lr=train_opt['lr_R'], weight_decay=wd_R)
55 | logger.info('Weight decay for R: {:f}'.format(wd_R))
56 | self.optimizers.append(self.optimizer_R)
57 |
58 | # schedulers
59 | self.schedulers = []
60 | if train_opt['lr_scheme'] == 'MultiStepLR':
61 | for optimizer in self.optimizers:
62 | self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
63 | train_opt['lr_steps'], train_opt['lr_gamma']))
64 | else:
65 | raise NotImplementedError('Only the MultiStepLR scheme is supported.')
66 |
67 | self.log_dict = OrderedDict()
68 |
69 | print('---------- Model initialized ------------------')
70 | self.print_network()
71 | print('-----------------------------------------------')
72 |
73 | def feed_data(self, data, need_img2=True):
74 | # input img1
75 | self.input_img1 = data['img1'].to(self.device)
76 |
77 | # label score1
78 | self.label_score1 = data['score1'].to(self.device)
79 |
80 | if need_img2:
81 | # input img2
82 | self.input_img2 = data['img2'].to(self.device)
83 |
84 | # label score2
85 | self.label_score2 = data['score2'].to(self.device)
86 |
87 | # rank label
88 | self.label = self.label_score1 >= self.label_score2 # get a ByteTensor
89 | # transfer into FloatTensor
90 | self.label = self.label.float()
91 | self.label = (self.label - 0.5) * 2
92 |
93 |
94 | def optimize_parameters(self, step):
95 | self.optimizer_R.zero_grad()
96 | self.predict_score1 = self.netR(self.input_img1)
97 | self.predict_score2 = self.netR(self.input_img2)
98 |
99 |
100 | self.predict_score1 = torch.clamp(self.predict_score1, min=-5, max=5)
101 | self.predict_score2 = torch.clamp(self.predict_score2, min=-5, max=5)
102 |
103 | l_rank = self.RankLoss(self.predict_score1, self.predict_score2, self.label)
104 |
105 | l_rank.backward()
106 | self.optimizer_R.step()
107 |
108 | # set log
109 | self.log_dict['l_rank'] = l_rank.item()
110 |
111 | def test(self):
112 | self.netR.eval()
113 | with torch.no_grad(): self.predict_score1 = self.netR(self.input_img1)
114 | self.netR.train()
115 |
116 | def get_current_log(self):
117 | return self.log_dict
118 |
119 | def get_current_visuals(self, need_HR=True):
120 | out_dict = OrderedDict()
121 | out_dict['predict_score1'] = self.predict_score1.data[0].float().cpu()
122 |
123 | return out_dict
124 |
125 | def print_network(self):
126 | s, n = self.get_network_description(self.netR)
127 | if isinstance(self.netR, (nn.DataParallel, DistributedDataParallel)):
128 | net_struc_str = '{} - {}'.format(self.netR.__class__.__name__,
129 | self.netR.module.__class__.__name__)
130 | else:
131 | net_struc_str = '{}'.format(self.netR.__class__.__name__)
132 | logger.info('Network R structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
133 | logger.info(s)
134 |
135 | def load(self):
136 | load_path_R = self.opt['path']['pretrain_model_R']
137 | if load_path_R is not None:
138 | logger.info('Loading pretrained model for R [{:s}] ...'.format(load_path_R))
139 | self.load_network(load_path_R, self.netR)
140 |
141 | def save(self, iter_step):
142 | self.save_network(self.netR, 'R', iter_step)
143 |
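A standalone sketch of the rank-label convention used by feed_data() and optimize_parameters() above: the label is +1 when img1 should rank higher and -1 otherwise, and nn.MarginRankingLoss(margin=0.5) penalizes max(0, -y*(s1 - s2) + margin). The scores below are toy values.

import torch
import torch.nn as nn

s1 = torch.tensor([1.2, 0.1])      # predicted ranker scores for img1
s2 = torch.tensor([0.3, 0.9])      # predicted ranker scores for img2
score1 = torch.tensor([3.0, 1.0])  # hypothetical perceptual-metric labels
score2 = torch.tensor([2.0, 4.0])
y = ((score1 >= score2).float() - 0.5) * 2          # tensor([ 1., -1.])
loss = nn.MarginRankingLoss(margin=0.5)(s1, s2, y)
print(loss.item())  # 0.0: both pairs are already separated by more than the margin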
--------------------------------------------------------------------------------
/codes/models/SR_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.loss import CharbonnierLoss
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class SRModel(BaseModel):
16 | def __init__(self, opt):
17 | super(SRModel, self).__init__(opt)
18 |
19 | if opt['dist']:
20 | self.rank = torch.distributed.get_rank()
21 | else:
22 | self.rank = -1 # non dist training
23 | train_opt = opt['train']
24 |
25 | # define network and load pretrained models
26 | self.netG = networks.define_G(opt).to(self.device)
27 | if opt['dist']:
28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
29 | else:
30 | self.netG = DataParallel(self.netG)
31 | # print network
32 | self.print_network()
33 | self.load()
34 |
35 | if self.is_train:
36 | self.netG.train()
37 |
38 | # loss
39 | loss_type = train_opt['pixel_criterion']
40 | if loss_type == 'l1':
41 | self.cri_pix = nn.L1Loss().to(self.device)
42 | elif loss_type == 'l2':
43 | self.cri_pix = nn.MSELoss().to(self.device)
44 | elif loss_type == 'cb':
45 | self.cri_pix = CharbonnierLoss().to(self.device)
46 | else:
47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
48 | self.l_pix_w = train_opt['pixel_weight']
49 |
50 | # optimizers
51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
52 | optim_params = []
53 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
54 | if v.requires_grad:
55 | optim_params.append(v)
56 | else:
57 | if self.rank <= 0:
58 | logger.warning('Params [{:s}] will not optimize.'.format(k))
59 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
60 | weight_decay=wd_G,
61 | betas=(train_opt['beta1'], train_opt['beta2']))
62 | self.optimizers.append(self.optimizer_G)
63 |
64 | # schedulers
65 | if train_opt['lr_scheme'] == 'MultiStepLR':
66 | for optimizer in self.optimizers:
67 | self.schedulers.append(
68 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
69 | restarts=train_opt['restarts'],
70 | weights=train_opt['restart_weights'],
71 | gamma=train_opt['lr_gamma'],
72 | clear_state=train_opt['clear_state']))
73 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
74 | for optimizer in self.optimizers:
75 | self.schedulers.append(
76 | lr_scheduler.CosineAnnealingLR_Restart(
77 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
78 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
79 | else:
80 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
81 |
82 | self.log_dict = OrderedDict()
83 |
84 | def feed_data(self, data, need_GT=True):
85 | self.var_L = data['LQ'].to(self.device) # LQ
86 | if need_GT:
87 | self.real_H = data['GT'].to(self.device) # GT
88 |
89 | def optimize_parameters(self, step):
90 | self.optimizer_G.zero_grad()
91 | self.fake_H = self.netG(self.var_L)
92 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
93 | l_pix.backward()
94 | self.optimizer_G.step()
95 |
96 | # set log
97 | self.log_dict['l_pix'] = l_pix.item()
98 |
99 | def test(self):
100 | self.netG.eval()
101 | with torch.no_grad():
102 | self.fake_H = self.netG(self.var_L)
103 | self.netG.train()
104 |
105 | def test_x8(self):
106 | # from https://github.com/thstkdgus35/EDSR-PyTorch
107 | self.netG.eval()
108 |
109 | def _transform(v, op):
110 | # if self.precision != 'single': v = v.float()
111 | v2np = v.data.cpu().numpy()
112 | if op == 'v':
113 | tfnp = v2np[:, :, :, ::-1].copy()
114 | elif op == 'h':
115 | tfnp = v2np[:, :, ::-1, :].copy()
116 | elif op == 't':
117 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
118 |
119 | ret = torch.Tensor(tfnp).to(self.device)
120 | # if self.precision == 'half': ret = ret.half()
121 |
122 | return ret
123 |
124 | lr_list = [self.var_L]
125 | for tf in 'v', 'h', 't':
126 | lr_list.extend([_transform(t, tf) for t in lr_list])
127 | with torch.no_grad():
128 | sr_list = [self.netG(aug) for aug in lr_list]
129 | for i in range(len(sr_list)):
130 | if i > 3:
131 | sr_list[i] = _transform(sr_list[i], 't')
132 | if i % 4 > 1:
133 | sr_list[i] = _transform(sr_list[i], 'h')
134 | if (i % 4) % 2 == 1:
135 | sr_list[i] = _transform(sr_list[i], 'v')
136 |
137 | output_cat = torch.cat(sr_list, dim=0)
138 | self.fake_H = output_cat.mean(dim=0, keepdim=True)
139 | self.netG.train()
140 |
141 | def get_current_log(self):
142 | return self.log_dict
143 |
144 | def get_current_visuals(self, need_GT=True):
145 | out_dict = OrderedDict()
146 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
147 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
148 | if need_GT:
149 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
150 | return out_dict
151 |
152 | def print_network(self):
153 | s, n = self.get_network_description(self.netG)
154 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
155 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
156 | self.netG.module.__class__.__name__)
157 | else:
158 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
159 | if self.rank <= 0:
160 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
161 | logger.info(s)
162 |
163 | def load(self):
164 | load_path_G = self.opt['path']['pretrain_model_G']
165 | if load_path_G is not None:
166 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
167 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
168 |
169 | def save(self, iter_label):
170 | self.save_network(self.netG, 'G', iter_label)
171 |
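A standalone sketch of the augmentation ordering that test_x8() above relies on: because the loop doubles the list once per op, index i encodes in binary which of 'v', 'h', 't' were applied (bit 0 = 'v', bit 1 = 'h', bit 2 = 't'), which is exactly what the un-transform conditions on i test before averaging.

ops = ['v', 'h', 't']
applied = [[op for b, op in enumerate(ops) if (i >> b) & 1] for i in range(8)]
print(applied)
# [[], ['v'], ['h'], ['v', 'h'], ['t'], ['v', 't'], ['h', 't'], ['v', 'h', 't']]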
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 | # image restoration
8 | if model == 'sr': # PSNR-oriented super resolution
9 | from .SR_model import SRModel as M
10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
11 | from .SRGAN_model import SRGANModel as M
12 | elif model == 'ranksrgan': # GAN-based super resolution
13 | from .RankSRGAN_model import SRGANModel as M
14 | # Ranker
15 | elif model == 'rank':
16 | from .Ranker_model import Ranker_Model as M
17 | else:
18 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
19 | m = M(opt)
20 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
21 | return m
22 |
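A dispatch sketch: create_model() only inspects opt['model'] to choose the class; every other option key is consumed by that class's __init__, so the placeholder dict below is illustrative rather than a runnable training config.

opt = {'model': 'ranksrgan'}   # one of 'sr' | 'srgan' | 'ranksrgan' | 'rank'
# m = create_model(opt)        # would return RankSRGAN_model.SRGANModel,
#                              # given the full option dict from the config file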
--------------------------------------------------------------------------------
/codes/models/archs/RRDBNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import models.archs.arch_util as arch_util
6 |
7 |
8 | class ResidualDenseBlock_5C(nn.Module):
9 | def __init__(self, nf=64, gc=32, bias=True):
10 | super(ResidualDenseBlock_5C, self).__init__()
11 | # gc: growth channel, i.e. intermediate channels
12 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
13 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
14 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
15 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
16 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
17 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
18 |
19 | # initialization
20 | arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
21 | 0.1)
22 |
23 | def forward(self, x):
24 | x1 = self.lrelu(self.conv1(x))
25 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
26 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
27 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
28 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
29 | return x5 * 0.2 + x
30 |
31 |
32 | class RRDB(nn.Module):
33 | '''Residual in Residual Dense Block'''
34 |
35 | def __init__(self, nf, gc=32):
36 | super(RRDB, self).__init__()
37 | self.RDB1 = ResidualDenseBlock_5C(nf, gc)
38 | self.RDB2 = ResidualDenseBlock_5C(nf, gc)
39 | self.RDB3 = ResidualDenseBlock_5C(nf, gc)
40 |
41 | def forward(self, x):
42 | out = self.RDB1(x)
43 | out = self.RDB2(out)
44 | out = self.RDB3(out)
45 | return out * 0.2 + x
46 |
47 |
48 | class RRDBNet(nn.Module):
49 | def __init__(self, in_nc, out_nc, nf, nb, gc=32):
50 | super(RRDBNet, self).__init__()
51 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
52 |
53 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
54 | self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
55 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
56 | #### upsampling
57 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
58 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
59 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
60 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
61 |
62 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
63 |
64 | def forward(self, x):
65 | fea = self.conv_first(x)
66 | trunk = self.trunk_conv(self.RRDB_trunk(fea))
67 | fea = fea + trunk
68 |
69 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
70 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
71 | out = self.conv_last(self.lrelu(self.HRconv(fea)))
72 |
73 | return out
74 |
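A quick shape-check sketch: the two nearest-neighbour upsamplings in forward() fix the scale factor at 4. The depth nb=23 is an assumption (the usual ESRGAN setting), not something this file prescribes.

import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
with torch.no_grad():
    print(net(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 3, 128, 128])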
--------------------------------------------------------------------------------
/codes/models/archs/RankSRGAN_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 | import models.archs.arch_util as arch_util
4 |
5 | ####################
6 | # Generator for RankSRGAN
7 | ####################
8 | class SRResNet(nn.Module):
9 |
10 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
11 | super(SRResNet, self).__init__()
12 | self.upscale = upscale
13 |
14 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
15 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
16 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
17 | self.LRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
18 |
19 | # upsampling
20 | if self.upscale == 2:
21 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
22 | self.pixel_shuffle = nn.PixelShuffle(2)
23 | elif self.upscale == 3:
24 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
25 | self.pixel_shuffle = nn.PixelShuffle(3)
26 | elif self.upscale == 4:
27 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
28 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
29 | self.pixel_shuffle = nn.PixelShuffle(2)
30 |
31 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
32 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
33 |
34 | # activation function
35 | self.relu = nn.ReLU(inplace=True)
36 |
37 | # initialization
38 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
39 | 0.1)
40 | if self.upscale == 4:
41 | arch_util.initialize_weights(self.upconv2, 0.1)
42 |
43 | def forward(self, x):
44 | fea = self.conv_first(x)
45 | out = self.recon_trunk(fea)
46 | out = self.LRconv(out)
47 |
48 | if self.upscale == 4:
49 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))
50 | out = self.relu(self.pixel_shuffle(self.upconv2(out)))
51 | elif self.upscale == 3 or self.upscale == 2:
52 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))
53 |
54 | out = self.conv_last(self.relu(self.HRconv(out)))
55 |
56 | return out
57 |
58 | ####################
59 | # Discriminator with patchsize 296 for RankSRGAN
60 | ####################
61 | class Discriminator_VGG_296(nn.Module):
62 | def __init__(self, in_nc, nf):
63 | super(Discriminator_VGG_296, self).__init__()
64 | # [64, 128, 128]
65 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
66 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
67 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
68 | # [64, 64, 64]
69 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
70 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
71 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
72 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
73 | # [128, 32, 32]
74 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
75 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
76 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
77 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
78 | # [256, 16, 16]
79 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
80 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
81 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
82 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
83 | # [512, 8, 8]
84 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
85 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
86 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
87 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
88 |
89 | self.linear1 = nn.Linear(512 * 9 * 9, 100)
90 | self.linear2 = nn.Linear(100, 1)
91 |
92 | # activation function
93 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
94 |
95 | def forward(self, x):
96 | fea = self.lrelu(self.conv0_0(x))
97 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
98 |
99 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
100 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
101 |
102 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
103 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
104 |
105 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
106 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
107 |
108 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
109 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
110 |
111 | fea = fea.view(fea.size(0), -1)
112 | fea = self.lrelu(self.linear1(fea))
113 | out = self.linear2(fea)
114 | return out
115 |
116 |
117 | ####################
118 | # Ranker
119 | ####################
120 |
121 | class Ranker_VGG12_296(nn.Module):
122 | def __init__(self, in_nc, nf):
123 | super(Ranker_VGG12_296, self).__init__()
124 | # [64, 128, 128]
125 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
126 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=True)
127 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
128 | # [64, 64, 64]
129 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=True)
130 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
131 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=True)
132 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
133 | # [128, 32, 32]
134 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=True)
135 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
136 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=True)
137 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
138 | # [256, 16, 16]
139 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=True)
140 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
141 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=True)
142 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
143 | # [512, 8, 8]
144 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=True)
145 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
146 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=True)
147 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
148 |
149 | # classifier
150 | self.classifier = nn.Sequential(
151 | nn.Linear(512, 100),
152 | nn.LeakyReLU(0.2, True),
153 | nn.Linear(100, 1)
154 | )
155 |
156 | # activation function
157 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
158 |
159 | def forward(self, x):
160 | fea = self.lrelu(self.conv0_0(x))
161 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
162 |
163 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
164 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
165 |
166 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
167 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
168 |
169 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
170 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
171 |
172 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
173 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
174 |
175 | fea = nn.AvgPool2d(fea.size()[2])(fea)
176 | fea = fea.view(fea.size(0), -1)
177 | out = self.classifier(fea)
178 | return out
179 |
180 |
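A small sketch of why the ranker ends with nn.AvgPool2d(fea.size()[2]): that call is a global average pool, so the network emits one scalar score per image and also tolerates square inputs other than the 296x296 patches named in the comments above.

import torch

ranker = Ranker_VGG12_296(in_nc=3, nf=64)
with torch.no_grad():
    print(ranker(torch.randn(2, 3, 296, 296)).shape)  # torch.Size([2, 1])
    print(ranker(torch.randn(2, 3, 128, 128)).shape)  # torch.Size([2, 1])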
--------------------------------------------------------------------------------
/codes/models/archs/SRResNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import models.archs.arch_util as arch_util
5 |
6 |
7 | class MSRResNet(nn.Module):
8 | ''' modified SRResNet'''
9 |
10 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
11 | super(MSRResNet, self).__init__()
12 | self.upscale = upscale
13 |
14 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
15 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
16 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
17 |
18 | # upsampling
19 | if self.upscale == 2:
20 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
21 | self.pixel_shuffle = nn.PixelShuffle(2)
22 | elif self.upscale == 3:
23 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
24 | self.pixel_shuffle = nn.PixelShuffle(3)
25 | elif self.upscale == 4:
26 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
27 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
28 | self.pixel_shuffle = nn.PixelShuffle(2)
29 |
30 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
31 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
32 |
33 | # activation function
34 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
35 |
36 | # initialization
37 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
38 | 0.1)
39 | if self.upscale == 4:
40 | arch_util.initialize_weights(self.upconv2, 0.1)
41 |
42 | def forward(self, x):
43 | fea = self.lrelu(self.conv_first(x))
44 | out = self.recon_trunk(fea)
45 |
46 | if self.upscale == 4:
47 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
48 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
49 | elif self.upscale == 3 or self.upscale == 2:
50 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
51 |
52 | out = self.conv_last(self.lrelu(self.HRconv(out)))
53 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
54 | out += base
55 | return out
56 |
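A shape-check sketch highlighting the design difference from the SRResNet generator in RankSRGAN_arch.py: MSRResNet adds a bilinear upsample of the input as a global skip (out += base), so the convolutional path only has to learn the residual detail.

import torch

net = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4)
with torch.no_grad():
    print(net(torch.randn(1, 3, 24, 24)).shape)  # torch.Size([1, 3, 96, 96])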
--------------------------------------------------------------------------------
/codes/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/models/archs/__init__.py
--------------------------------------------------------------------------------
/codes/models/archs/arch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def make_layer(block, n_layers):
28 | layers = []
29 | for _ in range(n_layers):
30 | layers.append(block())
31 | return nn.Sequential(*layers)
32 |
33 |
34 | class ResidualBlock_noBN(nn.Module):
35 | '''Residual block w/o BN
36 | ---Conv-ReLU-Conv-+-
37 | |________________|
38 | '''
39 |
40 | def __init__(self, nf=64):
41 | super(ResidualBlock_noBN, self).__init__()
42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44 |
45 | # initialization
46 | initialize_weights([self.conv1, self.conv2], 0.1)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = F.relu(self.conv1(x), inplace=True)
51 | out = self.conv2(out)
52 | return identity + out
53 |
54 |
--------------------------------------------------------------------------------
/codes/models/archs/discriminator_vgg_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class Discriminator_VGG_128(nn.Module):
7 | def __init__(self, in_nc, nf):
8 | super(Discriminator_VGG_128, self).__init__()
9 | # [64, 128, 128]
10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
13 | # [64, 64, 64]
14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
18 | # [128, 32, 32]
19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
23 | # [256, 16, 16]
24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
28 | # [512, 8, 8]
29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
33 |
34 | self.linear1 = nn.Linear(512 * 4 * 4, 100)
35 | self.linear2 = nn.Linear(100, 1)
36 |
37 | # activation function
38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
39 |
40 | def forward(self, x):
41 | fea = self.lrelu(self.conv0_0(x))
42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
43 |
44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
46 |
47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
49 |
50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
52 |
53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
55 |
56 | fea = fea.view(fea.size(0), -1)
57 | fea = self.lrelu(self.linear1(fea))
58 | out = self.linear2(fea)
59 | return out
60 |
61 |
62 | class VGGFeatureExtractor(nn.Module):
63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
64 | device=torch.device('cpu')):
65 | super(VGGFeatureExtractor, self).__init__()
66 | self.use_input_norm = use_input_norm
67 | if use_bn:
68 | model = torchvision.models.vgg19_bn(pretrained=True)
69 | else:
70 | model = torchvision.models.vgg19(pretrained=True)
71 | if self.use_input_norm:
72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
76 | self.register_buffer('mean', mean)
77 | self.register_buffer('std', std)
78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
79 | # no need to backprop through the frozen feature extractor
80 | for k, v in self.features.named_parameters():
81 | v.requires_grad = False
82 |
83 | def forward(self, x):
84 | # Assume input range is [0, 1]
85 | if self.use_input_norm:
86 | x = (x - self.mean) / self.std
87 | output = self.features(x)
88 | return output
89 |
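A sketch of the perceptual-loss pattern this extractor serves (see define_F in networks.py): an L1 distance between VGG feature maps of the output and the ground truth. The tensors are dummies; torchvision downloads the VGG19 weights on first use.

import torch
import torch.nn as nn

netF = VGGFeatureExtractor(feature_layer=34, use_bn=False, use_input_norm=True)
netF.eval()
sr = torch.rand(1, 3, 96, 96)   # dummy network output in [0, 1]
gt = torch.rand(1, 3, 96, 96)   # dummy ground truth in [0, 1]
with torch.no_grad():
    l_fea = nn.L1Loss()(netF(sr), netF(gt))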
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
38 | """Set learning rate for warmup
39 | lr_groups_l: list for lr_groups. each for a optimizer"""
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | """Get the initial lr, which is set by the scheduler"""
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | # set up warm-up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
67 |
68 | def get_network_description(self, network):
69 | """Get the string and total parameters of the network"""
70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
71 | network = network.module
72 | return str(network), sum(map(lambda x: x.numel(), network.parameters()))
73 |
74 | def save_network(self, network, network_label, iter_label):
75 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
76 | save_path = os.path.join(self.opt['path']['models'], save_filename)
77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
78 | network = network.module
79 | state_dict = network.state_dict()
80 | for key, param in state_dict.items():
81 | state_dict[key] = param.cpu()
82 | torch.save(state_dict, save_path)
83 |
84 | def load_network(self, load_path, network, strict=True):
85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
86 | network = network.module
87 | load_net = torch.load(load_path)
88 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
89 | for k, v in load_net.items():
90 | if k.startswith('module.'):
91 | load_net_clean[k[7:]] = v
92 | else:
93 | load_net_clean[k] = v
94 | network.load_state_dict(load_net_clean, strict=strict)
95 |
96 | def save_training_state(self, epoch, iter_step):
97 | """Save training state during training, which will be used for resuming"""
98 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
99 | for s in self.schedulers:
100 | state['schedulers'].append(s.state_dict())
101 | for o in self.optimizers:
102 | state['optimizers'].append(o.state_dict())
103 | save_filename = '{}.state'.format(iter_step)
104 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
105 | torch.save(state, save_path)
106 |
107 | def resume_training(self, resume_state):
108 | """Resume the optimizers and schedulers for training"""
109 | resume_optimizers = resume_state['optimizers']
110 | resume_schedulers = resume_state['schedulers']
111 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
112 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
113 | for i, o in enumerate(resume_optimizers):
114 | self.optimizers[i].load_state_dict(o)
115 | for i, s in enumerate(resume_schedulers):
116 | self.schedulers[i].load_state_dict(s)
117 |
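The warm-up in update_learning_rate() scales each group's initial learning rate linearly until cur_iter reaches warmup_iter; a sketch of the schedule in isolation, with made-up values:

init_lr, warmup_iter = 2e-4, 4000
for cur_iter in (1, 2000, 4000):
    lr = init_lr / warmup_iter * cur_iter if cur_iter < warmup_iter else init_lr
    print(cur_iter, lr)  # 1 -> 5e-08, 2000 -> 0.0001, 4000 -> 0.0002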
--------------------------------------------------------------------------------
/codes/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """Charbonnier Loss (L1)"""
7 |
8 | def __init__(self, eps=1e-6):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = eps
11 |
12 | def forward(self, x, y):
13 | diff = x - y
14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps))
15 | return loss
16 |
17 |
18 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
19 | class GANLoss(nn.Module):
20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
21 | super(GANLoss, self).__init__()
22 | self.gan_type = gan_type.lower()
23 | self.real_label_val = real_label_val
24 | self.fake_label_val = fake_label_val
25 |
26 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
27 | self.loss = nn.BCEWithLogitsLoss()
28 | elif self.gan_type == 'lsgan':
29 | self.loss = nn.MSELoss()
30 | elif self.gan_type == 'wgan-gp':
31 |
32 | def wgan_loss(input, target):
33 | # target is boolean
34 | return -1 * input.mean() if target else input.mean()
35 |
36 | self.loss = wgan_loss
37 | else:
38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
39 |
40 | def get_target_label(self, input, target_is_real):
41 | if self.gan_type == 'wgan-gp':
42 | return target_is_real
43 | if target_is_real:
44 | return torch.empty_like(input).fill_(self.real_label_val)
45 | else:
46 | return torch.empty_like(input).fill_(self.fake_label_val)
47 |
48 | def forward(self, input, target_is_real):
49 | target_label = self.get_target_label(input, target_is_real)
50 | loss = self.loss(input, target_label)
51 | return loss
52 |
53 |
54 | class GradientPenaltyLoss(nn.Module):
55 | def __init__(self, device=torch.device('cpu')):
56 | super(GradientPenaltyLoss, self).__init__()
57 | self.register_buffer('grad_outputs', torch.Tensor())
58 | self.grad_outputs = self.grad_outputs.to(device)
59 |
60 | def get_grad_outputs(self, input):
61 | if self.grad_outputs.size() != input.size():
62 | self.grad_outputs.resize_(input.size()).fill_(1.0)
63 | return self.grad_outputs
64 |
65 | def forward(self, interp, interp_crit):
66 | grad_outputs = self.get_grad_outputs(interp_crit)
67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
68 | grad_outputs=grad_outputs, create_graph=True,
69 | retain_graph=True, only_inputs=True)[0]
70 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
71 | grad_interp_norm = grad_interp.norm(2, dim=1)
72 |
73 | loss = ((grad_interp_norm - 1)**2).mean()
74 | return loss
75 |
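A usage sketch for GANLoss: the boolean flag selects the target label, so a single criterion serves both the discriminator and the generator updates. The logits are dummies.

import torch

cri_gan = GANLoss('gan', real_label_val=1.0, fake_label_val=0.0)
pred_real = torch.randn(4, 1)
pred_fake = torch.randn(4, 1)
l_d = cri_gan(pred_real, True) + cri_gan(pred_fake, False)  # discriminator step
l_g = cri_gan(pred_fake, True)   # generator step: fakes should look real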
--------------------------------------------------------------------------------
/codes/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restarts = [v + 1 for v in self.restarts]
16 | self.restart_weights = weights if weights else [1]
17 | assert len(self.restarts) == len(
18 | self.restart_weights), 'restarts and their weights do not match.'
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
20 |
21 | def get_lr(self):
22 | if self.last_epoch in self.restarts:
23 | if self.clear_state:
24 | self.optimizer.state = defaultdict(dict)
25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
27 | if self.last_epoch not in self.milestones:
28 | return [group['lr'] for group in self.optimizer.param_groups]
29 | return [
30 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
31 | for group in self.optimizer.param_groups
32 | ]
33 |
34 |
35 | class CosineAnnealingLR_Restart(_LRScheduler):
36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
37 | self.T_period = T_period
38 | self.T_max = self.T_period[0] # current T period
39 | self.eta_min = eta_min
40 | self.restarts = restarts if restarts else [0]
41 | self.restarts = [v + 1 for v in self.restarts]
42 | self.restart_weights = weights if weights else [1]
43 | self.last_restart = 0
44 | assert len(self.restarts) == len(
45 | self.restart_weights), 'restarts and their weights do not match.'
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
47 |
48 | def get_lr(self):
49 | if self.last_epoch == 0:
50 | return self.base_lrs
51 | elif self.last_epoch in self.restarts:
52 | self.last_restart = self.last_epoch
53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
57 | return [
58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
60 | ]
61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
63 | (group['lr'] - self.eta_min) + self.eta_min
64 | for group in self.optimizer.param_groups]
65 |
66 |
67 | if __name__ == "__main__":
68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
69 | betas=(0.9, 0.99))
70 | ##############################
71 | # MultiStepLR_Restart
72 | ##############################
73 | ## Original
74 | lr_steps = [200000, 400000, 600000, 800000]
75 | restarts = None
76 | restart_weights = None
77 |
78 | ## two
79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
80 | restarts = [500000]
81 | restart_weights = [1]
82 |
83 | ## four
84 | lr_steps = [
85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
87 | ]
88 | restarts = [250000, 500000, 750000]
89 | restart_weights = [1, 1, 1]
90 |
91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
92 | clear_state=False)
93 |
94 | ##############################
95 | # Cosine Annealing Restart
96 | ##############################
97 | ## two
98 | T_period = [500000, 500000]
99 | restarts = [500000]
100 | restart_weights = [1]
101 |
102 | ## four
103 | T_period = [250000, 250000, 250000, 250000]
104 | restarts = [250000, 500000, 750000]
105 | restart_weights = [1, 1, 1]
106 |
107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
108 | weights=restart_weights)
109 |
110 | ##############################
111 | # Draw figure
112 | ##############################
113 | N_iter = 1000000
114 | lr_l = list(range(N_iter))
115 | for i in range(N_iter):
116 | scheduler.step()
117 | current_lr = optimizer.param_groups[0]['lr']
118 | lr_l[i] = current_lr
119 |
120 | import matplotlib as mpl
121 | from matplotlib import pyplot as plt
122 | import matplotlib.ticker as mtick
123 | mpl.style.use('default')
124 | import seaborn
125 | seaborn.set(style='whitegrid')
126 | seaborn.set_context('paper')
127 |
128 | plt.figure(1)
129 | plt.subplot(111)
130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
131 | plt.title('Title', fontsize=16, color='k')
132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
133 | legend = plt.legend(loc='upper right', shadow=False)
134 | ax = plt.gca()
135 | labels = ax.get_xticks().tolist()
136 | for k, v in enumerate(labels):
137 | labels[k] = str(int(v / 1000)) + 'K'
138 | ax.set_xticklabels(labels)
139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
140 |
141 | ax.set_ylabel('Learning rate')
142 | ax.set_xlabel('Iteration')
143 | fig = plt.gcf()
144 | plt.show()
145 |
--------------------------------------------------------------------------------
/codes/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import models.archs.SRResNet_arch as SRResNet_arch
3 | import models.archs.discriminator_vgg_arch as SRGAN_arch
4 | import models.archs.RRDBNet_arch as RRDBNet_arch
5 | import models.archs.RankSRGAN_arch as RankSRGAN_arch
6 |
7 |
8 | # Generator
9 | def define_G(opt):
10 | opt_net = opt['network_G']
11 | which_model = opt_net['which_model_G']
12 |
13 | # image restoration
14 | if which_model == 'MSRResNet':
15 | netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
16 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
17 | elif which_model == 'RRDBNet':
18 | netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
19 | nf=opt_net['nf'], nb=opt_net['nb'])
20 | elif which_model == 'SRResNet':
21 | netG = RankSRGAN_arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
22 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
23 | else:
24 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
25 |
26 | return netG
27 |
28 |
29 | # Discriminator
30 | def define_D(opt):
31 | opt_net = opt['network_D']
32 | which_model = opt_net['which_model_D']
33 |
34 | if which_model == 'discriminator_vgg_128':
35 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
36 | elif which_model == 'discriminator_vgg_296':
37 | netD = RankSRGAN_arch.Discriminator_VGG_296(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
38 | else:
39 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
40 | return netD
41 |
42 | # Define network used for perceptual loss
43 | def define_F(opt, use_bn=False):
44 | gpu_ids = opt['gpu_ids']
45 | device = torch.device('cuda' if gpu_ids else 'cpu')
46 | # PyTorch pretrained VGG19-54, before ReLU.
47 | if use_bn:
48 | feature_layer = 49
49 | else:
50 | feature_layer = 34
51 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
52 | use_input_norm=True, device=device)
53 | netF.eval() # No need to train
54 | return netF
55 |
56 | # Define network used for rank-content loss
57 | def define_R(opt):
58 | opt_net = opt['network_R']
59 | which_model = opt_net['which_model_R']
60 |
61 | if which_model == 'Ranker_VGG12':
62 | netR = RankSRGAN_arch.Ranker_VGG12_296(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
63 | else:
64 | raise NotImplementedError('Ranker model [{:s}] is not recognized'.format(which_model))
65 |
66 | return netR
67 |
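A usage sketch with hypothetical option values: define_G() reads everything from the 'network_G' sub-dict, including 'scale', which the option parser is assumed to copy in from the top level of the config.

opt = {'network_G': {'which_model_G': 'SRResNet', 'in_nc': 3, 'out_nc': 3,
                     'nf': 64, 'nb': 16, 'scale': 4}}
netG = define_G(opt)  # -> RankSRGAN_arch.SRResNet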
--------------------------------------------------------------------------------
/codes/options/README.md:
--------------------------------------------------------------------------------
1 | # Configurations
2 | - Use **json** files to configure options.
3 | - The json file is converted to a python dict (a minimal parsing sketch follows below).
4 | - `//` comments are supported and `null` is used for `None`.
5 |
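A naive sketch (not the repository's actual parser) of turning such a commented json file into a python dict:

```python
import json
from collections import OrderedDict


def parse_jsonc(path):
    """Strip '//' comments line by line, then parse as ordinary json.
    Naive: this would also truncate '//' inside string values."""
    json_str = ''
    with open(path, 'r') as f:
        for line in f:
            json_str += line.split('//')[0] + '\n'
    return json.loads(json_str, object_pairs_hook=OrderedDict)
```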
6 | ## Table
7 | Click a link below for a detailed explanation of each json file.
8 |
9 | 1. [RankSRGAN_NIQE.json](#RankSRGAN_NIQE_json)
10 | 2. [Ranker.json](#Ranker_json)
11 |
12 | ## RankSRGAN_NIQE_json
13 | ```c++
14 | {
15 | "name": "RankSRGANx4_NIQE"
16 | ,"model":"ranksrgan" // use tensorboard_logger
17 | ,"scale": 4
18 | ,"gpu_ids": [2] // specify GPUs, actually it sets the `CUDA_VISIBLE_DEVICES`
19 |
20 | ,"datasets": { // configure the training and validation datasets
21 | "train": { // training dataset configurations
22 | "name": "DIV2K"
23 | ,"mode": "LRHR"
24 | ,"dataroot_HR": "/home/wlzhang/BasicSR12/data/DIV2K800_sub.lmdb" // HR data root
25 | ,"dataroot_LR": "/home/wlzhang/BasicSR12/data/DIV2K800_sub_bicLRx4.lmdb" // LR data root
26 | ,"subset_file": null
27 | ,"use_shuffle": true
28 | ,"n_workers": 8 // number of data load workers
29 | ,"batch_size": 8
30 | ,"HR_size": 296 // 128 for SRGAN | 296 for RankSRGAN, cropped HR patch size
31 | ,"use_flip": true
32 | ,"use_rot": true
33 | , "random_flip": false // whether use horizontal and vertical flips
34 | , "random_scale": false // whether use rotations: 90, 190, 270 degrees
35 | }
36 | , "val": { // validation dataset configurations
37 | "name": "val_PIRM"
38 | ,"mode": "LRHR"
39 | ,"dataroot_HR": "/home/wlzhang/BasicSR12/data/val/PIRMtestHR"
40 | ,"dataroot_LR": "/home/wlzhang/BasicSR12/data/val/PIRMtest"
41 | }
42 | }
43 |
44 | ,"path": {
45 | "root": "/home/wlzhang/RankSRGAN", // root path
46 | // "resume_state": "../experiments/RankSRGANx4_NIQE/training_state/152000.state", // Resume the training from 152000 iteration
47 | "pretrain_model_G": "/home/wlzhang/RankSRGAN/experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth", // G network pretrain model
48 | "pretrain_model_R": "/home/wlzhang/RankSRGAN/experiments/pretrained_models/Ranker_NIQE.pth", // R network pretrain model
49 |
50 | "experiments_root": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE",
51 | "models": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE/models",
52 | "log": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE",
53 | "val_images": "/home/wlzhang/RankSRGAN/experiments/RankSRGANx4_NIQE/val_images"
54 | }
55 |
56 | ,"network_G": { // configurations for the network G
57 | "which_model_G": "sr_resnet"
58 | ,"norm_type": null // null | "batch", norm type
59 | ,"mode": "CNA" // Convolution mode: CNA for Conv-Norm_Activation
60 | ,"nf": 64 // number of features for each layer
61 | ,"nb": 16 // number of blocks
62 | ,"in_nc": 3 // input channels
63 | ,"out_nc": 3 // output channels
64 | ,"group": 1
65 | }
66 | ,"network_D": { // configurations for the network D
67 | "which_model_D": "discriminator_vgg_128"
68 | ,"norm_type": "batch"
69 | ,"act_type": "leakyrelu"
70 | ,"mode": "CNA"
71 | ,"nf": 64
72 | ,"in_nc": 3
73 | },
74 | "network_R": {
75 | "which_model_R": "Ranker_VGG12",
76 | "norm_type": "batch",
77 | "act_type": "leakyrelu",
78 | "mode": "CNA",
79 | "nf": 64,
80 | "in_nc": 3
81 | },
82 | "train": { // training strategies
83 | "lr_G": 0.0001, // initialized learning rate for G
84 | "train_D": 1,
85 | "weight_decay_G": 0,
86 | "beta1_G": 0.9,
87 | "lr_D": 0.0001, // initialized learning rate for D
88 | "weight_decay_D": 0,
89 | "beta1_D": 0.9,
90 | "lr_scheme": "MultiStepLR", // learning rate decay scheme
91 | "lr_steps": [
92 | 50000,
93 | 100000,
94 | 200000,
95 | 300000
96 | ],
97 | "lr_gamma": 0.5,
98 | "pixel_criterion": "l1", // "l1" | "l2", pixel criterion
99 | "pixel_weight": 0,
100 | "feature_criterion": "l1", // perceptual criterion (VGG loss)
101 | "feature_weight": 1,
102 | "gan_type": "vanilla", // GAN type
103 | "gan_weight": 0.005,
104 | "D_update_ratio": 1,
105 | "D_init_iters": 0,
106 | "R_weight": 0.03, // Ranker-content loss
107 | "R_bias": 0,
108 | "manual_seed": 0,
109 | "niter": 500000.0, // total training iteration
110 | "val_freq": 5000 // validation frequency
111 | },
112 | "logger": { // logger configurations
113 | "print_freq": 200
114 | ,"save_checkpoint_freq": 5000
115 | },
116 | "timestamp": "180804-004247",
117 | "is_train": true,
118 | "fine_tune": false
119 | }
120 |
121 | ```
122 | ## Ranker_json
123 |
124 | ```c++
125 | {
126 | "name": "Ranker_NIQE" //
127 | ,"use_tb_logger": true // use tensorboard_logger
128 | ,"model":"rank"
129 | ,"scale": 4
130 | ,"gpu_ids": [2,5] // specify GPUs, actually it sets the `CUDA_VISIBLE_DEVICES`
131 | ,"datasets": { // configure the training and validation rank datasets
132 | "train": { // training dataset configurations
133 | "name": "DF2K_train_rankdataset"
134 | ,"mode": "RANK_IMIM_Pair"
135 | ,"dataroot_HR": null
136 | ,"dataroot_LR":null
137 | ,"dataroot_img1": "/home/wlzhang/data/rankdataset/DF2K_train_patch_esrgan/" // Rankdataset: Perceptual level1 data root
138 | ,"dataroot_img2": "/home/wlzhang/data/rankdataset/DF2K_train_patch_srgan/" // Rankdataset: Perceptual level2 data root
139 | ,"dataroot_img3": "/home/wlzhang/data/rankdataset/DF2K_train_patch_srres/" // Rankdataset: Perceptual level3 data root
140 | ,"dataroot_label_file": "/home/wlzhang/data/rankdataset/DF2K_train_patch_label.txt" // Rankdataset: Perceptual rank label root
141 | ,"subset_file": null
142 | ,"use_shuffle": true
143 | ,"n_workers": 8 // number of data load workers
144 | ,"batch_size": 32
145 | ,"HR_size": 128
146 | ,"use_flip": true
147 | ,"use_rot": true
148 | }
149 | , "val": {
150 | "name": "DF2K_valid_rankdataset" // validation dataset configurations
151 | ,"mode": "RANK_IMIM_Pair"
152 | ,"dataroot_HR": null
153 | ,"dataroot_LR":null
154 | ,"dataroot_img1": "/home/wlzhang/data/rankdataset/DF2K_test_patch_all/"
155 | ,"dataroot_label_file": "/home/wlzhang/data/rankdataset/DF2K_test_patch_label.txt"
156 | }
157 | }
158 |
159 | ,"path": { // root path
160 | "root": "/home/wlzhang/RankSRGAN",
161 | "experiments_root": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE",
162 | "models": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE/models",
163 | "log": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE",
164 | "val_images": "/home/wlzhang/RankSRGAN/experiments/Ranker_NIQE/val_images"
165 | }
166 |
167 | ,"network_G": {
168 | "which_model_G": "sr_resnet"
169 | ,"norm_type": null
170 | ,"mode": "CNA"
171 | ,"nf": 64
172 | ,"nb": 16
173 | ,"in_nc": 3
174 | ,"out_nc": 3
175 | ,"group": 1
176 | }
177 | ,"network_R": { // configurations for the network Ranker
178 | "which_model_R": "Ranker_VGG12"
179 | ,"norm_type": "batch" // null | "batch", norm type
180 | ,"act_type": "leakyrelu"
181 | ,"mode": "CNA"
182 | ,"nf": 64
183 | ,"nb": 16
184 | ,"in_nc": 3
185 | ,"out_nc": 3
186 | ,"in_nc": 3
187 | }
188 |
189 | ,"train": { // training strategies
190 | "lr_R": 1e-3 // initialized learning rate for R
191 | ,"weight_decay_R": 1e-4
192 | ,"beta1_G": 0.9
193 | ,"lr_D": 1e-4
194 | ,"weight_decay_D": 0
195 | ,"beta1_D": 0.9
196 | ,"lr_scheme": "MultiStepLR"
197 | ,"lr_steps": [100000, 200000] // learning rate decay scheme
198 |
199 | ,"lr_gamma": 0.5
200 |
201 | // ,"pixel_criterion": "l1"
202 | // ,"pixel_weight": 1
203 | // ,"feature_criterion": "l1"
204 | // ,"feature_weight": 1
205 | // ,"gan_type": "vanilla"
206 | // ,"gan_weight": 5e-3
207 |
208 | ,"D_update_ratio": 1
209 | ,"D_init_iters": 0
210 |
211 | ,"manual_seed": 0
212 | ,"niter": 400000 // total training iteration
213 | ,"val_freq": 5000 // validation frequency
214 | }
215 |
216 | ,"logger": { // logger configurations
217 | "print_freq": 200
218 | ,"save_checkpoint_freq": 5000
219 | }
220 | }
221 | '''
222 |
--------------------------------------------------------------------------------
/codes/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__init__.py
--------------------------------------------------------------------------------
/codes/options/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/options/__pycache__/options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/options/__pycache__/options.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr':
19 | scale = opt['scale']
20 |
21 | # datasets
22 | for phase, dataset in opt['datasets'].items():
23 | phase = phase.split('_')[0]
24 | dataset['phase'] = phase
25 | if opt['distortion'] == 'sr':
26 | dataset['scale'] = scale
27 | is_lmdb = False
28 | if dataset.get('dataroot_GT', None) is not None:
29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
30 | if dataset['dataroot_GT'].endswith('lmdb'):
31 | is_lmdb = True
32 | if dataset.get('dataroot_LQ', None) is not None:
33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
34 | if dataset['dataroot_LQ'].endswith('lmdb'):
35 | is_lmdb = True
36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
37 | if dataset['mode'].endswith('mc'): # for memcached
38 | dataset['data_type'] = 'mc'
39 | dataset['mode'] = dataset['mode'].replace('_mc', '')
40 |
41 | # path
42 | for key, path in opt['path'].items():
43 | if path and key in opt['path'] and key != 'strict_load':
44 | opt['path'][key] = osp.expanduser(path)
45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
46 | if is_train:
47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
48 | opt['path']['experiments_root'] = experiments_root
49 | opt['path']['models'] = osp.join(experiments_root, 'models')
50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
51 | opt['path']['log'] = experiments_root
52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
53 |
54 | # change some options for debug mode
55 | if 'debug' in opt['name']:
56 | opt['train']['val_freq'] = 8
57 | opt['logger']['print_freq'] = 1
58 | opt['logger']['save_checkpoint_freq'] = 8
59 | else: # test
60 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
61 | opt['path']['results_root'] = results_root
62 | opt['path']['log'] = results_root
63 |
64 | # network
65 | if opt['distortion'] == 'sr':
66 | opt['network_G']['scale'] = scale
67 |
68 | return opt
69 |
70 |
71 | def dict2str(opt, indent_l=1):
72 | '''dict to string for logger'''
73 | msg = ''
74 | for k, v in opt.items():
75 | if isinstance(v, dict):
76 | msg += ' ' * (indent_l * 2) + k + ':[\n'
77 | msg += dict2str(v, indent_l + 1)
78 | msg += ' ' * (indent_l * 2) + ']\n'
79 | else:
80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
81 | return msg
82 |
83 |
84 | class NoneDict(dict):
85 | def __missing__(self, key):
86 | return None
87 |
88 |
89 | # convert to NoneDict, which returns None for missing keys.
90 | def dict_to_nonedict(opt):
91 | if isinstance(opt, dict):
92 | new_opt = dict()
93 | for key, sub_opt in opt.items():
94 | new_opt[key] = dict_to_nonedict(sub_opt)
95 | return NoneDict(**new_opt)
96 | elif isinstance(opt, list):
97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
98 | else:
99 | return opt
100 |
101 |
102 | def check_resume(opt, resume_iter):
103 | '''Check resume states and pretrain_model paths'''
104 | logger = logging.getLogger('base')
105 | if opt['path']['resume_state']:
106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
107 | 'pretrain_model_D', None) is not None:
108 | logger.warning('pretrain_model path will be ignored when resuming training.')
109 |
110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
111 | '{}_G.pth'.format(resume_iter))
112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
113 | if 'gan' in opt['model']:
114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
115 | '{}_D.pth'.format(resume_iter))
116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
117 |
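Typical usage of this module, mirroring `codes/test.py` later in this repo:

```python
import options.options as option

# Parse a YAML config and wrap it so missing keys read as None.
opt = option.parse('options/test/test_RankSRGAN.yml', is_train=False)
opt = option.dict_to_nonedict(opt)
```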
--------------------------------------------------------------------------------
/codes/options/test/test_RankSRGAN.yml:
--------------------------------------------------------------------------------
1 | name: RankSRGANx4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # crop border during evaluation. If None (~), crop scale pixels
7 | gpu_ids: [3]
8 |
9 | datasets:
10 | test_1: # the 1st test dataset
11 | name: set14
12 | mode: LQGT
13 | dataroot_GT: /home/wlzhang/BasicSR12/data/val/Set14_mod
14 | dataroot_LQ:
15 | test_2: # the 2nd test dataset
16 | name: PIRMtest
17 | mode: LQGT
18 | dataroot_GT: /home/wlzhang/RankSRGAN/data/val/PIRMtestHR
19 | dataroot_LQ: /home/wlzhang/RankSRGAN/data/val/PIRMtest
20 |
21 | #### network structures
22 | network_G:
23 | which_model_G: SRResNet # SRResNet for RankSRGAN
24 | in_nc: 3
25 | out_nc: 3
26 | nf: 64
27 | nb: 16
28 | upscale: 4
29 |
30 | #### path
31 | # Download pretrained models from https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing
32 | path:
33 | pretrain_model_G: ../experiments/pretrained_models/mmsr_RankSRGAN_NIQE.pth
34 |
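This config is consumed by `codes/test.py`, e.g. `python test.py -opt options/test/test_RankSRGAN.yml`; with `is_train=False`, outputs are written under `<root>/results/RankSRGANx4/` (see the test branch of `parse()` in `options/options.py`).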
--------------------------------------------------------------------------------
/codes/options/train/train_RankSRGAN.yml:
--------------------------------------------------------------------------------
1 | # Not exactly the same as SRGAN in the original paper:
2 | # with 16 residual blocks w/o BN
3 |
4 | #### general settings
5 | name: 002_RankSRGANx4_SRResNetx4_DIV2K
6 | use_tb_logger: true
7 | model: ranksrgan
8 | distortion: sr
9 | scale: 4
10 | gpu_ids: [0,1]
11 |
12 | #### datasets
13 | datasets:
14 | train:
15 | name: DIV2K
16 | mode: LQGT
17 | dataroot_GT: /home/wlzhang/data/DIV2K/DIV2K_train_HR
18 | dataroot_LQ:
19 |
20 | use_shuffle: true
21 | n_workers: 6 # per GPU
22 | batch_size: 8
23 | GT_size: 296
24 | use_flip: true
25 | use_rot: true
26 | color: RGB
27 | val:
28 | name: val_Pirm_test100
29 | mode: LQGT
30 | dataroot_GT: /home/wlzhang/BasicSR12/data/val/PIRMtestHR
31 | dataroot_LQ: /home/wlzhang/BasicSR12/data/val/PIRMtest
32 |
33 | #### network structures
34 | network_G:
35 | which_model_G: SRResNet # SRResNet for RankSRGAN
36 | in_nc: 3
37 | out_nc: 3
38 | nf: 64
39 | nb: 16
40 | upscale: 4
41 | network_D:
42 | which_model_D: discriminator_vgg_296
43 | in_nc: 3
44 | nf: 64
45 |
46 | network_R:
47 | which_model_R: Ranker_VGG12
48 | in_nc: 3
49 | nf: 64
50 |
51 | #### path
52 | # Download pretrained models from https://drive.google.com/drive/folders/1_KhEc_zBRW7iLeEJITU3i923DC6wv51T?usp=sharing
53 | path:
54 | pretrain_model_G: ../experiments/pretrained_models/mmsr_SRResNet_pretrain.pth
55 | pretrain_model_R: ../experiments/pretrained_models/mmsr_Ranker_NIQE.pth
56 |
57 | strict_load: true
58 | resume_state: ~
59 |
60 | #### training settings: learning rate scheme, loss
61 | train:
62 | lr_G: !!float 1e-4
63 | weight_decay_G: 0
64 | beta1_G: 0.9
65 | beta2_G: 0.99
66 | lr_D: !!float 1e-4
67 | weight_decay_D: 0
68 | beta1_D: 0.9
69 | beta2_D: 0.99
70 | lr_scheme: MultiStepLR
71 |
72 | niter: 500000
73 | warmup_iter: -1 # no warm up
74 | lr_steps: [50000, 100000, 200000, 300000]
75 | lr_gamma: 0.5
76 |
77 | pixel_criterion: l1
78 | pixel_weight: 0
79 | feature_criterion: l1
80 | feature_weight: 1
81 | R_weight: !!float 3e-2 # rank-content loss
82 | R_bias: 0
83 | gan_type: gan # gan | ragan
84 | gan_weight: !!float 5e-3
85 |
86 | D_update_ratio: 1
87 | D_init_iters: 0
88 |
89 | manual_seed: 10
90 | val_freq: !!float 2e3
91 |
92 | #### logger
93 | logger:
94 | print_freq: 200
95 | save_checkpoint_freq: !!float 2e3
96 |
--------------------------------------------------------------------------------
/codes/options/train/train_Ranker.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: 001_VGG_Ranker_DF2K
3 | use_tb_logger: true
4 | model: rank
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0,1]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: Rankdataset
13 | mode: RANK_IMIM_Pair
14 | dataroot_img1: /home/wlzhang/data/rankdataset/DF2K_train_patch_esrgan/
15 | dataroot_img2: /home/wlzhang/data/rankdataset/DF2K_train_patch_srgan/
16 | dataroot_img3: /home/wlzhang/data/rankdataset/DF2K_train_patch_srres/
17 | dataroot_label_file: /home/wlzhang/data/rankdataset/DF2K_train_patch_label.txt # Rankdataset: Perceptual rank label root
18 |
19 | use_shuffle: true
20 | n_workers: 6 # per GPU
21 | batch_size: 16
22 | GT_size: 128
23 | use_flip: true
24 | use_rot: true
25 | color: RGB
26 | val:
27 | name: DF2K_valid_rankdataset
28 | mode: RANK_IMIM_Pair
29 | dataroot_img1: /home/wlzhang/data/rankdataset/DF2K_test_patch_all/
30 | dataroot_label_file: /home/wlzhang/data/rankdataset/DF2K_test_patch_label.txt
31 |
32 | #### network structures
33 | network_G:
34 | which_model_G: RRDBNet
35 | in_nc: 3
36 | out_nc: 3
37 | nf: 64
38 | nb: 23
39 | network_R: # configurations for the network Ranker
40 | which_model_R: Ranker_VGG12
41 | in_nc: 3
42 | nf: 64
43 |
44 | #### path
45 | path:
46 | strict_load: true
47 | resume_state: ~
48 |
49 | #### training settings: learning rate scheme, loss
50 | train:
51 | lr_R: !!float 1e-4
52 | weight_decay_R: !!float 1e-4
53 | beta1_G: 0.9
54 | lr_scheme: MultiStepLR
55 |
56 | niter: 400000
57 | warmup_iter: -1 # no warm up
58 | lr_steps: [100000, 200000]
59 | lr_gamma: 0.5
60 |
61 | manual_seed: 10
62 | val_freq: !!float 5e3
63 |
64 | #### logger
65 | logger:
66 | print_freq: 200
67 | save_checkpoint_freq: !!float 5e3
68 |
--------------------------------------------------------------------------------
/codes/options/train/train_SRGAN.yml:
--------------------------------------------------------------------------------
1 | # Not exactly the same as SRGAN in the original paper:
2 | # with 16 residual blocks w/o BN
3 |
4 | #### general settings
5 | name: 003_SRGANx4_MSRResNetx4Ini_DIV2K
6 | use_tb_logger: true
7 | model: srgan
8 | distortion: sr
9 | scale: 4
10 | gpu_ids: [1]
11 |
12 | #### datasets
13 | datasets:
14 | train:
15 | name: DIV2K
16 | mode: LQGT
17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
19 |
20 | use_shuffle: true
21 | n_workers: 6 # per GPU
22 | batch_size: 16
23 | GT_size: 128
24 | use_flip: true
25 | use_rot: true
26 | color: RGB
27 | val:
28 | name: val_set14
29 | mode: LQGT
30 | dataroot_GT: ../datasets/val_set14/Set14
31 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
32 |
33 | #### network structures
34 | network_G:
35 | which_model_G: MSRResNet
36 | in_nc: 3
37 | out_nc: 3
38 | nf: 64
39 | nb: 16
40 | upscale: 4
41 | network_D:
42 | which_model_D: discriminator_vgg_128
43 | in_nc: 3
44 | nf: 64
45 |
46 | #### path
47 | path:
48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
49 | strict_load: true
50 | resume_state: ~
51 |
52 | #### training settings: learning rate scheme, loss
53 | train:
54 | lr_G: !!float 1e-4
55 | weight_decay_G: 0
56 | beta1_G: 0.9
57 | beta2_G: 0.99
58 | lr_D: !!float 1e-4
59 | weight_decay_D: 0
60 | beta1_D: 0.9
61 | beta2_D: 0.99
62 | lr_scheme: MultiStepLR
63 |
64 | niter: 400000
65 | warmup_iter: -1 # no warm up
66 | lr_steps: [50000, 100000, 200000, 300000]
67 | lr_gamma: 0.5
68 |
69 | pixel_criterion: l1
70 | pixel_weight: !!float 1e-2
71 | feature_criterion: l1
72 | feature_weight: 1
73 | gan_type: gan # gan | ragan
74 | gan_weight: !!float 5e-3
75 |
76 | D_update_ratio: 1
77 | D_init_iters: 0
78 |
79 | manual_seed: 10
80 | val_freq: !!float 5e3
81 |
82 | #### logger
83 | logger:
84 | print_freq: 100
85 | save_checkpoint_freq: !!float 5e3
86 |
--------------------------------------------------------------------------------
/codes/options/train/train_SRResNet.yml:
--------------------------------------------------------------------------------
1 | # Not exactly the same as SRResNet in the original paper:
2 | # with 16 residual blocks w/o BN
3 |
4 | #### general settings
5 | name: 004_MSRResNetx4_scratch_DIV2K
6 | use_tb_logger: true
7 | model: sr
8 | distortion: sr
9 | scale: 4
10 | gpu_ids: [0]
11 |
12 | #### datasets
13 | datasets:
14 | train:
15 | name: DIV2K
16 | mode: LQGT
17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
19 |
20 | use_shuffle: true
21 | n_workers: 6 # per GPU
22 | batch_size: 16
23 | GT_size: 128
24 | use_flip: true
25 | use_rot: true
26 | color: RGB
27 | val:
28 | name: val_set5
29 | mode: LQGT
30 | dataroot_GT: ../datasets/val_set5/Set5
31 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
32 |
33 | #### network structures
34 | network_G:
35 | which_model_G: MSRResNet
36 | in_nc: 3
37 | out_nc: 3
38 | nf: 64
39 | nb: 16
40 | upscale: 4
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ~
45 | strict_load: true
46 | resume_state: ~
47 |
48 | #### training settings: learning rate scheme, loss
49 | train:
50 | lr_G: !!float 2e-4
51 | lr_scheme: CosineAnnealingLR_Restart
52 | beta1: 0.9
53 | beta2: 0.99
54 | niter: 1000000
55 | warmup_iter: -1 # no warm up
56 | T_period: [250000, 250000, 250000, 250000]
57 | restarts: [250000, 500000, 750000]
58 | restart_weights: [1, 1, 1]
59 | eta_min: !!float 1e-7
60 |
61 | pixel_criterion: l1
62 | pixel_weight: 1.0
63 |
64 | manual_seed: 10
65 | val_freq: !!float 5e3
66 |
67 | #### logger
68 | logger:
69 | print_freq: 100
70 | save_checkpoint_freq: !!float 5e3
71 |
--------------------------------------------------------------------------------
/codes/run_scripts.sh:
--------------------------------------------------------------------------------
1 | # image SR training
2 |
3 | python train.py -opt options/train/train_RankSRGAN.yml # Validation with PSNR
4 | python train_niqe.py -opt options/train/train_RankSRGAN.yml # Validation with PSNR and NIQE
5 |
6 | # Ranker training
7 |
8 | python train_rank.py -opt options/train/train_Ranker.yml
9 |
--------------------------------------------------------------------------------
/codes/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Scripts
2 | We provide some useful scripts here.
3 |
4 | ## List
5 |
6 | | Name | Description |
7 | |:---:|:---:|
8 | | back projection | `Matlab` codes for back projection |
9 |
--------------------------------------------------------------------------------
/codes/scripts/__pycache__/arch_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/scripts/__pycache__/arch_util.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/scripts/__pycache__/block.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/scripts/__pycache__/block.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/scripts/arch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def make_layer(block, n_layers):
28 | layers = []
29 | for _ in range(n_layers):
30 | layers.append(block())
31 | return nn.Sequential(*layers)
32 |
33 |
34 | class ResidualBlock_noBN(nn.Module):
35 | '''Residual block w/o BN
36 | ---Conv-ReLU-Conv-+-
37 | |________________|
38 | '''
39 |
40 | def __init__(self, nf=64):
41 | super(ResidualBlock_noBN, self).__init__()
42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44 |
45 | # initialization
46 | initialize_weights([self.conv1, self.conv2], 0.1)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = F.relu(self.conv1(x), inplace=True)
51 | out = self.conv2(out)
52 | return identity + out
53 |
54 |
55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
56 | """Warp an image or feature map with optical flow
57 | Args:
58 | x (Tensor): size (N, C, H, W)
59 | flow (Tensor): size (N, H, W, 2), normal value
60 | interp_mode (str): 'nearest' or 'bilinear'
61 | padding_mode (str): 'zeros' or 'border' or 'reflection'
62 |
63 | Returns:
64 | Tensor: warped image or feature map
65 | """
66 | assert x.size()[-2:] == flow.size()[1:3]
67 | B, C, H, W = x.size()
68 | # mesh grid
69 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
70 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
71 | grid.requires_grad = False
72 | grid = grid.type_as(x)
73 | vgrid = grid + flow
74 | # scale grid to [-1,1]
75 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
76 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
77 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
78 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
79 | return output
80 |
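A short example of how these helpers compose (this mirrors the SRResNet-style generators elsewhere in the repo):

```python
import functools
import torch

# Build the 16-block residual trunk used by the SRResNet-style generators.
block = functools.partial(ResidualBlock_noBN, nf=64)
trunk = make_layer(block, 16)

x = torch.randn(1, 64, 24, 24)
y = trunk(x)  # residual blocks preserve shape: (1, 64, 24, 24)
```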
--------------------------------------------------------------------------------
/codes/scripts/back_projection/backprojection.m:
--------------------------------------------------------------------------------
1 | function [im_h] = backprojection(im_h, im_l, maxIter)
2 |
3 | [row_l, col_l,~] = size(im_l);
4 | [row_h, col_h,~] = size(im_h);
5 |
6 | p = fspecial('gaussian', 5, 1);
7 | p = p.^2;
8 | p = p./sum(p(:));
9 |
10 | im_l = double(im_l);
11 | im_h = double(im_h);
12 |
13 | for ii = 1:maxIter
14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
15 | im_diff = im_l - im_l_s;
16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
20 | end
21 |
--------------------------------------------------------------------------------
/codes/scripts/back_projection/main_bp.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20bp';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | %tic
19 | im_out = backprojection(im_out, im_LR, max_iter);
20 | %toc
21 | imwrite(im_out, fullfile(save_folder, im_name));
22 | end
23 |
--------------------------------------------------------------------------------
/codes/scripts/back_projection/main_reverse_filter.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20if';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | J = imresize(im_LR,4,'bicubic');
19 | %tic
20 | for m = 1:max_iter
21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
22 | end
23 | %toc
24 | imwrite(im_out, fullfile(save_folder, im_name));
25 | end
26 |
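Both scripts iterate the same idea: re-downsample the current SR estimate, compare it with the true LR image, and add the upsampled residual back. A rough Python equivalent of the reverse-filter loop above (a sketch: uses OpenCV bicubic resizing, expects float images in [0, 1] as from `im2double`, and omits the Gaussian back-projection kernel of `backprojection.m`):

```python
import cv2

def reverse_filter(im_out, im_LR, scale=4, max_iter=20):
    h, w = im_out.shape[:2]
    J = cv2.resize(im_LR, (w, h), interpolation=cv2.INTER_CUBIC)
    for _ in range(max_iter):
        down = cv2.resize(im_out, (w // scale, h // scale),
                          interpolation=cv2.INTER_CUBIC)
        up = cv2.resize(down, (w, h), interpolation=cv2.INTER_CUBIC)
        im_out = im_out + (J - up)
    return im_out
```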
--------------------------------------------------------------------------------
/codes/scripts/calparameters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchsummary import summary
5 | from collections import OrderedDict
6 | 
7 | # block.py is a local helper module in codes/scripts/
8 | import block as B
9 |
10 | class Discriminator_VGG_128(nn.Module):
11 | def __init__(self, in_nc = 3, base_nf = 64, norm_type='batch', act_type='leakyrelu', mode='CNA'):
12 | super(Discriminator_VGG_128, self).__init__()
13 | # features
14 | # hxw, c
15 | # 128, 64
16 |
17 | conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
18 | mode=mode)
19 | conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
20 | act_type=act_type, mode=mode)
21 | # 64, 64
22 | conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
23 | act_type=act_type, mode=mode)
24 | conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
25 | act_type=act_type, mode=mode)
26 | # 32, 128
27 | conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
28 | act_type=act_type, mode=mode)
29 | conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
30 | act_type=act_type, mode=mode)
31 | # 16, 256
32 | conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
33 | act_type=act_type, mode=mode)
34 | conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
35 | act_type=act_type, mode=mode)
36 | # 8, 512
37 | conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
38 | act_type=act_type, mode=mode)
39 | conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
40 | act_type=act_type, mode=mode)
41 | # 4, 512
42 | self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
43 | conv9)
44 |
45 | # classifier
46 | self.classifier = nn.Sequential(
47 |             nn.Linear(512 * 9 * 9, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))  # final 512-ch feature map: 9x9 for a 296 input (12 for 400, 8 for 256, 6 for 200)
48 |
49 | def forward(self, x):
50 | x = self.features(x)
51 | x = x.view(x.size(0), -1)
52 | x = self.classifier(x)
53 | return x
54 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
55 | model = Discriminator_VGG_128().to(device)
56 |
57 | summary(model, (3, 296, 296))  # input must match in_nc=3; 296 gives the 9x9 map assumed by the classifier
--------------------------------------------------------------------------------
/codes/scripts/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path
3 | import glob
4 | import pickle
5 | import lmdb
6 | import cv2
7 |
8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9 | from utils.progress_bar import ProgressBar
10 |
11 | # configurations
12 | img_folder = '/sdd/BasicSR_datasets/DIV2K800/DIV2K800/*' # glob matching pattern
13 | lmdb_save_path = '/sdd/BasicSR_datasets/DIV2K800/DIV2K800.lmdb' # must end with .lmdb
14 |
15 | img_list = sorted(glob.glob(img_folder))
16 | dataset = []
17 | data_size = 0
18 |
19 | print('Read images...')
20 | pbar = ProgressBar(len(img_list))
21 | for i, v in enumerate(img_list):
22 | pbar.update('Read {}'.format(v))
23 | img = cv2.imread(v, cv2.IMREAD_UNCHANGED)
24 | dataset.append(img)
25 | data_size += img.nbytes
26 | env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
27 | print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list)))
28 |
29 | pbar = ProgressBar(len(img_list))
30 | with env.begin(write=True) as txn: # txn is a Transaction object
31 | for i, v in enumerate(img_list):
32 | pbar.update('Write {}'.format(v))
33 | base_name = os.path.splitext(os.path.basename(v))[0]
34 | key = base_name.encode('ascii')
35 | data = dataset[i]
36 | if dataset[i].ndim == 2:
37 | H, W = dataset[i].shape
38 | C = 1
39 | else:
40 | H, W, C = dataset[i].shape
41 | meta_key = (base_name + '.meta').encode('ascii')
42 | meta = '{:d}, {:d}, {:d}'.format(H, W, C)
43 | # The encode is only essential in Python 3
44 | txn.put(key, data)
45 | txn.put(meta_key, meta.encode('ascii'))
46 | print('Finish writing lmdb.')
47 |
48 | # create keys cache
49 | keys_cache_file = os.path.join(lmdb_save_path, '_keys_cache.p')
50 | env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False)
51 | with env.begin(write=False) as txn:
52 | print('Create lmdb keys cache: {}'.format(keys_cache_file))
53 | keys = [key.decode('ascii') for key, _ in txn.cursor()]
54 | pickle.dump(keys, open(keys_cache_file, "wb"))
55 | print('Finish creating lmdb keys cache.')
56 |
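Reading an entry back means fetching the raw buffer plus its `.meta` shape string (a sketch; the key `'0801_s001'` is hypothetical and 8-bit images are assumed):

```python
import lmdb
import numpy as np

env = lmdb.open(lmdb_save_path, readonly=True, lock=False)
with env.begin() as txn:
    buf = txn.get('0801_s001'.encode('ascii'))  # hypothetical key
    meta = txn.get('0801_s001.meta'.encode('ascii')).decode('ascii')
H, W, C = [int(s) for s in meta.split(',')]     # meta is 'H, W, C'
img = np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)
```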
--------------------------------------------------------------------------------
/codes/scripts/extract_subimgs_single.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | from multiprocessing import Pool
4 | import time
5 | import numpy as np
6 | import cv2
7 |
8 |
9 | def main():
10 | GT_dir = '/media/sdc/wlzhang/data/DIV2K_train_HR'
11 | save_GT_dir = '/media/sdc/wlzhang/data/DIV2K800_sub2'
12 | n_thread = 20
13 |
14 | print('Parent process %s.' % os.getpid())
15 | start = time.time()
16 |
17 | p = Pool(n_thread)
18 | # read all files to a list
19 | all_files = []
20 | for root, _, fnames in sorted(os.walk(GT_dir)):
21 | full_path = [os.path.join(root, x) for x in fnames]
22 | all_files.extend(full_path)
23 | # cut into subtasks
24 | def chunkify(lst, n): # for non-continuous chunks
25 | return [lst[i::n] for i in range(n)]
26 |
27 | sub_lists = chunkify(all_files, n_thread)
28 | # call workers
29 | for i in range(n_thread):
30 | p.apply_async(worker, args=(sub_lists[i], save_GT_dir))
31 | print('Waiting for all subprocesses done...')
32 | p.close()
33 | p.join()
34 | end = time.time()
35 | print('All subprocesses done. Using time {} sec.'.format(end - start))
36 |
37 |
38 | def worker(GT_paths, save_GT_dir):
39 | crop_sz = 480
40 | step = 240
41 | thres_sz = 48
42 |
43 | for GT_path in GT_paths:
44 | base_name = os.path.basename(GT_path)
45 | print(base_name, os.getpid())
46 | img_GT = cv2.imread(GT_path, cv2.IMREAD_UNCHANGED)
47 |
48 | n_channels = len(img_GT.shape)
49 | if n_channels == 2:
50 | h, w = img_GT.shape
51 | elif n_channels == 3:
52 | h, w, c = img_GT.shape
53 | else:
54 | raise ValueError('Wrong image shape - {}'.format(n_channels))
55 |
56 | h_space = np.arange(0, h - crop_sz + 1, step)
57 | if h - (h_space[-1] + crop_sz) > thres_sz:
58 | h_space = np.append(h_space, h - crop_sz)
59 | w_space = np.arange(0, w - crop_sz + 1, step)
60 | if w - (w_space[-1] + crop_sz) > thres_sz:
61 | w_space = np.append(w_space, w - crop_sz)
62 | index = 0
63 | for x in h_space:
64 | for y in w_space:
65 | index += 1
66 | if n_channels == 2:
67 | crop_img = img_GT[x:x + crop_sz, y:y + crop_sz]
68 | else:
69 | crop_img = img_GT[x:x + crop_sz, y:y + crop_sz, :]
70 |
71 | crop_img = np.ascontiguousarray(crop_img)
72 | index_str = '{:03d}'.format(index)
73 | # var = np.var(crop_img / 255)
74 | # if var > 0.008:
75 | # print(index_str, var)
76 | cv2.imwrite(os.path.join(save_GT_dir, base_name.replace('.png', \
77 | '_s'+index_str+'.png')), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
/codes/scripts/generate_mod_LR_bic.m:
--------------------------------------------------------------------------------
1 | function generate_mod_LR_bic()
2 | %% matlab code to generate mod images, bicubic-downsampled LR images, and bicubic-upsampled images.
3 |
4 | %% set parameters
5 | % comment the unnecessary line
6 | input_folder = '/sdd/wlzhang/data/DF2K_HR';
7 | %save_mod_folder = '/home/wlzhang/BasicSR/data_samples/urban100_mod';
8 | save_LR_folder = '/sdd/wlzhang/data/DF2K800_bicLRx4';
9 | % save_bic_folder = '';
10 |
11 | up_scale = 4;
12 | mod_scale = 4;
13 |
14 | if exist('save_mod_folder', 'var')
15 | if exist(save_mod_folder, 'dir')
16 | disp(['It will cover ', save_mod_folder]);
17 | else
18 | mkdir(save_mod_folder);
19 | end
20 | end
21 | if exist('save_LR_folder', 'var')
22 | if exist(save_LR_folder, 'dir')
23 | disp(['It will cover ', save_LR_folder]);
24 | else
25 | mkdir(save_LR_folder);
26 | end
27 | end
28 | if exist('save_bic_folder', 'var')
29 | if exist(save_bic_folder, 'dir')
30 | disp(['It will cover ', save_bic_folder]);
31 | else
32 | mkdir(save_bic_folder);
33 | end
34 | end
35 |
36 | idx = 0;
37 | filepaths = dir(fullfile(input_folder,'*.*'));
38 | for i = 1 : length(filepaths)
39 | [paths,imname,ext] = fileparts(filepaths(i).name);
40 | if isempty(imname)
41 | disp('Ignore . folder.');
42 | elseif strcmp(imname, '.')
43 | disp('Ignore .. folder.');
44 | else
45 | idx = idx + 1;
46 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
47 | fprintf(str_rlt);
48 | % read image
49 | img = imread(fullfile(input_folder, [imname, ext]));
50 | img = im2double(img);
51 | % modcrop
52 | img = modcrop(img, mod_scale);
53 | if exist('save_mod_folder', 'var')
54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
55 | end
56 |
57 | %sigma = 1.8;
58 | %image_kernel = fspecial('gaussian',21, sigma); % 7x7 Gaussian kernel k_d with width 1.6
59 | %img = imfilter(img,double(image_kernel),'replicate'); % blur
60 |
61 | % LR
62 | im_LR = imresize(img, 1/up_scale, 'bicubic');
63 | if exist('save_LR_folder', 'var')
64 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
65 | end
66 | % Bicubic
67 | if exist('save_bic_folder', 'var')
68 | im_B = imresize(im_LR, up_scale, 'bicubic');
69 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
70 | end
71 | end
72 | end
73 | end
74 |
75 | %% modcrop
76 | function img = modcrop(img, modulo)
77 | if size(img,3) == 1
78 | sz = size(img);
79 | sz = sz - mod(sz, modulo);
80 | img = img(1:sz(1), 1:sz(2));
81 | else
82 | tmpsz = size(img);
83 | sz = tmpsz(1:2);
84 | sz = sz - mod(sz, modulo);
85 | img = img(1:sz(1), 1:sz(2),:);
86 | end
87 | end
88 |
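For reference, the `modcrop` step translates directly to Python (a sketch for HxW or HxWxC numpy arrays):

```python
def modcrop(img, modulo):
    # Crop height and width down to the nearest multiple of `modulo`.
    h, w = img.shape[0], img.shape[1]
    return img[: h - h % modulo, : w - w % modulo]
```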
--------------------------------------------------------------------------------
/codes/scripts/transfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchsummary import summary
5 | from collections import OrderedDict
6 | import torch
7 | import torch.nn as nn
8 | import block as B
9 | import os
10 | import math
11 | import functools
12 | import arch_util as arch_util
13 |
14 | class SRResNet(nn.Module):
15 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4, norm_type= None , act_type='relu', \
16 | mode='CNA', res_scale=1, upsample_mode='pixelshuffle'):
17 | super(SRResNet, self).__init__()
18 | n_upscale = int(math.log(upscale, 2))
19 | if upscale == 3:
20 | n_upscale = 1
21 |
22 | fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
23 | resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,\
24 | mode=mode, res_scale=res_scale) for _ in range(nb)]
25 | LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
26 |
27 | if upsample_mode == 'upconv':
28 | upsample_block = B.upconv_blcok
29 | elif upsample_mode == 'pixelshuffle':
30 | upsample_block = B.pixelshuffle_block
31 | else:
32 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
33 | if upscale == 3:
34 | upsampler = upsample_block(nf, nf, 3, act_type=act_type)
35 | else:
36 | upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
37 | HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
38 | HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
39 |
40 | self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),\
41 | *upsampler, HR_conv0, HR_conv1)
42 |
43 | def forward(self, x):
44 | x = self.model(x)
45 | return x
46 | class mmsrSRResNet(nn.Module):
47 | ''' modified SRResNet'''
48 |
49 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
50 | super(mmsrSRResNet, self).__init__()
51 | self.upscale = upscale
52 |
53 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
54 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
55 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
56 | self.LRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
57 |
58 | # upsampling
59 | if self.upscale == 2:
60 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
61 | self.pixel_shuffle = nn.PixelShuffle(2)
62 | elif self.upscale == 3:
63 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
64 | self.pixel_shuffle = nn.PixelShuffle(3)
65 | elif self.upscale == 4:
66 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
67 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
68 | self.pixel_shuffle = nn.PixelShuffle(2)
69 |
70 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
71 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
72 |
73 | # activation function
74 | self.relu = nn.ReLU(inplace=True)
75 | # self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
76 |
77 | # initialization
78 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
79 | 0.1)
80 | if self.upscale == 4:
81 | arch_util.initialize_weights(self.upconv2, 0.1)
82 |
83 | def forward(self, x):
84 | fea = self.conv_first(x)
85 | out = self.recon_trunk(fea)
86 | out = self.LRconv(out)
87 |
88 | if self.upscale == 4:
89 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))
90 | out = self.relu(self.pixel_shuffle(self.upconv2(out)))
91 | elif self.upscale == 3 or self.upscale == 2:
92 | out = self.relu(self.pixel_shuffle(self.upconv1(out+fea)))
93 |
94 | out = self.conv_last(self.relu(self.HRconv(out)))
95 |
96 | return out
97 |
98 |
99 | def transfer_network(load_path, network,ordereddict, strict=True):
100 | load_net = torch.load(load_path)
101 |     load_net_dict = OrderedDict()   # source checkpoint params, in original order
102 |     load_net_clean = OrderedDict()  # params remapped onto the new architecture's keys
103 | load_model_key = []
104 | for k, v in load_net.items():
105 | load_net_dict[k] = v
106 | load_model_key.append(k)
107 |
108 | i = 0
109 |     for param_tensor in model2.state_dict():  # iterates the target network (the global model2 defined below)
110 | load_net_clean[param_tensor] = load_net_dict[load_model_key[i]]
111 | print('-------')
112 | print(param_tensor)
113 | print(load_model_key[i])
114 | i=i+1
115 | print(i)
116 |
117 | torch.save(load_net_clean, '/home/wlzhang/mmsr/experiments/pretrained_models/mmsr_SRResNet_pretrain.pth')
118 | network.load_state_dict(load_net_clean, strict=strict)
119 |
120 | net_old = SRResNet()
121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
122 | model = net_old.to(device)
123 |
124 | # print("Model's state_dict:")
125 | # for param_tensor in model.state_dict():
126 | # print(param_tensor, "\t", model.state_dict()[param_tensor].size())
127 |
128 | net_new = mmsrSRResNet()
129 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
130 | model2 = net_new.to(device)
131 |
132 | print("Model2's state_dict:")
133 | ordereddict = []
134 | for param_tensor in model2.state_dict():
135 | ordereddict.append(param_tensor)
136 | # print(param_tensor, "\t", model2.state_dict()[param_tensor].size())
137 |
138 | # print("key state_dict:")
139 | # print(ordereddict)
140 |
141 | transfer_network('/home/wlzhang/mmsr/experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth', net_new, ordereddict)
142 |
143 |
144 | # print("key:")
145 | # print(ordereddict)
146 | # summary(model, (3, 296, 296))
--------------------------------------------------------------------------------
/codes/scripts/transfer_params_MSRResNet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 | import torch
4 | try:
5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
6 | import models.archs.SRResNet_arch as SRResNet_arch
7 | except ImportError:
8 | pass
9 |
10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
12 | crt_net = crt_model.state_dict()
13 |
14 | for k, v in crt_net.items():
15 | if k in pretrained_net and 'upconv1' not in k:
16 | crt_net[k] = pretrained_net[k]
17 | print('replace ... ', k)
18 |
19 | # x4 -> x3: upconv1 for x3 needs nf*9 = 576 output channels for PixelShuffle(3); reuse the 256 x4 channels (halved) and fill the remainder
20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
26 |
27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
28 |
--------------------------------------------------------------------------------
/codes/test.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import logging
3 | import time
4 | import argparse
5 | from collections import OrderedDict
6 |
7 | import options.options as option
8 | import utils.util as util
9 | from data.util import bgr2ycbcr
10 | from data import create_dataset, create_dataloader
11 | from models import create_model
12 |
13 | #### options
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
16 | opt = option.parse(parser.parse_args().opt, is_train=False)
17 | opt = option.dict_to_nonedict(opt)
18 |
19 | util.mkdirs(
20 | (path for key, path in opt['path'].items()
21 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
22 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
23 | screen=True, tofile=True)
24 | logger = logging.getLogger('base')
25 | logger.info(option.dict2str(opt))
26 |
27 | #### Create test dataset and dataloader
28 | test_loaders = []
29 | for phase, dataset_opt in sorted(opt['datasets'].items()):
30 | test_set = create_dataset(dataset_opt)
31 | test_loader = create_dataloader(test_set, dataset_opt)
32 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
33 | test_loaders.append(test_loader)
34 |
35 | model = create_model(opt)
36 | for test_loader in test_loaders:
37 | test_set_name = test_loader.dataset.opt['name']
38 | logger.info('\nTesting [{:s}]...'.format(test_set_name))
39 | test_start_time = time.time()
40 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
41 | util.mkdir(dataset_dir)
42 |
43 | test_results = OrderedDict()
44 | test_results['psnr'] = []
45 | test_results['ssim'] = []
46 | test_results['psnr_y'] = []
47 | test_results['ssim_y'] = []
48 |
49 | for data in test_loader:
50 | need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True
51 | model.feed_data(data, need_GT=need_GT)
52 | img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0]
53 | img_name = osp.splitext(osp.basename(img_path))[0]
54 |
55 | model.test()
56 | visuals = model.get_current_visuals(need_GT=need_GT)
57 |
58 | sr_img = util.tensor2img(visuals['rlt']) # uint8
59 |
60 | # save images
61 | suffix = opt['suffix']
62 | if suffix:
63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
64 | else:
65 | save_img_path = osp.join(dataset_dir, img_name + '.png')
66 | util.save_img(sr_img, save_img_path)
67 |
68 | # calculate PSNR and SSIM
69 | if need_GT:
70 | gt_img = util.tensor2img(visuals['GT'])
71 | sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
72 | psnr = util.calculate_psnr(sr_img, gt_img)
73 | ssim = util.calculate_ssim(sr_img, gt_img)
74 | test_results['psnr'].append(psnr)
75 | test_results['ssim'].append(ssim)
76 |
77 | if gt_img.shape[2] == 3: # RGB image
78 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
79 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
80 |
81 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
82 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
83 | test_results['psnr_y'].append(psnr_y)
84 | test_results['ssim_y'].append(ssim_y)
85 | logger.info(
86 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
87 | format(img_name, psnr, ssim, psnr_y, ssim_y))
88 | else:
89 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
90 | else:
91 | logger.info(img_name)
92 |
93 | if need_GT: # metrics
94 | # Average PSNR/SSIM results
95 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
96 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
97 | logger.info(
98 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
99 | test_set_name, ave_psnr, ave_ssim))
100 | if test_results['psnr_y'] and test_results['ssim_y']:
101 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
102 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
103 | logger.info(
104 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
105 | format(ave_psnr_y, ave_ssim_y))
106 |
--------------------------------------------------------------------------------
/codes/utils/README.md:
--------------------------------------------------------------------------------
1 | # Utils
2 |
3 | ## Tensorboard Logger (tb_logger)
4 |
5 | [tensorboard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) is a handy visualization tool for comparing training loss, validation PSNR, etc.
6 |
7 | You can turn it on/off in the json option file with the key `use_tb_logger`.
8 |
9 | ### Install
10 | 1. `pip install tensorflow` - probably the easiest way to get tensorboard, though it installs tensorflow as well.
11 | 1. `pip install tensorboard_logger` - install [tensorboard_logger](https://github.com/TeamHG-Memex/tensorboard_logger)
12 |
13 | ### Run
14 | 1. In terminal: `tensorboard --logdir xxx/xxx`.
15 | 1. Open the TensorBoard UI at http://localhost:6006 in your browser.
16 |
--------------------------------------------------------------------------------
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utils/__pycache__/rank_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/rank_test.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utils/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/calc_NIQE.m:
--------------------------------------------------------------------------------
1 | function NIQE = calc_NIQE(input_image_path,shave_width)
2 |
3 | %% Loading model
4 | load modelparameters.mat
5 | blocksizerow = 96;
6 | blocksizecol = 96;
7 | blockrowoverlap = 0;
8 | blockcoloverlap = 0;
9 | %% Calculating scores
10 | NIQE = [];
11 | input_image = convert_shave_image(imread(input_image_path),shave_width);
12 |
13 | % Calculating scores
14 | NIQE = computequality(input_image,blocksizerow,blocksizecol,...
15 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam);
16 |
17 | end
18 |
19 |
20 |
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/computefeature.m:
--------------------------------------------------------------------------------
1 | function feat = computefeature(structdis)
2 |
3 | % Input - MSCN coefficients
4 | % Output - Compute the 18 dimensional feature vector
5 |
6 | feat = [];
7 |
8 |
9 |
10 | [alpha betal betar] = estimateaggdparam(structdis(:));
11 |
12 | feat = [feat;alpha;(betal+betar)/2];
13 |
14 | shifts = [ 0 1;1 0 ;1 1;1 -1];
15 |
16 | for itr_shift =1:4
17 |
18 | shifted_structdis = circshift(structdis,shifts(itr_shift,:));
19 | pair = structdis(:).*shifted_structdis(:);
20 | [alpha betal betar] = estimateaggdparam(pair);
21 | meanparam = (betar-betal)*(gamma(2/alpha)/gamma(1/alpha));
22 | feat = [feat;alpha;meanparam;betal;betar];
23 |
24 | end
25 |
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/computemean.m:
--------------------------------------------------------------------------------
1 | function val = computemean(patch)
2 |
3 | val = mean2(patch);
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/computequality.m:
--------------------------------------------------------------------------------
1 | function quality = computequality(im,blocksizerow,blocksizecol,...
2 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam)
3 |
4 | % Input
5 | % im - Image whose quality needs to be computed
6 | % blocksizerow - Height of the blocks in to which image is divided
7 | % blocksizecol - Width of the blocks in to which image is divided
8 | % blockrowoverlap - Amount of vertical overlap between blocks
9 | % blockcoloverlap - Amount of horizontal overlap between blocks
10 | % mu_prisparam - mean of multivariate Gaussian model
11 | % cov_prisparam - covariance of multivariate Gaussian model
12 |
13 | % For good performance, it is advisable to make the multivariate Gaussian model
14 | % using the same size patches as the distorted image is divided into
15 |
16 | % Output
17 | %quality - Quality of the input distorted image
18 |
19 | % Example call
20 | %quality = computequality(im,96,96,0,0,mu_prisparam,cov_prisparam)
21 |
22 | % ---------------------------------------------------------------
23 | %Number of features
24 | % 18 features at each scale
25 | featnum = 18;
26 | %----------------------------------------------------------------
27 | %Compute features
28 | if(size(im,3)==3)
29 | im = rgb2gray(im);
30 | end
31 | im = double(im);
32 | [row col] = size(im);
33 | block_rownum = floor(row/blocksizerow);
34 | block_colnum = floor(col/blocksizecol);
35 |
36 | im = im(1:block_rownum*blocksizerow,1:block_colnum*blocksizecol);
37 | [row col] = size(im);
38 | block_rownum = floor(row/blocksizerow);
39 | block_colnum = floor(col/blocksizecol);
40 | im = im(1:block_rownum*blocksizerow, ...
41 | 1:block_colnum*blocksizecol);
42 | window = fspecial('gaussian',7,7/6);
43 | window = window/sum(sum(window));
44 | scalenum = 2;
45 | warning('off')
46 |
47 | feat = [];
48 |
49 |
50 | for itr_scale = 1:scalenum
51 |
52 |
53 | mu = imfilter(im,window,'replicate');
54 | mu_sq = mu.*mu;
55 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq));
56 | structdis = (im-mu)./(sigma+1);
57 |
58 |
59 |
60 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ...
61 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ...
62 | @computefeature);
63 | feat_scale = reshape(feat_scale,[featnum ....
64 | size(feat_scale,1)*size(feat_scale,2)/featnum]);
65 | feat_scale = feat_scale';
66 |
67 |
68 | if(itr_scale == 1)
69 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ...
70 | [blockrowoverlap blockcoloverlap],@computemean);
71 | sharpness = sharpness(:);
72 | end
73 |
74 |
75 | feat = [feat feat_scale];
76 |
77 | im =imresize(im,0.5);
78 |
79 | end
80 |
81 |
82 | % Fit a MVG model to distorted patch features
83 | distparam = feat;
84 | mu_distparam = nanmean(distparam);
85 | cov_distparam = nancov(distparam);
86 |
87 | % Compute quality
88 | invcov_param = pinv((cov_prisparam+cov_distparam)/2);
89 | quality = sqrt((mu_prisparam-mu_distparam)* ...
90 | invcov_param*(mu_prisparam-mu_distparam)');
91 |
92 |
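
The score on lines 82-90 is a Mahalanobis-style distance between the pristine MVG model and an MVG fitted to the distorted image's patch features (18 features at each of the 2 scales, so 36 columns). A minimal NumPy sketch of just that final step (names illustrative; it drops whole rows containing NaNs, a simplification of MATLAB's nanmean/nancov):

```python
import numpy as np


def mvg_distance(mu_pris, cov_pris, feat):
    """NIQE score from per-patch features `feat` of shape (num_patches, 36)."""
    keep = ~np.isnan(feat).any(axis=1)       # simplified nanmean/nancov handling
    mu_dist = feat[keep].mean(axis=0)
    cov_dist = np.cov(feat[keep], rowvar=False)
    inv_cov = np.linalg.pinv((cov_pris + cov_dist) / 2.0)
    diff = mu_pris - mu_dist
    return float(np.sqrt(diff @ inv_cov @ diff))
```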
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/convert_shave_image.m:
--------------------------------------------------------------------------------
1 | function shaved = convert_shave_image(input_image,shave_width)
2 |
3 | % Converting to y channel only
4 | image_ychannel = rgb2ycbcr(input_image);
5 | image_ychannel = image_ychannel(:,:,1);
6 |
7 | % Shaving image
8 | shaved = image_ychannel(1+shave_width:end-shave_width,...
9 | 1+shave_width:end-shave_width);
10 |
11 | end
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/estimateaggdparam.m:
--------------------------------------------------------------------------------
1 | function [alpha betal betar] = estimateaggdparam(vec)
2 |
3 |
4 | gam = 0.2:0.001:10;
5 | r_gam = ((gamma(2./gam)).^2)./(gamma(1./gam).*gamma(3./gam));
6 |
7 |
8 | leftstd = sqrt(mean((vec(vec<0)).^2));
9 | rightstd = sqrt(mean((vec(vec>0)).^2));
10 |
11 | gammahat = leftstd/rightstd;
12 | rhat = (mean(abs(vec)))^2/mean((vec).^2);
13 | rhatnorm = (rhat*(gammahat^3 +1)*(gammahat+1))/((gammahat^2 +1)^2);
14 | [min_difference, array_position] = min((r_gam - rhatnorm).^2);
15 | alpha = gam(array_position);
16 |
17 | betal = leftstd *sqrt(gamma(1/alpha)/gamma(3/alpha));
18 | betar = rightstd*sqrt(gamma(1/alpha)/gamma(3/alpha));
19 |
20 |
21 |
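
In equation form, this is the standard moment-matching estimator for the asymmetric generalized Gaussian distribution (AGGD), with \(\sigma_l\) and \(\sigma_r\) the code's `leftstd` and `rightstd`:

```latex
r(\alpha) = \frac{\Gamma(2/\alpha)^2}{\Gamma(1/\alpha)\,\Gamma(3/\alpha)}, \qquad
\hat{\gamma} = \frac{\sigma_l}{\sigma_r}, \qquad
\hat{R} = \frac{\mathbb{E}[|x|]^2}{\mathbb{E}[x^2]}, \qquad
\hat{R}_{\mathrm{norm}} = \hat{R}\,\frac{(\hat{\gamma}^3 + 1)(\hat{\gamma} + 1)}{(\hat{\gamma}^2 + 1)^2},

\alpha^\ast = \arg\min_\alpha \bigl( r(\alpha) - \hat{R}_{\mathrm{norm}} \bigr)^2, \qquad
\beta_l = \sigma_l \sqrt{\frac{\Gamma(1/\alpha^\ast)}{\Gamma(3/\alpha^\ast)}}, \qquad
\beta_r = \sigma_r \sqrt{\frac{\Gamma(1/\alpha^\ast)}{\Gamma(3/\alpha^\ast)}}.
```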
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/estimatemodelparam.m:
--------------------------------------------------------------------------------
1 | function [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,...
2 | blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th)
3 |
4 | % Input
5 | % folderpath - Folder containing the pristine images
6 | % blocksizerow - Height of the blocks in to which image is divided
7 | % blocksizecol - Width of the blocks in to which image is divided
8 | % blockrowoverlap - Amount of vertical overlap between blocks
9 | % blockcoloverlap - Amount of horizontal overlap between blocks
10 | % sh_th - The sharpness threshold level
11 | %Output
12 | %mu_prisparam - mean of multivariate Gaussian model
13 | %cov_prisparam - covariance of multivariate Gaussian model
14 |
15 | % Example call
16 |
17 | %[mu_prisparam cov_prisparam] = estimatemodelparam('pristine',96,96,0,0,0.75);
18 |
19 |
20 | %----------------------------------------------------------------
21 | % Find the names of images in the folder
22 | current = pwd;
23 | cd(sprintf('%s',folderpath))
24 | names = ls;
25 | names = names(3:end,:);
26 | cd(current)
27 | % ---------------------------------------------------------------
28 | %Number of features
29 | % 18 features at each scale
30 | featnum = 18;
31 | % ---------------------------------------------------------------
32 | % Make the directory for storing the features
33 | mkdir(sprintf('local_risquee_prisfeatures'))
34 | % ---------------------------------------------------------------
35 | % Compute pristine image features
36 | for itr = 1:size(names,1)
37 | itr
38 | im = imread(sprintf('%s\\%s',folderpath,names(itr,:)));
39 | if(size(im,3)==3)
40 | im = rgb2gray(im);
41 | end
42 | im = double(im);
43 | [row col] = size(im);
44 | block_rownum = floor(row/blocksizerow);
45 | block_colnum = floor(col/blocksizecol);
46 | im = im(1:block_rownum*blocksizerow, ...
47 | 1:block_colnum*blocksizecol);
48 | window = fspecial('gaussian',7,7/6);
49 | window = window/sum(sum(window));
50 | scalenum = 2;
51 | warning('off')
52 |
53 | feat = [];
54 |
55 |
56 | for itr_scale = 1:scalenum
57 |
58 |
59 | mu = imfilter(im,window,'replicate');
60 | mu_sq = mu.*mu;
61 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq));
62 | structdis = (im-mu)./(sigma+1);
63 |
64 |
65 |
66 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ...
67 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ...
68 | @computefeature);
69 | feat_scale = reshape(feat_scale,[featnum ....
70 | size(feat_scale,1)*size(feat_scale,2)/featnum]);
71 | feat_scale = feat_scale';
72 |
73 |
74 | if(itr_scale == 1)
75 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ...
76 | [blockrowoverlap blockcoloverlap],@computemean);
77 | sharpness = sharpness(:);
78 | end
79 |
80 |
81 | feat = [feat feat_scale];
82 |
83 | im =imresize(im,0.5);
84 |
85 | end
86 |
87 | save(sprintf('local_risquee_prisfeatures\\prisfeatures_local%d.mat',...
88 | itr),'feat','sharpness');
89 | end
90 |
91 |
92 |
93 | %----------------------------------------------
94 | % Load pristine image features
95 | prisparam = [];
96 | current = pwd;
97 | cd(sprintf('%s','local_risquee_prisfeatures'))
98 | names = ls;
99 | names = names(3:end,:);
100 | cd(current)
101 | for itr = 1:size(names,1)
102 | % Load the features and keep only those from sufficiently sharp patches
103 | load(sprintf('local_risquee_prisfeatures\\%s',strtrim(names(itr,:))));
104 | IX = find(sharpness(:) >sh_th*max(sharpness(:)));
105 | feat = feat(IX,:);
106 | prisparam = [prisparam; feat];
107 |
108 | end
109 | %----------------------------------------------
110 | % Compute model parameters
111 | mu_prisparam = nanmean(prisparam);
112 | cov_prisparam = nancov(prisparam);
113 | %----------------------------------------------
114 | % Save features in the mat file
115 | save('modelparameters_new.mat','mu_prisparam','cov_prisparam');
116 | %----------------------------------------------
117 |
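
Beyond the per-image feature extraction (identical to computequality.m), the distinctive step here is the sharpness filter on lines 104-105: only patches whose mean local deviation exceeds a fraction `sh_th` of the sharpest patch contribute to the pristine model. A one-function NumPy sketch (names illustrative, not part of this repo):

```python
import numpy as np


def select_sharp_features(feat, sharpness, sh_th=0.75):
    """Keep features of patches above the sharpness threshold."""
    keep = sharpness.flatten() > sh_th * np.max(sharpness)
    return feat[keep]
```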
--------------------------------------------------------------------------------
/codes/utils/perceptualmetric/modelparameters.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/codes/utils/perceptualmetric/modelparameters.mat
--------------------------------------------------------------------------------
/codes/utils/rank_test.py:
--------------------------------------------------------------------------------
1 | 
2 | def rank_pair_test(predict_file, label_file):
3 |     """Compute pairwise rank accuracy of predicted scores against rank labels.
4 | 
5 |     Both files hold one "patch_name score" pair per line."""
6 |     predict_score = {}
7 |     label_score = {}
8 | 
9 |     with open(predict_file, 'r') as f1:
10 |         for line in f1:
11 |             img_name, img_score = line.split()[:2]
12 |             predict_score[img_name] = float(img_score)
13 | 
14 |     with open(label_file, 'r') as f2:
15 |         for line in f2:
16 |             img_name, img_score = line.split()[:2]
17 |             label_score[img_name] = float(img_score)
18 | 
19 |     # Sort so that all versions of the same patch index are adjacent,
20 |     # e.g. 1_esrgan.png, 1_srgan.png, 1_srres.png, 10_esrgan.png, ...
21 |     keys_list = list(predict_score.keys())
22 |     keys_list.sort()
23 |     cursor = keys_list[0].split('_')[0]
24 |     class_num = 0  # number of versions (levels) per patch index
25 |     for key in keys_list:
26 |         if cursor == key.split('_')[0]:
27 |             class_num += 1
28 |         else:
29 |             break
30 |     count = 0     # pairs compared
31 |     positive = 0  # pairs ranked consistently with the labels
32 |     for idx in range(0, len(keys_list), class_num):
33 |         for i in range(idx, idx + class_num):
34 |             for j in range(i + 1, idx + class_num):
35 | 
36 |                 real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1
37 | 
38 |                 predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1
39 | 
40 |                 count += 1
41 |                 if real_rank == predict_rank:
42 |                     positive += 1
43 | 
44 |     # print('%d/%d ' % (positive, count))
45 |     accuracy = positive / count
46 |     # print('Aligned Pair Accuracy: %f ' % accuracy)
47 | 
48 |     count1 = 1     # start at 1 to avoid division by zero
49 |     count2 = 1
50 |     positive1 = 0
51 |     positive2 = 0
52 | 
53 |     for idx in range(0, len(keys_list), class_num):
54 | 
55 |         i = idx
56 |         j = i + 1
57 |         real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1
58 | 
59 |         predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1
60 | 
61 | 
62 |         if real_rank == 1:
63 |             count1 += 1
64 |             if real_rank == predict_rank:
65 |                 positive1 += 1
66 |         if real_rank == -1:
67 |             count2 += 1
68 |             if real_rank == predict_rank:
69 |                 positive2 += 1
70 | 
71 |     # print('%d/%d' % (positive1, count1))
72 |     accuracy_esrganbig1 = positive1 / count1
73 |     # print('accuracy_esrganbig: %f' % accuracy_esrganbig1)
74 |     #
75 |     # print('%d/%d' % (positive2, count2))
76 |     accuracy_srganbig1 = positive2 / count2
77 |     # print('accuracy_srganbig: %f' % accuracy_srganbig1)
78 |     count1 = 1
79 |     count2 = 1
80 |     positive1 = 0
81 |     positive2 = 0
82 |     for idx in range(0, len(keys_list), class_num):
83 | 
84 |         i = idx
85 |         j = i + 2
86 |         real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1
87 | 
88 |         predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1
89 | 
90 | 
91 |         if real_rank == 1:
92 |             count1 += 1
93 |             if real_rank == predict_rank:
94 |                 positive1 += 1
95 |         if real_rank == -1:
96 |             count2 += 1
97 |             if real_rank == predict_rank:
98 |                 positive2 += 1
99 | 
100 |     # print('%d/%d' % (positive1, count1))
101 |     # accuracy_esrganbig = positive1 / count1
102 |     # print('accuracy2: %f' % accuracy_esrganbig)
103 |     #
104 |     # print('%d/%d' % (positive2, count2))
105 |     # accuracy_srganbig = positive2 / count2
106 |     # print('accuracy2: %f' % accuracy_srganbig)
107 | 
108 |     count1 = 1
109 |     count2 = 1
110 |     positive1 = 0
111 |     positive2 = 0
112 | 
113 |     for idx in range(0, len(keys_list), class_num):
114 | 
115 |         i = idx + 1
116 |         j = i + 1
117 |         real_rank = 1 if label_score[keys_list[i]] >= label_score[keys_list[j]] else -1
118 |         predict_rank = 1 if predict_score[keys_list[i]] >= predict_score[keys_list[j]] else -1
119 | 
120 | 
121 |         if real_rank == 1:
122 |             count1 += 1
123 |             if real_rank == predict_rank:
124 |                 positive1 += 1
125 |         if real_rank == -1:
126 |             count2 += 1
127 |             if real_rank == predict_rank:
128 |                 positive2 += 1
129 | 
130 |     # print('%d/%d' % (positive1, count1))
131 |     # accuracy_esrganbig = positive1 / count1
132 |     # print('accuracy3: %f' % accuracy_esrganbig)
133 | 
134 |     # print('%d/%d' % (positive2, count2))
135 |     # accuracy_srganbig = positive2 / count2
136 |     # print('accuracy3: %f' % accuracy_srganbig)
137 | 
138 |     return accuracy, accuracy_esrganbig1, accuracy_srganbig1
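
A minimal usage sketch (file names hypothetical): both inputs hold one `patch_name score` pair per line, as written by generate_train_ranklabel.m, and the returns are the aligned-pair accuracy plus the two direction-conditioned accuracies from the first adjacent-pair loop:

```python
# Hypothetical paths; each file holds lines like "101_srgan.png 5.4321".
acc, acc_pos_pairs, acc_neg_pairs = rank_pair_test('predict_NIQE.txt',
                                                   'DF2K_valid_NIQE.txt')
print('Aligned pair accuracy: %.4f' % acc)
```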
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/README.md:
--------------------------------------------------------------------------------
1 | ## How To Prepare Rank Dataset
2 | ### Prepare perceptual data
3 | 1. Prepare the three-level SR models. You can download [SRResNet (SRResNet_bicx4_in3nf64nb16.pth),
4 | SRGAN (SRGAN.pth), ESRGAN (ESRGAN_SuperSR.pth)] from
5 | [Google Drive](https://drive.google.com/drive/folders/16DkwrBa4cIqAoTbGU_bKMYoATcXC4IT6?usp=sharing)
6 | or [Baidu Drive](https://pan.baidu.com/s/1HFZokeAWne9oUkmJBnGr-A), and place them in [`./experiments/pretrained_models/`](../../master/experiments/pretrained_models/).
7 | 
8 | 2. Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://github.com/LimBee/NTIRE2017)
9 | from [Google Drive](https://drive.google.com/drive/folders/1B-uaxvV9qeuQ-t7MFiN1oEdA6dKnj2vW?usp=sharing) or
10 | [Baidu Drive](https://pan.baidu.com/s/1CFIML6KfQVYGZSNFrhMXmA).
11 | 3. Generate the three levels of SR images following 'How to test' with [`codes/options/test/test_RankSRGAN.yml`](../../master/codes/options/test/test_RankSRGAN.yml).
12 | ### Generate rank dataset
13 | 4. **Training dataset:** Use [`./datasets/generate_rankdataset/generate_rankdataset.m`](../../master/datasets/generate_rankdataset/generate_rankdataset.m)
14 | to generate the three-level training patches.
15 | 5. **Validation dataset:** Use [`./datasets/generate_rankdataset/move_valid.py`](../../master/datasets/generate_rankdataset/move_valid.py)
16 | to split off the three-level validation patches.
17 | 6. **Rank label:** Use [`./datasets/generate_rankdataset/generate_train_ranklabel.m`](../../master/datasets/generate_rankdataset/generate_train_ranklabel.m)
18 | to generate the training rank labels (NIQE).
19 | Use [`./datasets/generate_rankdataset/generate_valid_ranklabel.m`](../../master/datasets/generate_rankdataset/generate_valid_ranklabel.m)
20 | to generate the validation rank labels (NIQE); a sample of the label file format is shown below.
21 |
22 |
23 |
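
For reference, the label file written in step 6 holds one `patch_name NIQE_score` pair per line, three lines per patch index; the scores below are made-up examples:

```
1_srres.png 7.231000
1_srgan.png 5.812000
1_esrgan.png 4.903000
2_srres.png 6.554000
2_srgan.png 5.127000
2_esrgan.png 4.466000
```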
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/generate_rankdataset.m:
--------------------------------------------------------------------------------
1 |
2 | Level1_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_ESRGAN/'
3 | Level2_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_srgan/'
4 | Level3_folder = '/home/wlzhang/data/DIV2K/DIV2K_train_srres/'
5 |
6 | Level1_filepaths = dir(fullfile(Level1_folder,'*.png'));
7 | Level2_filepaths = dir(fullfile(Level2_folder,'*.png'));
8 | Level3_filepaths = dir(fullfile(Level3_folder,'*.png'));
9 |
10 | Level1_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_esrgan';
11 | Level2_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srgan';
12 | Level3_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srres';
13 |
14 | mkdir(Level3_patchsave_path)
15 | mkdir(Level2_patchsave_path)
16 | mkdir(Level1_patchsave_path)
17 |
18 | addpath utils
19 |
20 | patch_sz = 296;
21 | stride = 148; %300 148 200
22 | blocksizerow = 96;
23 | blocksizecol = 96;
24 | blockrowoverlap = 0;
25 | blockcoloverlap = 0;
26 |
27 | total_count = 0;
28 | selected_count = 0;
29 | display_flag = 0;
30 | save_count = 0;
31 | count_class = 0;
32 | fprintf('-------- Starting -----------');
33 | for k = 1 : length(Level1_filepaths)
34 | tic;
35 | fprintf('Processing img: %s\n',Level1_filepaths(k).name);
36 | fprintf('Processing srganimg: %s\n',Level2_filepaths(k).name);
37 |
38 | level1_img = imread(fullfile(Level1_folder,Level1_filepaths(k).name));
39 | level2_img = imread(fullfile(Level2_folder,Level2_filepaths(k).name));
40 | %srres_img = imread(fullfile(Level3_folder,srres_filepaths(k).name));
41 |
42 | if display_flag == 1
43 | subplot(1,2,1);
44 | imshow(level1_img);
45 | end
46 | img_height = size(level1_img,1);
47 | img_width = size(level1_img,2);
48 |
49 | h_num = ceil((img_height-patch_sz)/stride);
50 | w_num = ceil((img_width-patch_sz)/stride);
51 | count = 0;
52 | patch_list = zeros(patch_sz,patch_sz,3,1);
53 |
54 | % esrgan_patch_list1 = zeros(patch_sz,patch_sz,3,200);
55 | % srres_patch_list1 = zeros(patch_sz,patch_sz,3,200);
56 | %srgan_patch_list1 = zeros(patch_sz,patch_sz,3,200);
57 |
58 | location_list = zeros(1,2,15);
59 | i = 0;
60 | for h = 1:stride:img_height-patch_sz
61 | for w = 1:stride:img_width-patch_sz
62 | count = count +1;
63 | total_count = total_count + 1;
64 | level1_img_patch = level1_img(h:h+patch_sz-1,w:w+patch_sz-1,:);
65 | if(size(level1_img_patch,3)==3)
66 | im = rgb2gray(level1_img_patch);
67 | end
68 | im = double(im);
69 |
70 | sharpness = compute_sharpness(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap);
71 |
72 | if sharpness >= 40
73 | i = i+1;
74 | if display_flag == 1
75 | subplot(1,2,1);
76 | rectangle('Position',[w,h,patch_sz,patch_sz],'edgecolor','r');
77 | end
78 | selected_count = selected_count + 1;
79 | level1_patch_list(:,:,:,i) = level1_img_patch;
80 | location_list(:,:,i) = [h,w];
81 |
82 | else
83 | if display_flag == 1
84 | subplot(1,2,1);
85 | rectangle('Position',[w,h,patch_sz,patch_sz],'edgecolor','k');
86 | end
87 | end
88 |
89 |
90 | end
91 | end
92 |     fprintf('patch number:%d ',i);
93 | count1 = 0;
94 | count2 = 0;
95 | count3 = 0;
96 | for j = 1:i
97 |
98 | level1_save_patch = uint8(level1_patch_list(:,:,:,j));
99 | level1_NIQE = calc_NIQE(level1_save_patch);
100 |
101 | patch_h = location_list(1,1,j);
102 | patch_w = location_list(1,2,j);
103 | Level3_img_patch=get_sigle_patch(Level3_folder,Level1_filepaths(k).name,...
104 | Level3_patchsave_path,...
105 | [num2str(save_count) '_srres.png'],...
106 | patch_h,patch_w,patch_sz);
107 | level2_img_patch=get_sigle_patch(Level2_folder,Level1_filepaths(k).name,...
108 | Level2_patchsave_path,...
109 | [num2str(save_count) '_srgan.png'],...
110 | patch_h,patch_w,patch_sz);
111 | level2_NIQE = calc_NIQE(level2_img_patch);
112 | Level3_NIQE = calc_NIQE(Level3_img_patch);
113 |
114 | if abs(level2_NIQE - level1_NIQE) > 0.1
115 | count1 = count1+1;
116 | level1_patch_list1(:,:,:,count1) = level1_save_patch;
117 | level2_patch_list1(:,:,:,count1) = level2_img_patch;
118 | Level3_patch_list1(:,:,:,count1) = Level3_img_patch;
119 | end
120 |
121 | end
122 |     fprintf('patches with NIQE gap > 0.1: %d ',count1);
123 | if count1 < 200 %200
124 | for idx= 1:count1
125 | save_count = save_count + 1;
126 |
127 | save_name = [num2str(save_count) '_esrgan.png'];
128 | level1_patch = uint8(level1_patch_list1(:,:,:,idx));
129 | imwrite(level1_patch,fullfile(Level1_patchsave_path,save_name));
130 |
131 | save_name = [num2str(save_count) '_srgan.png'];
132 | level2_patch = uint8(level2_patch_list1(:,:,:,idx));
133 | imwrite(level2_patch,fullfile(Level2_patchsave_path,save_name));
134 |
135 | save_name = [num2str(save_count) '_srres.png'];
136 | Level3_patch = uint8(Level3_patch_list1(:,:,:,idx));
137 | imwrite(Level3_patch,fullfile(Level3_patchsave_path,save_name));
138 |
139 | end
140 | else
141 | rand_order = randperm(count1);
142 | for idx= 1:200
143 | save_count = save_count + 1;
144 |
145 | save_name = [num2str(save_count) '_esrgan.png'];
146 | level1_patch = uint8(level1_patch_list1(:,:,:,rand_order(idx)));
147 | imwrite(level1_patch,fullfile(Level1_patchsave_path,save_name));
148 |
149 | save_name = [num2str(save_count) '_srgan.png'];
150 | level2_patch = uint8(level2_patch_list1(:,:,:,rand_order(idx)));
151 | imwrite(level2_patch,fullfile(Level2_patchsave_path,save_name));
152 |
153 | save_name = [num2str(save_count) '_srres.png'];
154 | Level3_patch = uint8(Level3_patch_list1(:,:,:,rand_order(idx)));
155 | imwrite(Level3_patch,fullfile(Level3_patchsave_path,save_name));
156 |
157 | end
158 | end
159 | fprintf('Current save patches:%d\n',save_count);
160 |
161 | end
162 | toc;
163 | fprintf('Total image patch: %d\n',total_count);
164 | fprintf('Selected image patch: %d\n',selected_count);
165 | fprintf('Generated image patch: %d\n',save_count);
166 |
167 |
168 |
169 |
170 |
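
The selection logic above in brief: slide a 296x296 window with stride 148 and keep patches whose summed block-wise sharpness clears the threshold. A rough NumPy sketch (`patch_sharpness` only approximates compute_sharpness.m, so the threshold of 40 would likely need retuning; all names are illustrative):

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def patch_sharpness(gray, block=96):
    """Sum of block-wise mean local deviation, approximating compute_sharpness.m."""
    mu = gaussian_filter(gray, sigma=7 / 6.0)
    sigma = np.sqrt(np.abs(gaussian_filter(gray * gray, sigma=7 / 6.0) - mu * mu))
    h = (sigma.shape[0] // block) * block
    w = (sigma.shape[1] // block) * block
    blocks = sigma[:h, :w].reshape(h // block, block, w // block, block)
    return blocks.mean(axis=(1, 3)).sum()


def select_patch_locations(gray_img, patch_sz=296, stride=148, thresh=40.0):
    """Return (h, w) corners of sufficiently sharp patches."""
    locations = []
    for h in range(0, gray_img.shape[0] - patch_sz, stride):
        for w in range(0, gray_img.shape[1] - patch_sz, stride):
            patch = gray_img[h:h + patch_sz, w:w + patch_sz].astype(np.float64)
            if patch_sharpness(patch) >= thresh:
                locations.append((h, w))
    return locations
```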
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/generate_train_ranklabel.m:
--------------------------------------------------------------------------------
1 |
2 | %% set path
3 | start = clock;
4 | data_path = '../../data/Rank_dataset_test/';
5 |
6 | level1_path = [data_path,'DF2K_train_patch_esrgan'];
7 | level2_path = [data_path,'DF2K_train_patch_srgan'];
8 | level3_path = [data_path,'DF2K_train_patch_srres'];
9 |
10 | ranklabel_path = [data_path,'/DF2K_train_NIQE.txt'];
11 |
12 | level1_dir = fullfile(pwd,level1_path);
13 | level2_dir = fullfile(pwd,level2_path);
14 | level3_dir = fullfile(pwd,level3_path);
15 |
16 | % Number of pixels to shave off image borders when calculating scores
17 | shave_width = 4;
18 |
19 | % Set verbose option
20 | verbose = true;
21 | %% Calculate scores and save
22 | addpath utils
23 | addpath(genpath(fullfile(pwd,'utils')));
24 |
25 | %% Reading file list
26 | level1_file_list = dir([level1_dir,'/*.png']);
27 | level2_file_list = dir([level2_dir,'/*.png']);
28 | level3_file_list = dir([level3_dir,'/*.png']);
29 |
30 | im_num = length(level1_file_list)
31 | %fprintf(' %f\n',im_num);
32 |
33 | %% Calculating scores
34 | % (the label file is opened for writing after all scores are computed)
35 | tic;
36 | pp = parpool('local',28);
37 | pp.IdleTimeout = 9800;
38 | disp('Already initialized');
39 | fprintf('-------- Starting -----------');
40 | parfor ii=(1:im_num)
41 | [scoresname,scoresniqe] = parcal_niqe(ii,level3_dir,level3_file_list,im_num)
42 | level3_name{ii} = scoresname;
43 | level3_niqe(ii) = scoresniqe;
44 |
45 | end
46 | parfor ii=(1:im_num)
47 | [scoresname,scoresniqe] = parcal_niqe(ii,level2_dir,level2_file_list,im_num)
48 | level2_name{ii} = scoresname;
49 | level2_niqe(ii) = scoresniqe;
50 | end
51 | parfor ii=(1:im_num)
52 | [scoresname,scoresniqe] = parcal_niqe(ii,level1_dir,level1_file_list,im_num)
53 | level1_name{ii} = scoresname;
54 | level1_niqe(ii) = scoresniqe;
55 | end
56 |
57 | toc;
58 | delete(pp)
59 | txtfp = fopen(ranklabel_path,'w');
60 | for ii=(1:im_num)
61 | fprintf(txtfp,level3_name{ii});
62 | fprintf(txtfp,' %f\n',level3_niqe(ii));
63 | fprintf(txtfp,level2_name{ii});
64 | fprintf(txtfp,' %f\n',level2_niqe(ii));
65 | fprintf(txtfp,level1_name{ii});
66 | fprintf(txtfp,' %f\n',level1_niqe(ii));
67 | end
68 |
69 | fclose(txtfp);
70 |
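
Each parfor loop maps one NIQE evaluation per file. A Python process-pool equivalent under the same assumptions (a `calc_niqe(path)` scoring function is assumed to exist and does not ship with this repo):

```python
import glob
import os
from multiprocessing import Pool


def score_dir(dir_path, calc_niqe, workers=28):
    """Score every PNG in dir_path in parallel; returns (name, score) pairs."""
    files = sorted(glob.glob(os.path.join(dir_path, '*.png')))
    with Pool(workers) as pool:  # calc_niqe must be a picklable top-level function
        scores = pool.map(calc_niqe, files)
    return [(os.path.basename(f), s) for f, s in zip(files, scores)]
```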
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/generate_valid_ranklabel.m:
--------------------------------------------------------------------------------
1 |
2 | %% set path
3 | start = clock;
4 | data_path = '../../data/Rank_dataset_test/';
5 |
6 | level1_path = [data_path,'DF2K_valid_patch_esrgan'];
7 | level2_path = [data_path,'DF2K_valid_patch_srgan'];
8 | level3_path = [data_path,'DF2K_valid_patch_srres'];
9 |
10 | ranklabel_path = [data_path,'/DF2K_valid_NIQE.txt'];
11 |
12 | level1_dir = fullfile(pwd,level1_path);
13 | level2_dir = fullfile(pwd,level2_path);
14 | level3_dir = fullfile(pwd,level3_path);
15 |
16 | % Number of pixels to shave off image borders when calculating scores
17 | shave_width = 4;
18 |
19 | % Set verbose option
20 | verbose = true;
21 | %% Calculate scores and save
22 | addpath utils
23 | addpath(genpath(fullfile(pwd,'utils')));
24 |
25 | %% Reading file list
26 | level1_file_list = dir([level1_dir,'/*.png']);
27 | level2_file_list = dir([level2_dir,'/*.png']);
28 | level3_file_list = dir([level3_dir,'/*.png']);
29 |
30 | im_num = length(level1_file_list)
31 | %fprintf(' %f\n',im_num);
32 |
33 | %% Calculating scores
34 | % (the label file is opened for writing after all scores are computed)
35 | tic;
36 | pp = parpool('local',28);
37 | pp.IdleTimeout = 9800;
38 | disp('Already initialized');
39 | fprintf('-------- Starting -----------');
40 | parfor ii=(1:im_num)
41 | [scoresname,scoresniqe] = parcal_niqe(ii,level3_dir,level3_file_list,im_num)
42 | level3_name{ii} = scoresname;
43 | level3_niqe(ii) = scoresniqe;
44 |
45 | end
46 | parfor ii=(1:im_num)
47 | [scoresname,scoresniqe] = parcal_niqe(ii,level2_dir,level2_file_list,im_num)
48 | level2_name{ii} = scoresname;
49 | level2_niqe(ii) = scoresniqe;
50 | end
51 | parfor ii=(1:im_num)
52 | [scoresname,scoresniqe] = parcal_niqe(ii,level1_dir,level1_file_list,im_num)
53 | level1_name{ii} = scoresname;
54 | level1_niqe(ii) = scoresniqe;
55 | end
56 |
57 | toc;
58 | delete(pp)
59 | txtfp = fopen(ranklabel_path,'w');
60 | for ii=(1:im_num)
61 | fprintf(txtfp,level3_name{ii});
62 | fprintf(txtfp,' %f\n',level3_niqe(ii));
63 | fprintf(txtfp,level2_name{ii});
64 | fprintf(txtfp,' %f\n',level2_niqe(ii));
65 | fprintf(txtfp,level1_name{ii});
66 | fprintf(txtfp,' %f\n',level1_niqe(ii));
67 | end
68 |
69 | fclose(txtfp);
70 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/move_valid.py:
--------------------------------------------------------------------------------
1 | import os, random, shutil
2 | Level1_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_esrgan/'
3 | Level2_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srgan/'
4 | Level3_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_train_patch_srres/'
5 | 
6 | Level1_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_esrgan/'
7 | Level2_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_srgan/'
8 | Level3_valid_patchsave_path = '/home/wlzhang/RankSRGAN/data/Rank_dataset_test/DF2K_valid_patch_srres/'
9 |
10 | if not os.path.exists(Level1_valid_patchsave_path):
11 | os.makedirs(Level1_valid_patchsave_path)
12 | else:
13 | print('exists')
14 |
15 | if not os.path.exists(Level2_valid_patchsave_path):
16 | os.makedirs(Level2_valid_patchsave_path)
17 | else:
18 | print('exists')
19 |
20 | if not os.path.exists(Level3_valid_patchsave_path):
21 | os.makedirs(Level3_valid_patchsave_path)
22 | else:
23 | print('exists')
24 |
25 | pathDir = os.listdir(Level1_patchsave_path)  # list the source patch files
26 | filenumber = len(pathDir)
27 | rate = 0.1  # move 10% of the patches to the validation set
28 | picknumber = int(filenumber * rate)
29 | 
30 | sample = random.sample(pathDir, picknumber)
31 |
32 |
33 | for name in sample:
34 | 
35 |     # The three levels of a patch share the numeric prefix, e.g. 12_esrgan.png
36 |     name = name.split('_')
37 |     print(name[0])
38 |
39 | shutil.move(Level1_patchsave_path+name[0]+'_esrgan.png', Level1_valid_patchsave_path)
40 | shutil.move(Level2_patchsave_path+name[0]+'_srgan.png', Level2_valid_patchsave_path)
41 | shutil.move(Level3_patchsave_path+name[0]+'_srres.png', Level3_valid_patchsave_path)
42 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/calc_NIQE.m:
--------------------------------------------------------------------------------
1 | function NIQE = calc_NIQE(input_image)
2 |
3 | addpath(genpath(fullfile(pwd,'utils')));
4 |
5 | %% Loading model
6 | load modelparameters.mat
7 | blocksizerow = 96;
8 | blocksizecol = 96;
9 | blockrowoverlap = 0;
10 | blockcoloverlap = 0;
11 |
12 | %% Reading file list
13 | %scores = struct([]);
14 |
15 |
16 | % Calculating scores
17 | NIQE = computequality(input_image,blocksizerow,blocksizecol,...
18 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam);
19 | % perceptual_score = ([scores(ii).NIQE] + (10 - [scores(ii).Ma])) / 2;
20 | % perceptual_score = scores(ii).NIQE;
21 | % fprintf([' perceptual scores is: ',num2str(perceptual_score),' ',scores(ii).name,'']);
22 |
23 | end
24 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/compute_sharpness.m:
--------------------------------------------------------------------------------
1 | function sharpness = compute_sharpness(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap)
2 |
3 | window = fspecial('gaussian',7,7/6);
4 | window = window/sum(sum(window));
5 |
6 | sigma = sqrt(imfilter((im-imfilter(im,window,'replicate')).*(im-imfilter(im,window,'replicate')),window,'replicate'));
7 | paper_sharpness = blkproc(sigma,[blocksizerow blocksizecol],[blockrowoverlap blockcoloverlap],@computemean);
8 | paper_sharpness = paper_sharpness(:);
9 | sharpness = sum(paper_sharpness);
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | end
20 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/computemean.m:
--------------------------------------------------------------------------------
1 | function val = computemean(patch)
2 |
3 | val = mean2(patch);
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/convert_shave_image.m:
--------------------------------------------------------------------------------
1 | function shaved = convert_shave_image(input_image,shave_width)
2 |
3 | % Converting to y channel only
4 | image_ychannel = rgb2ycbcr(input_image);
5 | image_ychannel = image_ychannel(:,:,1);
6 |
7 | % Shaving image
8 | shaved = image_ychannel(1+shave_width:end-shave_width,...
9 | 1+shave_width:end-shave_width);
10 |
11 | end
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/get_sigle_patch.m:
--------------------------------------------------------------------------------
1 | function img_patch=get_sigle_patch(file_path,file_name,...
2 | save_dir,save_name,h,w,patch_sz)
3 |
4 | img = imread(fullfile(file_path,file_name));
5 | img_patch = img(h:h+patch_sz-1,w:w+patch_sz-1,:);
6 |
7 |
8 | end
9 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/computefeature.m:
--------------------------------------------------------------------------------
1 | function feat = computefeature(structdis)
2 |
3 | % Input - MSCn coefficients
4 | % Output - Compute the 18 dimensional feature vector
5 |
6 | feat = [];
7 |
8 |
9 |
10 | [alpha betal betar] = estimateaggdparam(structdis(:));
11 |
12 | feat = [feat;alpha;(betal+betar)/2];
13 |
14 | shifts = [ 0 1;1 0 ;1 1;1 -1];
15 |
16 | for itr_shift =1:4
17 |
18 | shifted_structdis = circshift(structdis,shifts(itr_shift,:));
19 | pair = structdis(:).*shifted_structdis(:);
20 | [alpha betal betar] = estimateaggdparam(pair);
21 | meanparam = (betar-betal)*(gamma(2/alpha)/gamma(1/alpha));
22 | feat = [feat;alpha;meanparam;betal;betar];
23 |
24 | end
25 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/computemean.m:
--------------------------------------------------------------------------------
1 | function val = computemean(patch)
2 |
3 | val = mean2(patch);
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/computequality.m:
--------------------------------------------------------------------------------
1 | function quality = computequality(im,blocksizerow,blocksizecol,...
2 | blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam)
3 |
4 | % Input
5 | % im - Image whose quality needs to be computed
6 | % blocksizerow - Height of the blocks in to which image is divided
7 | % blocksizecol - Width of the blocks in to which image is divided
8 | % blockrowoverlap - Amount of vertical overlap between blocks
9 | % blockcoloverlap - Amount of horizontal overlap between blocks
10 | % mu_prisparam - mean of multivariate Gaussian model
11 | % cov_prisparam - covariance of multivariate Gaussian model
12 |
13 | % For good performance, it is advisable to make the multivariate Gaussian model
14 | % using the same size patches as the distorted image is divided into
15 |
16 | % Output
17 | %quality - Quality of the input distorted image
18 |
19 | % Example call
20 | %quality = computequality(im,96,96,0,0,mu_prisparam,cov_prisparam)
21 |
22 | % ---------------------------------------------------------------
23 | %Number of features
24 | % 18 features at each scale
25 | featnum = 18;
26 | %----------------------------------------------------------------
27 | %Compute features
28 | if(size(im,3)==3)
29 | im = rgb2gray(im);
30 | end
31 | im = double(im);
32 | [row col] = size(im);
33 | block_rownum = floor(row/blocksizerow);
34 | block_colnum = floor(col/blocksizecol);
35 |
36 | im = im(1:block_rownum*blocksizerow,1:block_colnum*blocksizecol);
37 | [row col] = size(im);
38 | block_rownum = floor(row/blocksizerow);
39 | block_colnum = floor(col/blocksizecol);
40 | im = im(1:block_rownum*blocksizerow, ...
41 | 1:block_colnum*blocksizecol);
42 | window = fspecial('gaussian',7,7/6);
43 | window = window/sum(sum(window));
44 | scalenum = 2;
45 | warning('off')
46 |
47 | feat = [];
48 |
49 |
50 | for itr_scale = 1:scalenum
51 |
52 |
53 | mu = imfilter(im,window,'replicate');
54 | mu_sq = mu.*mu;
55 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq));
56 | structdis = (im-mu)./(sigma+1);
57 |
58 |
59 |
60 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ...
61 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ...
62 | @computefeature);
63 | feat_scale = reshape(feat_scale,[featnum ....
64 | size(feat_scale,1)*size(feat_scale,2)/featnum]);
65 | feat_scale = feat_scale';
66 |
67 |
68 | if(itr_scale == 1)
69 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ...
70 | [blockrowoverlap blockcoloverlap],@computemean);
71 | sharpness = sharpness(:);
72 | end
73 |
74 |
75 | feat = [feat feat_scale];
76 |
77 | im =imresize(im,0.5);
78 |
79 | end
80 |
81 |
82 | % Fit a MVG model to distorted patch features
83 | distparam = feat;
84 | mu_distparam = nanmean(distparam);
85 | cov_distparam = nancov(distparam);
86 |
87 | % Compute quality
88 | invcov_param = pinv((cov_prisparam+cov_distparam)/2);
89 | quality = sqrt((mu_prisparam-mu_distparam)* ...
90 | invcov_param*(mu_prisparam-mu_distparam)');
91 |
92 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/estimateaggdparam.m:
--------------------------------------------------------------------------------
1 | function [alpha betal betar] = estimateaggdparam(vec)
2 |
3 |
4 | gam = 0.2:0.001:10;
5 | r_gam = ((gamma(2./gam)).^2)./(gamma(1./gam).*gamma(3./gam));
6 |
7 |
8 | leftstd = sqrt(mean((vec(vec<0)).^2));
9 | rightstd = sqrt(mean((vec(vec>0)).^2));
10 |
11 | gammahat = leftstd/rightstd;
12 | rhat = (mean(abs(vec)))^2/mean((vec).^2);
13 | rhatnorm = (rhat*(gammahat^3 +1)*(gammahat+1))/((gammahat^2 +1)^2);
14 | [min_difference, array_position] = min((r_gam - rhatnorm).^2);
15 | alpha = gam(array_position);
16 |
17 | betal = leftstd *sqrt(gamma(1/alpha)/gamma(3/alpha));
18 | betar = rightstd*sqrt(gamma(1/alpha)/gamma(3/alpha));
19 |
20 |
21 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/estimatemodelparam.m:
--------------------------------------------------------------------------------
1 | function [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,...
2 | blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th)
3 |
4 | % Input
5 | % folderpath - Folder containing the pristine images
6 | % blocksizerow - Height of the blocks in to which image is divided
7 | % blocksizecol - Width of the blocks in to which image is divided
8 | % blockrowoverlap - Amount of vertical overlap between blocks
9 | % blockcoloverlap - Amount of horizontal overlap between blocks
10 | % sh_th - The sharpness threshold level
11 | %Output
12 | %mu_prisparam - mean of multivariate Gaussian model
13 | %cov_prisparam - covariance of multivariate Gaussian model
14 |
15 | % Example call
16 |
17 | %[mu_prisparam cov_prisparam] = estimatemodelparam('pristine',96,96,0,0,0.75);
18 |
19 |
20 | %----------------------------------------------------------------
21 | % Find the names of images in the folder
22 | current = pwd;
23 | cd(sprintf('%s',folderpath))
24 | names = ls;
25 | names = names(3:end,:);
26 | cd(current)
27 | % ---------------------------------------------------------------
28 | %Number of features
29 | % 18 features at each scale
30 | featnum = 18;
31 | % ---------------------------------------------------------------
32 | % Make the directory for storing the features
33 | mkdir(sprintf('local_risquee_prisfeatures'))
34 | % ---------------------------------------------------------------
35 | % Compute pristine image features
36 | for itr = 1:size(names,1)
37 | itr
38 | im = imread(sprintf('%s\\%s',folderpath,names(itr,:)));
39 | if(size(im,3)==3)
40 | im = rgb2gray(im);
41 | end
42 | im = double(im);
43 | [row col] = size(im);
44 | block_rownum = floor(row/blocksizerow);
45 | block_colnum = floor(col/blocksizecol);
46 | im = im(1:block_rownum*blocksizerow, ...
47 | 1:block_colnum*blocksizecol);
48 | window = fspecial('gaussian',7,7/6);
49 | window = window/sum(sum(window));
50 | scalenum = 2;
51 | warning('off')
52 |
53 | feat = [];
54 |
55 |
56 | for itr_scale = 1:scalenum
57 |
58 |
59 | mu = imfilter(im,window,'replicate');
60 | mu_sq = mu.*mu;
61 | sigma = sqrt(abs(imfilter(im.*im,window,'replicate') - mu_sq));
62 | structdis = (im-mu)./(sigma+1);
63 |
64 |
65 |
66 | feat_scale = blkproc(structdis,[blocksizerow/itr_scale blocksizecol/itr_scale], ...
67 | [blockrowoverlap/itr_scale blockcoloverlap/itr_scale], ...
68 | @computefeature);
69 | feat_scale = reshape(feat_scale,[featnum ....
70 | size(feat_scale,1)*size(feat_scale,2)/featnum]);
71 | feat_scale = feat_scale';
72 |
73 |
74 | if(itr_scale == 1)
75 | sharpness = blkproc(sigma,[blocksizerow blocksizecol], ...
76 | [blockrowoverlap blockcoloverlap],@computemean);
77 | sharpness = sharpness(:);
78 | end
79 |
80 |
81 | feat = [feat feat_scale];
82 |
83 | im =imresize(im,0.5);
84 |
85 | end
86 |
87 | save(sprintf('local_risquee_prisfeatures\\prisfeatures_local%d.mat',...
88 | itr),'feat','sharpness');
89 | end
90 |
91 |
92 |
93 | %----------------------------------------------
94 | % Load pristine image features
95 | prisparam = [];
96 | current = pwd;
97 | cd(sprintf('%s','local_risquee_prisfeatures'))
98 | names = ls;
99 | names = names(3:end,:);
100 | cd(current)
101 | for itr = 1:size(names,1)
102 | % Load the features and keep only those from sufficiently sharp patches
103 | load(sprintf('local_risquee_prisfeatures\\%s',strtrim(names(itr,:))));
104 | IX = find(sharpness(:) >sh_th*max(sharpness(:)));
105 | feat = feat(IX,:);
106 | prisparam = [prisparam; feat];
107 |
108 | end
109 | %----------------------------------------------
110 | % Compute model parameters
111 | mu_prisparam = nanmean(prisparam);
112 | cov_prisparam = nancov(prisparam);
113 | %----------------------------------------------
114 | % Save features in the mat file
115 | save('modelparameters_new.mat','mu_prisparam','cov_prisparam');
116 | %----------------------------------------------
117 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/modelparameters.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/datasets/generate_rankdataset/utils/niqe_release/modelparameters.mat
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/niqe_release/readme.txt:
--------------------------------------------------------------------------------
1 | NIQE Software release.
2 |
3 | =======================================================================
4 | -----------COPYRIGHT NOTICE STARTS WITH THIS LINE------------
5 | Copyright (c) 2011 The University of Texas at Austin
6 | All rights reserved.
7 |
8 | Permission is hereby granted, without written agreement and without license or royalty fees, to use, copy,
9 | modify, and distribute this code (the source files) and its documentation for
10 | any purpose, provided that the copyright notice in its entirety appear in all copies of this code, and the
11 | original source of this code, Laboratory for Image and Video Engineering (LIVE, http://live.ece.utexas.edu)
12 | and Center for Perceptual Systems (CPS, http://www.cps.utexas.edu) at the University of Texas at Austin (UT Austin,
13 | http://www.utexas.edu), is acknowledged in any publication that reports research using this code. The research
14 | is to be cited in the bibliography as:
15 |
16 | 1) A. Mittal, R. Soundararajan and A. C. Bovik, "NIQE Software Release",
17 | URL: http://live.ece.utexas.edu/research/quality/niqe.zip, 2012.
18 |
19 | 2) A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", submitted to IEEE Signal Processing Letters, 2012.
20 |
21 | IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT AUSTIN BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL,
22 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF THIS DATABASE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF TEXAS
23 | AT AUSTIN HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 | THE UNIVERSITY OF TEXAS AT AUSTIN SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
26 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE DATABASE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS,
27 | AND THE UNIVERSITY OF TEXAS AT AUSTIN HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
28 |
29 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------%
30 |
31 | Author : Anish Mittal
32 | Version : 1.0
33 |
34 | The authors are with the Laboratory for Image and Video Engineering
35 | (LIVE), Department of Electrical and Computer Engineering, The
36 | University of Texas at Austin, Austin, TX.
37 |
38 | Kindly report any suggestions or corrections to mittal.anish@gmail.com
39 |
40 | =======================================================================
41 |
42 | This is a demonstration of the Naturalness Image Quality Evaluator (NIQE) index. The algorithm is described in:
43 |
44 | A. Mittal, R. Soundararajan and A. C. Bovik, "Making a Completely Blind Image Quality Analyzer", submitted to IEEE Signal Processing Letters, 2012.
45 |
46 | You can change this program as you like and use it anywhere, but please
47 | refer to its original source (cite our paper and our web page at
48 | http://live.ece.utexas.edu/research/quality/niqe_release.zip).
49 |
50 | =======================================================================
51 | Running on Matlab
52 |
53 | Input : A test image loaded in an array
54 |
55 | Output: A quality score for the image. A higher value indicates lower quality.
56 |
57 | Usage:
58 |
59 | 1. Load the image, for example
60 |
61 | image = imread('testimage1.bmp');
62 |
63 | 2. Load the parameters of pristine multivariate Gaussian model.
64 |
65 |
66 | load modelparameters.mat;
67 |
68 |
69 | The images used for making the current model may be viewed at http://live.ece.utexas.edu/research/quality/pristinedata.zip
70 |
71 |
72 | 3. Initialize different parameters
73 |
74 | Height of the block
75 | blocksizerow = 96;
76 | Width of the block
77 | blocksizecol = 96;
78 | Vertical overlap between blocks
79 | blockrowoverlap = 0;
80 | Horizontal overlap between blocks
81 | blockcoloverlap = 0;
82 |
83 | For good performance, it is advisable to divide the distorted image into same-size patches as used for the construction of the multivariate Gaussian model.
84 |
85 | 4. Call this function to calculate the quality score:
86 |
87 |
88 | qualityscore = computequality(im,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,mu_prisparam,cov_prisparam)
89 |
90 | Sample execution is also shown through example.m
91 |
92 |
93 | =======================================================================
94 |
95 | MATLAB files: (provided with release): example.m, computefeature.m, computemean.m, computequality.m, estimateaggdparam.m and estimatemodelparam.m
96 |
97 | Image Files: image1.bmp, image2.bmp, image3.bmp and image4.bmp
98 |
99 | Dependencies: Mat file: modelparameters.mat provided with release
100 |
101 | =======================================================================
102 |
103 | Note on training:
104 | This release version of NIQE was trained on 125 pristine images with patch size set to 96x96 and sharpness threshold of 0.75.
105 |
106 | Training the model
107 |
108 | If the user wants to retrain the model using a different set of pristine images or different patch sizes, this can be done
109 | using the following function. The images used for making the current model may be viewed at http://live.ece.utexas.edu/research/quality/pristinedata.zip
110 |
111 | Folder containing the pristine images
112 | folderpath = 'pristine'
113 | Height of the block
114 | blocksizerow = 96;
115 | Width of the block
116 | blocksizecol = 96;
117 | Vertical overlap between blocks
118 | blockrowoverlap = 0;
119 | Horizontal overlap between blocks
120 | blockcoloverlap = 0;
121 | The sharpness threshold level
122 | sh_th = 0.75;
123 |
124 |
125 | [mu_prisparam cov_prisparam] = estimatemodelparam(folderpath,blocksizerow,blocksizecol,blockrowoverlap,blockcoloverlap,sh_th)
126 | =======================================================================
127 |
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/parcal_niqe.m:
--------------------------------------------------------------------------------
1 | function [scoresname,scoresNIQE] = parcal_niqe(ii,input_dir,file_list,im_num)
2 |
3 | load modelparameters.mat
4 | fprintf(['\nCalculating scores for image ',num2str(ii),' / ',num2str(im_num)]);
5 | % Reading and converting images
6 | input_image_path = fullfile(input_dir,file_list(ii).name);
7 | input_image = convert_shave_image(imread(input_image_path),4);
8 |
9 | % Calculating scores
10 | scoresname = file_list(ii).name;
11 |
12 | scoresNIQE = computequality(input_image,96,96,...
13 | 0,0,mu_prisparam,cov_prisparam);
14 |
15 | fprintf([' perceptual score is: ',num2str(scoresNIQE),' ',scoresname,'']);
16 | end
--------------------------------------------------------------------------------
/datasets/generate_rankdataset/utils/save_patch_img.m:
--------------------------------------------------------------------------------
1 | function img_patch=save_patch_img(file_path,file_name,...
2 | save_dir,save_name,h,w,patch_sz)
3 |
4 | img = imread(fullfile(file_path,file_name));
5 | img_patch = img(h:h+patch_sz-1,w:w+patch_sz-1,:);
6 |
7 | imwrite(img_patch,fullfile(save_dir,save_name));
8 |
9 | end
10 |
--------------------------------------------------------------------------------
/experiments/pretrained_models/readme.md:
--------------------------------------------------------------------------------
1 | ## Place pretrained models here.
2 |
3 | ### Pretrained models
4 |
5 | 1. `SRResNet_bicx4_in3nf64nb16.pth`: the well-trained, PSNR-oriented SRResNet model.
6 | 2. `SRGAN.pth`: the pretrained SRGAN model implemented in [BasicSR](https://github.com/xinntao/BasicSR).
7 | 3. `ESRGAN_SuperSR.pth`: the pretrained [ESRGAN_SuperSR](https://github.com/xinntao/ESRGAN) model.
8 |
9 | ### Three pretrained Ranker models:
10 | 1. `Ranker_NIQE.pth`: the well-trained Ranker with **NIQE** metric.
11 | 2. `Ranker_Ma.pth`: the well-trained Ranker with **Ma** metric.
12 | 3. `Ranker_PI.pth`: the well-trained Ranker with **PI** metric.
13 |
14 | ### RankSRGAN models
15 | 1. `RankSRGAN_NIQE.pth`: the RankSRGAN model in **NIQE** orientation.
16 | 2. `RankSRGAN_Ma.pth`: the RankSRGAN model in **Ma** orientation.
17 | 3. `RankSRGAN_PI.pth`: the RankSRGAN model in **PI** orientation.
18 |
19 |
20 |
21 | *Note that* the pretrained models are trained with the `MATLAB bicubic` downsampling kernel.
22 | If your LR images are generated with a different kernel, the results may contain artifacts.
23 |
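
A quick sanity check that a downloaded checkpoint loads (assuming PyTorch is installed; the file name is one of the models listed above):

```python
import torch

# Load on CPU just to verify the checkpoint is intact.
state_dict = torch.load('RankSRGAN_NIQE.pth', map_location='cpu')
print('%d parameter tensors' % len(state_dict))
```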
--------------------------------------------------------------------------------
/experiments/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figures/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/figures/method.png
--------------------------------------------------------------------------------
/figures/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figures/visual_results1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XPixelGroup/RankSRGAN/b313c24a25c9844d1d0c7ea8fd1e35da00ad8975/figures/visual_results1.png
--------------------------------------------------------------------------------
/results/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------