├── KOALAnet_framework.png
├── README.md
├── koalanet.py
├── main.py
├── ops.py
└── utils.py

/KOALAnet_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjSim/KOALAnet/0a8453f969b7670d8c6785967e62701b49be05fb/KOALAnet_framework.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # KOALAnet
 2 | **This is the official repository of "KOALAnet: Blind Super-Resolution using Kernel-Oriented Adaptive Local Adjustment", CVPR 2021,** by [Soo Ye Kim](https://sites.google.com/view/sooyekim)\*, [Hyeonjun Sim](https://sites.google.com/view/hjsim)\* and Munchurl Kim. (\* equal contribution)
 3 | 
 4 | We provide the training and test code along with the trained weights and the test dataset.
 5 | If you find this repository useful, please consider citing our paper [[arXiv](https://arxiv.org/abs/2012.08103)].
 6 | 
 7 | Please watch our **[5 min presentation](https://youtu.be/j9WX5CkdF5w)** on YouTube.
 8 | 
 9 | ![framework](/KOALAnet_framework.png)
10 | 
11 | ### Reference
12 | > Soo Ye Kim*, Hyeonjun Sim*, and Munchurl Kim, "KOALAnet: Blind Super-Resolution using Kernel-Oriented Adaptive Local Adjustment", CVPR, 2021. (* *equal contribution*)
13 | > 
14 | **BibTeX**
15 | ```bibtex
16 | @inproceedings{kim2021koalanet,
17 |   title={KOALAnet: Blind Super-Resolution using Kernel-Oriented Adaptive Local Adjustment},
18 |   author={Kim, Soo Ye and Sim, Hyeonjun and Kim, Munchurl},
19 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
20 |   pages={10611--10620},
21 |   year={2021}
22 | }
23 | ```
24 | ### Requirements
25 | Our code is implemented using TensorFlow and was tested under the following setting:
26 | * Python 3.6
27 | * TensorFlow 1.13
28 | * CUDA 10.0
29 | * cuDNN 7.4.1
30 | * NVIDIA TITAN RTX
31 | * Windows 10
32 | 
33 | ## Test Code
34 | 1. Download the files below and place them in **<source_path>**:
35 |    * Source code (main.py, koalanet.py, ops.py and utils.py)
36 |    * Test dataset, available [here](https://www.dropbox.com/sh/zkwia1ndleokeex/AAClDJY5sUDVWRLgSfi1sL3ka?dl=0)
37 |    * Trained weights, available [here](https://www.dropbox.com/sh/m0e2wezc2nv3z22/AAAaA-b1BGohioe4_EHzE_oIa?dl=0)
38 | 2. Set the arguments defined in main.py and run main.py:
39 |    * Set ```--phase 'test'```, provide the input and label paths via ```--test_data_path``` and ```--test_label_path```, and the checkpoint path via ```--test_ckpt_path```.
40 |    * Example:
41 |    ```
42 |    python main.py --phase 'test' --test_data_path './testset/Set5/LR/X4/imgs' --test_label_path './testset/Set5/HR' --test_ckpt_path './pretrained_ckpt'
43 |    ```
44 | 3. Result images will be saved in **<source_path>/results/imgs_test**.
45 | 
46 | ### Notes
47 | * Set ```--factor``` to ```2``` or ```4``` depending on your desired upscaling factor (an example command for an x2 test run is given at the end of this README).
48 | * The provided testset contains 6 datasets, each with 2 scaling factors. To try these out, set ```--test_data_path``` to ```'./testset/[dataset]/LR/X[factor]/imgs'``` and ```--test_label_path``` to ```'./testset/[dataset]/HR'```, where:
49 |   * ```[dataset]: Set5, Set14, BSD100, Urban100, Manga109 or DIV2K```
50 |   * ```[factor]: 2 or 4```
51 | * If you want to test our model on your own data, set ```--test_data_path``` and ```--test_label_path``` to your desired paths.
52 | * If no ground truth HR images are available, set ```--eval False``` (defaults to ```True```) to only save output images without computing PSNR. 
53 | * If you run into memory issues with large input images, try setting ```--test_patch``` to ```2 2```, ```4 4```, etc. (defaults to ```1 1```). This option divides the input into a grid of test_patch[0] x test_patch[1] patches, performs SR on each patch and stitches the patches back into a full image (a minimal sketch of this procedure is given at the end of this README). Note that inference time measurements will be inaccurate in this case.
54 | * When testing with your own trained version, set ```--test_ckpt_path``` to wherever you stored the weights.
55 | 
56 | ## Training Code
57 | 1. Download the files below and place them in **<source_path>**:
58 |    * Source code (main.py, koalanet.py, ops.py and utils.py)
59 |    * A training dataset (e.g. [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/))
60 | 2. 3-stage training:
61 |    * Pretrain the downsampling network with ```python main.py --phase 'train' --training_stage 1```
62 |    * Pretrain the upsampling network with ```python main.py --phase 'train' --training_stage 2```
63 |    * Jointly train both networks with ```python main.py --phase 'train' --training_stage 3```
64 |    * Set ```--training_data_path``` and ```--validation_data_path``` to the directories containing the training and validation data. For example,
65 |    ```
66 |    python main.py --phase 'train' --training_stage 3 --training_data_path './dataset/DIV2K/train/DIV2K_train_HR' --validation_data_path './dataset/DIV2K/val/DIV2K_valid_HR'
67 |    ```
68 | 3. Checkpoints will be saved in **<source_path>/ckpt**.
69 | 4. Monitoring training:
70 |    * Intermediate results will be available in **<source_path>/results/imgs_train**.
71 |    * A text log file and TensorBoard logs will be saved in **<source_path>/logs**.
72 | 
73 | ### Notes
74 | * Set ```--factor``` to ```2``` or ```4``` depending on your desired upscaling factor.
75 | * If ```--tensorboard True``` (defaults to ```True```), TensorBoard logs will be saved.
76 | * Model settings (Gaussian kernel size, local filter sizes in the downsampling and upsampling networks, etc.) and hyperparameters (number of epochs, batch size, patch size, learning rate, etc.) are defined as arguments. The default values are the ones used for the paper. Please refer to main.py for details.
77 | 
78 | ## Test Dataset
79 | In blind SR, not many benchmark datasets are available yet. We release the [random anisotropic Gaussian testset](https://www.dropbox.com/sh/zkwia1ndleokeex/AAClDJY5sUDVWRLgSfi1sL3ka?dl=0) used in our paper, consisting of six datasets (Set5, Set14, BSD100, Urban100, Manga109 and DIV2K) and two scale factors (2 and 4). We hope that the community will use it for future research in SR. A sketch of the degradation pipeline used to generate such data is given at the end of this README.
80 | 
81 | **Disclaimer:** The ```degradation_kernels``` folder contains images of the degradation kernels used for generating the corresponding LR images, *scaled and upsampled for better visualization*. They should only be used as a visual reference.
82 | 
83 | ## Contact
84 | Please contact us via any of the following emails: sooyekim@kaist.ac.kr, flhy5836@kaist.ac.kr, or leave a note in the issues tab. 
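
## Examples

For convenience, here is an example command for testing the x2 model on Set5, following the paths and arguments described above:

```
python main.py --phase 'test' --factor 2 --test_data_path './testset/Set5/LR/X2/imgs' --test_label_path './testset/Set5/HR' --test_ckpt_path './pretrained_ckpt'
```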
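
**Patch-wise inference (```--test_patch```).** The sketch below illustrates the grid splitting and stitching performed in ```test()``` of koalanet.py. It is a simplified NumPy version: it omits the ```patch_boundary``` handling (set to ```0``` in the released code) and assumes the input height and width are divisible by the grid dimensions; ```run_sr``` is a hypothetical stand-in for a session call to the SR graph.

```python
import numpy as np

def sr_by_patches(lr, factor, grid, run_sr):
    # lr: [1, H, W, C] input image; grid: (rows, cols) as in --test_patch
    _, h, w, c = lr.shape
    out = np.zeros((1, h * factor, w * factor, c), dtype=np.float32)
    sh, sw = h // grid[0], w // grid[1]  # LR patch size
    for ph in range(grid[0]):
        for pw in range(grid[1]):
            patch = lr[:, ph * sh:(ph + 1) * sh, pw * sw:(pw + 1) * sw, :]
            # super-resolve one patch and paste it into the matching HR slot
            out[:, ph * sh * factor:(ph + 1) * sh * factor,
                   pw * sw * factor:(pw + 1) * sw * factor, :] = run_sr(patch)
    return out
```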
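
**Degradation pipeline.** The testset was generated with random anisotropic Gaussian kernels. The following sketch, adapted from ```inv_covariance_matrix``` and ```anisotropic_gaussian_kernel``` in utils.py, shows how such a kernel is built and how an LR image can be synthesized from it. It is a simplification: the released pipeline additionally convolves the Gaussian with a bicubic base kernel (```get_ds_kernel```) and quantizes the result, and SciPy is assumed here only for the 2D convolution.

```python
import numpy as np
from scipy.signal import convolve2d

def anisotropic_gaussian(width=15, sig_x=2.0, sig_y=0.5, theta=0.3):
    # bivariate Gaussian pdf with covariance built from (sig_x, sig_y, theta);
    # assumes an odd kernel width (even widths need a 0.5 shift, see utils.py)
    ax = np.arange(-width // 2 + 1., width // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.stack([xx, yy], axis=2)
    d_inv = np.diag([1.0 / sig_x ** 2, 1.0 / sig_y ** 2])
    u = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    inv_cov = u @ d_inv @ u.T
    kernel = np.exp(-0.5 * np.sum((xy @ inv_cov) * xy, axis=2))
    return kernel / kernel.sum()

def degrade(hr, kernel, factor=4):
    # hr: [H, W] array; blur with the degradation kernel, then subsample
    blurred = convolve2d(hr, kernel, mode='same', boundary='symm')
    return blurred[::factor, ::factor]
```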
85 | 86 | 87 | -------------------------------------------------------------------------------- /koalanet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from datetime import datetime 3 | import time 4 | from tensorflow.data.experimental import shuffle_and_repeat, unbatch 5 | 6 | from utils import * 7 | from ops import * 8 | 9 | 10 | class KOALAnet: 11 | def __init__(self, args): 12 | self.phase = args.phase 13 | self.factor = args.factor 14 | 15 | """ Training Settings """ 16 | self.training_stage = args.training_stage 17 | self.tensorboard = args.tensorboard 18 | 19 | """ Testing Settings """ 20 | self.eval = args.eval 21 | self.test_data_path = args.test_data_path 22 | self.test_label_path = args.test_label_path 23 | self.test_ckpt_path = args.test_ckpt_path 24 | self.test_patch = args.test_patch 25 | 26 | """ Model Settings """ 27 | self.channels = args.channels 28 | self.bicubic_size = args.bicubic_size 29 | self.gaussian_size = args.gaussian_size 30 | self.down_kernel = args.down_kernel 31 | self.up_kernel = args.up_kernel 32 | self.anti_aliasing = args.anti_aliasing 33 | 34 | """ Hyperparameters """ 35 | self.max_epoch = args.max_epoch 36 | self.batch_size = args.batch_size 37 | self.val_batch_size = args.val_batch_size 38 | self.patch_size = args.patch_size 39 | self.val_patch_size = args.val_patch_size 40 | self.lr = args.lr 41 | self.lr_type = args.lr_type 42 | self.lr_stair_decay_points = args.lr_stair_decay_points 43 | self.lr_stair_decay_factor = args.lr_stair_decay_factor 44 | self.lr_linear_decay_point = args.lr_linear_decay_point 45 | self.n_display = args.n_display 46 | 47 | if self.training_stage == 1: 48 | self.model_name = 'downsampling_network' 49 | elif self.training_stage == 2: 50 | self.model_name = 'upsampling_network_baseline' 51 | elif self.training_stage == 3: 52 | self.model_name = 'upsampling_network' 53 | 54 | """ Directories """ 55 | self.ckpt_dir = os.path.join('ckpt', self.model_dir) 56 | self.result_dir = os.path.join('results') 57 | check_folder(self.ckpt_dir) 58 | check_folder(self.result_dir) 59 | 60 | """ Model Init """ 61 | config = tf.ConfigProto() 62 | config.gpu_options.allow_growth = True 63 | self.sess = tf.Session(config=config) 64 | 65 | """ Print Model """ 66 | print('Model arguments, [{:s}]'.format((str(datetime.now())[:-7]))) 67 | for arg in vars(args): 68 | print('# {} : {}'.format(arg, getattr(args, arg))) 69 | print("\n") 70 | 71 | def upsampling_network_baseline(self, input_LR, factor, kernel, channels=3, reuse=False, scope='SISR_DUF'): 72 | with tf.variable_scope(scope, reuse=reuse): 73 | ch = 64 74 | n_res = 12 75 | net = conv2d(input_LR, ch, 3) 76 | for res in range(n_res): 77 | net = res_block(net, ch, 3, scope='Residual_block0_' + str(res + 1)) 78 | net = tf.nn.relu(net) 79 | # upsampling kernel branch 80 | k2d = tf.nn.relu(conv2d(net, ch * 2, 3)) 81 | k2d = conv2d(k2d, kernel * kernel * factor * factor, 3) 82 | # rgb residual image branch 83 | rgb = tf.nn.relu(conv2d(net, ch * 2, 3)) 84 | rgb = tf.depth_to_space(rgb, 2) 85 | if factor == 4: 86 | rgb = tf.nn.relu(conv2d(rgb, ch, 3)) 87 | rgb = tf.depth_to_space(rgb, 2) 88 | rgb = conv2d(rgb, channels, 3) 89 | # local filtering and upsampling 90 | output_k2d = local_conv_us(input_LR, k2d, factor, channels, kernel) 91 | output = output_k2d + rgb 92 | return output 93 | 94 | def upsampling_network(self, input_LR, k2d_ds, factor, kernel, channels=3, reuse=False, scope='SISR_DUF'): 95 | with 
tf.variable_scope(scope, reuse=reuse): 96 | ch = 64 97 | n_res = 12 98 | skip_idx = np.arange(0, 5, 1) 99 | # extract degradation kernel features 100 | k = cr_block(k2d_ds, 3, ch, 3, 'kernel_condition') 101 | net = conv2d(input_LR, ch, 3) 102 | filter_p_list = [] 103 | for res in range(n_res): 104 | if res in skip_idx: 105 | net, filter_p = koala(net, k, ch, ch, conv_k_sz=3, lc_k_sz=7, scope_res='Residual_block0_' + str(res + 1), scope='KOALA_module/%d' % (res+1)) 106 | filter_p_list.append(filter_p) 107 | else: 108 | net = res_block(net, ch, 3, scope='Residual_block0_' + str(res + 1)) 109 | net = tf.nn.relu(net) 110 | # upsampling kernel branch 111 | k2d = tf.nn.relu(conv2d(net, ch * 2, 3)) 112 | k2d = conv2d(k2d, kernel * kernel * factor * factor, 3) 113 | # rgb residual image branch 114 | rgb = tf.nn.relu(conv2d(net, ch * 2, 3)) 115 | rgb = tf.depth_to_space(rgb, 2) 116 | if factor == 4: 117 | rgb = tf.nn.relu(conv2d(rgb, ch, 3)) 118 | rgb = tf.depth_to_space(rgb, 2) 119 | rgb = conv2d(rgb, channels, 3) 120 | # local filtering and upsampling 121 | output_k2d = local_conv_us(input_LR, k2d, factor, channels, kernel) 122 | output = output_k2d + rgb 123 | return output, filter_p_list[-1] 124 | 125 | def downsampling_network(self, input_LR, kernel, reuse=False, scope='SISR'): 126 | with tf.variable_scope(scope, reuse=reuse): 127 | ch = 64 128 | skip = dict() 129 | # encoder 130 | n, skip[0] = enc_level_res(input_LR, ch, scope='enc_block_res/0') 131 | n, skip[1] = enc_level_res(n, ch*2, scope='enc_block_res/1') 132 | # bottleneck 133 | n = bottleneck_res(n, ch*4) 134 | # decoder 135 | n = dec_level_res(n, skip[1], ch*2, scope='dec_block_res/0') 136 | n = dec_level_res(n, skip[0], ch, scope='dec_block_res/1') 137 | # downsampling kernel branch 138 | n = tf.nn.relu(conv2d(n, ch, 3)) 139 | k2d = conv2d(n, kernel * kernel, 3) 140 | return k2d 141 | 142 | def build_model(self, args): 143 | data = SISRData(args) 144 | if self.phase == 'train': 145 | """ Directories """ 146 | self.log_dir = os.path.join('logs', self.model_dir) 147 | self.img_dir = os.path.join(self.result_dir, 'imgs_train', self.model_dir) 148 | check_folder(self.log_dir) 149 | check_folder(self.img_dir) 150 | 151 | self.updates_per_epoch = int(data.num_train / self.batch_size) 152 | print("Update per epoch : ", self.updates_per_epoch) 153 | 154 | """ Training Data Generation """ 155 | train_folder_path = tf.data.Dataset.from_tensor_slices(data.list_train).apply(shuffle_and_repeat(len(data.list_train))) 156 | train_data = train_folder_path.map(data.image_processing, num_parallel_calls=4) 157 | train_data = train_data.apply(unbatch()).shuffle(data.Qsize*50).batch(data.batch_size).prefetch(1) 158 | train_data_iterator = train_data.make_one_shot_iterator() 159 | 160 | # self.train_hr : [B, H, W, C], self.gaussian_kernel : [B, gaussian_size, gaussian_size, 1], data.bicubic_kernel : [1, bicubic_size, bicubic_size, B] 161 | self.train_hr, self.gaussian_kernel = train_data_iterator.get_next() 162 | self.ds_kernel = get_ds_kernel(data.bicubic_kernel, self.gaussian_kernel) 163 | self.train_lr = get_ds_input(self.train_hr, self.ds_kernel, self.channels, self.batch_size, data.pad_left, data.pad_right, self.factor) 164 | self.train_lr = tf.math.round((self.train_lr+1.0)/2.0*255.0) 165 | self.train_lr = tf.cast(self.train_lr, tf.float32)/255.0 * 2.0 - 1.0 166 | print("#### Degraded train_lr is quantized.") 167 | 168 | # set placeholders for validation 169 | self.val_hr = tf.placeholder(tf.float32, (self.val_batch_size, self.val_patch_size * 
self.factor, self.val_patch_size * self.factor, self.channels)) 170 | self.val_base_k = tf.placeholder(tf.float32, (1, self.bicubic_size, self.bicubic_size, self.val_batch_size)) 171 | self.val_rand_k = tf.placeholder(tf.float32, (self.val_batch_size, self.gaussian_size, self.gaussian_size, 1)) 172 | self.ds_kernel_val = get_ds_kernel(self.val_base_k, self.val_rand_k) 173 | self.val_lr = get_ds_input(self.val_hr, self.ds_kernel_val, self.channels, self.val_batch_size, data.pad_left, data.pad_right, self.factor) 174 | self.val_lr = tf.math.round((self.val_lr+1.0)/2.0*255.0) 175 | self.val_lr = tf.cast(self.val_lr, tf.float32)/255.0 * 2.0 - 1.0 176 | print("#### Degraded val_lr is quantized.") 177 | self.list_val = data.list_val 178 | print("Training patch size : ", self.train_lr.get_shape()) 179 | 180 | """ Define Model """ 181 | if self.training_stage == 1: 182 | self.k2d_ds = self.downsampling_network(self.train_lr, self.down_kernel, reuse=False, scope='SISR_DDF') 183 | self.k2d_ds_val = self.downsampling_network(self.val_lr, self.down_kernel, reuse=True, scope='SISR_DDF') 184 | # reconstructed LR images 185 | self.output_ds_hr = local_conv_ds(self.train_hr, self.k2d_ds, self.factor, self.channels, self.down_kernel) 186 | self.output_ds_hr_val = local_conv_ds(self.val_hr, self.k2d_ds_val, self.factor, self.channels, self.down_kernel) 187 | elif self.training_stage == 2: 188 | # reconstructed HR images 189 | self.output = self.upsampling_network_baseline(self.train_lr, self.factor, self.up_kernel, self.channels, reuse=False, scope='SISR_DUF') 190 | self.output_val = self.upsampling_network_baseline(self.val_lr, self.factor, self.up_kernel, self.channels, reuse=True, scope='SISR_DUF') 191 | elif self.training_stage == 3: 192 | self.k2d_ds = self.downsampling_network(self.train_lr, self.down_kernel, reuse=False, scope='SISR_DDF') 193 | self.k2d_ds_val = self.downsampling_network(self.val_lr, self.down_kernel, reuse=True, scope='SISR_DDF') 194 | # reconstructed LR images 195 | self.output_ds_hr = local_conv_ds(self.train_hr, self.k2d_ds, self.factor, self.channels, self.down_kernel) 196 | self.output_ds_hr_val = local_conv_ds(self.val_hr, self.k2d_ds_val, self.factor, self.channels, self.down_kernel) 197 | # reconstructed HR images 198 | self.output, self.filter_p = self.upsampling_network(self.train_lr, self.k2d_ds, self.factor, self.up_kernel, self.channels, reuse=False, scope='SISR_DUF') 199 | self.output_val, _ = self.upsampling_network(self.val_lr, self.k2d_ds_val, self.factor, self.up_kernel, self.channels, reuse=True, scope='SISR_DUF') 200 | 201 | """ Define Losses """ 202 | if self.training_stage == 1: 203 | # training 204 | self.rec_loss_ds_hr = l1_loss(self.train_lr, self.output_ds_hr) 205 | self.k2d_ds = kernel_normalize(self.k2d_ds, self.down_kernel) 206 | k2d_mean = tf.reduce_mean(self.k2d_ds, axis=[1, 2], keepdims=True) 207 | self.kernel_loss = l1_loss(k2d_mean, get_1d_kernel(self.ds_kernel, self.batch_size)) 208 | self.total_loss = self.rec_loss_ds_hr + self.kernel_loss 209 | # validation 210 | self.val_rec_loss_ds_hr = l1_loss(self.val_lr, self.output_ds_hr_val) 211 | self.k2d_ds_val = kernel_normalize(self.k2d_ds_val, self.down_kernel) 212 | k2d_mean_val = tf.reduce_mean(self.k2d_ds_val, axis=[1, 2], keepdims=True) 213 | self.val_kernel_loss = l1_loss(k2d_mean_val, get_1d_kernel(self.ds_kernel_val, self.val_batch_size)) 214 | self.val_total_loss = self.val_rec_loss_ds_hr + self.val_kernel_loss 215 | self.val_PSNR = tf.reduce_mean(tf.image.psnr((self.val_lr + 1) / 2, 
(self.output_ds_hr_val + 1) / 2, max_val=1.0)) 216 | 217 | elif self.training_stage == 2: 218 | # training 219 | self.rec_loss = l1_loss(self.train_hr, self.output) 220 | self.total_loss = self.rec_loss 221 | # validation 222 | self.val_rec_loss = l1_loss(self.val_hr, self.output_val) 223 | self.val_total_loss = self.val_rec_loss 224 | self.val_PSNR = tf.reduce_mean(tf.image.psnr((self.val_hr + 1) / 2, (self.output_val + 1) / 2, max_val=1.0)) 225 | 226 | elif self.training_stage == 3: 227 | # training 228 | self.rec_loss = l1_loss(self.train_hr, self.output) 229 | self.rec_loss_ds_hr = l1_loss(self.train_lr, self.output_ds_hr) 230 | self.k2d_ds = kernel_normalize(self.k2d_ds, self.down_kernel) 231 | k2d_mean = tf.reduce_mean(self.k2d_ds, axis=[1, 2], keepdims=True) 232 | self.kernel_loss = l1_loss(k2d_mean, get_1d_kernel(self.ds_kernel, self.batch_size)) 233 | self.total_loss = self.rec_loss + self.rec_loss_ds_hr + self.kernel_loss 234 | # validation 235 | self.val_rec_loss = l1_loss(self.val_hr, self.output_val) 236 | self.val_rec_loss_ds_hr = l1_loss(self.val_lr, self.output_ds_hr_val) 237 | self.k2d_ds_val = kernel_normalize(self.k2d_ds_val, self.down_kernel) 238 | k2d_mean_val = tf.reduce_mean(self.k2d_ds_val, axis=[1, 2], keepdims=True) 239 | self.val_kernel_loss = l1_loss(k2d_mean_val, get_1d_kernel(self.ds_kernel_val, self.val_batch_size)) 240 | self.val_total_loss = self.val_rec_loss + self.val_rec_loss_ds_hr + self.val_kernel_loss 241 | self.val_PSNR = tf.reduce_mean(tf.image.psnr((self.val_hr + 1) / 2, (self.output_val + 1) / 2, max_val=1.0)) 242 | 243 | """ Visualization """ 244 | # visualization of GT degradation kernel 245 | self.ds_kernel_vis = tf.transpose(self.ds_kernel, (3, 1, 2, 0)) # [B, bicubic_size, bicubic_size, 1] 246 | kernel_min = tf.reduce_min(self.ds_kernel_vis, axis=(1, 2), keepdims=True) 247 | kernel_max = tf.reduce_max(self.ds_kernel_vis, axis=(1, 2), keepdims=True) 248 | self.scale_vis = (self.patch_size*self.factor)//self.bicubic_size 249 | self.ds_kernel_vis = local_conv_vis_ds(self.ds_kernel_vis, kernel_min, kernel_max, 3, self.scale_vis) 250 | 251 | # visualization of estimated degradation kernel 252 | if self.training_stage in [1, 3]: 253 | self.k2d_ds_vis = tf.reshape(k2d_mean, [self.batch_size, self.down_kernel, self.down_kernel, 1]) # [B, down_kernel, down_kernel, 1] 254 | self.k2d_ds_vis = local_conv_vis_ds(self.k2d_ds_vis, kernel_min, kernel_max, 3, self.scale_vis) 255 | 256 | # visualization of local filters in KOALA modules 257 | if self.training_stage == 3: 258 | self.filter_p = tf.reduce_mean(self.filter_p, axis=(1, 2)) 259 | self.filter_p = tf.reshape(self.filter_p, [self.batch_size, 7, 7, 1]) 260 | self.filter_p = local_conv_vis_ds(self.filter_p, None, None, 6, 10) 261 | 262 | """ Learning Rate Schedule """ 263 | global_step = tf.Variable(initial_value=0, trainable=False) 264 | if self.lr_type == "stair_decay": 265 | self.lr_decay_boundary = [y * (self.updates_per_epoch) for y in self.lr_stair_decay_points] 266 | self.lr_decay_value = [self.lr * (self.lr_stair_decay_factor ** y) for y in range(len(self.lr_stair_decay_points) + 1)] 267 | self.reduced_lr = tf.train.piecewise_constant(global_step, self.lr_decay_boundary, self.lr_decay_value) 268 | print("lr_type: stair_decay") 269 | elif self.lr_type == "linear_decay": 270 | self.reduced_lr = tf.placeholder(tf.float32, name='learning_rate') 271 | print("lr_type: linear_decay") 272 | else: # no decay 273 | self.reduced_lr = tf.convert_to_tensor(self.lr) 274 | print("lr_type: no decay") 275 | 276 | 
""" Optimizer """ 277 | srnet_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="SISR") 278 | # print("\nTrainable Parameters:") 279 | # for param in srnet_params: 280 | # print(param.name) 281 | self.optimizer = tf.train.AdamOptimizer(self.reduced_lr).minimize(self.total_loss, global_step=global_step, var_list=srnet_params) 282 | 283 | """" TensorBoard Summary """ 284 | if self.tensorboard: 285 | # loss summary 286 | total_loss_sum = tf.summary.scalar("val_total_loss", self.val_total_loss) 287 | train_PSNR_sum = tf.summary.scalar("val_PSNR", self.val_PSNR) 288 | self.total_summary_loss = tf.summary.merge([total_loss_sum, train_PSNR_sum]) 289 | # image summary 290 | lr_sum = tf.summary.image("LR", self.val_lr, max_outputs=self.val_batch_size) 291 | hr_sum = tf.summary.image("HR", self.val_hr, max_outputs=self.val_batch_size) 292 | # kernel summary 293 | self.ds_kernel_val_vis = tf.transpose(self.ds_kernel_val, [3, 1, 2, 0]) # [B, bicubic_size, bicubic_size, 1] 294 | self.ds_kernel_val_vis = local_conv_vis_ds(self.ds_kernel_val_vis, None, None, 3, self.scale_vis) 295 | ds_kernel_sum = tf.summary.image("Degradation Kernel (GT)", self.ds_kernel_val_vis, max_outputs=self.val_batch_size) 296 | self.total_summary_img = tf.summary.merge([ds_kernel_sum, lr_sum, hr_sum]) 297 | # result summary 298 | if self.training_stage in [1, 3]: 299 | self.k2d_ds_val_vis = tf.reshape(k2d_mean_val, [self.val_batch_size, self.down_kernel, self.down_kernel, 1]) 300 | self.k2d_ds_val_vis = local_conv_vis_ds(self.k2d_ds_val_vis, None, None, 3, self.scale_vis) 301 | k2d_ds_sum = tf.summary.image("Degradation Kernel (Predicted)", self.k2d_ds_val_vis, max_outputs=self.val_batch_size) 302 | output_sum_ds_hr = tf.summary.image("LR (Predicted)", self.output_ds_hr_val, max_outputs=self.val_batch_size) 303 | self.total_summary_img = tf.summary.merge([self.total_summary_img, k2d_ds_sum, output_sum_ds_hr]) 304 | if self.training_stage in [2, 3]: 305 | output_sum = tf.summary.image("SR (Predicted)", self.output_val, max_outputs=self.val_batch_size) 306 | self.total_summary_img = tf.summary.merge([self.total_summary_img, output_sum]) 307 | 308 | elif self.phase == 'test': 309 | assert self.training_stage == 3, "training_stage should be 3" 310 | 311 | """ Directories """ 312 | self.test_img_dir = os.path.join(self.result_dir, 'imgs_test', self.model_dir) 313 | check_folder(self.test_img_dir) 314 | 315 | """ Set Data Paths """ 316 | self.list_test_lr = data.list_test_lr # test_data_path (LR) 317 | if self.eval: 318 | self.list_test_hr = data.list_test_hr # test_label_path (HR) 319 | 320 | """ Set Placeholders """ 321 | self.test_lr = tf.placeholder(tf.float32, (1, None, None, self.channels)) 322 | self.test_hr = tf.placeholder(tf.float32, (1, None, None, self.channels)) 323 | 324 | """ Define Model """ 325 | self.k2d_ds_test = self.downsampling_network(self.test_lr, self.down_kernel, reuse=False, scope='SISR_DDF') 326 | self.output_test, _ = self.upsampling_network(self.test_lr, self.k2d_ds_test, self.factor, self.up_kernel, self.channels, reuse=False, scope='SISR_DUF') 327 | 328 | self.sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())) 329 | 330 | def train(self, args): 331 | if self.tensorboard: 332 | self.writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) 333 | saver = tf.train.Saver(max_to_keep=3) 334 | """ Restore Checkpoint """ 335 | ckpt = tf.train.get_checkpoint_state(self.ckpt_dir) 336 | if ckpt and ckpt.model_checkpoint_path: 337 | # 
print("####################### print tensors from checkpoints########") 338 | # print_tensors_in_checkpoint_file(ckpt.model_checkpoint_path,'',False,True) 339 | saver.restore(self.sess, ckpt.model_checkpoint_path) 340 | start_epoch = int(ckpt.model_checkpoint_path.split('-')[1]) 341 | print("!!!!!!!!!!!!!! Restored iteration : {}".format(start_epoch)) 342 | else: 343 | print("!!!!!!!!!!!!!! Learning from scratch") 344 | start_epoch = 1 345 | 346 | # load pre-trained model for downsampling_network and upsampling_network_baseline 347 | if self.training_stage == 3: 348 | print(" [*] Loading pre-trained downsampling_network model...") 349 | saver_ds = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='SISR_DDF')) 350 | ckpt_ds = tf.train.get_checkpoint_state('./ckpt/downsampling_network_x{}'.format(self.factor)) 351 | assert ckpt_ds is not None, " [!] No pretrained downsampling network - stage 1 training is needed!" 352 | if ckpt_ds.model_checkpoint_path: 353 | saver_ds.restore(self.sess, ckpt_ds.model_checkpoint_path) 354 | print(" [*] Restored {}".format(ckpt_ds.model_checkpoint_path)) 355 | 356 | print(" [*] Loading pre-trained upsampling_network_baseline model...") 357 | saver_us = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='SISR_DUF/conv2d') + tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='SISR_DUF/Residual_block0')) 358 | ckpt_us = tf.train.get_checkpoint_state('./ckpt/upsampling_network_baseline_x{}'.format(self.factor)) 359 | assert ckpt_us is not None, " [!] No pretrained upsampling network - stage 2 training is needed!" 360 | if ckpt_us.model_checkpoint_path: 361 | saver_us.restore(self.sess, ckpt_us.model_checkpoint_path) 362 | print(" [*] Restored {}".format(ckpt_us.model_checkpoint_path)) 363 | 364 | # write logs 365 | with open(self.log_dir+'/'+self.model_dir+'.txt', 'a') as log: 366 | log.write('----- Model parameters -----\n') 367 | log.write('[{:s}] \n'.format((str(datetime.now())[:-7]))) 368 | for arg in vars(args): 369 | log.write('{} : {}\n'.format(arg, getattr(args, arg))) 370 | log.write('\n\nepoch\tl1_loss\tPSNR\n') 371 | 372 | reduced_lr = self.lr 373 | feed_dict = {} 374 | for epoch in range(start_epoch, self.max_epoch+1): 375 | if self.lr_type == 'linear_decay': 376 | if epoch > self.lr_linear_decay_point: 377 | reduced_lr = self.lr * (1 - (epoch-float(self.lr_linear_decay_point))/(self.max_epoch-float(self.lr_linear_decay_point))) 378 | feed_dict = {self.reduced_lr: reduced_lr} 379 | 380 | """ Training """ 381 | rec_loss = 0.0 382 | for i in range(self.updates_per_epoch): 383 | if self.training_stage == 1: 384 | rec_loss_temp, _, lr_per_epoch = self.sess.run([self.rec_loss_ds_hr, self.optimizer, self.reduced_lr], feed_dict) 385 | elif self.training_stage in [2, 3]: 386 | rec_loss_temp, _, lr_per_epoch = self.sess.run([self.rec_loss, self.optimizer, self.reduced_lr], feed_dict) 387 | rec_loss += rec_loss_temp 388 | print('{:s}\t\tEpoch: [{}/{}], lr : {:.8}'.format((str(datetime.now())[:-7]), epoch, self.max_epoch, lr_per_epoch)) 389 | 390 | """ Validation """ 391 | val_hr_batch = np.empty((self.val_batch_size, self.factor*self.val_patch_size, self.factor*self.val_patch_size, self.channels)) 392 | val_bicubic_k = np.expand_dims(np.expand_dims(get_bicubic_kernel(self.bicubic_size, self.anti_aliasing, self.factor), axis=0), axis=3) 393 | val_bicubic_k = np.tile(val_bicubic_k, (1, 1, 1, self.val_batch_size)) 394 | val_gaussian_k = np.empty((self.val_batch_size, self.gaussian_size, 
self.gaussian_size, 1)) 395 | if epoch % 5 == 0: 396 | self.generate_sampled_image(int(epoch)) 397 | val_cnt = 0 398 | psnr = 0.0 399 | num_val = 20 400 | while val_cnt < num_val: 401 | for b in range(self.val_batch_size): 402 | val_gaussian_k[b, :, :, 0] = random_anisotropic_gaussian_kernel_seed(val_cnt+b, self.gaussian_size) 403 | val_hr = read_img_trim(self.list_val[val_cnt + b], factor=self.factor) 404 | # crop index 405 | _, h, w, _ = val_hr.shape 406 | h_idx = int(np.floor(h / 2) - np.floor(self.val_patch_size / 2 * self.factor)) 407 | w_idx = int(np.floor(w / 2) - np.floor(self.val_patch_size / 2 * self.factor)) 408 | # crop center 409 | val_hr = val_hr[:, h_idx:h_idx + self.factor * self.val_patch_size, w_idx:w_idx + self.factor * self.val_patch_size, :] 410 | # store as batch 411 | val_hr_batch[b, :, :, :] = val_hr 412 | 413 | if self.tensorboard: 414 | if (val_cnt == 0) & (epoch % 50 == 0): # add summary_img 415 | summary_loss, summary_img, val_PSNR = self.sess.run([self.total_summary_loss, self.total_summary_img, self.val_PSNR], 416 | {self.val_base_k: val_bicubic_k, self.val_rand_k: val_gaussian_k, self.val_hr: val_hr_batch}) 417 | self.writer.add_summary(summary_loss, epoch) 418 | self.writer.add_summary(summary_img, epoch) 419 | else: # only summary_loss (for speed) 420 | summary_loss, val_PSNR = self.sess.run([self.total_summary_loss, self.val_PSNR], 421 | {self.val_base_k: val_bicubic_k, self.val_rand_k: val_gaussian_k, self.val_hr: val_hr_batch}) 422 | self.writer.add_summary(summary_loss, epoch) 423 | else: 424 | val_PSNR = self.sess.run(self.val_PSNR, {self.val_base_k: val_bicubic_k, self.val_rand_k: val_gaussian_k, self.val_hr: val_hr_batch}) 425 | psnr += val_PSNR 426 | val_cnt += self.val_batch_size 427 | psnr /= (num_val / self.val_batch_size) 428 | rec_loss = rec_loss / (self.updates_per_epoch) 429 | print('Validation: Recon loss {:.8}, PSNR {:.4} dB'.format(rec_loss, psnr)) 430 | with open(self.log_dir+'/'+self.model_dir+'.txt', 'a') as log: 431 | log.write('{}\t{:.4}\t{:.4}\n'.format(epoch, rec_loss, psnr)) 432 | 433 | # save network weights 434 | if epoch % 10 == 0: 435 | print('Saving the model...') 436 | saver.save(self.sess, os.path.join(self.ckpt_dir, self.model_name), epoch) 437 | 438 | def test(self): 439 | assert self.training_stage == 3, "training_stage should be 3" 440 | # saver to save model 441 | self.saver = tf.train.Saver() 442 | # restore checkpoint 443 | ckpt = tf.train.get_checkpoint_state(os.path.join(self.test_ckpt_path, self.model_dir)) 444 | if ckpt and ckpt.model_checkpoint_path: 445 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 446 | print("!!!!!!!!!!!!!! 
Restored from {}".format(ckpt.model_checkpoint_path)) 447 | 448 | """" Test """ 449 | avg_inf_time = 0.0 450 | avg_test_PSNR = 0.0 451 | patch_boundary = 0 452 | for test_cnt in range(len(self.list_test_lr)): 453 | test_lr = read_img_trim(self.list_test_lr[test_cnt], factor=4*self.test_patch[0]) 454 | test_lr = check_gray(test_lr) 455 | if self.eval: 456 | test_hr = read_img_trim(self.list_test_hr[test_cnt], factor=self.factor*4*self.test_patch[0]) 457 | test_hr = check_gray(test_hr) 458 | _, h, w, c = test_lr.shape 459 | output_test = np.zeros((1, h*self.factor, w*self.factor, c)) 460 | inf_time = 0.0 461 | # test image divided into test_patch[0]*test_patch[1] to fit memory (default: 1x1) 462 | for p in range(self.test_patch[0] * self.test_patch[1]): 463 | pH = p // self.test_patch[1] 464 | pW = p % self.test_patch[1] 465 | sH = h // self.test_patch[0] 466 | sW = w // self.test_patch[1] 467 | # process data considering patch boundary 468 | H_low_ind, H_high_ind, W_low_ind, W_high_ind = get_HW_boundary(patch_boundary, h, w, pH, sH, pW, sW) 469 | test_lr_p = test_lr[:, H_low_ind: H_high_ind, W_low_ind: W_high_ind, :] 470 | st = time.time() 471 | output_test_p = self.sess.run([self.output_test], feed_dict={self.test_lr: test_lr_p}) 472 | inf_time_p = time.time() - st 473 | inf_time += inf_time_p 474 | output_test_p = trim_patch_boundary(output_test_p, patch_boundary, h, w, pH, sH, pW, sW, self.factor) 475 | output_test[:, pH * sH * self.factor: (pH + 1) * sH * self.factor, pW * sW * self.factor: (pW + 1) * sW * self.factor, :] = output_test_p 476 | avg_inf_time += inf_time 477 | # compute PSNR and print results 478 | if self.eval: 479 | test_PSNR = compute_y_psnr(output_test, test_hr) 480 | avg_test_PSNR += test_PSNR 481 | print(" [%4d/%4d]-th images, time: %4.4f(seconds), test_PSNR: %2.2f[dB] " 482 | % (int(test_cnt+1), len(self.list_test_lr), inf_time, test_PSNR)) 483 | else: 484 | print(" [%4d/%4d]-th images, time: %4.4f(seconds) " 485 | % (int(test_cnt + 1), len(self.list_test_lr), inf_time)) 486 | # save predicted SR images 487 | save_path = os.path.join(self.test_img_dir, os.path.basename(self.list_test_lr[test_cnt])) 488 | save_img(output_test, save_path) 489 | 490 | if self.eval: 491 | avg_test_PSNR /= float(len(self.list_test_lr)) 492 | print("######### Average Test PSNR: %.8f[dB] #########" % avg_test_PSNR) 493 | avg_inf_time /= float(len(self.list_test_lr)) 494 | print("######### Average Inference Time: %.8f[s] #########" % avg_inf_time) 495 | 496 | 497 | @property 498 | def model_dir(self): 499 | return "{}_x{}".format(self.model_name, self.factor) 500 | 501 | def generate_sampled_image(self, epoch): 502 | patch_size = self.patch_size 503 | if self.training_stage == 1: 504 | grid = patch_size 505 | else: 506 | grid = int(patch_size*self.factor) 507 | 508 | n = min(self.n_display, self.batch_size) 509 | if self.training_stage == 1: 510 | train_lr, output_ds_hr, ds_kernel_vis, k2d_ds_vis = self.sess.run([self.train_lr, self.output_ds_hr, self.ds_kernel_vis, self.k2d_ds_vis]) 511 | combined_img = np.zeros((n*grid, 4*grid, 3)) 512 | elif self.training_stage == 2: 513 | train_lr, train_hr, output, ds_kernel_vis = self.sess.run([self.train_lr, self.train_hr, self.output, self.ds_kernel_vis]) 514 | combined_img = np.zeros((n*grid, 4*grid, 3)) 515 | elif self.training_stage == 3: 516 | train_lr, train_hr, output, ds_kernel_vis, k2d_ds_vis, filter_p = self.sess.run([self.train_lr, self.train_hr, self.output, self.ds_kernel_vis, self.k2d_ds_vis, self.filter_p]) 517 | combined_img = 
np.zeros((n*grid, 6*grid, 3))
518 | 
519 |         for i in range(0, n):
520 |             if self.training_stage == 1:
521 |                 combined_img[i*grid:(i+1)*grid, 0*grid:1*grid] = imresize(ds_kernel_vis[i, :], output_shape=(grid, grid))
522 |                 combined_img[i*grid:(i+1)*grid, 1*grid:2*grid] = imresize(k2d_ds_vis[i, :], output_shape=(grid, grid))
523 |                 combined_img[i*grid:(i+1)*grid, 2*grid:3*grid] = train_lr[i, :]
524 |                 combined_img[i*grid:(i+1)*grid, 3*grid:4*grid] = output_ds_hr[i, :]
525 |             elif self.training_stage == 2:
526 |                 combined_img[i*grid:(i+1)*grid, 0*grid:1*grid] = imresize(ds_kernel_vis[i, :], output_shape=(grid, grid))
527 |                 combined_img[i*grid:(i+1)*grid, 1*grid:2*grid] = imresize(train_lr[i, :], self.factor)
528 |                 combined_img[i*grid:(i+1)*grid, 2*grid:3*grid] = output[i, :]
529 |                 combined_img[i*grid:(i+1)*grid, 3*grid:4*grid] = train_hr[i, :]
530 |             elif self.training_stage == 3:
531 |                 combined_img[i*grid:(i+1)*grid, 0*grid:1*grid] = imresize(ds_kernel_vis[i, :], output_shape=(grid, grid))
532 |                 combined_img[i*grid:(i+1)*grid, 1*grid:2*grid] = imresize(k2d_ds_vis[i, :], output_shape=(grid, grid))
533 |                 combined_img[i*grid:(i+1)*grid, 2*grid:3*grid] = imresize(filter_p[i, :], output_shape=(grid, grid))
534 |                 combined_img[i*grid:(i+1)*grid, 3*grid:4*grid] = imresize(train_lr[i, :], self.factor)
535 |                 combined_img[i*grid:(i+1)*grid, 4*grid:5*grid] = output[i, :]
536 |                 combined_img[i*grid:(i+1)*grid, 5*grid:6*grid] = train_hr[i, :]
537 | 
538 |         combined_img = np.clip(combined_img, -1.0, 1.0)
539 | 
540 |         combined_img = Image.fromarray(((np.squeeze(combined_img) + 1.0) / 2.0 * 255).astype(np.uint8))
541 |         combined_img.save(os.path.join(self.img_dir, 'img_' + '{:05d}'.format(epoch) + '.jpg'))
542 |         print("!!!!!!!!!!! Output image saved !!!!!!!!!!!! (check ./{})".format(self.img_dir))
543 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
 1 | from __future__ import print_function
 2 | import argparse
 3 | 
 4 | from koalanet import KOALAnet
 5 | 
 6 | 
 7 | def str2bool(v):
 8 |     # argparse's type=bool treats any non-empty string (including 'False') as True,
 9 |     # so boolean flags are parsed explicitly from their string values instead
10 |     if isinstance(v, bool):
11 |         return v
12 |     return v.lower() in ('true', '1', 'yes')
13 | 
14 | 
15 | def parse_args():
16 |     parser = argparse.ArgumentParser(description="SISR")
17 | 
18 |     parser.add_argument('--phase', type=str, default='test', choices=['train', 'test'])
19 |     parser.add_argument('--factor', type=int, default=4, help='scale factor')
20 | 
21 |     """ Training Settings """
22 |     parser.add_argument('--training_stage', type=int, default=3, choices=[1, 2, 3], help='Set stage for the 3-stage training strategy.')
23 |     parser.add_argument('--tensorboard', type=str2bool, default=True, help='If set to True, tensorboard summaries are created')
24 |     parser.add_argument('--training_data_path', type=str, default='./dataset/DIV2K/train/DIV2K_train_HR', help='training dataset path')
25 |     parser.add_argument('--validation_data_path', type=str, default='./dataset/DIV2K/val/DIV2K_valid_HR', help='validation dataset path')
26 | 
27 |     """ Testing Settings """
28 |     parser.add_argument('--eval', type=str2bool, default=True, help='If set to True, evaluation is performed with HR images during the testing phase')
29 |     parser.add_argument('--test_data_path', type=str, default='./testset/Set5/LR/X4/imgs', help='test dataset path')
30 |     parser.add_argument('--test_label_path', type=str, default='./testset/Set5/HR', help='test dataset label path for eval')
31 |     parser.add_argument('--test_ckpt_path', type=str, default='./pretrained_ckpt', help='checkpoint path with trained weights')
32 |     parser.add_argument('--test_patch', type=int, nargs='+', default=[1, 1], help='input image can be divided into an nxn grid of smaller patches in the test phase to fit memory')
33 | 
34 |     """ Model Settings """
35 |     parser.add_argument('--channels', type=int, default=3, help='img channels')
36 |     parser.add_argument('--bicubic_size', type=int, default=20, help='size of bicubic kernel - should be an even number; we recommend at least 4*factor; only 4 centered values are meaningful and the other (bicubic_size-4) values are all zeros.')
37 |     parser.add_argument('--gaussian_size', type=int, default=15, help='size of anisotropic gaussian kernel - should be an odd number')
38 |     parser.add_argument('--down_kernel', type=int, default=20, help='downsampling kernel size in the downsampling network')
39 |     parser.add_argument('--up_kernel', type=int, default=5, help='upsampling kernel size in the upsampling network')
40 |     parser.add_argument('--anti_aliasing', type=str2bool, default=False, help='Matlab anti-aliasing')
41 | 
42 |     """ Hyperparameters """
43 |     parser.add_argument('--max_epoch', type=int, default=2000, help='number of total epochs')
44 |     parser.add_argument('--batch_size', type=int, default=8, help='batch size for training')
45 |     parser.add_argument('--val_batch_size', type=int, default=4, help='batch size for validation')
46 |     parser.add_argument('--patch_size', type=int, default=64, help='training patch size')
47 |     parser.add_argument('--val_patch_size', type=int, default=100, help='validation patch size')
48 |     parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate')
49 |     parser.add_argument('--lr_type', type=str, default='stair_decay', choices=['stair_decay', 'linear_decay', 'no_decay'])
50 |     parser.add_argument('--lr_stair_decay_points', type=int, nargs='+', help='stair_decay - Epochs where lr is decayed', default=[1600, 1800])
51 |     parser.add_argument('--lr_stair_decay_factor', type=float, default=0.1, help='stair_decay - lr decreasing factor')
52 |     parser.add_argument('--lr_linear_decay_point', type=int, default=100, help='linear_decay - Epoch to start lr decay')
53 |     parser.add_argument('--Qsize', type=int, default=50, help='number of random crop patches from an image')
54 |     parser.add_argument('--n_display', type=int, default=4, help='number of images to display - should be less than or equal to batch_size')
55 |     return parser.parse_args()
56 | 
57 | 
58 | def main():
59 |     args = parse_args()
60 |     # set model class
61 |     model = KOALAnet(args)
62 |     # build model
63 |     model.build_model(args)
64 | 
65 |     # train
66 |     if args.phase == 'train':
67 |         print("Training phase starts!!!")
68 |         model.train(args)
69 |     # test
70 |     elif args.phase == 'test':
71 |         print("Testing phase starts!!!")
72 |         model.test()
73 | 
74 | 
75 | if __name__ == '__main__':
76 |     main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | 
 3 | 
 4 | ##################################################################################
 5 | # Network Blocks
 6 | ##################################################################################
 7 | 
 8 | def res_block(x, out_ch, k_sz, scope='Residual_block'):
 9 |     # residual block
10 |     with tf.variable_scope(scope):
11 |         n = conv2d(tf.nn.relu(x), out_ch, k_sz)
12 |         n = conv2d(tf.nn.relu(n), out_ch, k_sz)
13 |         return x + n
14 | 
15 | 
16 | def enc_level_res(x, out_ch, pool_factor=2, scope='enc_block_res'):
17 |     # encoder level with resblocks
18 |     with tf.variable_scope(scope):
19 |         n = conv2d(x, out_ch, 3)
20 |         n = 
res_block(n, out_ch, 3, 'res_block/0') 21 | n = tf.nn.relu(res_block(n, out_ch, 3, 'res_block/1')) 22 | skip = n 23 | n = max_pool(n, pool_factor) 24 | return n, skip 25 | 26 | 27 | def bottleneck_res(x, out_ch, scope='bottleneck_res'): 28 | # bottleneck using resblock 29 | with tf.variable_scope(scope): 30 | n = conv2d(x, out_ch, 3) 31 | n = tf.nn.relu(res_block(n, out_ch, 3)) 32 | return n 33 | 34 | 35 | def dec_level_res(x, skip, out_ch, stride=2, scope='dec_block_res'): 36 | # decoder level with resblocks 37 | with tf.variable_scope(scope): 38 | n = deconv2d(x, out_ch, 4, stride) 39 | n = tf.concat([n, skip], 3) 40 | n = conv2d(n, out_ch, 3) 41 | n = res_block(n, out_ch, 3, 'res_block/0') 42 | n = tf.nn.relu(res_block(n, out_ch, 3, 'res_block/1')) 43 | return n 44 | 45 | 46 | def koala(x, kernel, feat_ch, ker_ch, conv_k_sz, lc_k_sz, scope_res='Residual_block0_', scope='KOALA_module'): 47 | # kernel-oriented adaptive local adjustment (KOALA) module 48 | with tf.variable_scope(scope_res): 49 | n = conv2d(tf.nn.relu(x), feat_ch, conv_k_sz) 50 | n = conv2d(tf.nn.relu(n), feat_ch, conv_k_sz) 51 | with tf.variable_scope(scope): 52 | # multiplicative parameters 53 | mul_p = conv2d(kernel, feat_ch, conv_k_sz) 54 | mul_p = conv2d(tf.nn.relu(mul_p), feat_ch, conv_k_sz) 55 | # local filtering parameters 56 | filter_p = conv2d(kernel, ker_ch, k_sz=1) # 1x1 conv 57 | filter_p = conv2d(tf.nn.relu(filter_p), lc_k_sz*lc_k_sz, k_sz=1) # 1x1 conv 58 | # spatially-variant feature filtering 59 | n = tf.multiply(n, mul_p) 60 | n = local_conv_feat(n, filter_p, feat_ch, lc_k_sz) 61 | n = x+n 62 | return n, filter_p 63 | 64 | 65 | def cr_block(x, num_blocks, out_ch, k_sz, scope='condition'): 66 | # conv-relu stack 67 | with tf.variable_scope(scope): 68 | n = tf.nn.relu(conv2d(x, out_ch, k_sz)) 69 | for i in range(num_blocks-1): 70 | n = tf.nn.relu(conv2d(n, out_ch, k_sz)) 71 | return n 72 | 73 | 74 | ################################################################################## 75 | # Layers 76 | ################################################################################## 77 | 78 | def conv2d(x, out_ch, k_sz, stride=1): 79 | # convolution layer 80 | init = tf.contrib.layers.xavier_initializer(uniform=False) 81 | n = tf.layers.conv2d(x, out_ch, k_sz, stride, 'same', kernel_initializer=init) 82 | return n 83 | 84 | 85 | def deconv2d(x, out_ch, k_sz=4, stride=2): 86 | # deconvolution layer 87 | init = tf.contrib.layers.xavier_initializer(uniform=False) 88 | n = tf.layers.conv2d_transpose(x, out_ch, k_sz, stride, 'same', kernel_initializer=init) 89 | return n 90 | 91 | 92 | def max_pool(x, pool_factor): 93 | # max pooling layer 94 | n = tf.nn.max_pool(x, [1, pool_factor, pool_factor, 1], [1, pool_factor, pool_factor, 1], 'SAME') 95 | return n 96 | 97 | 98 | def local_conv_us(img, kernel_2d, factor, num_ch, k_sz): 99 | # local filtering operation for upsampling network 100 | # img: [B, H, W, num_ch] 101 | # kernel_2d: [B, H, W, k_sz*k_sz*factor*factor] 102 | 103 | # [B, H, W, k*k*c] 104 | img = tf.image.extract_image_patches(img, ksizes=(1, k_sz, k_sz, 1), strides=(1, 1, 1, 1), rates=(1, 1, 1, 1), padding="SAME") 105 | img = tf.split(img, k_sz*k_sz, axis=-1) # k*k of [B, H, W, c] 106 | img = tf.stack(img, axis=3) # [B, H, W, k*k, c] 107 | img = tf.tile(img, [1, 1, 1, 1, factor*factor]) # [B, H, W, k*k, f*f*c] 108 | 109 | kernel_2d = tf.split(kernel_2d, k_sz*k_sz, axis=-1) # k*k of [B, H, W, f*f] 110 | kernel_2d = tf.stack(kernel_2d, axis=3) # [B, H, W, k*k, f*f] 111 | kernel_2d = 
kernel_normalize(kernel_2d, k_sz) 112 | 113 | kernel_2d = tf.expand_dims(kernel_2d, -1) # [B, H, W, k*k, f*f, 1] 114 | kernel_2d = tf.tile(kernel_2d, [1, 1, 1, 1, 1, num_ch]) # [B, H, W, k*k, f*f, c] 115 | kernel_2d = tf.unstack(kernel_2d, axis=4) # f*f of [B, H, W, k*k, c] 116 | kernel_2d = tf.concat(kernel_2d, axis=4) # [B, H, W, k*k, f*f*c] 117 | 118 | result = tf.multiply(img, kernel_2d) # element-wise multiplication, resulting in [B, H, W, k*k, f*f*c] 119 | result = tf.reduce_sum(result, axis=3) # [B, H, W, f*f*c] 120 | result = tf.depth_to_space(result, factor) # [B, f*H, f*W, c] 121 | 122 | return result 123 | 124 | 125 | def local_conv_ds(img, kernel_2d, factor, num_ch, k_sz): 126 | # local filtering operation for downsampling network 127 | # img: [B, H, W, num_ch] 128 | # kernel_2d: [B, H, W, kernel*kernel] 129 | 130 | # [B, H, W, k*k*c] 131 | img = tf.image.extract_image_patches(img, ksizes=(1, k_sz, k_sz, 1), strides=(1, factor, factor, 1), rates=(1, 1, 1, 1), padding="SAME") 132 | img = tf.split(img, k_sz * k_sz, axis=-1) # k*k of [B, H, W, c] 133 | img = tf.stack(img, axis=3) # [B, H, W, k*k, c] 134 | 135 | kernel_2d = kernel_normalize(kernel_2d, k_sz) 136 | kernel_2d = tf.expand_dims(kernel_2d, -1) # [B, H, W, k*k, 1] 137 | kernel_2d = tf.tile(kernel_2d, [1, 1, 1, 1, num_ch]) # [B, H, W, k*k, c] 138 | 139 | result = tf.multiply(img, kernel_2d) # element-wise multiplication, resulting in [B, H, W, k*k, c] 140 | result = tf.reduce_sum(result, axis=3) # [B, H, W, c] 141 | 142 | return result 143 | 144 | 145 | def local_conv_feat(img, kernel_2d, num_ch, k_sz): 146 | # local filtering operation for features 147 | # img: [B, H, W, num_ch] 148 | # kernel_2d: [B, H, W, kernel*kernel] 149 | 150 | # [B, H, W, k*k*c] 151 | img = tf.image.extract_image_patches(img, ksizes=(1, k_sz, k_sz, 1), strides=(1, 1, 1, 1), rates=(1, 1, 1, 1), padding="SAME") 152 | img = tf.split(img, k_sz * k_sz, axis=-1) # k*k of [B, H, W, c] 153 | img = tf.stack(img, axis=3) # [B, H, W, k*k, c] 154 | 155 | kernel_2d = kernel_normalize(kernel_2d, k_sz) 156 | kernel_2d = tf.expand_dims(kernel_2d, -1) # [B, H, W, k*k, 1] 157 | kernel_2d = tf.tile(kernel_2d, [1, 1, 1, 1, num_ch]) # [B, H, W, k*k, c] 158 | 159 | result = tf.multiply(img, kernel_2d) # element-wise multiplication, resulting in [B, H, W, k*k, c] 160 | result = tf.reduce_sum(result, axis=3) # [B, H, W, c] 161 | 162 | return result 163 | 164 | 165 | def kernel_normalize(kernel_2d, k_sz): 166 | kernel_2d = kernel_2d - tf.reduce_mean(kernel_2d, axis=3, keepdims=True) 167 | kernel_2d = kernel_2d + 1.0 / (k_sz ** 2) 168 | return kernel_2d 169 | 170 | 171 | ################################################################################## 172 | # Loss Function 173 | ################################################################################## 174 | 175 | def l1_loss(x, y): 176 | loss = tf.reduce_mean(tf.abs(x - y)) 177 | return loss 178 | 179 | 180 | ################################################################################## 181 | # Degradation 182 | ################################################################################## 183 | 184 | def get_ds_kernel(base_kernel, rand_kernel): 185 | # convolve base_kernel with rand_kernel 186 | # base kernel: bicubic, random kernel: anisotropic gaussian 187 | rand_kernel = tf.transpose(rand_kernel, [1, 2, 0, 3]) # [gaussian_size, gaussian_size, B, 1] 188 | ds_kernel = tf.nn.depthwise_conv2d(base_kernel, filter=rand_kernel, strides=[1, 1, 1, 1], padding='SAME') # [1, bicubic_size, bicubic_size, B] 189 | 
return ds_kernel 190 | 191 | 192 | def get_ds_input(hr, ds_kernel, num_ch, batch_size, pad_l, pad_r, factor): 193 | # convolve HR image with the downsampling kernel to obtain input LR 194 | ds_kernel = tf.squeeze(ds_kernel, 0) # [bicubic_size, bicubic_size, B] 195 | ds_kernel = tf.expand_dims(ds_kernel, 3) # [bicubic_size, bicubic_size, B, 1] 196 | ds_kernel = tf.tile(ds_kernel, [1, 1, 1, num_ch]) # [bicubic_size, bicubic_size, B, channels] 197 | ds_kernel = tf.unstack(ds_kernel, batch_size, axis=2) # B*[bicubic_size, bicubic_size, channels] 198 | ds_kernel = tf.concat(ds_kernel, axis=2) # [bicubic_size, bicubic_size, B*channels] 199 | ds_kernel = tf.expand_dims(ds_kernel, 3) # [bicubic_size, bicubic_size, B*channels, 1] 200 | 201 | lr = tf.unstack(hr, batch_size, axis=0) # B*[H, W, C] 202 | lr = tf.concat(lr, axis=2) # [H, W, B*C] 203 | lr = tf.expand_dims(lr, 0) # [1, H, W, B*C] 204 | lr = tf.pad(lr, [[0, 0], [pad_l, pad_r], [pad_l, pad_r], [0, 0]], 'symmetric') 205 | lr = tf.nn.depthwise_conv2d(lr, filter=ds_kernel, strides=[1, factor, factor, 1], padding='VALID') # [1, H, W, B*C] 206 | lr = tf.split(lr, batch_size, axis=3) # B*[1, H, W, C] 207 | lr = tf.concat(lr, axis=0) # [B, H, W, C] 208 | return lr 209 | 210 | 211 | def get_1d_kernel(flattened_kernel, batch_size): 212 | # flattened_kernel : [1, k_sz, k_sz, B] 213 | # kernel_1d : [B, 1, 1, k_sz*k_sz] 214 | kernel_1d = tf.transpose(flattened_kernel, [3,1,2,0]) # [B, k_sz, k_sz, 1] 215 | kernel_1d = tf.reshape(kernel_1d, [batch_size,1,1,-1]) # [B, 1, 1, k_sz*k_sz] 216 | return kernel_1d 217 | 218 | 219 | ################################################################################## 220 | # Visualization 221 | ################################################################################## 222 | 223 | def local_conv_vis_ds(kernel_2d, kernel_min=None, kernel_max=None, padding=0, scale=1): 224 | if kernel_min is None: 225 | kernel_min = tf.reduce_min(kernel_2d, axis=(1, 2), keepdims=True) 226 | if kernel_max is None: 227 | kernel_max = tf.reduce_max(kernel_2d, axis=(1, 2), keepdims=True) 228 | if padding != 0: 229 | kernel_2d = tf.pad(kernel_2d, [[0, 0], [padding, padding], [padding, padding], [0, 0]], 'constant') 230 | kernel_2d = 2.0*(kernel_2d-kernel_min)/(kernel_max-kernel_min)-1.0 231 | if scale != 1: 232 | kernel_2d = nearest_neighbor(kernel_2d, scale) 233 | kernel_2d = tf.concat((kernel_2d, kernel_2d, kernel_2d), axis=3) 234 | return kernel_2d 235 | 236 | 237 | def nearest_neighbor(x, factor): 238 | y = tf.tile(x, [1, 1, 1, factor*factor]) 239 | y = tf.depth_to_space(y, factor) 240 | return y 241 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import numpy as np 5 | from math import ceil 6 | from random import random, randint 7 | from PIL import Image 8 | from skimage import color 9 | import tensorflow as tf 10 | from tensorflow.python.util import deprecation 11 | deprecation._PRINT_DEPRECATION_WARNINGS = False 12 | 13 | 14 | ################################################################################## 15 | # Training Data Processing 16 | ################################################################################## 17 | 18 | class SISRData: 19 | 20 | def __init__(self, args): 21 | self.factor = args.factor 22 | self.bicubic_size = args.bicubic_size 23 | self.gaussian_size = args.gaussian_size 24 | self.anti_aliasing = args.anti_aliasing 25 | 
self.channels = args.channels 26 | self.training_data_path = args.training_data_path 27 | self.validation_data_path = args.validation_data_path 28 | self.test_data_path = args.test_data_path 29 | self.test_label_path = args.test_label_path 30 | self.batch_size = args.batch_size 31 | self.patch_size = args.patch_size 32 | self.Qsize = args.Qsize 33 | 34 | path_train = os.path.join(self.training_data_path, '*') 35 | path_val = os.path.join(self.validation_data_path, '*') 36 | path_test_lr = os.path.join(self.test_data_path, '*') 37 | path_test_hr = os.path.join(self.test_label_path, '*') 38 | print("###### path_train ", path_train) 39 | print("###### path_val ", path_val) 40 | print("###### path_test_lr ", path_test_lr) 41 | print("###### path_test_hr ", path_test_hr) 42 | self.list_train = sorted(glob.glob(path_train)) 43 | self.list_val = sorted(glob.glob(path_val)) 44 | self.list_test_lr = sorted(glob.glob(path_test_lr)) 45 | self.list_test_hr = sorted(glob.glob(path_test_hr)) 46 | self.num_train = len(self.list_train) 47 | 48 | print('Load all files list') 49 | print("# training imgs : {} \n".format(self.num_train)) 50 | 51 | # bicubic kernel to be convolved by anisotropic gaussian 52 | self.bicubic_kernel = get_bicubic_kernel(self.bicubic_size, anti_aliasing=self.anti_aliasing, factor=self.factor) 53 | self.bicubic_kernel = tf.constant(self.bicubic_kernel, dtype=tf.float32, shape=(1, self.bicubic_size, self.bicubic_size, 1)) 54 | self.bicubic_kernel = tf.tile(self.bicubic_kernel, [1, 1, 1, self.batch_size]) 55 | self.pad_left = (self.bicubic_size - self.factor) // 2 56 | self.pad_right = self.pad_left 57 | 58 | def image_processing(self, img_path): 59 | y, gaussian_kernel = tf.py_func(self.image_processing_py, [img_path], [tf.float32, tf.float32]) 60 | y.set_shape((self.Qsize, self.factor * self.patch_size, self.factor * self.patch_size, self.channels)) 61 | gaussian_kernel.set_shape((self.Qsize, self.gaussian_size, self.gaussian_size, 1)) 62 | return y, gaussian_kernel 63 | 64 | def image_processing_py(self, img_path): 65 | img_hr = Image.open(img_path) 66 | 67 | width, height = img_hr.size 68 | patches_hr = np.zeros((self.Qsize, self.factor * self.patch_size, self.factor * self.patch_size, self.channels), dtype=np.float32) 69 | patches_gaussian_kernel = np.zeros((self.Qsize, self.gaussian_size, self.gaussian_size), dtype=np.float32) 70 | for patch in range(self.Qsize): 71 | w = int(random() * (width - self.patch_size * self.factor)) 72 | h = int(random() * (height - self.patch_size * self.factor)) 73 | patches_hr[patch] = np.array( 74 | img_hr.crop((w, h, w + self.patch_size * self.factor, h + self.patch_size * self.factor)), 'float32') 75 | patches_gaussian_kernel[patch] = random_anisotropic_gaussian_kernel(width=self.gaussian_size) 76 | 77 | if random() > 0.5: # horizontal flip 78 | patches_hr = np.flip(patches_hr, axis=2) 79 | 80 | rot = randint(0, 3) 81 | patches_hr = np.rot90(patches_hr, rot, (1, 2)) 82 | 83 | patches_hr = (patches_hr / 255.0) * 2 - 1 # normalize to [-1,1] 84 | patches_gaussian_kernel = np.expand_dims(patches_gaussian_kernel, -1) 85 | 86 | return patches_hr, patches_gaussian_kernel 87 | 88 | 89 | ################################################################################## 90 | # Degradation 91 | ################################################################################## 92 | 93 | def get_bicubic_kernel(bicubic_size, anti_aliasing=False, factor=1): 94 | # set correct factor if anti_aliasing=True 95 | # assert self.bicubic_size % 2 == 0, "bicubic_size 
should be an even number"
 96 |     cubic_input = np.arange(-bicubic_size // 2 + 1, bicubic_size // 2 + 1) - 0.5
 97 |     if anti_aliasing:
 98 |         bicubic_kernel = cubic32(cubic_input / float(factor))
 99 |     else:
100 |         bicubic_kernel = cubic32(cubic_input)
101 |     bicubic_kernel = bicubic_kernel / np.sum(bicubic_kernel)
102 |     bicubic_kernel = np.outer(bicubic_kernel, bicubic_kernel.T)
103 |     return bicubic_kernel
104 | 
105 | 
106 | def cubic32(x):
107 |     x = np.array(x).astype(np.float32)
108 |     absx = np.absolute(x)
109 |     absx2 = np.multiply(absx, absx)
110 |     absx3 = np.multiply(absx2, absx)
111 |     f = np.multiply(1.5*absx3 - 2.5*absx2 + 1, absx <= 1) + np.multiply(-0.5*absx3 + 2.5*absx2 - 4*absx + 2, (1 < absx) & (absx <= 2))
112 |     return f
113 | 
114 | 
115 | def inv_covariance_matrix(sig_x, sig_y, theta):
116 |     # sig_x : x-direction standard deviation
117 |     # sig_y : y-direction standard deviation
118 |     # theta : rotation angle
119 |     D_inv = np.array([[1/(sig_x ** 2), 0.], [0., 1/(sig_y ** 2)]])  # inverse of diagonal matrix D
120 |     U = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])  # eigenvector matrix
121 |     inv_cov = np.dot(U, np.dot(D_inv, U.T))  # inverse of covariance matrix
122 |     return inv_cov
123 | 
124 | 
125 | def anisotropic_gaussian_kernel(width, inv_cov):
126 |     # width : kernel size of anisotropic gaussian filter
127 |     ax = np.arange(-width // 2 + 1., width // 2 + 1.)
128 |     # avoid shift
129 |     if width % 2 == 0:
130 |         ax = ax - 0.5
131 |     xx, yy = np.meshgrid(ax, ax)
132 |     xy = np.stack([xx, yy], axis=2)
133 |     # pdf of bivariate gaussian distribution with the covariance matrix
134 |     kernel = np.exp(-0.5 * np.sum(np.dot(xy, inv_cov) * xy, 2))
135 |     kernel = kernel / np.sum(kernel)
136 |     return kernel
137 | 
138 | 
139 | def random_anisotropic_gaussian_kernel(width=15, sig_min=0.2, sig_max=4.0):
140 |     # width : kernel size of anisotropic gaussian filter
141 |     # sig_min : minimum of standard deviation
142 |     # sig_max : maximum of standard deviation
143 |     sig_x = np.random.random() * (sig_max - sig_min) + sig_min
144 |     sig_y = np.random.random() * (sig_max - sig_min) + sig_min
145 |     theta = np.random.random() * 3.141/2.
146 |     inv_cov = inv_covariance_matrix(sig_x, sig_y, theta)
147 |     kernel = anisotropic_gaussian_kernel(width, inv_cov)
148 |     kernel = kernel.astype(np.float32)
149 |     return kernel
150 | 
151 | 
152 | def random_anisotropic_gaussian_kernel_seed(s, width=15, sig_min=0.2, sig_max=4.0):
153 |     # width : kernel size of anisotropic gaussian filter
154 |     # sig_min : minimum of standard deviation
155 |     # sig_max : maximum of standard deviation
156 |     # s : random seed
157 |     np.random.seed(3 * s)
158 |     sig_x = np.random.random() * (sig_max - sig_min) + sig_min
159 |     np.random.seed(3 * s + 1)
160 |     sig_y = np.random.random() * (sig_max - sig_min) + sig_min
161 |     np.random.seed(3 * s + 2)
162 |     theta = np.random.random() * 3.141/2. 


def random_anisotropic_gaussian_kernel_seed(s, width=15, sig_min=0.2, sig_max=4.0):
    # width : kernel size of anisotropic gaussian filter
    # sig_min : minimum of standard deviation
    # sig_max : maximum of standard deviation
    # s : seed; reseeding before each draw makes the kernel fully determined by s
    np.random.seed(3 * s)
    sig_x = np.random.random() * (sig_max - sig_min) + sig_min
    np.random.seed(3 * s + 1)
    sig_y = np.random.random() * (sig_max - sig_min) + sig_min
    np.random.seed(3 * s + 2)
    theta = np.random.random() * np.pi / 2.
    inv_cov = inv_covariance_matrix(sig_x, sig_y, theta)
    kernel = anisotropic_gaussian_kernel(width, inv_cov)
    kernel = kernel.astype(np.float32)
    return kernel


##################################################################################
# Image I/O
##################################################################################

def read_img_trim(img_path, factor):
    # read an image and trim it so that its height and width are divisible by factor
    img = np.array(Image.open(img_path), 'float32')
    if len(img.shape) == 3:
        h, w, _ = img.shape
        h = h - np.remainder(h, factor)
        w = w - np.remainder(w, factor)
        img = np.expand_dims(img[:h, :w, :], axis=0)
    else:
        h, w = img.shape
        h = h - np.remainder(h, factor)
        w = w - np.remainder(w, factor)
        img = np.expand_dims(img[:h, :w], axis=0)
    img = (img / 255.0) * 2.0 - 1.0  # normalize to [-1, 1]
    return img


def save_img(img, img_path):
    img = np.squeeze(img)
    img = np.clip((img + 1.) / 2. * 255., 0, 255).round()  # [-1, 1] float -> [0, 255]
    img = Image.fromarray(img.astype('uint8'))
    img.save(img_path)


##################################################################################
# Image Processing
##################################################################################

def get_HW_boundary(patch_boundary, h, w, pH, sH, pW, sW):
    # get boundary indices (with extra context) for patch-wise processing
    H_low_ind = max(pH * sH - patch_boundary, 0)
    H_high_ind = min((pH + 1) * sH + patch_boundary, h)
    W_low_ind = max(pW * sW - patch_boundary, 0)
    W_high_ind = min((pW + 1) * sW + patch_boundary, w)

    return H_low_ind, H_high_ind, W_low_ind, W_high_ind


def trim_patch_boundary(img, patch_boundary, h, w, pH, sH, pW, sW, sf):
    # trim the extra context (scaled by sf) on each side, except at image borders
    if patch_boundary == 0:
        return img
    if pH * sH >= patch_boundary:
        img = img[:, patch_boundary * sf:, :, :]
    if (pH + 1) * sH + patch_boundary <= h:
        img = img[:, :-patch_boundary * sf, :, :]
    if pW * sW >= patch_boundary:
        img = img[:, :, patch_boundary * sf:, :]
    if (pW + 1) * sW + patch_boundary <= w:
        img = img[:, :, :-patch_boundary * sf, :]
    return img
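
The two helpers above support patch-wise inference (cf. the `--test_patch` option): each LR tile is cut with `patch_boundary` pixels of extra context, super-resolved, and then the context (scaled by `sf`) is trimmed so the tiles concatenate without seams. A runnable toy version, with nearest-neighbor upscaling standing in for the network:

```python
import numpy as np
from utils import get_HW_boundary, trim_patch_boundary

sf, boundary = 4, 8                   # scale factor; context width in LR pixels
n = 2                                 # 2x2 grid, as with --test_patch 2, 2
lr = np.random.rand(1, 64, 64, 3)     # NHWC low-res input (toy data)
h, w = lr.shape[1], lr.shape[2]
sH, sW = h // n, w // n

rows = []
for pH in range(n):
    cols = []
    for pW in range(n):
        h0, h1, w0, w1 = get_HW_boundary(boundary, h, w, pH, sH, pW, sW)
        tile = lr[:, h0:h1, w0:w1, :]                     # tile plus context
        sr = tile.repeat(sf, axis=1).repeat(sf, axis=2)   # stand-in for the SR network
        sr = trim_patch_boundary(sr, boundary, h, w, pH, sH, pW, sW, sf)
        cols.append(sr)
    rows.append(np.concatenate(cols, axis=2))
out = np.concatenate(rows, axis=1)
assert out.shape == (1, h * sf, w * sf, 3)  # seamless full-size result
```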


##################################################################################
# Resize functions from https://github.com/fatheral/matlab_imresize
##################################################################################

def deriveSizeFromScale(img_shape, scale):
    output_shape = []
    for k in range(2):
        output_shape.append(int(ceil(scale[k] * img_shape[k])))
    return output_shape


def deriveScaleFromSize(img_shape_in, img_shape_out):
    scale = []
    for k in range(2):
        scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
    return scale


def cubic(x):
    x = np.array(x).astype(np.float64)
    absx = np.absolute(x)
    absx2 = np.multiply(absx, absx)
    absx3 = np.multiply(absx2, absx)
    f = np.multiply(1.5*absx3 - 2.5*absx2 + 1, absx <= 1) + np.multiply(-0.5*absx3 + 2.5*absx2 - 4*absx + 2, (1 < absx) & (absx <= 2))
    return f


def contributions(in_length, out_length, scale, kernel, k_width):
    # compute weights and indices from kernel function
    if scale < 1:
        h = lambda x: scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width
    x = np.arange(1, out_length+1).astype(np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)
    P = int(ceil(kernel_width)) + 2
    ind = np.expand_dims(left, axis=1) + np.arange(P) - 1  # -1 because indexing from 0
    indices = ind.astype(np.int32)
    weights = h(np.expand_dims(u, axis=1) - indices - 1)  # -1 because indexing from 0
    weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
    aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]
    ind2store = np.nonzero(np.any(weights, axis=0))
    weights = weights[:, ind2store]
    indices = indices[:, ind2store]
    return weights, indices


def imresizevec(inimg, weights, indices, dim):
    wshape = weights.shape
    if dim == 0:
        weights = weights.reshape((wshape[0], wshape[2], 1, 1))
        outimg = np.sum(weights*((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
    elif dim == 1:
        weights = weights.reshape((1, wshape[0], wshape[2], 1))
        outimg = np.sum(weights*((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg


def imresize(I, scalar_scale=None, output_shape=None):
    kernel = cubic
    kernel_width = 4.0
    # Fill scale and output_size
    if scalar_scale is not None:
        scalar_scale = float(scalar_scale)
        scale = [scalar_scale, scalar_scale]
        output_size = deriveSizeFromScale(I.shape, scale)
    elif output_shape is not None:
        scale = deriveScaleFromSize(I.shape, output_shape)
        output_size = list(output_shape)
    else:
        print('Error: scalar_scale OR output_shape should be defined!')
        return
    scale_np = np.array(scale)
    order = np.argsort(scale_np)
    weights = []
    indices = []
    for k in range(2):
        w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
        weights.append(w)
        indices.append(ind)
    B = np.copy(I)
    flag2D = False
    if B.ndim == 2:
        B = np.expand_dims(B, axis=2)
        flag2D = True
    for k in range(2):
        dim = order[k]
        B = imresizevec(B, weights[dim], indices[dim], dim)
    if flag2D:
        B = np.squeeze(B, axis=2)
    return B
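
These functions port MATLAB's `imresize` (bicubic, with antialiasing when downscaling), which the SR literature assumes when generating LR/HR pairs. A quick usage sketch on synthetic data:

```python
import numpy as np
from utils import imresize

img = (np.random.rand(128, 96, 3) * 255).astype(np.uint8)
lr = imresize(img, scalar_scale=0.25)       # downscale by 4, antialiased
hr = imresize(lr, output_shape=(128, 96))   # upscale back via an explicit size
print(lr.shape, hr.shape, hr.dtype)         # (32, 24, 3) (128, 96, 3) uint8
```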


def convertDouble2Byte(I):
    # [0, 1] double -> [0, 255] uint8
    B = np.clip(I, 0.0, 1.0)
    B = 255*B
    return np.around(B).astype(np.uint8)


##################################################################################
# Miscellaneous
##################################################################################

def compute_psnr(img_gt, img_out, peak):
    mse = np.mean(np.square(img_gt - img_out))
    psnr = 10 * np.log10(peak*peak / mse)
    return psnr


def compute_y_psnr(img_gt_rgb, img_out_rgb):
    # images must be in the range [-1, 1], float or double
    peak = 255
    img_gt_rgb = np.squeeze(img_gt_rgb)
    img_out_rgb = np.squeeze(img_out_rgb)
    img_gt_rgb = np.clip((img_gt_rgb + 1.) / 2. * 255., 0, 255).round()
    img_out_rgb = np.clip((img_out_rgb + 1.) / 2. * 255., 0, 255).round()

    # PSNR is computed on the Y (luma) channel only, as is standard in SR evaluation
    img_gt_yuv = color.rgb2ycbcr(img_gt_rgb.astype('uint8'))
    img_out_yuv = color.rgb2ycbcr(img_out_rgb.astype('uint8'))
    img_gt_yuv = np.clip(img_gt_yuv[:, :, 0], 0, 255).round()
    img_out_yuv = np.clip(img_out_yuv[:, :, 0], 0, 255).round()
    psnr = compute_psnr(img_gt_yuv, img_out_yuv, peak)
    return psnr


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def check_gray(img):
    # replicate a single-channel (grayscale) batch to 3 channels
    if len(img.shape) == 3:
        img = np.expand_dims(img, axis=3)
        img = np.tile(img, (1, 1, 1, 3))
    return img
--------------------------------------------------------------------------------
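
For reference, a rough sketch of how these utilities compose at test time (the paths are illustrative and the network call is replaced by a noisy identity; the actual test loop lives in koalanet.py):

```python
import numpy as np
from utils import read_img_trim, save_img, compute_y_psnr, check_folder

factor = 4
hr = read_img_trim('./testset/Set5/HR/baby.png', factor)  # illustrative path
# in the real pipeline, sr comes from the upsampling network; a noisy
# stand-in is used here so the PSNR stays finite
sr = np.clip(hr + np.random.normal(0.0, 0.02, hr.shape), -1.0, 1.0)
print('Y-PSNR: %.2f dB' % compute_y_psnr(hr, sr))
save_img(sr, check_folder('./results/imgs_test') + '/baby.png')
```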