├── README.md
├── checkpoint_x2
│   ├── checkpoint
│   ├── model.ckpt.data-00000-of-00001
│   ├── model.ckpt.index
│   └── model.ckpt.meta
├── checkpoint_x3
│   ├── checkpoint
│   ├── model.ckpt.data-00000-of-00001
│   ├── model.ckpt.index
│   └── model.ckpt.meta
├── checkpoint_x4
│   ├── checkpoint
│   ├── model.ckpt.data-00000-of-00001
│   ├── model.ckpt.index
│   └── model.ckpt.meta
├── config.py
├── model.py
├── test.py
├── train_SR.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
# IDN-tensorflow
[[Original Caffe version]](https://github.com/Zheng222/IDN-Caffe)
## Testing
* Install TensorFlow 1.11 and MATLAB R2017a.
* Download the [test datasets](https://drive.google.com/open?id=1_K6mchwDGOQMIXuBIGrlDA4EAYgbtdmU).
* Modify `config.py` and `test.py`. For example, to test the x3 model on Set14, set `config.TEST.model_path = 'checkpoint_x3/model.ckpt'` and `config.TEST.dataset = 'Set14'` in `config.py`, and `scale = 3` in `test.py`.
* Run testing:
```bash
python test.py
```

## Training
* Download the [training dataset](https://drive.google.com/open?id=12hOYsMa8t1ErKj6PZA352icsx9mz1TwB).
* Modify `config.py` and `train_SR.py`. For example, to train the x4 model, set `config.TRAIN.hr_img_path = '/path/to/DIV2K_train_HR/'`, `config.TRAIN.checkpoint_dir = 'checkpoint_x4/'`, `config.VALID.hr_img_path = '/path/to/DIV2K_valid_HR/'` and `config.VALID.lr_img_path = '/path/to/DIV2K_valid_LR_x4/'` in `config.py`, and `scale = 4` in `train_SR.py`.
* Run training:
```bash
python train_SR.py
```
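
With the defaults in `config.py` (`lr_init = 2e-4`, `lr_decay = 0.5`, `decay_every = 2000`), the training loop in `train_SR.py` halves the learning rate every 2000 epochs. The following is a minimal pure-Python sketch of that step schedule; the function name `learning_rate` is illustrative and is not part of the repository.

```python
def learning_rate(epoch, lr_init=2e-4, lr_decay=0.5, decay_every=2000):
    """Step decay: multiply lr_init by lr_decay once per `decay_every` epochs."""
    return lr_init * lr_decay ** (epoch // decay_every)


# Learning rate at a few epochs of the default 10000-epoch run.
for epoch in (0, 1999, 2000, 10000):
    print(epoch, learning_rate(epoch))
```

Changing `decay_every` or `lr_decay` in `config.py` shifts the schedule accordingly.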
## Note
This TensorFlow version is trained on the DIV2K training dataset using RGB channels. Additionally, we replace the upsampling layer with sub-pixel convolution (the original version uses transposed convolution).
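
To illustrate the sub-pixel step, here is a minimal NumPy sketch of the channel-to-space rearrangement that `PS(..., color=True)` in `model.py` appears to perform for a single image (no batch dimension). The function name `pixel_shuffle_rgb` is an assumption made for this example; the repository itself relies on `tf.depth_to_space`.

```python
import numpy as np


def pixel_shuffle_rgb(x, r):
    """Rearrange an (H, W, 3*r*r) feature map into an (H*r, W*r, 3) image.

    Channel c*r*r + i*r + j at (h, w) moves to colour c at (h*r + i, w*r + j).
    """
    h, w, ch = x.shape
    assert ch == 3 * r * r, "expected 3 * r**2 channels"
    x = x.reshape(h, w, 3, r, r)      # axes: (h, w, colour, i, j)
    x = x.transpose(0, 3, 1, 4, 2)    # axes: (h, i, w, j, colour)
    return x.reshape(h * r, w * r, 3)


features = np.arange(12).reshape(1, 1, 12)  # one low-resolution pixel, scale 2
image = pixel_shuffle_rgb(features, 2)
print(image.shape)  # (2, 2, 3)
```

Unlike a transposed convolution, this step has no weights of its own: the preceding convolution produces `3 * scale ** 2` channels and the rearrangement only moves them into space.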

## Results
[Test_results](https://drive.google.com/open?id=1saFhGV8t2ytzRLHE2CaFc4H_UkvJo9KS)

The following PSNR/SSIM values were evaluated with MATLAB R2017a; for the evaluation code, see [Evaluate_PSNR_SSIM.m](https://github.com/yulunzhang/RCAN/blob/master/RCAN_TestCode/Evaluate_PSNR_SSIM.m).

| Training dataset | Scale | Set5 | Set14 | B100 | Urban100 |
|:---:|:---:|:---:|:---:|:---:|:---:|
| 291 | ×2 | 37.83 / 0.9600 | 33.30 / 0.9148 | 32.08 / 0.8985 | 31.27 / 0.9196 |
| DIV2K | ×2 | 37.85 / 0.9598 | 33.58 / 0.9178 | 32.11 / 0.8989 | 31.95 / 0.9266 |
| 291 | ×3 | 34.11 / 0.9253 | 29.99 / 0.8354 | 28.95 / 0.8013 | 27.42 / 0.8359 |
| DIV2K | ×3 | 34.24 / 0.9260 | 30.27 / 0.8408 | 29.03 / 0.8038 | 27.99 / 0.8489 |
| 291 | ×4 | 31.82 / 0.8903 | 28.25 / 0.7730 | 27.41 / 0.7297 | 25.41 / 0.7632 |
| DIV2K | ×4 | 31.99 / 0.8928 | 28.52 / 0.7794 | 27.52 / 0.7339 | 25.92 / 0.7801 |

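For a quick check without MATLAB, `test.py` also prints an average PSNR computed on the luma (Y) channel after cropping `scale` border pixels. The sketch below mirrors that computation in plain NumPy; the helper names are assumptions made for this example, and its output may differ slightly from the MATLAB script used for the table above.

```python
import numpy as np


def rgb_to_y(img):
    """Luma (Y of YCbCr, ITU-R BT.601, range 16-235) of an 8-bit RGB image."""
    img = img.astype(np.float64) / 255.0
    return 16.0 + 65.481 * img[..., 0] + 128.553 * img[..., 1] + 24.966 * img[..., 2]


def psnr(a, b, peak=255.0):
    """Peak signal-to-noise ratio in dB between two arrays of equal shape."""
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return float('inf') if mse == 0 else 10.0 * np.log10(peak ** 2 / mse)


def y_psnr(sr, hr, scale):
    """Crop `scale` border pixels, convert to Y, quantize, then compute PSNR."""
    sr, hr = (im[scale:-scale, scale:-scale] for im in (sr, hr))
    y_sr, y_hr = (np.uint8(rgb_to_y(im).clip(0, 255)) for im in (sr, hr))
    return psnr(y_sr, y_hr)
```

Here `sr` and `hr` are `uint8` arrays of shape `(H, W, 3)` with identical sizes, for example the saved result and the mod-cropped ground-truth image.
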
## Model Parameters
| Scale | Number of parameters |
|:---:|:---:|
| ×2 | **579,276** |
| ×3 | **587,931** |
| ×4 | **600,048** |
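
These counts can be reproduced by hand from the layer shapes in `model.py`. The sketch below is a back-of-the-envelope tally in pure Python, assuming every convolution has a bias and that a grouped convolution stores a kernel of shape `[k, k, c_in / groups, c_out]`, as `GroupConv2d` does; the helper names are illustrative.

```python
def conv_params(k, c_in, c_out, groups=1):
    """Weights plus biases of a k x k convolution with optional channel groups."""
    return k * k * (c_in // groups) * c_out + c_out


def idn_params(scale):
    """Trainable parameters of IDN for a given upscaling factor."""
    head = conv_params(3, 3, 64) + conv_params(3, 64, 64)
    block = (
        conv_params(3, 64, 48)
        + conv_params(3, 48, 32, groups=4)
        + conv_params(3, 32, 64)
        + conv_params(3, 48, 64)            # acts on the 48-channel slice of the split
        + conv_params(3, 64, 48, groups=4)
        + conv_params(3, 48, 80)
        + conv_params(1, 80, 64)            # 1x1 compression after the skip sum
    )
    tail = conv_params(3, 64, 64) + conv_params(3, 64, 3 * scale ** 2)
    return head + 4 * block + tail


for s in (2, 3, 4):
    print(s, idn_params(s))  # 2 579276, 3 587931, 4 600048
```

Only the last convolution depends on the scale, which is why the three models differ by just a few thousand parameters.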
## Citation

If you find IDN useful in your research, please consider citing:

```
@inproceedings{Hui-IDN-2018,
  title={Fast and Accurate Single Image Super-Resolution via Information Distillation Network},
  author={Hui, Zheng and Wang, Xiumei and Gao, Xinbo},
  booktitle={CVPR},
  pages={723--731},
  year={2018}
}
```
--------------------------------------------------------------------------------
/checkpoint_x2/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x2/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x2/model.ckpt.meta
--------------------------------------------------------------------------------
/checkpoint_x3/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x3/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x3/model.ckpt.meta
--------------------------------------------------------------------------------
/checkpoint_x4/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "model.ckpt"
all_model_checkpoint_paths: "model.ckpt"
--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.index
--------------------------------------------------------------------------------
/checkpoint_x4/model.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng222/IDN-tensorflow/2a3012360a172ef6b8c8fc99caf7d4ad83850c10/checkpoint_x4/model.ckpt.meta
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from easydict import EasyDict as edict

config = edict()
config.TRAIN = edict()

## Adam
config.TRAIN.batch_size = 16
config.TRAIN.lr_init = 2e-4

## learning-rate schedule
config.TRAIN.n_epoch = 10000
config.TRAIN.lr_decay = 0.5       # multiply the learning rate by this factor ...
config.TRAIN.decay_every = 2000   # ... every `decay_every` epochs

## training dataset location
config.TRAIN.hr_img_path = '/data/DIV2K_train_HR/'
config.TRAIN.checkpoint_dir = 'checkpoint_x2/'

config.VALID = edict()
config.VALID.hr_img_path = '/data/DIV2K_valid_HR/'
config.VALID.lr_img_path = '/data/DIV2K_valid_LR_x2/'

## test
config.TEST = edict()
config.TEST.model_path = 'checkpoint_x2/model.ckpt'
config.TEST.save_path = 'results'
config.TEST.dataset = 'Set5'  # Set5 | Set14 | B100 | Urban100 | manga109 | DIV2K_val
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def IDN(t_image, t_image_bicubic, scale, reuse=False):
    """Build the IDN graph from the LR input and its bicubic upsampling."""
    t_image_bicubic = tf.identity(t_image_bicubic)
    with tf.variable_scope("IDN", reuse=reuse):
        conv1 = tf.layers.conv2d(t_image, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name='conv1')
        conv2 = tf.layers.conv2d(conv1, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name='conv2')
        n = conv2
        for i in range(4):
            n = distillation(n, name='distill/%i' % i)
        # The network output is added to the bicubic-upsampled input.
        output = upsample(n, scale=scale, features=64, name=str(scale)) + t_image_bicubic
    return output


def distillation(x, name=''):
    """One distillation block: 64 channels in, 64 channels out."""
    tmp = tf.layers.conv2d(x, 48, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv1')
    tmp = GroupConv2d(tmp, act=lrelu, name=name+'/conv2')
    tmp = tf.layers.conv2d(tmp, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv3')
    # Keep 16 channels for the skip path; process the remaining 48 further.
    tmp1, tmp2 = tf.split(axis=3, num_or_size_splits=[16, 48], value=tmp)
    tmp2 = tf.layers.conv2d(tmp2, 64, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv4')
    tmp2 = GroupConv2d(tmp2, n_filter=48, act=lrelu, name=name+'/conv5')
    tmp2 = tf.layers.conv2d(tmp2, 80, (3, 3), (1, 1), padding='same', activation=lrelu, name=name+'/conv6')
    # 64 input channels + 16 retained channels = 80, matching tmp2.
    output = tf.concat(axis=3, values=[x, tmp1]) + tmp2
    output = tf.layers.conv2d(output, 64, (1, 1), (1, 1), padding='same', activation=lrelu, name=name+'/conv7')
    return output


def lrelu(x, alpha=0.05):
    return tf.maximum(alpha * x, x)


def _phase_shift(I, r):
    return tf.depth_to_space(I, r)


def PS(X, r, color=False):
    if color:
        Xc = tf.split(X, 3, 3)  # tf.split(value, num_or_size_splits, axis=0)
        X = tf.concat([_phase_shift(x, r) for x in Xc], 3)
    else:
        X = _phase_shift(X, r)
    return X


def upsample(x, scale=4, features=32, name=None):
    with tf.variable_scope(name):
        x = tf.layers.conv2d(x, features, 3, padding='same')
        ps_features = 3 * (scale ** 2)
        x = tf.layers.conv2d(x, ps_features, 3, padding='same')
        x = PS(x, scale, color=True)
    return x


def GroupConv2d(x, n_filter=32, filter_size=(3, 3), strides=(1, 1), n_group=4, act=None, padding='SAME', name=None):
    groupConv = lambda i, k: tf.nn.conv2d(i, k, strides=[1, strides[0], strides[1], 1], padding=padding)
    channels = int(x.get_shape()[-1])
    with tf.variable_scope(name):
        # Integer division: a float dimension is not a valid shape in Python 3.
        We = tf.get_variable(
            name='W', shape=[filter_size[0], filter_size[1], channels // n_group, n_filter], trainable=True
        )

        if n_group == 1:
            outputs = groupConv(x, We)
        else:
            inputGroups = tf.split(axis=3, num_or_size_splits=n_group, value=x)
            weightsGroups = tf.split(axis=3, num_or_size_splits=n_group, value=We)
            convGroups = [groupConv(i, k) for i, k in zip(inputGroups, weightsGroups)]

            outputs = tf.concat(axis=3, values=convGroups)

        b = tf.get_variable(
            name='b', shape=n_filter, trainable=True
        )

        outputs = tf.nn.bias_add(outputs, b, name='bias_add')

        if act:
            # Apply the activation that was passed in (previously hard-coded to lrelu).
            outputs = act(outputs)
    return outputs
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from config import config
import numpy as np
from scipy import misc  # misc.imread / misc.imsave require SciPy < 1.2
import os
import tensorflow as tf
import glob
from model import IDN
import utils
import skimage.color as sc

os.environ["CUDA_VISIBLE_DEVICES"] = '1'  # GPU index; adjust for your machine
dataset = config.TEST.dataset
model_path = config.TEST.model_path
saved_path = config.TEST.save_path
scale = 2  # 2 | 3 | 4
rgb = False  # False: compute PSNR on the Y channel only


def main():
    ## data
    print('Loading data...')
    test_hr_path = os.path.join('data/', dataset)
    if dataset == 'Set5':
        ext = '*.bmp'
    else:
        ext = '*.png'
    hr_paths = sorted(glob.glob(os.path.join(test_hr_path, ext)))

    ## model
    print('Loading model...')
    tensor_lr = tf.placeholder('float32', [1, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [1, None, None, 3], name='tensor_b')

    tensor_sr = IDN(tensor_lr, tensor_b, scale)
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.restore(sess, model_path)

    ## result
    save_path = os.path.join(saved_path, dataset, 'x' + str(scale))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    psnr_score = 0
    for i, hr_path in enumerate(hr_paths):
        print('processing image %d' % (i+1))
        # Crop HR so its size is divisible by scale, then synthesize the LR and bicubic inputs.
        img_hr = utils.modcrop(misc.imread(hr_path), scale)
        img_lr = utils.downsample_fn(img_hr, scale=scale)
        img_b = utils.upsample_fn(img_lr, scale=scale)
        [lr, b] = utils.datatype([img_lr, img_b])
        lr = lr[np.newaxis, :, :, :]
        b = b[np.newaxis, :, :, :]
        [sr] = sess.run([tensor_sr], {tensor_lr: lr, tensor_b: b})
        sr = utils.quantize(np.squeeze(sr))
        # Ignore `scale` border pixels when measuring PSNR.
        img_sr = utils.shave(sr, scale)
        img_hr = utils.shave(img_hr, scale)
        if not rgb:
            img_pre = utils.quantize(sc.rgb2ycbcr(img_sr)[:, :, 0])
            img_label = utils.quantize(sc.rgb2ycbcr(img_hr)[:, :, 0])
        else:
            img_pre = img_sr
            img_label = img_hr
        psnr_score += utils.compute_psnr(img_pre, img_label)
        misc.imsave(os.path.join(save_path, os.path.basename(hr_path)), sr)

    print('Average PSNR: %.4f' % (psnr_score / len(hr_paths)))
    print('Finish')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train_SR.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import tensorlayer as tl
from model import IDN
from config import config
import numpy as np
import os
from tensorboardX import SummaryWriter
import utils

os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # GPU index; adjust for your machine
batch_size = config.TRAIN.batch_size
lr_init = config.TRAIN.lr_init
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
checkpoint_dir = config.TRAIN.checkpoint_dir
hr_image_path = config.TRAIN.hr_img_path
scale = 2  # 2 | 3 | 4
eval_every = 10

with tf.variable_scope(tf.get_variable_scope()):
    ## create folders to save trained model
    tl.files.exists_or_mkdir(checkpoint_dir)

    ## pre-load data
    train_hr_npy = os.path.join(hr_image_path, 'train_hr.npy')
    valid_hr_npy = os.path.join(config.VALID.hr_img_path, 'valid_hr.npy')
    valid_lr_npy = os.path.join(config.VALID.lr_img_path, 'valid_lr_x{}.npy'.format(scale))

    if os.path.exists(train_hr_npy) and os.path.exists(valid_hr_npy) and os.path.exists(valid_lr_npy):
        print('Loading data...')
        train_hr_imgs = np.load(train_hr_npy)
        valid_hr_imgs = np.load(valid_hr_npy)
        valid_lr_imgs = np.load(valid_lr_npy)
    else:
        print('Creating data binary...')
        # Escape the dot so the pattern matches the literal ".png" extension.
        train_hr_imgs_list = sorted(tl.files.load_file_list(path=hr_image_path, regx=r'.*\.png', printable=False))
        valid_hr_imgs_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx=r'.*\.png', printable=False))

        train_hr_imgs = np.array(tl.visualize.read_images(train_hr_imgs_list, path=hr_image_path, n_threads=16))
        valid_hr_imgs = np.array(tl.visualize.read_images(valid_hr_imgs_list, path=config.VALID.hr_img_path, n_threads=16))
        valid_lr_imgs = tl.prepro.threading_data(valid_hr_imgs, fn=utils.downsample_fn, scale=scale)

        np.save(train_hr_npy, train_hr_imgs)
        np.save(valid_hr_npy, valid_hr_imgs)
        np.save(valid_lr_npy, valid_lr_imgs)

    ## define model
    tensor_lr = tf.placeholder('float32', [None, None, None, 3], name='tensor_lr')
    tensor_b = tf.placeholder('float32', [None, None, None, 3], name='tensor_b')
    tensor_hr = tf.placeholder('float32', [None, None, None, 3], name='tensor_hr')

    print('Loading model...')
    tensor_sr = IDN(tensor_lr, tensor_b, scale)

    ## calculate the number of parameters
    total_parameters = 0
    for variable in tf.trainable_variables():
        variable_parameters = 1
        for dim in variable.get_shape():
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Total number of trainable parameters: %d" % total_parameters)

    ## define loss functions
    mae_loss = tf.reduce_mean(tf.losses.absolute_difference(tensor_sr, tensor_hr))

    ## PSNR and multi-scale SSIM (evaluation)
    PSNR = tf.image.psnr(tensor_sr, tensor_hr, max_val=255)
    SSIM = tf.image.ssim_multiscale(tensor_sr, tensor_hr, max_val=255)

    ## create the optimization
    g_vars = [v for v in tf.global_variables() if v.name.startswith("IDN")]

    with tf.variable_scope("learning_rate"):
        lr_value = tf.Variable(lr_init, trainable=False)
    g_optim = tf.train.AdamOptimizer(learning_rate=lr_value).minimize(mae_loss, var_list=g_vars)

    ## restore model
    # Named sess_config so it does not shadow the `config` imported from config.py.
    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())

    # load from checkpoint if it exists
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt:
        print('loaded ' + ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)

    ## Tensorboard
    writer = SummaryWriter(os.path.join(checkpoint_dir, 'result'))
    tf.summary.FileWriter(os.path.join(checkpoint_dir, 'graph'), sess.graph)
    best_psnr, best_epoch = 0, 0

    ## training
    print("Training network...")
    for epoch in range(0, n_epoch + 1):
        # update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay ** (epoch // decay_every)
            sess.run(tf.assign(lr_value, lr_init * new_lr_decay))
            log = " ** new learning rate: %f" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_value, lr_init))
            log = " ** init lr: %f decay_every_init: %d, lr_decay: %f" % (lr_init, decay_every, lr_decay)
            print(log)

        index = np.random.permutation(len(train_hr_imgs))
        num_batches = len(train_hr_imgs) // batch_size

        total_losses = np.zeros(1)
        for i in range(num_batches):
            # Random HR crops, then LR and bicubic inputs derived from them.
            hr = tl.prepro.threading_data(train_hr_imgs[index[i * batch_size: (i+1) * batch_size]], fn=utils.crop_sub_imgs_fn, is_random=True)
            lr = tl.prepro.threading_data(hr, fn=utils.downsample_fn, scale=scale)
            b = tl.prepro.threading_data(lr, fn=utils.upsample_fn, scale=scale)
            [lr, hr, b] = utils.datatype([lr, hr, b])
            ## update G
            error_mae, _ = sess.run([mae_loss, g_optim], {tensor_lr: lr, tensor_hr: hr, tensor_b: b})
            total_losses += error_mae

        avg_loss = total_losses / num_batches
        log = "[*] Epoch: [%2d/%2d] mae: %.6f" % (epoch, n_epoch, avg_loss[0])
        print(log)

        writer.add_scalar('mae_loss', avg_loss[0], epoch)

        ## validating
        if (epoch != 0 and epoch % eval_every == 0):
            print('Validating...')
            val_psnr = 0
            val_ssim = 0
            for i in range(len(valid_hr_imgs)):
                hr = valid_hr_imgs[i]
                lr = valid_lr_imgs[i]
                b = utils.upsample_fn(lr, scale=scale)
                [lr, hr, b] = utils.datatype([lr, hr, b])

                hr_expand = np.expand_dims(hr, axis=0)
                lr_expand = np.expand_dims(lr, axis=0)
                b_expand = np.expand_dims(b, axis=0)

                psnr, ssim, sr_expand = sess.run([PSNR, SSIM, tensor_sr], {tensor_lr: lr_expand, tensor_hr: hr_expand, tensor_b: b_expand})
                sr = np.squeeze(sr_expand)
                utils.update_tensorboard(epoch, writer, i, lr, sr, hr)
                val_psnr += psnr
                val_ssim += ssim

            val_psnr = val_psnr / len(valid_hr_imgs)
            val_ssim = val_ssim / len(valid_hr_imgs)
            if val_psnr > best_psnr:
                best_psnr = val_psnr
                best_epoch = epoch
                print('Saving new best model, Epoch = %d, PSNR = %.4f' % (best_epoch, best_psnr))

                ## save model
                saver.save(sess, os.path.join(checkpoint_dir, 'model.ckpt'))
            writer.add_scalar('Validate PSNR', val_psnr, epoch)
            writer.add_scalar('Validate SSIM', val_ssim, epoch)
167 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from tensorlayer.prepro import imresize, crop, rotation, flip_axis
import random
from scipy import misc
import numpy as np
from skimage.measure import compare_psnr as psnr
from skimage.measure import compare_ssim as ssim


def get_imgs_fn(file_name, path):
    return misc.imread(path + file_name, mode='RGB')


def augment(x, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = flip_axis(img, axis=1)
        if vflip:
            img = flip_axis(img, axis=0)
        if rot90:
            img = rotation(img, rg=90)
        return img
    return _augment(x)


def crop_sub_imgs_fn(x, is_random=True):
    x = crop(x, wrg=192, hrg=192, is_random=is_random)
    x = augment(x)
    return x


def downsample_fn(x, scale=4):
    h, w = x.shape[0:2]
    hs, ws = h // scale, w // scale

    x = imresize(x, size=[hs, ws], interp='bicubic')
    return x


def upsample_fn(x, scale=4):
    h, w = x.shape[0:2]
    newh, neww = h * scale, w * scale
    x = imresize(x, size=(newh, neww), interp='bicubic')
    return x


def datatype(x):
    for i in range(len(x)):
        x[i] = x[i].astype(np.float32)
    return x


def datarange(x):
    for i in range(len(x)):
        x[i] = x[i] / 255.
    return x


def transpose(xs):
    for i in range(len(xs)):
        xs[i] = xs[i].transpose(2, 0, 1)
    return xs


def update_tensorboard(epoch, tb, img_idx, lr, sr, hr):  # tb--> tensorboard
    [lr, sr, hr] = transpose([lr, sr, hr])
    [lr, sr, hr] = datarange([lr, sr, hr])  # for visualizing correctly

    if epoch == 20:
        tb.add_image(str(img_idx) + '_LR', lr, 0)
        tb.add_image(str(img_idx) + '_HR', hr, 0)
    tb.add_image(str(img_idx) + '_SR', np.clip(sr, 0, 1), epoch)


def modcrop(im, modulo):
    sz = im.shape
    h = np.int32(sz[0] / modulo) * modulo
    w = np.int32(sz[1] / modulo) * modulo
    ims = im[0:h, 0:w, ...]
    return ims


def shave(im, border):
    im = im[border:-border, border:-border, ...]
    return im


def compute_psnr(im1, im2):
    p = psnr(im1, im2)
    return p


def compute_ssim(im1, im2):
    s = ssim(im1, im2, K1=0.01, K2=0.03, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, multichannel=False)
    return s


def quantize(img):
    return np.uint8(img.clip(0, 255))
--------------------------------------------------------------------------------