├── README.md
├── checkpoint_x2
    ├── checkpoint
    ├── model.ckpt.data-00000-of-00001
    ├── model.ckpt.index
    └── model.ckpt.meta
├── checkpoint_x3
    ├── checkpoint
    ├── model.ckpt.data-00000-of-00001
    ├── model.ckpt.index
    └── model.ckpt.meta
├── checkpoint_x4
    ├── checkpoint
    ├── model.ckpt.data-00000-of-00001
    ├── model.ckpt.index
    └── model.ckpt.meta
├── config.py
├── model.py
├── test.py
├── train_SR.py
└── utils.py


/README.md:
--------------------------------------------------------------------------------
# IDN-tensorflow
[[Original Caffe version]](https://github.com/Zheng222/IDN-Caffe)
## Testing
* Install TensorFlow 1.11 and MATLAB R2017a.
* Download the [test datasets](https://drive.google.com/open?id=1_K6mchwDGOQMIXuBIGrlDA4EAYgbtdmU) and put each one under `data/` (for example, `data/Set5/`). `test.py` looks for the HR test images there.
* Edit `config.py` and `test.py` to select the model, dataset and scale. For example, to test the ×3 model on Set14, set `config.TEST.model_path = 'checkpoint_x3/model.ckpt'` and `config.TEST.dataset = 'Set14'` in `config.py`, and set `scale = 3` in `test.py`.
* Run testing:
```bash
python test.py
```

## Training
* Download the [training dataset](https://drive.google.com/open?id=12hOYsMa8t1ErKj6PZA352icsx9mz1TwB).
* Edit `config.py` and `train_SR.py`. For example, to train the ×4 model, set `config.TRAIN.hr_img_path = '/path/to/DIV2K_train_HR/'`, `config.TRAIN.checkpoint_dir = 'checkpoint_x4/'`, `config.VALID.hr_img_path = '/path/to/DIV2K_valid_HR/'` and `config.VALID.lr_img_path = '/path/to/DIV2K_valid_LR_x4/'` in `config.py`, and set `scale = 4` in `train_SR.py`. On the first run, `train_SR.py` reads the PNG images and creates the LR validation images by bicubic downsampling. It then caches all of them as `.npy` files in these directories.
* Run training:
```bash
python train_SR.py
```
## Note
This TensorFlow version is trained on RGB images from the DIV2K training set. In addition, we replaced the upsampling layer with sub-pixel convolution (the original version uses transposed convolution).
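To illustrate the sub-pixel step: in `model.py`, `upsample` predicts `3 * scale**2` channels at LR resolution. `PS(..., color=True)` then rearranges those channels into pixels with `tf.depth_to_space`, one colour at a time. The NumPy sketch below reproduces that rearrangement. It assumes NHWC layout and TensorFlow's documented `depth_to_space` channel ordering. `pixel_shuffle` and `ps_color` are hypothetical helper names used only for illustration; they are not functions in this repository.

```python
import numpy as np


def pixel_shuffle(x, r):
    """NumPy equivalent of tf.depth_to_space(x, r) for an NHWC array."""
    n, h, w, c = x.shape
    if c % (r * r) != 0:
        raise ValueError("channel count must be divisible by r * r")
    c_out = c // (r * r)
    # Input channel (i * r + j) * c_out + k feeds output pixel (h * r + i, w * r + j), channel k.
    y = x.reshape(n, h, w, r, r, c_out)
    y = y.transpose(0, 1, 3, 2, 4, 5)  # -> (n, h, i, w, j, c_out)
    return y.reshape(n, h * r, w * r, c_out)


def ps_color(x, r):
    """Mirrors PS(X, r, color=True): shuffle each colour's r*r channels separately."""
    groups = np.split(x, 3, axis=3)
    return np.concatenate([pixel_shuffle(g, r) for g in groups], axis=3)


# A 1x1 "image" with 3 * 2**2 = 12 channels becomes a 2x2 RGB image.
x = np.arange(12, dtype=np.float32).reshape(1, 1, 1, 12)
print(pixel_shuffle(x, 2).shape)   # (1, 2, 2, 3)
print(ps_color(x, 2)[0, :, :, 0])  # red plane, built from channels 0 to 3
```

The two orderings differ. Plain `depth_to_space` interleaves the colours within each group of `3` channels. The `color=True` path instead reads channels 0 to 3 as red, 4 to 7 as green and 8 to 11 as blue (for ×2). The last convolution's weights are tied to whichever ordering was used during training, so swapping one for the other would presumably require retraining.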

## Results
[Test results](https://drive.google.com/open?id=1saFhGV8t2ytzRLHE2CaFc4H_UkvJo9KS)

The PSNR/SSIM values below were computed in MATLAB R2017a with [Evaluate_PSNR_SSIM.m](https://github.com/yulunzhang/RCAN/blob/master/RCAN_TestCode/Evaluate_PSNR_SSIM.m). Each cell shows PSNR (dB) / SSIM.

| Training dataset | Scale | Set5 | Set14 | B100 | Urban100 |
|:---:|:---:|:---:|:---:|:---:|:---:|
| 291 | ×2 | 37.83 / 0.9600 | 33.30 / 0.9148 | 32.08 / 0.8985 | 31.27 / 0.9196 |
| DIV2K | ×2 | 37.85 / 0.9598 | 33.58 / 0.9178 | 32.11 / 0.8989 | 31.95 / 0.9266 |
| 291 | ×3 | 34.11 / 0.9253 | 29.99 / 0.8354 | 28.95 / 0.8013 | 27.42 / 0.8359 |
| DIV2K | ×3 | 34.24 / 0.9260 | 30.27 / 0.8408 | 29.03 / 0.8038 | 27.99 / 0.8489 |
| 291 | ×4 | 31.82 / 0.8903 | 28.25 / 0.7730 | 27.41 / 0.7297 | 25.41 / 0.7632 |
| DIV2K | ×4 | 31.99 / 0.8928 | 28.52 / 0.7794 | 27.52 / 0.7339 | 25.92 / 0.7801 |

## Model Parameters
| Scale | Number of parameters |
|:---:|:---:|
| ×2 | **579,276** |
| ×3 | **587,931** |
| ×4 | **600,048** |
## Citation

If you find IDN useful in your research, please consider citing:

```
@inproceedings{Hui-IDN-2018,
  title={Fast and Accurate Single Image Super-Resolution via Information Distillation Network},
  author={Hui, Zheng and Wang, Xiumei and Gao, Xinbo},
  booktitle={CVPR},
  pages = {723--731},
  year={2018}
}
```

--------------------------------------------------------------------------------
/checkpoint_x2/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"

--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.meta
--------------------------------------------------------------------------------
/checkpoint_x3/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"

--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.meta
--------------------------------------------------------------------------------
/checkpoint_x4/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"

--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.meta
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from easydict import EasyDict as edict

config = edict()
config.TRAIN = edict()

## Adam optimizer and batch size
config.TRAIN.batch_size = 16
config.TRAIN.lr_init = 2e-4

## training schedule: the learning rate is multiplied by lr_decay every decay_every epochs
config.TRAIN.n_epoch = 10000
config.TRAIN.lr_decay = 0.5
config.TRAIN.decay_every = 2000

## training dataset location
config.TRAIN.hr_img_path = '/data/DIV2K_train_HR/'
config.TRAIN.checkpoint_dir = 'checkpoint_x2/'

config.VALID = edict()
config.VALID.hr_img_path = '/data/DIV2K_valid_HR/'
config.VALID.lr_img_path = '/data/DIV2K_valid_LR_x2/'
## test
config.TEST = edict()
config.TEST.model_path = 'checkpoint_x2/model.ckpt'
config.TEST.save_path = 'results'
config.TEST.dataset = 'Set5'  # Set5 | Set14 | B100 | Urban100 | manga109 | DIV2K_val

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf

def IDN(t_image, t_image_bicubic, scale, reuse=False):
    # The network predicts a residual that is added to the bicubic upsampling of the LR input.
    t_image_bicubic = tf.identity(t_image_bicubic)
    with tf.variable_scope("IDN", reuse=reuse):
        # feature extraction
        conv1 = tf.layers.conv2d(t_image, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name='conv1')
        conv2 = tf.layers.conv2d(conv1, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name='conv2')
        n = conv2
        # four stacked information distillation blocks
        for i in range(4):
            n = distillation(n, name='distill/%i' % i)
        # sub-pixel upsampling, then add the bicubic image back
        output = upsample(n, scale=scale, features=64, name=str(scale)) + t_image_bicubic
    return output

def distillation(x, name=''):
    # enhancement unit, first stage: three 3x3 convolutions (the middle one grouped)
    tmp = tf.layers.conv2d(x, 48, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv1')
    tmp = GroupConv2d(tmp, act=lrelu, name=name+'/conv2')
    tmp = tf.layers.conv2d(tmp, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv3')
    # slice: 16 channels are kept for the shortcut, 48 go through the second stage
    tmp1, tmp2 = tf.split(axis=3, num_or_size_splits=[16, 48], value=tmp)
    tmp2 = tf.layers.conv2d(tmp2, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv4')
    tmp2 = GroupConv2d(tmp2, n_filter=48, act=lrelu, name=name+'/conv5')
    tmp2 = tf.layers.conv2d(tmp2, 80, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv6')
    # concatenate the input with the kept slice (64 + 16 = 80 channels) and add the second-stage output
    output = tf.concat(axis=3, values=[x, tmp1]) + tmp2
    # compression unit: 1x1 convolution back to 64 channels
    output = tf.layers.conv2d(output, 64, (1, 1), (1, 1),
                              padding='same', activation=lrelu, name=name+'/conv7')
    return output


def lrelu(x, alpha=0.05):
    return tf.maximum(alpha * x, x)


def _phase_shift(I, r):
    return tf.depth_to_space(I, r)


def PS(X, r, color=False):
    if color:
        # shuffle the r*r channels of each colour separately, then stack the three colours
        Xc = tf.split(X, 3, 3)  # tf.split(value, num_or_size_splits, axis=0)
        X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
    else:
        X = _phase_shift(X, r)
    return X

def upsample(x, scale=4, features=32, name=None):
    # sub-pixel convolution: predict 3 * scale**2 channels at LR resolution, then rearrange them into pixels
    with tf.variable_scope(name):
        x = tf.layers.conv2d(x, features, 3, padding='same')
        ps_features = 3 * (scale ** 2)
        x = tf.layers.conv2d(x, ps_features, 3, padding='same')
        x = PS(x, scale, color=True)
    return x

def GroupConv2d(x, n_filter=32, filter_size=(3, 3), strides=(1, 1), n_group=4, act=None, padding='SAME', name=None):
    # Grouped convolution: the input channels and the kernel's output channels are split into
    # n_group groups, and each input group is convolved only with its own slice of the kernel.
    groupConv = lambda i, k: tf.nn.conv2d(i, k, strides=[1, strides[0], strides[1], 1], padding=padding)
    channels = int(x.get_shape()[-1])
    with tf.variable_scope(name):
        We = tf.get_variable(
            name='W', shape=[filter_size[0], filter_size[1], channels // n_group, n_filter], trainable=True
        )

        if n_group == 1:
            outputs = groupConv(x, We)
        else:
            inputGroups = tf.split(axis=3, num_or_size_splits=n_group, value=x)
            weightsGroups = tf.split(axis=3, num_or_size_splits=n_group, value=We)
            convGroups = [groupConv(i, k) for i, k in zip(inputGroups, weightsGroups)]

            outputs = tf.concat(axis=3, values=convGroups)

        b = tf.get_variable(
            name='b', shape=n_filter, trainable=True
        )

        outputs = tf.nn.bias_add(outputs, b, name='bias_add')

        if act:
            outputs = act(outputs)
    return outputs
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from config import config
import numpy as np
from scipy import misc
import os
import tensorflow as tf
import glob
from model import IDN
import utils
import skimage.color as sc

os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # index of the GPU to run on
dataset = config.TEST.dataset
model_path = config.TEST.model_path
saved_path = config.TEST.save_path
scale = 2  # 2 | 3 | 4
rgb = False  # False: compute PSNR on the Y channel of YCbCr; True: compute it on RGB

def main():
    ## data: HR test images are read from data/<dataset>/
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    if dataset == 'Set5':
        ext = '*.bmp'
    else:
        ext = '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')

    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, model_path)  # overwrites the initial values with the pretrained weights

    ## result
    save_path = os.path.join(saved_path, dataset, 'x' + str(scale))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    psnr_score = 0
    for i, _ in enumerate(hr_paths):
        print('processing image %d' % (i + 1))
        # crop the HR image to a multiple of scale, then build the bicubic LR input and its bicubic upsampling
        img_hr = utils.modcrop(misc.imread(hr_paths[i]), scale)
        img_lr = utils.downsample_fn(img_hr, scale=scale)
        img_b = utils.upsample_fn(img_lr, scale=scale)
        [lr, b] = utils.datatype([img_lr, img_b])
        lr = lr[np.newaxis, :, :, :]
        b = b[np.newaxis, :, :, :]
        [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
        sr = utils.quantize(np.squeeze(sr))
        # ignore a border of `scale` pixels when computing PSNR
        img_sr = utils.shave(sr, scale)
        img_hr = utils.shave(img_hr, scale)
        if not rgb:
            img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
            img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
        else:
            img_pre = img_sr
            img_label = img_hr
        psnr_score += utils.compute_psnr(img_pre, img_label)
        misc.imsave(os.path.join(save_path, os.path.basename(hr_paths[i])), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train_SR.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorlayer as tl
from model import IDN
from config import config
import numpy as np
import os
from tensorboardX import SummaryWriter
import utils

os.environ["CUDA_VISIBLE_DEVICES"] = '0'
batch_size = config.TRAIN.batch_size
lr_init = config.TRAIN.lr_init
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
checkpoint_dir = config.TRAIN.checkpoint_dir
hr_image_path = config.TRAIN.hr_img_path
scale = 2  # 2 | 3 | 4
eval_every = 10  # validate every eval_every epochs

with tf.variable_scope(tf.get_variable_scope()):
    ## create folders to save trained model
    tl.files.exists_or_mkdir(checkpoint_dir)

    ## pre-load data (cached as .npy files after the first run)
    train_hr_npy = os.path.join(hr_image_path, 'train_hr.npy')
    valid_hr_npy = os.path.join(config.VALID.hr_img_path, 'valid_hr.npy')
    valid_lr_npy = os.path.join(config.VALID.lr_img_path, 'valid_lr_x{}.npy'.format(scale))

    if os.path.exists(train_hr_npy) and os.path.exists(valid_hr_npy) and os.path.exists(valid_lr_npy):
        print('Loading data...')
        # DIV2K images differ in size, so the cached arrays are object arrays,
        # which np.load only reads with allow_pickle=True
        train_hr_imgs = np.load(train_hr_npy, allow_pickle=True)
        valid_hr_imgs = np.load(valid_hr_npy, allow_pickle=True)
        valid_lr_imgs = np.load(valid_lr_npy, allow_pickle=True)
    else:
        print('Creating data binary...')
        train_hr_imgs_list = sorted(tl.files.load_file_list(path=hr_image_path, regx='.*.png', printable=False))
        valid_hr_imgs_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))

        train_hr_imgs = np.array(tl.visualize.read_images(train_hr_imgs_list, path=hr_image_path, n_threads=16))
        valid_hr_imgs = np.array(tl.visualize.read_images(valid_hr_imgs_list, path=config.VALID.hr_img_path, n_threads=16))
        # LR validation images are generated from the HR images by bicubic downsampling
        valid_lr_imgs = tl.prepro.threading_data(valid_hr_imgs, fn=utils.downsample_fn, scale=scale)

        np.save(train_hr_npy, train_hr_imgs)
        np.save(valid_hr_npy, valid_hr_imgs)
        np.save(valid_lr_npy, valid_lr_imgs)

    ## define model
    tensor_lr = tf.placeholder('float32', [None, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [None, None, None, 3], name='tensor_b')
    tensor_hr = tf.placeholder('float32', [None, None, None, 3], name='tensor_hr')

    print('Loading model...')
    tensor_sr = IDN(tensor_lr, tensor_b, scale)

    ## calculate the number of parameters
    total_parameters = 0
    for variable in tf.trainable_variables():
        variable_parameters = 1
        for dim in variable.get_shape():
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Total number of trainable parameters: %d" % total_parameters)


    ## define loss function (L1 / mean absolute error)
    mae_loss = tf.reduce_mean(tf.losses.absolute_difference(labels=tensor_hr, predictions=tensor_sr))

    ## PSNR and SSIM (evaluation); note that tf.image.ssim_multiscale computes multi-scale SSIM
    PSNR = tf.image.psnr(tensor_sr, tensor_hr, max_val=255)
    SSIM = tf.image.ssim_multiscale(tensor_sr, tensor_hr, max_val=255)

    ## create the optimization
    g_vars = [v for v in tf.global_variables() if v.name.startswith("IDN")]

    with tf.variable_scope("learning_rate"):
        lr_value = tf.Variable(lr_init, trainable=False)
    g_optim = tf.train.AdamOptimizer(learning_rate=lr_value).minimize(mae_loss, var_list=g_vars)

    ## restore model

    # note: this rebinds the name `config` (the EasyDict from config.py) to the session configuration
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # resume from the latest checkpoint in checkpoint_dir, if one exists
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print('loaded ' + ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)


    ## Tensorboard
    writer = SummaryWriter(os.path.join(checkpoint_dir, 'result'))
    tf.summary.FileWriter(os.path.join(checkpoint_dir, 'graph'), sess.graph)
    best_psnr, best_epoch = 0, 0

    ## training
    print("Training network...")
    for epoch in range(0, n_epoch + 1):
        # update learning rate (step decay)
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(tf.assign(lr_value, lr_init * new_lr_decay))
            log = " ** new learning rate: %f" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_value, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f" % (lr_init, decay_every, lr_decay)
            print(log)

        index = np.random.permutation(len(train_hr_imgs))
        num_batches = len(train_hr_imgs) // batch_size

        total_losses = np.zeros(1)
        for i in range(num_batches):
            # random augmented HR patches, their bicubic LR versions, and the bicubic upsampling of those
            hr = tl.prepro.threading_data(train_hr_imgs[index[i * batch_size: (i + 1) * batch_size]], fn=utils.crop_sub_imgs_fn, is_random=True)
            lr = tl.prepro.threading_data(hr, fn=utils.downsample_fn, scale=scale)
            b = tl.prepro.threading_data(lr, fn=utils.upsample_fn, scale=scale)
            [lr, hr, b] = utils.datatype([lr, hr, b])
            ## update the network
            error_mae, _ = sess.run([mae_loss, g_optim], {tensor_lr: lr, tensor_hr: hr, tensor_b: b})
            total_losses += error_mae

        avg_loss = total_losses / num_batches
        log = "[*] Epoch: [%2d/%2d] mae: %.6f" % \
            (epoch, n_epoch, avg_loss[0])
        print(log)

        writer.add_scalar('mae_loss', avg_loss[0], epoch)

        ## validating
        if epoch != 0 and epoch % eval_every == 0:
            print('Validating...')
            val_psnr = 0
            val_ssim = 0
            for i in range(len(valid_hr_imgs)):
                hr = valid_hr_imgs[i]
                lr = valid_lr_imgs[i]
                b = utils.upsample_fn(lr, scale=scale)
                [lr, hr, b] = utils.datatype([lr, hr, b])

                hr_expand = np.expand_dims(hr, axis=0)
                lr_expand = np.expand_dims(lr, axis=0)
                b_expand = np.expand_dims(b, axis=0)

                psnr, ssim, sr_expand = sess.run([PSNR, SSIM, tensor_sr], {tensor_lr: lr_expand, tensor_hr: hr_expand, tensor_b: b_expand})
                sr = np.squeeze(sr_expand)
                utils.update_tensorboard(epoch, writer, i, lr, sr, hr)
                val_psnr += psnr
                val_ssim += ssim

            val_psnr = val_psnr / len(valid_hr_imgs)
            val_ssim = val_ssim / len(valid_hr_imgs)
            if val_psnr > best_psnr:
                best_psnr = val_psnr
                best_epoch = epoch
                print('Saving new best model, Epoch = %d, PSNR = %.4f' % (best_epoch, best_psnr))

                ## save model
                saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
            writer.add_scalar('Validate PSNR', val_psnr, epoch)
            writer.add_scalar('Validate SSIM', val_ssim, epoch)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from tensorlayer.prepro import imresize, crop, rotation, flip_axis
import random
from scipy import misc
import numpy as np
from skimage.measure import compare_psnr as psnr
from skimage.measure import compare_ssim as ssim

def get_imgs_fn(file_name, path):
    return misc.imread(path + file_name, mode='RGB')

def augment(x, hflip=True, rot=True):
    # randomly flip horizontally, flip vertically and rotate by 90 degrees (each with probability 0.5)
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = flip_axis(img, axis=1)
        if vflip:
            img = flip_axis(img, axis=0)
        if rot90:
            img = rotation(img, rg=90)
        return img
    return _augment(x)

def crop_sub_imgs_fn(x, is_random=True):
    # crop a 192x192 HR patch and augment it
    x = crop(x, wrg=192, hrg=192, is_random=is_random)
    x = augment(x)
    return x

def downsample_fn(x, scale=4):
    h, w = x.shape[0:2]
    hs, ws = h // scale, w // scale

    x = imresize(x, size=[hs, ws], interp='bicubic')
    return x

def upsample_fn(x, scale=4):
    h, w = x.shape[0:2]
    newh, neww = h * scale, w * scale
    x = imresize(x, size=(newh, neww), interp='bicubic')
    return x

def datatype(x):
    # cast every array in the list to float32 (the list is modified in place and returned)
    for i in range(len(x)):
        x[i] = x[i].astype(np.float32)
    return x

def datarange(x):
    # scale every array in the list from [0, 255] to [0, 1]
    for i in range(len(x)):
        x[i] = x[i] / 255.
    return x

def transpose(xs):
    # HWC -> CHW
    for i in range(len(xs)):
        xs[i] = xs[i].transpose(2, 0, 1)
    return xs

def update_tensorboard(epoch, tb, img_idx, lr, sr, hr):  # tb: tensorboardX SummaryWriter
    [lr, sr, hr] = transpose([lr, sr, hr])
    [lr, sr, hr] = datarange([lr, sr, hr])  # for visualizing correctly

    # LR and HR images are logged only once, at epoch 20; SR images at every validation
    if epoch == 20:
        tb.add_image(str(img_idx) + '_LR', lr, 0)
        tb.add_image(str(img_idx) + '_HR', hr, 0)
    tb.add_image(str(img_idx) + '_SR', np.clip(sr, 0, 1), epoch)

def modcrop(im, modulo):
    # crop height and width down to multiples of modulo
    sz = im.shape
    h = sz[0] // modulo * modulo
    w = sz[1] // modulo * modulo
    ims = im[0:h, 0:w, ...]
    return ims

def shave(im, border):
    # remove `border` pixels from every edge
    im = im[border:-border, border:-border, ...]
    return im

def compute_psnr(im1, im2):
    p = psnr(im1, im2)
    return p

def compute_ssim(im1, im2):
    s = ssim(im1, im2, K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, multichannel=False)
    return s

def quantize(img):
    return np.uint8(img.clip(0, 255))
--------------------------------------------------------------------------------