├── README.md ├── image └── example1.jpg ├── main_train.py ├── network.py ├── ops.py ├── pre_train1.py ├── pre_train2.py ├── pretrain_generator.py ├── test.py ├── train_module.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Real-World Super Resolution with Cycle Generative Adversarial Network and Domain Discriminator (DDGAN) 2 | This is the project for the CVPR 2020 workshop paper "Unsupervised Real-World Super Resolution with Cycle Generative Adversarial Network and Domain Discriminator", which achieved 5th place in the NTIRE2020 Real World Super Resolution Challenge Track 1. 3 | 4 | This code is based on the TensorFlow implementation of ESRGAN by hiram64 (github.com/hiram64/ESRGAN-tensorflow). Thank you! 5 | 6 | ![image1](./image/example1.jpg) 7 | 8 | 9 | ## Dependencies 10 | Python==3.5.2 11 | Numpy==1.17.2 12 | Scipy==1.2.0 13 | OpenCV==3.4.4.19 14 | Tensorflow-gpu==1.12.0 15 | 16 | 17 | ## Train 18 | ### Stage 1-1 19 | > python pre_train1.py 20 | 21 | ### Stage 1-2 22 | > python pre_train2.py 23 | 24 | ### Stage 2 25 | > python main_train.py 26 | 27 | 28 | ## Evaluate 29 | ### Track 1 30 | > python test.py --data_dir ./data/track1/ --checkpoint_dir ./checkpoint_track1/ --test_result_dir ./test_result_track1 31 | ### Track 2 32 | > python test.py --data_dir ./data/track2/ --checkpoint_dir ./checkpoint_track2/ --test_result_dir ./test_result_track2 33 | 34 | ## Pretrained model 35 | 36 | Trained with the NTIRE2020 Real World Super-Resolution Challenge Track 1 dataset: 37 | 38 | https://www.dropbox.com/s/vv68nsj36tiig4l/checkpoint_track1.zip?dl=0 39 | 40 | Trained with the NTIRE2020 Real World Super-Resolution Challenge Track 2 dataset: 41 | 42 | https://www.dropbox.com/s/xnjzlff0tezk75f/checkpoint_track2.zip?dl=0 43 | -------------------------------------------------------------------------------- /image/example1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GT-KIM/unsupervised-super-resolution-domain-discriminator/5e8c26988b2ebf311c97d8501e572b5384b48ccf/image/example1.jpg -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import argparse 7 | import tensorflow as tf 8 | import numpy as np 9 | from glob import glob 10 | from train_module import Network, Loss, Optimizer 11 | from utils import create_dirs, log, save_image, generate_batch, generate_testset 12 | from ops import load_vgg19_weight 13 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" 15 | 16 | parser = argparse.ArgumentParser(description='') 17 | 18 | # directory 19 | parser.add_argument('--data_dir', dest='data_dir', default='/sdc1/NTIRE2020/task1/', help='path of the dataset') 20 | parser.add_argument('--train_result_dir', dest='train_result_dir', default='/sdc1/NTIRE2020/result/train_result', help='output directory during training') 21 | parser.add_argument('--train_LR_result_dir', dest='train_LR_result_dir', default='/sdc1/NTIRE2020/result/train_LR_result', help='output directory during training') 22 | parser.add_argument('--valid_result_dir', dest='valid_result_dir', default='/sdc1/NTIRE2020/result/valid_result', help='output directory during training') 23 | parser.add_argument('--valid_LR_result_dir', dest='valid_LR_result_dir', default='/sdc1/NTIRE2020/result/valid_LR_result', help='output directory during training') 24 | 25 | # About Data 26 | parser.add_argument('--crop', dest='crop', default=True, help='patch width') 27 | parser.add_argument('--crop_size', dest='crop_size', type=int, default=56, help='patch height') 28 | parser.add_argument('--stride', dest='stride', type=int, default=56, help='patch stride') 29 | 
parser.add_argument('--data_augmentation', dest='data_augmentation', default=True, help='') 30 | 31 | # About Network 32 | parser.add_argument('--scale_SR', dest='scale_SR', default=4, help='the scale of super-resolution') 33 | parser.add_argument('--num_repeat_RRDB', dest='num_repeat_RRDB', type=int, default=15, help='the number of RRDB blocks') 34 | parser.add_argument('--residual_scaling', dest='residual_scaling', type=float, default=0.2, help='residual scaling parameter') 35 | parser.add_argument('--initialization_random_seed', dest='initialization_random_seed', default=111, help='random_seed') 36 | parser.add_argument('--perceptual_loss', dest='perceptual_loss', default='VGG19', help='the part of loss function. "VGG19" or "pixel-wise"') 37 | parser.add_argument('--gan_loss_type', dest='gan_loss_type', default='MaGAN', help='the type of GAN loss functions. "RaGAN" or "GAN"') 38 | 39 | # About training 40 | parser.add_argument('--num_iter', dest='num_iter', type=int, default=50000, help='The number of iterations') 41 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=2, help='Mini-batch size') 42 | parser.add_argument('--channel', dest='channel', type=int, default=3, help='Number of input/output image channel') 43 | parser.add_argument('--pretrain_generator', dest='pretrain_generator', type=bool, default=False, help='Whether to pretrain generator') 44 | parser.add_argument('--pretrain_learning_rate', dest='pretrain_learning_rate', type=float, default=2e-4, help='learning rate for pretrain') 45 | parser.add_argument('--pretrain_lr_decay_step', dest='pretrain_lr_decay_step', type=float, default=100000, help='decay by every n iteration') 46 | parser.add_argument('--learning_rate', dest='learning_rate', type=float, default=1e-4, help='learning rate') 47 | parser.add_argument('--weight_initialize_scale', dest='weight_initialize_scale', type=float, default=0.1, help='scale to multiply after MSRA initialization') 48 | 
parser.add_argument('--HR_image_size', dest='HR_image_size', type=int, default=128, help='Image width and height of LR image. This is should be 1/4 of HR_image_size exactly') 49 | parser.add_argument('--LR_image_size', dest='LR_image_size', type=int, default=32, help='Image width and height of LR image.') 50 | parser.add_argument('--epsilon', dest='epsilon', type=float, default=1e-12, help='used in loss function') 51 | parser.add_argument('--gan_loss_coeff', dest='gan_loss_coeff', type=float, default=1.0, help='used in perceptual loss') 52 | parser.add_argument('--content_loss_coeff', dest='content_loss_coeff', type=float, default=0.01, help='used in content loss') 53 | 54 | # About log 55 | parser.add_argument('--logging', dest='logging', type=bool, default=True, help='whether to record training log') 56 | parser.add_argument('--train_sample_save_freq', dest='train_sample_save_freq', type=int, default=100, help='save samples during training every n iteration') 57 | parser.add_argument('--train_ckpt_save_freq', dest='train_ckpt_save_freq', type=int, default=100, help='save checkpoint during training every n iteration') 58 | parser.add_argument('--train_summary_save_freq', dest='train_summary_save_freq', type=int, default=200, help='save summary during training every n iteration') 59 | parser.add_argument('--pre_train_checkpoint_dir1', dest='pre_train_checkpoint_dir1', type=str, default='./pre_train_checkpoint', help='pre-train checkpoint directory') 60 | parser.add_argument('--pre_train_checkpoint_dir2', dest='pre_train_checkpoint_dir2', type=str, default='./pre_train_checkpoint2', help='pre-train checkpoint directory') 61 | parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', type=str, default='./checkpoint', help='checkpoint directory') 62 | parser.add_argument('--logdir', dest='logdir', type=str, default='./log3', help='log directory') 63 | 64 | # GPU setting 65 | parser.add_argument('--gpu_dev_num', dest='gpu_dev_num', type=str, default='0,1,2', 
help='Which GPU to use for multi-GPUs') 66 | 67 | # ----- my args ----- 68 | parser.add_argument('--image_batch_size', dest='image_batch_size', type=int, default=64, help='Mini-batch size') 69 | parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='Total Epochs') 70 | 71 | args = parser.parse_args() 72 | 73 | def set_logger(args): 74 | """set logger for training recording""" 75 | if args.logging: 76 | logfile = '{0}/training_logfile_{1}.log'.format(args.logdir, datetime.now().strftime("%Y%m%d_%H%M%S")) 77 | formatter = '%(levelname)s:%(asctime)s:%(message)s' 78 | logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S') 79 | return True 80 | else: 81 | print('No logging is set') 82 | return False 83 | 84 | 85 | def main(): 86 | # make dirs 87 | target_dirs = [args.checkpoint_dir, args.logdir, 88 | args.train_result_dir, args.train_LR_result_dir, 89 | args.valid_result_dir, args.valid_LR_result_dir] 90 | create_dirs(target_dirs) 91 | 92 | # set logger 93 | logflag = set_logger(args) 94 | log(logflag, 'Training script start', 'info') 95 | 96 | NLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel], 97 | name='NLR_input') 98 | CLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel], 99 | name='CLR_input') 100 | NHR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel], 101 | name='NHR_input') 102 | CHR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel], 103 | name='CHR_input') 104 | 105 | # build Generator and Discriminator 106 | network = Network(args, NLR_data=NLR_data,CLR_data=CLR_data, NHR_data=NHR_data,CHR_data=CHR_data) 107 | CLR_C1, NLR_C1, CLR_C2, CHR_C3, NLR_C3, CHR_C4, CLR_I1, CHR_I1, CHR_I2 = network.train_generator() 108 | D_out, Y_out, C_out = network.train_discriminator(CLR_C1, CHR_C3) 109 | 110 | # build loss function 111 | loss = Loss() 112 | gen_loss, g_gen_loss, dis_loss, Y_dis_loss, C_dis_loss = 
loss.gan_loss(args, NLR_data, CLR_data, NHR_data, CHR_data, 113 | CLR_C1, NLR_C1, CLR_C2, CHR_C3, NLR_C3, CHR_C4, CLR_I1, CHR_I1, CHR_I2, 114 | D_out, Y_out, C_out) 115 | 116 | # define optimizers 117 | global_iter = tf.Variable(0, trainable=False) 118 | dis_var, dis_optimizer, gen_var, gen_optimizer, Y_dis_optimizer, C_dis_optimizer = Optimizer().gan_optimizer( 119 | args, global_iter, dis_loss, gen_loss, Y_dis_loss, C_dis_loss) 120 | 121 | # build summary writer 122 | tr_summary = tf.summary.merge(loss.add_summary_writer()) 123 | fetches = {'dis_optimizer': dis_optimizer,'Y_dis_optimizer': Y_dis_optimizer, 124 | 'C_dis_optimizer': C_dis_optimizer, 'gen_optimizer': gen_optimizer, 125 | 'dis_loss': dis_loss,'Y_dis_loss': Y_dis_loss,'C_dis_loss': C_dis_loss, 126 | 'gen_loss': gen_loss, 'g_gen_loss' : g_gen_loss, 127 | 'CHR_out': CHR_C3, 'CLR_out' : CLR_C1, 'summary' : tr_summary} 128 | """ 129 | fetches = {'dis_optimizer': dis_optimizer,'dis_optimizerLR': dis_optimizerLR, 'gen_optimizer': gen_optimizer, 130 | 'dis_loss': dis_loss, 'gen_loss': gen_loss, 131 | 'CHR_out': CHR_C3, 'CLR_out' : CLR_C1, 'summary' : tr_summary} 132 | """ 133 | 134 | gc.collect() 135 | 136 | config = tf.ConfigProto( 137 | gpu_options=tf.GPUOptions( 138 | allow_growth=True, 139 | visible_device_list=args.gpu_dev_num 140 | ) 141 | ) 142 | 143 | # Start Session 144 | with tf.Session(config=config) as sess: 145 | log(logflag, 'Training ESRGAN starts', 'info') 146 | 147 | sess.run(tf.global_variables_initializer()) 148 | sess.run(global_iter.initializer) 149 | 150 | writer = tf.summary.FileWriter(args.logdir, graph=sess.graph) 151 | 152 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator/Generator1')\ 153 | + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator/Generator2')\ 154 | + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='discriminator') 155 | 156 | pre_saver = tf.train.Saver(var_list=var_list) 157 | pre_saver.restore(sess, 
tf.train.latest_checkpoint(args.pre_train_checkpoint_dir1)) 158 | 159 | var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='generator/SR') 160 | 161 | pre_saver = tf.train.Saver(var_list=var_list) 162 | pre_saver.restore(sess, tf.train.latest_checkpoint(args.pre_train_checkpoint_dir2)) 163 | 164 | if args.perceptual_loss == 'VGG19': 165 | sess.run(load_vgg19_weight(args)) 166 | 167 | saver = tf.train.Saver(max_to_keep=10) 168 | saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir)) 169 | 170 | _datapathNLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/source/train_HR_aug/x4/', '*.png')))) 171 | _datapathNHR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/source/train_HR_aug/x4/', '*.png')))) 172 | _datapathCLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/target/train_LR_aug/x4/', '*.png')))) 173 | _datapathCHR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/target/train_HR_aug/x4/', '*.png')))) 174 | idxLR = np.random.permutation(len(_datapathNLR)) 175 | datapathNLR = _datapathNLR[idxLR] 176 | datapathNHR = _datapathNHR[idxLR] 177 | idxHR = np.random.permutation(len(_datapathCHR)) 178 | datapathCLR = _datapathCLR[idxHR] 179 | datapathCHR = _datapathCHR[idxHR] 180 | 181 | epoch = 0 182 | counterNLR = 0 183 | counterCLR = 0 184 | 185 | g_loss = 0.0 186 | d_loss = 0.0 187 | steps = 0 188 | psnr_max = 20 189 | while True: 190 | if counterNLR >= len(_datapathNLR): 191 | log(logflag, 'Train Epoch: {0} g.loss : {1} d.loss : {2}'.format( 192 | epoch, g_loss / steps, d_loss / steps), 'info') 193 | idx = np.random.permutation(len(_datapathNLR)) 194 | datapathNLR = _datapathNLR[idx] 195 | datapathNHR = _datapathNHR[idx] 196 | counterNLR = 0 197 | g_loss = 0.0 198 | d_loss = 0.0 199 | steps = 0 200 | epoch += 1 201 | if epoch == 200: 202 | break 203 | if counterCLR >= len(_datapathCHR): 204 | idx = np.random.permutation(len(_datapathCHR)) 205 | datapathCHR = _datapathCHR[idx] 206 | datapathCLR = 
_datapathCLR[idx] 207 | counterCLR = 0 208 | 209 | dataNLR, dataNHR, dataCLR, dataCHR = generate_batch(datapathNLR[counterNLR:counterNLR + args.image_batch_size], 210 | datapathNHR[counterNLR:counterNLR + args.image_batch_size], 211 | datapathCLR[counterCLR:counterCLR + args.image_batch_size], 212 | datapathCHR[counterCLR:counterCLR + args.image_batch_size], 213 | args) 214 | 215 | counterNLR += args.image_batch_size 216 | counterCLR += args.image_batch_size 217 | 218 | for iteration in range(0, dataCLR.shape[0], args.batch_size): 219 | 220 | _CHR_data = dataCHR[iteration:iteration + args.batch_size] 221 | _CLR_data = dataCLR[iteration:iteration + args.batch_size] 222 | _NLR_data = dataNLR[iteration:iteration + args.batch_size] 223 | _NHR_data = dataNHR[iteration:iteration + args.batch_size] 224 | 225 | feed_dict = { 226 | CHR_data: _CHR_data, 227 | CLR_data : _CLR_data, 228 | NLR_data: _NLR_data, 229 | NHR_data: _NHR_data, 230 | } 231 | # update weights 232 | result = sess.run(fetches=fetches, feed_dict=feed_dict) 233 | current_iter = tf.train.global_step(sess, global_iter) 234 | 235 | g_loss += result['gen_loss'] 236 | d_loss += result['dis_loss'] 237 | steps += 1 238 | 239 | # save summary every n iter 240 | if current_iter % args.train_summary_save_freq == 0: 241 | writer.add_summary(result['summary'], global_step=current_iter) 242 | 243 | # save samples every n iter 244 | if current_iter % 100 == 0 : 245 | log(logflag, 246 | 'Mymodel iteration : {0}, gen_loss : {1}, dis_loss : {2},' 247 | ' Y_dis_loss {3} C_dis_loss {4} g_gen_loss {5}'.format(current_iter, 248 | result['gen_loss'], 249 | result['dis_loss'], 250 | result['Y_dis_loss'], 251 | result['C_dis_loss'], 252 | result['g_gen_loss']), 253 | 'info') 254 | if current_iter % args.train_sample_save_freq == 0: 255 | validpathLR = np.sort( 256 | np.asarray(glob(os.path.join(args.data_dir + '/validation/x/', '*.png')))) 257 | validpathHR = np.sort( 258 | np.asarray(glob(os.path.join(args.data_dir + 
'/validation/y/', '*.png')))) 259 | psnr_avg = 0.0 260 | for valid_ii in range(100) : 261 | #valid_i = np.random.randint(100) 262 | validLR, validHR = generate_testset(validpathLR[valid_ii], 263 | validpathHR[valid_ii], 264 | args) 265 | 266 | validLR = np.transpose(validLR[:,:,:,np.newaxis],(3,0,1,2)) 267 | validHR = np.transpose(validHR[:,:,:,np.newaxis],(3,0,1,2)) 268 | valid_out, valid_out_LR = sess.run([CHR_C3, CLR_C1], 269 | feed_dict = {NLR_data : validLR, 270 | NHR_data: validHR, 271 | CLR_data : validLR, 272 | CHR_data : validHR}) 273 | 274 | validLR = np.rot90(validLR, 1, axes=(1,2)) 275 | validHR = np.rot90(validHR, 1, axes=(1,2)) 276 | valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1], 277 | feed_dict = {NLR_data : validLR, 278 | NHR_data: validHR, 279 | CLR_data : validLR, 280 | CHR_data : validHR}) 281 | valid_out += np.rot90(valid_out90, 3, axes=(1,2)) 282 | valid_out_LR += np.rot90(valid_out_LR90, 3, axes=(1,2)) 283 | 284 | validLR = np.rot90(validLR, 1, axes=(1,2)) 285 | validHR = np.rot90(validHR, 1, axes=(1,2)) 286 | valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1], 287 | feed_dict = {NLR_data : validLR, 288 | NHR_data: validHR, 289 | CLR_data : validLR, 290 | CHR_data : validHR}) 291 | valid_out += np.rot90(valid_out90, 2, axes=(1,2)) 292 | valid_out_LR += np.rot90(valid_out_LR90, 2, axes=(1,2)) 293 | 294 | validLR = np.rot90(validLR, 1, axes=(1,2)) 295 | validHR = np.rot90(validHR, 1, axes=(1,2)) 296 | valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1], 297 | feed_dict = {NLR_data : validLR, 298 | NHR_data: validHR, 299 | CLR_data : validLR, 300 | CHR_data : validHR}) 301 | valid_out += np.rot90(valid_out90, 1, axes=(1,2)) 302 | valid_out_LR += np.rot90(valid_out_LR90, 1, axes=(1,2)) 303 | 304 | valid_out /=4. 305 | valid_out_LR /=4. 306 | 307 | from utils import de_normalize_image 308 | validHR = np.rot90(validHR, 1, axes=(1, 2)) 309 | F = de_normalize_image(validHR) / 255. 310 | G = de_normalize_image(valid_out) / 255. 
311 | 312 | E = F-G 313 | N = np.size(E) 314 | PSNR = 10*np.log10( N / np.sum(E ** 2)) 315 | print(PSNR) 316 | psnr_avg += PSNR / 100 317 | if valid_ii < 5 : 318 | save_image(args, valid_out, 'valid', current_iter+valid_ii, save_max_num=5) 319 | save_image(args, valid_out_LR, 'valid_LR', current_iter + valid_ii, save_max_num=5) 320 | if psnr_avg > psnr_max : 321 | print("max psnr : %f" %psnr_avg) 322 | psnr_max = psnr_avg 323 | saver.save(sess, os.path.join('./best_checkpoint', 'gen'), global_step=current_iter) 324 | #save_image(args, result['gen_HR'], 'My_train', current_iter, save_max_num=5) 325 | #save_image(args, result['gen_out_LR'], 'My_train_LR', current_iter, save_max_num=5) 326 | # save checkpoints 327 | if current_iter % args.train_ckpt_save_freq == 0: 328 | saver.save(sess, os.path.join(args.checkpoint_dir, 'gen'), global_step=current_iter) 329 | 330 | writer.close() 331 | log(logflag, 'Training ESRGAN end', 'info') 332 | log(logflag, 'Training script end', 'info') 333 | 334 | 335 | if __name__ == '__main__': 336 | main() -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import batch_instance_norm, instance_norm 3 | class Generator(object) : 4 | def __init__(self, args) : 5 | self.channel= args.channel 6 | self.n_filter = 64 7 | self.inc_filter = 32 8 | self.num_repeat_RRDB = args.num_repeat_RRDB 9 | self.residual_scaling = args.residual_scaling 10 | self.init_kernel = tf.initializers.he_normal(seed=args.initialization_random_seed) 11 | 12 | def _conv_RRDB(self, x, out_channel, num=None, activate=True) : 13 | with tf.variable_scope('block{0}'.format(num)) : 14 | x =tf.layers.conv2d(x, out_channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, name='conv') 15 | if activate : 16 | #x = instance_norm(x, name='RRDB_IN' + str(num)) #0218 pre-train2 학습 17 | x = tf.nn.leaky_relu(x, 
alpha=0.2, name='leakyReLU') 18 | return x 19 | 20 | def _denseBlock(self, x, num=None) : 21 | with tf.variable_scope('DenseBlock_sub{0}'.format(num)) : 22 | x1 = self._conv_RRDB(x, self.inc_filter, 0) 23 | x2 = self._conv_RRDB(tf.concat([x, x1], axis=3), self.inc_filter, 1) 24 | x3 = self._conv_RRDB(tf.concat([x, x1, x2], axis=3), self.inc_filter, 2) 25 | x4 = self._conv_RRDB(tf.concat([x, x1, x2, x3], axis=3), self.inc_filter, 3) 26 | x5 = self._conv_RRDB(tf.concat([x, x1, x2, x3, x4], axis=3), self.n_filter, 4, activate=False) 27 | return x5 * self.residual_scaling 28 | 29 | def _RRDB(self, x, num=None) : 30 | with tf.variable_scope('RRDB_sub{0}'.format(num)) : 31 | x_branch = tf.identity(x) 32 | 33 | x_branch += self._denseBlock(x_branch, 0) 34 | x_branch += self._denseBlock(x_branch, 1) 35 | x_branch += self._denseBlock(x_branch, 2) 36 | return x + x_branch * self.residual_scaling 37 | 38 | def _upsampling_layer(self, x, num=None) : 39 | x = tf.layers.conv2d_transpose(x, self.n_filter, 3, 2, padding='same', name='upsample_{0}'.format(num)) 40 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 41 | 42 | return x 43 | 44 | def _upsampling_layer_v2(self, x, num=None) : 45 | x = tf.image.resize_images(x, [tf.shape(x)[1]*2, tf.shape(x)[2]*2], align_corners=True) 46 | x = tf.pad(x, [[0,0],[1,1],[1,1],[0,0]], "REFLECT") 47 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='valid', name='upsample_{0}'.format(num)) 48 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 49 | 50 | return x 51 | 52 | def _upsampling_layer_v3(self, x, num=None) : 53 | x = tf.layers.conv2d(x, 64, 3, 1, padding='same', name='upsamplef_{0}'.format(num)) 54 | x = tf.nn.depth_to_space(x, 2, name='pixel_shuffle_{0}'.format(num)) 55 | x = tf.layers.conv2d(x, 64, 3, 1, padding='same', name='upsampleb_{0}'.format(num)) 56 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 57 | 58 | return x 59 | 60 | def _upsampling_layer_v4(self, x, num=None) : 61 | x_b = tf.identity(x) 62 | x_b = 
tf.image.resize_images(x_b, [tf.shape(x_b)[1]*2, tf.shape(x_b)[2]*2], align_corners=True) 63 | x_b = tf.layers.conv2d(x_b, self.n_filter, 3, 1, padding='same', name='upsampleb_{0}'.format(num)) 64 | x_b = tf.nn.leaky_relu(x_b, alpha=0.2, name='leakyReLU') 65 | 66 | x = tf.layers.conv2d_transpose(x, self.n_filter, 3, 2, padding='same', name='upsamplet_{0}'.format(num)) 67 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 68 | 69 | return 0.1 * x + x_b 70 | 71 | def style_pool(self, x, name) : 72 | with tf.variable_scope(name) : 73 | x_avg, x_std = tf.nn.moments(x, axes=[1,2]) 74 | x_std = tf.abs(tf.sqrt(x_std + 1e-12)) 75 | x_avg = tf.expand_dims(x_avg, axis=2) 76 | x_std = tf.expand_dims(x_std, axis=2) 77 | x_feature = tf.concat((x_avg, x_std), axis=2) 78 | _, width, channel = x_feature.get_shape().as_list() 79 | x_weight = tf.get_variable("W", shape=[width, channel], initializer=self.init_kernel) 80 | x_feature = tf.reduce_sum(x_feature * x_weight, axis=2) 81 | x_feature = tf.expand_dims(x_feature, axis=1) 82 | x_feature = tf.expand_dims(x_feature, axis=2) 83 | #x_feature = instance_norm(x_feature) 84 | x_feature = tf.nn.sigmoid(x_feature) 85 | 86 | x = x * x_feature 87 | return x 88 | 89 | def _mask_layer(self, x, i) : 90 | x_mask = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel) 91 | #x_mask = tf.layers.BatchNormalization(name='batch_norm_0')(x_mask) 92 | x_mask = tf.nn.leaky_relu(x_mask, alpha=0.2) 93 | #x_mask = self.style_pool(x_mask, "style0_{0}".format(i)) 94 | 95 | x_mask = tf.layers.conv2d(x_mask, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel) 96 | #x_mask = tf.layers.BatchNormalization(name='batch_norm_1')(x_mask) 97 | x_mask = tf.nn.leaky_relu(x_mask, alpha=0.2) 98 | #x_mask = self.style_pool(x_mask, "style1_{0}".format(i)) 99 | x_mask = x + 0.2 * x_mask 100 | return x_mask 101 | 102 | def build_G1(self, NLR) : 103 | with tf.variable_scope('Generator1') : 104 | with 
tf.variable_scope('first_conv') : 105 | x = tf.layers.conv2d(NLR, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 106 | name='conv0') 107 | x = tf.nn.leaky_relu(x, alpha=0.2) 108 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 109 | name='conv1') 110 | x = tf.nn.leaky_relu(x, alpha=0.2) 111 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 112 | name='conv2') 113 | x = tf.nn.leaky_relu(x, alpha=0.2) 114 | 115 | with tf.variable_scope('mask_conv') : 116 | x = self._mask_layer(x, 0) 117 | x = self._mask_layer(x, 1) 118 | x = self._mask_layer(x, 2) 119 | x = self._mask_layer(x, 3) 120 | with tf.variable_scope('last_conv') : 121 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 122 | name='conv0') 123 | x = tf.nn.leaky_relu(x, alpha=0.2) 124 | CLR = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, 125 | name='conv1') 126 | return CLR 127 | 128 | def build_G2(self, CLR) : 129 | with tf.variable_scope('Generator2') : 130 | with tf.variable_scope('first_conv') : 131 | x = tf.layers.conv2d(CLR, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 132 | name='conv0') 133 | x = tf.nn.leaky_relu(x, alpha=0.2) 134 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 135 | name='conv1') 136 | x = tf.nn.leaky_relu(x, alpha=0.2) 137 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 138 | name='conv2') 139 | x = tf.nn.leaky_relu(x, alpha=0.2) 140 | 141 | with tf.variable_scope('mask_conv') : 142 | x = self._mask_layer(x, 0) 143 | x = self._mask_layer(x, 1) 144 | x = self._mask_layer(x, 2) 145 | x = self._mask_layer(x, 3) 146 | with tf.variable_scope('last_conv') : 147 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', 
kernel_initializer=self.init_kernel, 148 | name='conv0') 149 | x = tf.nn.leaky_relu(x, alpha=0.2) 150 | NLR = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, 151 | name='conv1') 152 | return NLR 153 | 154 | def build_G3(self, CHR) : 155 | with tf.variable_scope('Generator3') : 156 | with tf.variable_scope('first_conv') : 157 | x = tf.layers.conv2d(CHR, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 158 | name='conv0') 159 | x = tf.nn.leaky_relu(x, alpha=0.2) 160 | x = tf.layers.conv2d(x, self.n_filter, 3, 2, padding='same', kernel_initializer=self.init_kernel, 161 | name='conv1') 162 | x = tf.nn.leaky_relu(x, alpha=0.2) 163 | x = tf.layers.conv2d(x, self.n_filter, 3, 2, padding='same', kernel_initializer=self.init_kernel, 164 | name='conv2') 165 | x = tf.nn.leaky_relu(x, alpha=0.2) 166 | with tf.variable_scope('mask_conv') : 167 | x = self._mask_layer(x, 0) 168 | x = self._mask_layer(x, 1) 169 | x = self._mask_layer(x, 2) 170 | x = self._mask_layer(x, 3) 171 | with tf.variable_scope('last_conv') : 172 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 173 | name='conv0') 174 | x = tf.nn.leaky_relu(x, alpha=0.2) 175 | NLR = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, 176 | name='conv1') 177 | return NLR 178 | 179 | def build_SR(self, CLR) : 180 | with tf.variable_scope('SR') : 181 | with tf.variable_scope('first_conv') : 182 | x = tf.layers.conv2d(CLR, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 183 | name='conv') 184 | x = tf.nn.leaky_relu(x, alpha=0.2) 185 | 186 | with tf.variable_scope('RRDB'): 187 | x_branch = tf.identity(x) 188 | x_branch = tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same', 189 | kernel_initializer=self.init_kernel, name='conv') 190 | x_branch = tf.nn.leaky_relu(x_branch, alpha=0.2, name='leakyReLU') 191 | for i in 
range(self.num_repeat_RRDB): 192 | x_branch = self._RRDB(x_branch, i) 193 | 194 | x_branch = tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same', 195 | kernel_initializer=self.init_kernel, name='trunk_conv') 196 | x += x_branch 197 | 198 | with tf.variable_scope('Upsampling'): 199 | x = self._upsampling_layer_v4(x, 1) 200 | x = self._upsampling_layer_v4(x, 2) 201 | 202 | with tf.variable_scope('last_conv'): 203 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 204 | name='conv_1') 205 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 206 | CHR = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, 207 | name='conv_2') 208 | # x = tf.nn.tanh(x) 209 | return CHR 210 | 211 | def build_downsample(self, x) : 212 | with tf.variable_scope('mask_conv') : 213 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 214 | name='conv') 215 | x = tf.nn.leaky_relu(x, alpha=0.2) 216 | 217 | with tf.variable_scope('RRDB') : 218 | x_branch = tf.identity(x) 219 | x_branch =tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same', 220 | kernel_initializer=self.init_kernel, name='conv') 221 | x_branch = tf.nn.leaky_relu(x_branch, alpha=0.2, name='leakyReLU') 222 | for i in range(self.num_repeat_RRDB) : 223 | x_branch = self._RRDB(x_branch, i) 224 | 225 | x_branch = tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same', 226 | kernel_initializer=self.init_kernel,name='trunk_conv') 227 | x += x_branch 228 | 229 | with tf.variable_scope('Upsampling') : 230 | x = self._upsampling_layer_v2(x, 1) 231 | x = self._upsampling_layer_v2(x, 2) 232 | 233 | with tf.variable_scope('last_conv') : 234 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 235 | name='conv_1') 236 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 237 | x = tf.layers.conv2d(x, self.channel, 3, 1, 
padding='same', kernel_initializer=self.init_kernel, 238 | name='conv_2') 239 | #x = tf.nn.tanh(x) 240 | return x 241 | 242 | class Discriminator(object) : 243 | def __init__(self, args) : 244 | self.channel = args.channel 245 | self.n_filter = 64 246 | self.init_kernel = tf.initializers.he_normal(seed=args.initialization_random_seed) 247 | 248 | def _conv_block(self, x, out_channel, num=None): 249 | with tf.variable_scope('block_{0}'.format(num)): 250 | x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', use_bias=False, 251 | kernel_initializer=self.init_kernel, name='conv_1') 252 | #x = tf.layers.BatchNormalization(name='batch_norm_1')(x) 253 | x = batch_instance_norm(x, name='dconvBIN1_'+str(num)) 254 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1') 255 | 256 | x = tf.layers.conv2d(x, out_channel, 4, 2, padding='same', use_bias=False, 257 | kernel_initializer=self.init_kernel, name='conv_2') 258 | #x = tf.layers.BatchNormalization(name='batch_norm_2')(x) 259 | x = batch_instance_norm(x, name='dconvBIN2_'+str(num)) 260 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2') 261 | 262 | return x 263 | 264 | def build(self, x): 265 | with tf.variable_scope('first_conv'): 266 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', use_bias=False, 267 | kernel_initializer=self.init_kernel, name='conv_1') 268 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1') 269 | x = tf.layers.conv2d(x, self.n_filter, 4, 2, padding='same', use_bias=False, 270 | kernel_initializer=self.init_kernel, name='conv_2') 271 | #x = tf.layers.BatchNormalization(name='batch_norm_1')(x) 272 | x = batch_instance_norm(x, name='BIN1') 273 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2') 274 | 275 | with tf.variable_scope('conv_block'): 276 | x = self._conv_block(x, self.n_filter * 1, 0) 277 | x = self._conv_block(x, self.n_filter * 1, 1) 278 | x = self._conv_block(x, self.n_filter * 2, 2) 279 | x = self._conv_block(x, self.n_filter * 4, 3) 280 | 281 | with 
class Discriminator_color(object) :
    """Discriminator operating on color information of LR images.

    A stack of stride-2 conv blocks followed by global average pooling and
    two dense layers, producing one real/fake logit per image.
    """

    def __init__(self, args) :
        # number of image channels (3 for RGB)
        self.channel = args.channel
        # base number of conv filters; multiplied in deeper blocks
        self.n_filter = 64
        # He-normal initializer, seeded for reproducible initialization
        self.init_kernel = tf.initializers.he_normal(seed=args.initialization_random_seed)

    def _conv_block(self, x, out_channel, num=None):
        """One downsampling unit: 3x3/s1 conv + BN + lReLU, then 4x4/s2 conv + BN + lReLU."""
        with tf.variable_scope('block_{0}'.format(num)):
            x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', use_bias=False,
                                 kernel_initializer=self.init_kernel, name='conv_1')
            # NOTE(review): BatchNormalization is applied without a `training`
            # flag, so it runs with moving statistics (inference mode) even
            # during training -- confirm this is intended.
            x = tf.layers.BatchNormalization(name='batch_norm_1')(x)
            #x = instance_norm(x, name='dconvIN1_'+str(num))
            x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')

            # 4x4 stride-2 conv halves the spatial resolution
            x = tf.layers.conv2d(x, out_channel, 4, 2, padding='same', use_bias=False,
                                 kernel_initializer=self.init_kernel, name='conv_2')
            x = tf.layers.BatchNormalization(name='batch_norm_2')(x)
            #x = instance_norm(x, name='dconvIN2_'+str(num))
            x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2')

        return x

    def build(self, x):
        """Build the discriminator graph; returns a (batch, 1) logit tensor."""
        with tf.variable_scope('first_conv'):
            x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', use_bias=False,
                                 kernel_initializer=self.init_kernel, name='conv_1')
            x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')
            x = tf.layers.conv2d(x, self.n_filter, 4, 2, padding='same', use_bias=False,
                                 kernel_initializer=self.init_kernel, name='conv_2')
            x = tf.layers.BatchNormalization(name='batch_norm_1')(x)
            #x = instance_norm(x, name='dconvIN0')
            x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2')

        with tf.variable_scope('conv_block'):
            # filter progression: 64, 64, 128, 256; each block halves H and W
            x = self._conv_block(x, self.n_filter, 0)
            x = self._conv_block(x, self.n_filter, 1)
            x = self._conv_block(x, self.n_filter * 2, 2)
            x = self._conv_block(x, self.n_filter * 4, 3)

        with tf.variable_scope('full_connected'):
            # global average pooling over the spatial axes
            x = tf.reduce_mean(x, axis=(1,2))
            #x = tf.layers.flatten(x)
            x = tf.layers.dense(x, 100, name='fully_connected_1')
            x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')
            # single logit output (no sigmoid; loss applies it)
            x = tf.layers.dense(x, 1, name='fully_connected_2')

        return x

class Perceptual_VGG19(object):
    """the definition of VGG19. This network is used for constructing perceptual loss"""
    @staticmethod
    def build(x):
        # Standard VGG19 feature extractor up to block5_conv4. Layer names must
        # match Keras' VGG19 so that load_vgg19_weight() can copy ImageNet
        # weights by name.
        # Block 1
        x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv1')
        x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv2')
        x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block1_pool')

        # Block 2
        x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv1')
        x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv2')
        x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block2_pool')

        # Block 3
        x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv1')
        x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv2')
        x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv3')
        x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv4')
        x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block3_pool')

        # Block 4
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv1')
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv2')
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv3')
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv4')
        x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block4_pool')

        # Block 5
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv1')
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv2')
        x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv3')
        # final conv has no activation: the perceptual feature is taken
        # before the ReLU (activation=None)
        x = tf.layers.conv2d(x, 512, (3, 3), activation=None, padding='same', name='block5_conv4')

        return x

# ---- ops.py ----
from collections import OrderedDict

import tensorflow as tf
from tensorflow.keras.applications.vgg19 import VGG19


def instance_norm(input, name="instance_norm"):
    """Instance normalization: normalize each sample over its spatial axes
    (H, W) per channel, with learned per-channel scale and offset."""
    with tf.variable_scope(name):
        depth = input.get_shape()[3]
        scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32))
        offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0))
        mean, variance = tf.nn.moments(input, axes=[1,2], keep_dims=True)
        epsilon = 1e-5
        # rsqrt avoids an explicit division
        inv = tf.rsqrt(variance + epsilon)
        normalized = (input-mean)*inv
        return scale*normalized + offset

def batch_instance_norm(input, name="batch_instance_norm"):
    """Batch-instance normalization: a learned per-channel blend `rho` between
    batch-normalized and instance-normalized activations, followed by a
    learned affine (gamma, beta). `rho` is clipped to [0, 1]."""
    with tf.variable_scope(name):
        ch = input.get_shape()[-1]
        epsilon = 1e-5

        # statistics over the whole batch and spatial axes
        batch_mean, batch_sigma = tf.nn.moments(input, axes=[0,1,2], keep_dims=True)
        x_batch = (input - batch_mean) / (tf.sqrt(batch_sigma + epsilon))

        # per-sample statistics over spatial axes only
        ins_mean, ins_sigma = tf.nn.moments(input, axes=[1,2], keep_dims=True)
        x_ins = (input - ins_mean) / (tf.sqrt(ins_sigma + epsilon))

        # rho in [0,1] interpolates between batch and instance normalization
        rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0),
                              constraint=lambda x : tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0))
        gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))

        x_hat = rho * x_batch + (1 - rho) * x_ins
        x_hat = x_hat * gamma + beta

        return x_hat

def scale_initialization(weights, FLAGS):
    """Return assign ops that scale every weight by FLAGS.weight_initialize_scale
    (applied once after initialization)."""
    return [tf.assign(weight, weight * FLAGS.weight_initialize_scale) for weight in weights]


def _transfer_vgg19_weight(FLAGS, weight_dict):
    """Build a Keras VGG19 with ImageNet weights and return assign ops copying
    each conv layer's kernel/bias into the matching graph variable in
    `weight_dict` (keyed by 'loss_generator/perceptual_vgg19/<layer>/...')."""
    from_model = VGG19(include_top=False, weights='imagenet', input_tensor=None,
                       input_shape=(FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel))

    fetch_weight = []

    for layer in from_model.layers:
        # only conv layers carry (kernel, bias) pairs to transfer
        if 'conv' in layer.name:
            W, b = layer.get_weights()

            fetch_weight.append(
                tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/kernel'.format(layer.name)], W)
            )
            fetch_weight.append(
                tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/bias'.format(layer.name)], b)
            )

    return fetch_weight
def load_vgg19_weight(FLAGS):
    """Collect the perceptual-VGG19 variables from the current graph and return
    assign ops that load ImageNet weights into them.

    Raises (via assert) if no variable is found under the expected scope.
    """
    vgg_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='loss_generator/perceptual_vgg19')

    assert len(vgg_weight) > 0, 'No VGG19 weight was collected. The target scope might be wrong.'

    weight_dict = {}
    for weight in vgg_weight:
        # strip the ':0' output suffix so keys match the layer-name scheme
        weight_dict[weight.name.rsplit(':', 1)[0]] = weight

    return _transfer_vgg19_weight(FLAGS, weight_dict)


def extract_weight(network_vars):
    """Evaluate each variable in `network_vars` and return an OrderedDict
    mapping variable name -> numpy value. Must run inside an active session."""
    weight_dict = OrderedDict()

    for weight in network_vars:
        weight_dict[weight.name] = weight.eval()

    return weight_dict


def interpolate_weight(FLAGS, pretrain_weight):
    """Return assign ops blending current graph weights with `pretrain_weight`:
    new = (1 - alpha) * pretrained + alpha * current, alpha = FLAGS.interpolation_param
    (ESRGAN-style network interpolation)."""
    fetch_weight = []
    alpha = FLAGS.interpolation_param

    for name, pre_weight in pretrain_weight.items():
        esrgan_weight = tf.get_default_graph().get_tensor_by_name(name)

        assert pre_weight.shape == esrgan_weight.shape, 'The shape of weights does not match'

        fetch_weight.append(tf.assign(esrgan_weight, (1 - alpha) * pre_weight + alpha * esrgan_weight))

    return fetch_weight
# About Data (patch extraction)
# help texts fixed: crop_size/stride previously carried copy-pasted
# 'patch width/height' descriptions from the --crop flag.
parser.add_argument('--crop_size', dest='crop_size', type=int, default=64, help='patch size (width and height)')
parser.add_argument('--stride', dest='stride', type=int, default=64, help='patch stride')

# About Network
parser.add_argument('--scale_SR', dest='scale_SR', default=4, help='the scale of super-resolution')
parser.add_argument('--num_repeat_RRDB', dest='num_repeat_RRDB', type=int, default=15, help='the number of RRDB blocks')
parser.add_argument('--residual_scaling', dest='residual_scaling', type=float, default=0.2, help='residual scaling parameter')
parser.add_argument('--initialization_random_seed', dest='initialization_random_seed', default=111, help='random_seed')
parser.add_argument('--perceptual_loss', dest='perceptual_loss', default='VGG19', help='the part of loss function. "VGG19" or "pixel-wise"')
# NOTE(review): default 'MaGAN' is not among the documented choices
# ("RaGAN" or "GAN"); pre_train2.py uses 'RaGAN' -- confirm intended.
parser.add_argument('--gan_loss_type', dest='gan_loss_type', default='MaGAN', help='the type of GAN loss functions. "RaGAN" or "GAN"')

# About training
parser.add_argument('--num_iter', dest='num_iter', type=int, default=50000, help='The number of iterations')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=16, help='Mini-batch size')
parser.add_argument('--channel', dest='channel', type=int, default=3, help='Number of input/output image channel')
# NOTE(review): type=bool on a CLI flag is an argparse pitfall -- any
# non-empty string (including "False") parses as True. Left unchanged to
# preserve the interface.
parser.add_argument('--pretrain_generator', dest='pretrain_generator', type=bool, default=True, help='Whether to pretrain generator')
parser.add_argument('--pretrain_learning_rate', dest='pretrain_learning_rate', type=float, default=2e-4, help='learning rate for pretrain')
parser.add_argument('--pretrain_lr_decay_step', dest='pretrain_lr_decay_step', type=float, default=20000, help='decay by every n iteration')
parser.add_argument('--learning_rate', dest='learning_rate', type=float, default=1e-4, help='learning rate')
parser.add_argument('--weight_initialize_scale', dest='weight_initialize_scale', type=float, default=0.1, help='scale to multiply after MSRA initialization')
# help text fixed: previously described the LR image with broken grammar
# ('This is should be 1/4 of HR_image_size').
parser.add_argument('--HR_image_size', dest='HR_image_size', type=int, default=128, help='Image width and height of HR image. LR_image_size should be exactly 1/4 of this')
parser.add_argument('--LR_image_size', dest='LR_image_size', type=int, default=32, help='Image width and height of LR image.')
parser.add_argument('--epsilon', dest='epsilon', type=float, default=1e-12, help='used in loss function')
parser.add_argument('--gan_loss_coeff', dest='gan_loss_coeff', type=float, default=0.1, help='used in perceptual loss')
parser.add_argument('--content_loss_coeff', dest='content_loss_coeff', type=float, default=0.01, help='used in content loss')

# About log
parser.add_argument('--logging', dest='logging', type=bool, default=True, help='whether to record training log')
parser.add_argument('--train_sample_save_freq', dest='train_sample_save_freq', type=int, default=500, help='save samples during training every n iteration')
parser.add_argument('--train_ckpt_save_freq', dest='train_ckpt_save_freq', type=int, default=500, help='save checkpoint during training every n iteration')
parser.add_argument('--train_summary_save_freq', dest='train_summary_save_freq', type=int, default=200, help='save summary during training every n iteration')

# GPU setting
parser.add_argument('--gpu_dev_num', dest='gpu_dev_num', type=str, default='0,1', help='Which GPU to use for multi-GPUs')

# ----- my args -----
parser.add_argument('--image_batch_size', dest='image_batch_size', type=int, default=64, help='Mini-batch size')
parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='Total Epochs')

parser.add_argument('--logdir', dest='logdir', type=str, default='./log', help='log directory')
def set_logger(args):
    """Set up file logging for training recording.

    Returns True when logging is enabled (and configured), False otherwise.
    """
    if args.logging:
        logfile = '{0}/training_logfile_{1}.log'.format(args.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        formatter = '%(levelname)s:%(asctime)s:%(message)s'
        logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S')
        return True
    else:
        print('No logging is set')
        return False


def main():
    """Stage 1-1 entry point: create output dirs, set up logging, and run
    the LR-CycleGAN generator pre-training."""
    # make dirs
    target_dirs = [args.pre_train_checkpoint_dir, args.logdir,
                   args.pre_train_result_dir, args.pre_raw_result_dir, args.pre_valid_result_dir,
                   args.pre_valid_LR_result_dir]
    create_dirs(target_dirs)

    # set logger (previously this block was duplicated, configuring the
    # logger and emitting the start message twice)
    logflag = set_logger(args)
    log(logflag, 'Training script start', 'info')

    # pre-train generator with pixel-wise loss and save the trained model
    if args.pretrain_generator:
        pretrain_generator1(args, logflag)
        #test_pretrain_generator(args, logflag)
        tf.reset_default_graph()
        gc.collect()
    else:
        log(logflag, 'Pre-train : Pre-train skips and an existing trained model will be used', 'info')

if __name__ == '__main__':
    main()
parser.add_argument('--initialization_random_seed', dest='initialization_random_seed', default=111, help='random_seed')
#parser.add_argument('--perceptual_loss', dest='perceptual_loss', default='VGG19', help='the part of loss function. "VGG19" or "pixel-wise"')
parser.add_argument('--gan_loss_type', dest='gan_loss_type', default='RaGAN', help='the type of GAN loss functions. "RaGAN" or "GAN"')

# About training
parser.add_argument('--num_iter', dest='num_iter', type=int, default=50000, help='The number of iterations')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=8, help='Mini-batch size')
parser.add_argument('--channel', dest='channel', type=int, default=3, help='Number of input/output image channel')
# NOTE(review): type=bool on a CLI flag is an argparse pitfall -- any
# non-empty string (including "False") parses as True. Left unchanged to
# preserve the interface.
parser.add_argument('--pretrain_generator', dest='pretrain_generator', type=bool, default=True, help='Whether to pretrain generator')
parser.add_argument('--pretrain_learning_rate', dest='pretrain_learning_rate', type=float, default=2e-4, help='learning rate for pretrain')
parser.add_argument('--pretrain_lr_decay_step', dest='pretrain_lr_decay_step', type=float, default=20000, help='decay by every n iteration')
parser.add_argument('--learning_rate', dest='learning_rate', type=float, default=1e-4, help='learning rate')
parser.add_argument('--weight_initialize_scale', dest='weight_initialize_scale', type=float, default=0.1, help='scale to multiply after MSRA initialization')
# help text fixed: previously described the LR image with broken grammar
# ('This is should be 1/4 of HR_image_size').
parser.add_argument('--HR_image_size', dest='HR_image_size', type=int, default=128, help='Image width and height of HR image. LR_image_size should be exactly 1/4 of this')
parser.add_argument('--LR_image_size', dest='LR_image_size', type=int, default=32, help='Image width and height of LR image.')
parser.add_argument('--epsilon', dest='epsilon', type=float, default=1e-12, help='used in loss function')
parser.add_argument('--gan_loss_coeff', dest='gan_loss_coeff', type=float, default=0.1, help='used in perceptual loss')
parser.add_argument('--content_loss_coeff', dest='content_loss_coeff', type=float, default=0.01, help='used in content loss')

# About log
parser.add_argument('--logging', dest='logging', type=bool, default=True, help='whether to record training log')
parser.add_argument('--train_sample_save_freq', dest='train_sample_save_freq', type=int, default=500, help='save samples during training every n iteration')
parser.add_argument('--train_ckpt_save_freq', dest='train_ckpt_save_freq', type=int, default=500, help='save checkpoint during training every n iteration')
parser.add_argument('--train_summary_save_freq', dest='train_summary_save_freq', type=int, default=200, help='save summary during training every n iteration')

# GPU setting
parser.add_argument('--gpu_dev_num', dest='gpu_dev_num', type=str, default='0', help='Which GPU to use for multi-GPUs')

# ----- my args -----
parser.add_argument('--image_batch_size', dest='image_batch_size', type=int, default=64, help='Mini-batch size')
parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='Total Epochs')

parser.add_argument('--logdir', dest='logdir', type=str, default='./log', help='log directory')
parser.add_argument('--pre_train_checkpoint_dir', dest='pre_train_checkpoint_dir', type=str, default='./pre_train_checkpoint2', help='pre-train checkpoint directory')
def set_logger(args):
    """Set logger for training recording.

    Returns True when logging is enabled (and configured), False otherwise.
    """
    if args.logging:
        logfile = '{0}/training_logfile_{1}.log'.format(args.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        formatter = '%(levelname)s:%(asctime)s:%(message)s'
        logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S')
        return True
    else:
        print('No logging is set')
        return False


def main():
    """Stage 1-2 entry point: create output dirs, set up logging, and run
    the SR generator pre-training."""
    # make dirs
    target_dirs = [args.pre_train_checkpoint_dir, args.logdir,
                   args.pre_train_result_dir, args.pre_raw_result_dir, args.pre_valid_result_dir,
                   args.pre_valid_LR_result_dir]
    create_dirs(target_dirs)

    # set logger (previously this block was duplicated, configuring the
    # logger and emitting the start message twice)
    logflag = set_logger(args)
    log(logflag, 'Training script start', 'info')

    # pre-train generator with pixel-wise loss and save the trained model
    if args.pretrain_generator:
        pretrain_generator2(args, logflag)
        #test_pretrain_generator(args, logflag)
        tf.reset_default_graph()
        gc.collect()
    else:
        log(logflag, 'Pre-train : Pre-train skips and an existing trained model will be used', 'info')

if __name__ == '__main__':
    main()
def pretrain_generator1(args, logflag):
    """pre-train Low Resolution CycleGAN

    Stage 1-1: learns the noisy-LR <-> clean-LR mappings with a cycle GAN plus
    auxiliary Y/C discriminators, saving samples and checkpoints periodically.
    """
    log(logflag, 'Pre-train : Process start', 'info')

    # placeholders for noisy-LR (source domain) and clean-LR (target domain) batches
    NLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='NLR_input')
    CLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='CLR_input')

    # build Generator
    # CLR_N: clean from noisy; NLR_NC/NLR_C and CLR_CN/CLR_C: cycle/identity branches
    network = Network(args, NLR_data=NLR_data, CLR_data=CLR_data)
    CLR_N, NLR_NC, NLR_C, CLR_CN, CLR_C = network.pretrain_generator_LR()
    dis_out_realLR, dis_out_fakeLR, dis_out_noisyLR, Y_out, C_out = network.pretrain_discriminator_LR(CLR_N)
    print("Built generator!")

    # build loss function
    loss = Loss()
    gen_loss, dis_loss, Y_loss, C_loss = loss.pretrain_loss(args, NLR_data, CLR_data, NLR_NC, CLR_CN, CLR_C, dis_out_realLR,
                                                            dis_out_fakeLR, dis_out_noisyLR, Y_out, C_out)
    print("Built loss function!")

    # build optimizer
    global_iter = tf.Variable(0, trainable=False)
    pre_gen_var, pre_gen_optimizer, pre_dis_optimizer, Y_optimizer, C_optimizer = \
        Optimizer().pretrain_optimizer(args, global_iter, gen_loss, dis_loss, Y_loss, C_loss)
    print("Built optimizer!")

    # build summary writer
    #pre_summary = tf.summary.merge(loss.add_summary_writer())

    # one sess.run of `fetches` performs a G, D, Y and C update simultaneously
    fetches = {'pre_gen_loss': gen_loss,
               'pre_dis_loss': dis_loss, 'Y_loss': Y_loss, 'C_loss': C_loss,
               'pre_gen_optimizer': pre_gen_optimizer,
               'pre_dis_optimizer': pre_dis_optimizer,
               'Y_optimizer': Y_optimizer,
               'C_optimizer': C_optimizer,
               'gen_HR': CLR_N}

    gc.collect()

    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            allow_growth=True,
            visible_device_list=args.gpu_dev_num
        )
    )

    saver = tf.train.Saver(max_to_keep=10)

    # Start session
    with tf.Session(config=config) as sess:
        log(logflag, 'Pre-train : Training starts', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)
        sess.run(scale_initialization(pre_gen_var, args))
        # NOTE(review): restore() is called unconditionally; on a fresh run with
        # no checkpoint, latest_checkpoint() returns None and this raises --
        # confirm a checkpoint is always expected to exist here.
        saver.restore(sess, tf.train.latest_checkpoint(args.pre_train_checkpoint_dir))
        if args.perceptual_loss == 'VGG19':
            sess.run(load_vgg19_weight(args))
        writer = tf.summary.FileWriter(args.logdir, graph=sess.graph, filename_suffix='pre-train')

        # noisy-LR (source) and clean-LR (target) training file lists,
        # independently shuffled for the first epoch
        _datapathNLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/source/train_HR_aug/x4/', '*.png'))))
        _datapathCLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/target/train_HR_aug/x4/', '*.png'))))
        idxN = np.random.permutation(len(_datapathNLR))
        idxC = np.random.permutation(len(_datapathCLR))
        datapathNLR = _datapathNLR[idxN]
        datapathCLR = _datapathCLR[idxC]

        epoch = 0
        counter = 0  # index into the shuffled file lists

        log(logflag, 'Pre-train Epoch: {0}'.format(epoch), 'info')
        start_time = time.time()
        #while counter <= args.num_iter:
        # NOTE(review): this rebinding shadows the Loss() instance named `loss`
        # above; harmless here since the instance is no longer used, but fragile.
        loss = 0.0  # running generator-loss accumulator for the epoch average
        steps = 0
        while True:
            lr = args.pretrain_learning_rate  # NOTE(review): computed but never fed/used
            if counter >= len(_datapathCLR) - args.image_batch_size:
                # epoch boundary: log average loss and reshuffle
                log(logflag, 'Pre-train Epoch: {0} avg.loss : {1}'.format(epoch, loss / steps), 'info')
                # NOTE(review): unlike the initial shuffle, a single permutation
                # `idx` is reused for both lists here, which assumes both lists
                # have the same length and correlates their ordering -- confirm.
                idx = np.random.permutation(len(_datapathCLR))
                datapathNLR = _datapathNLR[idx]
                datapathCLR = _datapathCLR[idx]
                counter = 0
                loss = 0.0
                steps = 0
                epoch += 1
                if epoch == 200:  # hard-coded epoch limit (args.epochs is not used here)
                    break
            # load a chunk of image files and cut them into training patches
            dataNLR, dataCLR = generate_pretrain_batch(datapathNLR[counter:counter + args.image_batch_size],
                                                       datapathCLR[counter:counter + args.image_batch_size],
                                                       args)
            counter += args.image_batch_size
            # iterate over the loaded patches in mini-batches
            for iteration in range(0, dataNLR.shape[0], args.batch_size):
                _NLR_data = dataNLR[iteration:iteration + args.batch_size]
                _CLR_data = dataCLR[iteration:iteration + args.batch_size]
                feed_dict = {
                    NLR_data: _NLR_data,
                    CLR_data: _CLR_data,
                }
                # update weights
                result = sess.run(fetches=fetches, feed_dict=feed_dict)
                current_iter = tf.train.global_step(sess, global_iter)
                loss += result['pre_gen_loss']
                steps += 1

                # save samples every n iter
                if current_iter % args.train_sample_save_freq == 0:
                    validpathLR = np.sort(
                        np.asarray(glob(os.path.join(args.data_dir + '/validation/x/', '*.png'))))
                    validpathHR = np.sort(
                        np.asarray(glob(os.path.join(args.data_dir + '/validation/y/', '*.png'))))
                    for valid_i in range(5):
                        validLR, validHR = generate_testset(validpathLR[valid_i],
                                                            validpathHR[valid_i],
                                                            args)
                        # add a batch axis: HWC -> 1HWC
                        validLR = np.transpose(validLR[:, :, :, np.newaxis], (3, 0, 1, 2))
                        validHR = np.transpose(validHR[:, :, :, np.newaxis], (3, 0, 1, 2))
                        # the CLR placeholder is fed the LR image too: only the
                        # noisy->clean branch (CLR_N) is evaluated here
                        valid_out = sess.run(CLR_N, feed_dict={NLR_data: validLR,
                                                               CLR_data: validLR,
                                                               })
                        save_image(args, valid_out, 'pre-valid', current_iter + valid_i, save_max_num=5)
                    save_image(args, result['gen_HR'], 'pre-train', current_iter, save_max_num=5)
                    save_image(args, _NLR_data, 'pre-raw', current_iter, save_max_num=5)
                if current_iter % 10 == 0:
                    log(logflag,
                        'Pre-train iteration : {0}, pre_gen_loss : {1}, pre_dis_loss : {2} Y_loss : {3}, C_loss : {4}'.format(
                            current_iter, result['pre_gen_loss'], result['pre_dis_loss'], result['Y_loss'], result['C_loss']),
                        'info')

                # save checkpoint
                if current_iter % args.train_ckpt_save_freq == 0:
                    saver.save(sess, os.path.join(args.pre_train_checkpoint_dir, 'pre_gen'), global_step=current_iter)

        writer.close()
        log(logflag, 'Pre-train : Process end', 'info')
def pretrain_generator2(args, logflag):
    """pre-train Low Resolution CycleGAN

    Stage 1-2: pre-trains the clean-LR -> HR super-resolution generator with a
    pixel-wise loss, resuming from the stage 1-1 checkpoint directory.
    """
    log(logflag, 'Pre-train : Process start', 'info')
    """
    v_list = list()
    with tf.Session() as sess :
        saver = tf.train.import_meta_graph(args.pre_train_checkpoint_dir + "/pre_gen-831000.meta")
        saver.restore(sess=sess, save_path = args.pre_train_checkpoint_dir + "/pre_gen-831000")
        for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) :
            v_list.append(v.name)
    tf.reset_default_graph()
    """
    # build in an explicit graph so the session below is isolated from any
    # previously constructed default graph
    g = tf.Graph()
    with g.as_default() as graph:
        # clean-LR input and its ground-truth HR counterpart
        CLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                                  name='CLR_input')
        CHR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                                  name='CHR_input')
        # build Generator
        network = Network(args, CLR_data=CLR_data, CHR_data=CHR_data)
        CHR = network.pretrain_generator_SR()
        print("Built generator!")

        # build loss function (pixel-wise SR loss between prediction and GT)
        loss = Loss()
        gen_loss = loss.pretrain_loss2(CHR_data, CHR)
        print("Built loss function!")

        # build optimizer
        global_iter = tf.Variable(0, trainable=False)
        pre_gen_var, pre_gen_optimizer = Optimizer().pretrain_optimizer2(args, global_iter, gen_loss)
        print("Built optimizer!")

        # build summary writer
        #pre_summary = tf.summary.merge(loss.add_summary_writer())

        fetches = {'pre_gen_loss': gen_loss,
                   'pre_gen_optimizer': pre_gen_optimizer,
                   'CHR_out': CHR}
        gc.collect()

        config = tf.ConfigProto(
            gpu_options=tf.GPUOptions(
                allow_growth=True,
                visible_device_list=args.gpu_dev_num
            )
        )

        # Start session
        with tf.Session(config=config, graph=g) as sess:
            log(logflag, 'Pre-train : Training starts', 'info')

            sess.run(tf.global_variables_initializer())
            sess.run(global_iter.initializer)
            sess.run(scale_initialization(pre_gen_var, args))
            #var_list = [g.get_tensor_by_name('%s' % name) for name in v_list]
            #saver = tf.train.Saver(max_to_keep=10, var_list=var_list)
            #saver.restore(sess, tf.train.latest_checkpoint(args.pre_train_checkpoint_dir))
            # NOTE(review): restore() fails if no checkpoint exists in
            # pre_train_checkpoint_dir -- this stage assumes stage 1-1 ran first.
            saver = tf.train.Saver(max_to_keep=10)
            saver.restore(sess, tf.train.latest_checkpoint(args.pre_train_checkpoint_dir))

            writer = tf.summary.FileWriter(args.logdir, graph=sess.graph, filename_suffix='pre-train')

            # paired clean LR/HR training file lists, shuffled with one shared
            # permutation to keep LR/HR pairs aligned
            _datapathCLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/target/train_LR_aug/x4/', '*.png'))))
            _datapathCHR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/target/train_HR_aug/x4/', '*.png'))))
            idx = np.random.permutation(len(_datapathCHR))
            datapathCLR = _datapathCLR[idx]
            datapathCHR = _datapathCHR[idx]

            epoch = 0
            counter = 0  # index into the shuffled file lists

            log(logflag, 'Pre-train Epoch: {0}'.format(epoch), 'info')
            start_time = time.time()
            # NOTE(review): this rebinding shadows the Loss() instance named
            # `loss` above; harmless here, but fragile.
            loss = 0.0  # running loss accumulator for the epoch average
            steps = 0
            while True:
                lr = args.pretrain_learning_rate  # NOTE(review): computed but never fed/used
                if counter >= len(_datapathCLR) - args.image_batch_size:
                    # epoch boundary: log average loss and reshuffle (shared idx keeps pairs aligned)
                    log(logflag, 'Pre-train Epoch: {0} avg.loss : {1}'.format(epoch, loss / steps), 'info')
                    idx = np.random.permutation(len(_datapathCLR))
                    datapathCLR = _datapathCLR[idx]
                    datapathCHR = _datapathCHR[idx]
                    counter = 0
                    loss = 0.0
                    steps = 0
                    epoch += 1
                    if epoch == 200:  # hard-coded epoch limit (args.epochs is not used here)
                        break
                # load a chunk of image files and cut them into training patches
                dataCLR, dataCHR = generate_pretrain_batch2(datapathCLR[counter:counter + args.image_batch_size],
                                                            datapathCHR[counter:counter + args.image_batch_size],
                                                            args)
                counter += args.image_batch_size
                #if current_iter > args.num_iter:
                #    break
                # iterate over the loaded patches in mini-batches
                for iteration in range(0, dataCLR.shape[0], args.batch_size):
                    _CLR_data = dataCLR[iteration:iteration + args.batch_size]
                    _CHR_data = dataCHR[iteration:iteration + args.batch_size]
                    feed_dict = {
                        CLR_data: _CLR_data,
                        CHR_data: _CHR_data,
                    }
                    # update weights
                    result = sess.run(fetches=fetches, feed_dict=feed_dict)
                    current_iter = tf.train.global_step(sess, global_iter)
                    loss += result['pre_gen_loss']
                    steps += 1

                    # save summary every n iter
                    #if current_iter % args.train_summary_save_freq == 0:
                    #    writer.add_summary(result['summary'], global_step=current_iter)

                    # save samples every n iter
                    if current_iter % args.train_sample_save_freq == 0:
                        validpathLR = np.sort(
                            np.asarray(glob(os.path.join(args.data_dir + '/validation/x/', '*.png'))))
                        validpathHR = np.sort(
                            np.asarray(glob(os.path.join(args.data_dir + '/validation/y/', '*.png'))))
                        for valid_i in range(5):
                            validLR, validHR = generate_testset(validpathLR[valid_i],
                                                                validpathHR[valid_i],
                                                                args)
                            # add a batch axis: HWC -> 1HWC
                            validLR = np.transpose(validLR[:, :, :, np.newaxis], (3, 0, 1, 2))
                            validHR = np.transpose(validHR[:, :, :, np.newaxis], (3, 0, 1, 2))
                            valid_out = sess.run(CHR, feed_dict={CLR_data: validLR,
                                                                 CHR_data: validHR})
                            save_image(args, valid_out, 'pre-valid', current_iter + valid_i, save_max_num=5)

                        save_image(args, result['CHR_out'], 'pre-train', current_iter, save_max_num=5)
                        save_image(args, _CHR_data, 'pre-raw', current_iter, save_max_num=5)
                    if current_iter % 10 == 0:
                        log(logflag,
                            'Pre-train iteration : {0}, pre_gen_loss : {1}'.format(
                                current_iter, result['pre_gen_loss']),
                            'info')
                    # save checkpoint
                    if current_iter % args.train_ckpt_save_freq == 0:
                        saver.save(sess, os.path.join(args.pre_train_checkpoint_dir, 'pre_gen'), global_step=current_iter)

            writer.close()
            log(logflag, 'Pre-train : Process end', 'info')
def test_pretrain_generator(args, logflag):
    """Run the pre-trained generator over the validation images and save outputs.

    Restores the latest checkpoint from ``args.pre_train_checkpoint_dir``,
    feeds each validation LR image through the generator one at a time
    (batch of 1, native spatial size) and writes the result via ``save_image``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line configuration.
    logflag : bool
        Whether file logging is enabled (as returned by ``set_logger``).
    """
    log(logflag, 'Pre-test : Process start', 'info')

    # Dynamic H/W: every validation image is processed at its own size.
    LR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                             name='LR_input')
    HR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                             name='HR_input')

    # build Generator
    # NOTE(review): Network.generator() and Loss.pretrain_loss(out, HR) do not
    # match the signatures in the train_module.py version visible alongside
    # this file — confirm this script targets the intended module revision.
    network = Network(args, LR_data)
    pre_gen_out = network.generator()

    # build loss function (graph must match training so the checkpoint loads)
    loss = Loss()
    pre_gen_loss = loss.pretrain_loss(pre_gen_out, HR_data)

    # Build the optimizer even though it is never run here: restoring the
    # checkpoint requires its slot variables to exist in the graph.
    global_iter = tf.Variable(0, trainable=False)
    pre_gen_var, pre_gen_optimizer = Optimizer().pretrain_optimizer(args, global_iter, pre_gen_loss)

    # build summary writer
    pre_summary = tf.summary.merge(loss.add_summary_writer())

    fetches = {'gen_HR': pre_gen_out, 'summary': pre_summary}

    gc.collect()

    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            allow_growth=True,
            visible_device_list=args.gpu_dev_num
        )
    )

    saver = tf.train.Saver(max_to_keep=10)

    # Start session
    with tf.Session(config=config) as sess:
        log(logflag, 'Pre-train : Test starts', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)
        sess.run(scale_initialization(pre_gen_var, args))
        saver.restore(sess, tf.train.latest_checkpoint(args.pre_train_checkpoint_dir))

        writer = tf.summary.FileWriter(args.logdir, graph=sess.graph, filename_suffix='pre-train')

        _datapathLR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/validation/x/', '*.png'))))
        _datapathHR = np.sort(np.asarray(glob(os.path.join(args.data_dir + '/validation/y/', '*.png'))))

        # Deterministic order: no shuffling at test time.
        datapathLR = _datapathLR
        datapathHR = _datapathHR

        for i in range(len(datapathLR)):
            # BUG FIX: original called log(logflag, 'Pre-train, info'), omitting
            # the log-level argument that every other call site supplies.
            log(logflag, 'Pre-test : image {0}'.format(i), 'info')
            dataLR, dataHR = generate_testset(datapathLR[i],
                                              datapathHR[i],
                                              args)
            # Add a batch axis: (H, W, C) -> (1, H, W, C)
            dataLR = np.transpose(dataLR[:, :, :, np.newaxis], (3, 0, 1, 2))
            dataHR = np.transpose(dataHR[:, :, :, np.newaxis], (3, 0, 1, 2))
            feed_dict = {
                HR_data: dataHR,
                LR_data: dataLR
            }
            # Forward pass only — no optimizer op is fetched here.
            result = sess.run(fetches=fetches, feed_dict=feed_dict)
            save_image(args, result['gen_HR'], 'pre-test', i, save_max_num=5)
            print("saved %d" % i)

        # Close the writer to flush the graph summary (was leaked before).
        writer.close()

    log(logflag, 'Pre-test : Process end', 'info')
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line configuration for the test script. Defaults reproduce the
# NTIRE2020 track-1 evaluation setup described in the README.
parser = argparse.ArgumentParser(description='')
# ----- vanilla ------
parser.add_argument('--data_dir', dest='data_dir', default='./data/track1/', help='path of the dataset')
parser.add_argument('--test_result_dir', dest='test_result_dir', default='./test_result', help='output directory during training')
#parser.add_argument('--test_LR_result_dir', dest='test_LR_result_dir', default='./test_LR_result', help='output directory during training')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', type=str, default='./checkpoint_track1', help='checkpoint directory')

# About Data
# NOTE(review): the help strings of --crop/--crop_size below ('patch width'
# / 'patch height') do not match the option names — likely copy-paste text.
parser.add_argument('--crop', dest='crop', default=True, help='patch width')
parser.add_argument('--crop_size', dest='crop_size', type=int, default=64, help='patch height')
parser.add_argument('--stride', dest='stride', type=int, default=64, help='patch stride')
parser.add_argument('--data_augmentation', dest='data_augmentation', default=True, help='')

# About Network
parser.add_argument('--scale_SR', dest='scale_SR', default=4, help='the scale of super-resolution')
parser.add_argument('--num_repeat_RRDB', dest='num_repeat_RRDB', type=int, default=15, help='the number of RRDB blocks')
parser.add_argument('--residual_scaling', dest='residual_scaling', type=float, default=0.2, help='residual scaling parameter')
parser.add_argument('--initialization_random_seed', dest='initialization_random_seed', default=111, help='random_seed')
parser.add_argument('--perceptual_loss', dest='perceptual_loss', default='VGG19', help='the part of loss function. "VGG19" or "pixel-wise"')
parser.add_argument('--gan_loss_type', dest='gan_loss_type', default='MaGAN', help='the type of GAN loss functions. "RaGAN" or "GAN"')

# About training
parser.add_argument('--num_iter', dest='num_iter', type=int, default=50000, help='The number of iterations')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=2, help='Mini-batch size')
parser.add_argument('--channel', dest='channel', type=int, default=3, help='Number of input/output image channel')
# NOTE(review): argparse `type=bool` is a known pitfall — bool('False') is
# True, so these flags cannot be disabled from the command line as written.
parser.add_argument('--pretrain_generator', dest='pretrain_generator', type=bool, default=False, help='Whether to pretrain generator')
parser.add_argument('--pretrain_learning_rate', dest='pretrain_learning_rate', type=float, default=2e-4, help='learning rate for pretrain')
parser.add_argument('--pretrain_lr_decay_step', dest='pretrain_lr_decay_step', type=float, default=20000, help='decay by every n iteration')
parser.add_argument('--learning_rate', dest='learning_rate', type=float, default=1e-4, help='learning rate')
parser.add_argument('--weight_initialize_scale', dest='weight_initialize_scale', type=float, default=0.1, help='scale to multiply after MSRA initialization')
parser.add_argument('--HR_image_size', dest='HR_image_size', type=int, default=128, help='Image width and height of LR image. This is should be 1/4 of HR_image_size exactly')
parser.add_argument('--LR_image_size', dest='LR_image_size', type=int, default=32, help='Image width and height of LR image.')
parser.add_argument('--epsilon', dest='epsilon', type=float, default=1e-12, help='used in loss function')
parser.add_argument('--gan_loss_coeff', dest='gan_loss_coeff', type=float, default=1.0, help='used in perceptual loss')
parser.add_argument('--content_loss_coeff', dest='content_loss_coeff', type=float, default=0.01, help='used in content loss')

# About log
parser.add_argument('--logging', dest='logging', type=bool, default=True, help='whether to record training log')
parser.add_argument('--train_sample_save_freq', dest='train_sample_save_freq', type=int, default=100, help='save samples during training every n iteration')
parser.add_argument('--train_ckpt_save_freq', dest='train_ckpt_save_freq', type=int, default=100, help='save checkpoint during training every n iteration')
parser.add_argument('--train_summary_save_freq', dest='train_summary_save_freq', type=int, default=200, help='save summary during training every n iteration')
parser.add_argument('--pre_train_checkpoint_dir', dest='pre_train_checkpoint_dir', type=str, default='./pre_train_checkpoint', help='pre-train checkpoint directory')
parser.add_argument('--logdir', dest='logdir', type=str, default='./log_test', help='log directory')

# GPU setting
parser.add_argument('--gpu_dev_num', dest='gpu_dev_num', type=str, default='0', help='Which GPU to use for multi-GPUs')

# ----- my args -----
parser.add_argument('--image_batch_size', dest='image_batch_size', type=int, default=64, help='Mini-batch size')
parser.add_argument('--epochs', dest='epochs', type=int, default=200, help='Total Epochs')

# Parsed at import time; `args` is used as a module-level global below.
args = parser.parse_args()
def set_logger(args):
    """set logger for training recording

    Configures file-based logging under ``args.logdir`` when ``args.logging``
    is truthy. Returns True when logging is active, False otherwise; the
    returned flag is threaded through ``utils.log`` calls as ``logflag``.
    """
    if args.logging:
        logfile = '{0}/training_logfile_{1}.log'.format(args.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        formatter = '%(levelname)s:%(asctime)s:%(message)s'
        logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S')
        return True
    else:
        print('No logging is set')
        return False

def main():
    """Evaluate the trained generator with 8-fold self-ensembling.

    For each input image the SR output (CHR_C3) and the cleaned-LR output
    (CLR_C1) are computed for the 4 rotations and the 4 rotations of the
    vertically flipped image; the 8 de-transformed outputs are averaged
    before saving.
    """
    # make dirs
    target_dirs = [args.logdir,args.test_result_dir]#., args.test_LR_result_dir]
    create_dirs(target_dirs)

    # set logger
    logflag = set_logger(args)
    log(logflag, 'Test script start', 'info')

    # All four placeholders have dynamic H/W so arbitrary image sizes work.
    NLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='NLR_input')
    CLR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='CLR_input')
    NHR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='NHR_input')
    CHR_data = tf.placeholder(tf.float32, shape=[None, None, None, args.channel],
                              name='CHR_input')

    # build Generator and Discriminator
    network = Network(args, NLR_data=NLR_data, CLR_data=CLR_data, NHR_data=NHR_data, CHR_data=CHR_data, is_test =True)
    CLR_C1, NLR_C1, CLR_C2, CHR_C3, NLR_C3, CHR_C4, CLR_I1, CHR_I1, CHR_I2 = network.train_generator()

    # define optimizers (only the step counter is needed at test time)
    global_iter = tf.Variable(0, trainable=False)

    gc.collect()

    config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            allow_growth=True,
            visible_device_list=args.gpu_dev_num
        )
    )

    # Start Session
    with tf.Session(config=config) as sess:
        log(logflag, 'Start Session', 'info')

        sess.run(tf.global_variables_initializer())
        sess.run(global_iter.initializer)

        saver = tf.train.Saver(max_to_keep=10)
        saver.restore(sess, tf.train.latest_checkpoint(args.checkpoint_dir))

        # NOTE(review): LR and "HR" paths glob the same directory — the HR
        # placeholders are fed with the LR images; confirm this is intended
        # (no ground truth is available at challenge test time).
        validpathLR = np.sort(
            np.asarray(glob(os.path.join(args.data_dir, '*.png'))))
        validpathHR = np.sort(
            np.asarray(glob(os.path.join(args.data_dir, '*.png'))))
        import time

        avgtime = 0
        # NOTE(review): hardcoded 100 images — raises IndexError if the
        # directory holds fewer; consider len(validpathLR).
        for valid_i in range(100) :
            validLR, validHR = generate_testset(validpathLR[valid_i],
                                                validpathHR[valid_i],
                                                args)
            name = validpathLR[valid_i].split('/')[-1]
            # Add batch axis: (H, W, C) -> (1, H, W, C)
            validLR = np.transpose(validLR[:, :, :, np.newaxis], (3, 0, 1, 2))
            validHR = np.transpose(validHR[:, :, :, np.newaxis], (3, 0, 1, 2))
            starttime = time.time()
            # --- ensemble member 1: identity orientation ---
            valid_out, valid_out_LR = sess.run([CHR_C3, CLR_C1],
                                               feed_dict={NLR_data: validLR,
                                                          NHR_data: validHR,
                                                          CLR_data: validLR,
                                                          CHR_data: validHR})

            # --- member 2: rotate 90, un-rotate the outputs by 270 ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out += np.rot90(valid_out90, 3, axes=(1, 2))
            valid_out_LR += np.rot90(valid_out_LR90, 3, axes=(1, 2))

            # --- member 3: rotate 180, un-rotate by 180 ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out += np.rot90(valid_out90, 2, axes=(1, 2))
            valid_out_LR += np.rot90(valid_out_LR90, 2, axes=(1, 2))

            # --- member 4: rotate 270, un-rotate by 90 ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out += np.rot90(valid_out90, 1, axes=(1, 2))
            valid_out_LR += np.rot90(valid_out_LR90, 1, axes=(1, 2))

            # Fourth rotation restores the original orientation, then flip
            # along the height axis for the mirrored half of the ensemble.
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))

            validLR = validLR[:,::-1, :, :]
            validHR = validHR[:,::-1, :, :]

            # --- member 5: flipped, identity rotation ---
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out += valid_out90[:,::-1,:,:]
            valid_out_LR += valid_out_LR90[:,::-1,:,:]

            # --- member 6: flipped + 90, un-rotate then un-flip ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out90 = np.rot90(valid_out90, 3, axes=(1, 2))
            valid_out_LR90 = np.rot90(valid_out_LR90, 3, axes=(1, 2))

            valid_out += valid_out90[:,::-1,:,:]
            valid_out_LR += valid_out_LR90[:,::-1,:,:]

            # --- member 7: flipped + 180 ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out90 = np.rot90(valid_out90, 2, axes=(1, 2))
            valid_out_LR90 = np.rot90(valid_out_LR90, 2, axes=(1, 2))

            valid_out += valid_out90[:,::-1,:,:]
            valid_out_LR += valid_out_LR90[:,::-1,:,:]

            # --- member 8: flipped + 270 ---
            validLR = np.rot90(validLR, 1, axes=(1, 2))
            validHR = np.rot90(validHR, 1, axes=(1, 2))
            valid_out90, valid_out_LR90 = sess.run([CHR_C3, CLR_C1],
                                                   feed_dict={NLR_data: validLR,
                                                              NHR_data: validHR,
                                                              CLR_data: validLR,
                                                              CHR_data: validHR})
            valid_out90 = np.rot90(valid_out90, 1, axes=(1, 2))
            valid_out_LR90 = np.rot90(valid_out_LR90, 1, axes=(1, 2))

            valid_out += valid_out90[:,::-1,:,:]
            valid_out_LR += valid_out_LR90[:,::-1,:,:]

            # Average the 8 accumulated predictions.
            valid_out /= 8.
            valid_out_LR /= 8.
            currtime = time.time() - starttime
            print("time : %fs"%(currtime))
            # Running mean assumes exactly 100 iterations (see loop bound).
            avgtime += currtime / 100
            save_image(args, valid_out, 'test', name, save_max_num=1)
            #save_image(args, valid_out_LR, 'test_LR', valid_i, save_max_num=5)
        print("avg. time : %f"%avgtime)

if __name__ == '__main__':
    main()
def gauss_blur(image) :
    """Gaussian-blur a [-1, 1] image batch with a fixed 7x7, sigma=3 kernel.

    The image is mapped to [0, 1], convolved, and mapped back to [-1, 1].
    NOTE(review): the kernel has shape (7, 7, 3, 1), so tf.nn.conv2d sums
    over the RGB channels and the output is single-channel — confirm this
    (rather than a depthwise per-channel blur) is the intended behavior.
    """
    image = (image + 1) / 2
    gauss_kernel = gaussian_kernel(7, 0., 3.)
    blur = tf.nn.conv2d(image, gauss_kernel, strides=[1,1,1,1], padding='SAME')
    blur = blur * 2 - 1
    return blur

class Network(object):
    """class to build networks

    Wires the Generator/Discriminator sub-networks into the training and
    pre-training graphs. Naming convention visible in this file:
    N = noisy, C = clean, LR = low-res, HR = high-res (e.g. NLR_data is the
    noisy low-res input batch). Variable scopes with reuse=False create the
    weights; reuse=True shares them, so the call ORDER below matters.
    """
    def __init__(self, args, NLR_data=None, NHR_data=None, CLR_data=None, CHR_data=None, mask_data=None, is_test=False):
        # args: parsed command-line configuration
        # *_data: input placeholders/tensors (NHWC); any may be None when
        #         the corresponding graph branch is not built
        # is_test: selects GPU placement in train_generator
        self.args = args
        self.NLR_data = NLR_data
        self.NHR_data = NHR_data
        self.CLR_data = CLR_data
        self.CHR_data = CHR_data
        self.mask_data = mask_data
        self.is_test = is_test

    def pretrain_generator_LR(self):
        """Build the LR-domain cycle graphs used during stage-1 pre-training.

        Returns (CLR_N, NLR_NC, NLR_C, CLR_CN, CLR_C):
        G1 denoises (N->C), G2 re-noises (C->N); the two cycles and the
        identity mapping share weights via scope reuse.
        """
        with tf.device("/gpu:1"):
            with tf.name_scope('generator'):
                # First call creates the G1/G2 variables (reuse=False).
                with tf.variable_scope('generator', reuse=False):
                    CLR_N = Generator(self.args).build_G1(self.NLR_data)
                    NLR_NC = Generator(self.args).build_G2(CLR_N)
                with tf.variable_scope('generator', reuse=True):
                    NLR_C = Generator(self.args).build_G2(self.CLR_data)
                    CLR_CN = Generator(self.args).build_G1(NLR_C)
                with tf.variable_scope('generator', reuse=True):
                    CLR_C = Generator(self.args).build_G1(self.CLR_data)

        return CLR_N, NLR_NC, NLR_C, CLR_CN, CLR_C

    def pretrain_discriminator_LR(self, CLR_out):
        """Build the three LR discriminators (RGB, luminance-Y, blurred-color).

        Each discriminator sees real clean LR, generated clean LR (CLR_out)
        and noisy LR. Returns the RGB logits plus Y_out and C_out lists,
        each ordered [real, fake, noisy].
        """
        with tf.device("/gpu:1"):
            discriminatorLR = Discriminator(self.args)
            with tf.name_scope('real_discriminator'):
                with tf.variable_scope('discriminator', reuse=False):
                    dis_out_realLR = discriminatorLR.build(self.CLR_data)
            with tf.name_scope('fake_discriminator'):
                with tf.variable_scope('discriminator', reuse=True):
                    dis_out_fakeLR = discriminatorLR.build(CLR_out)
            with tf.name_scope('noisy_discriminator'):
                with tf.variable_scope('discriminator', reuse=True):
                    dis_out_noisyLR = discriminatorLR.build(self.NLR_data)

        # Luminance-only discriminator: judges structure, not color.
        Y_CLR_out = rgb2gray(CLR_out)
        Y_CLR_data = rgb2gray(self.CLR_data)
        Y_NLR_data = rgb2gray(self.NLR_data)
        Y_out = list()
        with tf.device("/gpu:1"):
            Y_discriminator = Discriminator(self.args)
            with tf.name_scope('Y_real_discriminator'):
                with tf.variable_scope('Y_discriminator', reuse=False):
                    Y_dis_out_realLR = Y_discriminator.build(Y_CLR_data)
                    Y_out.append(Y_dis_out_realLR)
            with tf.name_scope('Y_fake_discriminator'):
                with tf.variable_scope('Y_discriminator', reuse=True):
                    Y_dis_out_fakeLR = Y_discriminator.build(Y_CLR_out)
                    Y_out.append(Y_dis_out_fakeLR)
            with tf.name_scope('Y_noisy_discriminator'):
                with tf.variable_scope('Y_discriminator', reuse=True):
                    Y_dis_out_noisyLR = Y_discriminator.build(Y_NLR_data)
                    Y_out.append(Y_dis_out_noisyLR)

        # Blurred-color discriminator: judges color, not high-freq detail.
        C_CLR_out = gauss_blur(CLR_out)
        C_CLR_data = gauss_blur(self.CLR_data)
        C_NLR_data = gauss_blur(self.NLR_data)
        C_out = list()
        with tf.device("/gpu:1"):
            C_discriminator = Discriminator(self.args)
            with tf.name_scope('C_real_discriminator'):
                with tf.variable_scope('C_discriminator', reuse=False):
                    C_dis_out_realLR = C_discriminator.build(C_CLR_data)
                    C_out.append(C_dis_out_realLR)
            with tf.name_scope('C_fake_discriminator'):
                with tf.variable_scope('C_discriminator', reuse=True):
                    C_dis_out_fakeLR = C_discriminator.build(C_CLR_out)
                    C_out.append(C_dis_out_fakeLR)
            with tf.name_scope('C_noisy_discriminator'):
                with tf.variable_scope('C_discriminator', reuse=True):
                    C_dis_out_noisyLR = C_discriminator.build(C_NLR_data)
                    C_out.append(C_dis_out_noisyLR)

        return dis_out_realLR, dis_out_fakeLR, dis_out_noisyLR, Y_out, C_out

    def pretrain_generator_SR(self):
        """Build the SR branch alone: clean LR -> clean HR (stage 1-2)."""
        with tf.device("/gpu:0"):
            with tf.name_scope('generator'):
                with tf.variable_scope('generator', reuse=False):
                    CHR = Generator(self.args).build_SR(self.CLR_data)

        return CHR

    def train_generator(self):
        """Build the full stage-2 generator graph: cycles + identity maps.

        Returns the 9 tensors consumed by gan_loss / the test script:
        (CLR_C1, NLR_C1, CLR_C2, CHR_C3, NLR_C3, CHR_C4, CLR_I1, CHR_I1, CHR_I2).
        At test time (is_test=True) the SR/G3 ops are placed on gpu:0.
        """
        if self.is_test:
            dev = "0"
        else:
            dev = "1"
        generator = Generator(self.args)
        with tf.name_scope('generator'):
            # cycle Nl -> Cl -> Nl (creates G1/G2 variables, reuse=False)
            with tf.variable_scope('generator', reuse=False):
                with tf.device("/gpu:0"):
                    CLR_C1 = generator.build_G1(self.NLR_data)
                    NLR_C1 = generator.build_G2(CLR_C1)

            # cycle Cl -> Nl -> Cl
            with tf.variable_scope('generator', reuse=True):
                with tf.device("/gpu:0"):
                    NLR_C2 = generator.build_G2(self.CLR_data)
                    CLR_C2 = generator.build_G1(NLR_C2)

            # cycle Nl -> Ch -> Nl (creates SR/G3 variables, reuse=False)
            with tf.variable_scope('generator', reuse=False):
                with tf.device("/gpu:"+dev):
                    CHR_C3 = generator.build_SR(CLR_C1)
                with tf.device("/gpu:"+dev):
                    NLR_C3 = generator.build_G3(CHR_C3)

            # cycle Ch -> Nl -> Ch
            with tf.variable_scope('generator', reuse=True):
                with tf.device("/gpu:"+dev):
                    NLR_C4 = generator.build_G3(self.CHR_data)
                with tf.device("/gpu:0"):
                    # Intermediate result is overwritten by the SR pass below.
                    CHR_C4 = generator.build_G1(NLR_C4)
                with tf.device("/gpu:"+dev):
                    CHR_C4 = generator.build_SR(CHR_C4)

            # Identity
            with tf.variable_scope('generator', reuse=True):
                with tf.device("/gpu:0"):
                    CLR_I1 = generator.build_G1(self.CLR_data)
                with tf.device("/gpu:"+dev):
                    CHR_I1 = generator.build_SR(CLR_I1)
                    CHR_I2 = generator.build_SR(self.CLR_data)

        return CLR_C1, NLR_C1, CLR_C2, CHR_C3, NLR_C3, CHR_C4, CLR_I1, CHR_I1, CHR_I2

    def train_discriminator(self, CLR_out, CHR_out):
        """Build the three stage-2 discriminators (RGB, Y, blurred-color).

        Each output list is ordered
        [real HR, fake HR, noisy HR, fake LR, noisy LR] — the index layout
        that gan_loss depends on. Returns (D_out, Y_out, C_out).
        """
        D_out = list()
        with tf.device("/gpu:2"):
            discriminator = Discriminator(self.args)
            with tf.variable_scope('discriminator', reuse=False):
                dis_out_real = discriminator.build(self.CHR_data)
                D_out.append(dis_out_real)
            with tf.variable_scope('discriminator', reuse=True):
                dis_out_fake = discriminator.build(CHR_out)
                D_out.append(dis_out_fake)
            with tf.variable_scope('discriminator', reuse=True):
                dis_out_noisy = discriminator.build(self.NHR_data)
                D_out.append(dis_out_noisy)
            with tf.variable_scope('discriminator', reuse=True):
                dis_out_fakeLR = discriminator.build(CLR_out)
                D_out.append(dis_out_fakeLR)
            with tf.variable_scope('discriminator', reuse=True):
                dis_out_noisyLR = discriminator.build(self.NLR_data)
                D_out.append(dis_out_noisyLR)
        # Luminance-only views for the Y discriminator.
        Y_CHR_out = rgb2gray(CHR_out)
        Y_CHR_data = rgb2gray(self.CHR_data)
        Y_NHR_data = rgb2gray(self.NHR_data)
        Y_CLR_out = rgb2gray(CLR_out)
        Y_NLR_data = rgb2gray(self.NLR_data)
        Y_out = list()

        with tf.device("/gpu:2"):
            Y_discriminator = Discriminator(self.args)
            with tf.variable_scope('Y_discriminator', reuse=False):
                Y_dis_out_real = Y_discriminator.build(Y_CHR_data)
                Y_out.append(Y_dis_out_real)
            with tf.variable_scope('Y_discriminator', reuse=True):
                Y_dis_out_fake = Y_discriminator.build(Y_CHR_out)
                Y_out.append(Y_dis_out_fake)
            with tf.variable_scope('Y_discriminator', reuse=True):
                Y_dis_out_noisy = Y_discriminator.build(Y_NHR_data)
                Y_out.append(Y_dis_out_noisy)
            with tf.variable_scope('Y_discriminator', reuse=True):
                Y_dis_out_fakeLR = Y_discriminator.build(Y_CLR_out)
                Y_out.append(Y_dis_out_fakeLR)
            with tf.variable_scope('Y_discriminator', reuse=True):
                Y_dis_out_noisyLR = Y_discriminator.build(Y_NLR_data)
                Y_out.append(Y_dis_out_noisyLR)

        # Blurred views for the color discriminator.
        C_CHR_out = gauss_blur(CHR_out)
        C_CHR_data = gauss_blur(self.CHR_data)
        C_NHR_data = gauss_blur(self.NHR_data)
        C_CLR_out = gauss_blur(CLR_out)
        C_NLR_data = gauss_blur(self.NLR_data)
        C_out = list()
        with tf.device("/gpu:2"):
            C_discriminator = Discriminator(self.args)
            with tf.variable_scope('C_discriminator', reuse=False):
                C_dis_out_real = C_discriminator.build(C_CHR_data)
                C_out.append(C_dis_out_real)
            with tf.variable_scope('C_discriminator', reuse=True):
                C_dis_out_fake = C_discriminator.build(C_CHR_out)
                C_out.append(C_dis_out_fake)
            with tf.variable_scope('C_discriminator', reuse=True):
                C_dis_out_noisy = C_discriminator.build(C_NHR_data)
                C_out.append(C_dis_out_noisy)
            with tf.variable_scope('C_discriminator', reuse=True):
                C_dis_out_fakeLR = C_discriminator.build(C_CLR_out)
                C_out.append(C_dis_out_fakeLR)
            with tf.variable_scope('C_discriminator', reuse=True):
                C_dis_out_noisyLR = C_discriminator.build(C_NLR_data)
                C_out.append(C_dis_out_noisyLR)
        return D_out, Y_out, C_out
262 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fakeLR, 263 | labels=tf.ones_like(dis_out_fakeLR))) 264 | g_loss_fake += tf.reduce_mean( 265 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[1], 266 | labels=tf.ones_like(Y_out[1]))) 267 | g_loss_fake += tf.reduce_mean( 268 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[1], 269 | labels=tf.ones_like(C_out[1]))) 270 | 271 | gen_loss = pre_gen_N + pre_gen_C + pre_identity + 1e-2 *g_loss_fake 272 | 273 | with tf.variable_scope('loss_discriminator'): 274 | if args.gan_loss_type == 'GAN': 275 | d_loss_real = tf.reduce_mean( 276 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_realLR, 277 | labels=tf.ones_like(dis_out_realLR))) 278 | d_loss_fake = tf.reduce_mean( 279 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fakeLR, 280 | labels=tf.zeros_like(dis_out_fakeLR))) 281 | d_loss_noisy = tf.reduce_mean( 282 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_noisyLR, 283 | labels=tf.zeros_like(dis_out_noisyLR))) 284 | dis_loss = d_loss_real + d_loss_fake + d_loss_noisy 285 | 286 | if args.gan_loss_type == 'MaGAN': 287 | d_loss_real = tf.reduce_mean( 288 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_realLR, 289 | labels=tf.ones_like(dis_out_realLR))) 290 | d_loss_fake = tf.reduce_mean( 291 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fakeLR, 292 | labels=tf.zeros_like(dis_out_fakeLR))) 293 | d_loss_noisy = tf.reduce_mean( 294 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_noisyLR, 295 | labels=tf.zeros_like(dis_out_noisyLR))) 296 | Y_loss_real = tf.reduce_mean( 297 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[0], 298 | labels=tf.ones_like(Y_out[0]))) 299 | Y_loss_fake = tf.reduce_mean( 300 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[1], 301 | labels=tf.zeros_like(Y_out[1]))) 302 | Y_loss_noisy = tf.reduce_mean( 303 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[2], 304 | labels=tf.zeros_like(Y_out[2]))) 305 | 
    def pretrain_loss2(self, CHR_data, CHR):
        """Pixel-wise pre-training loss for the SR branch (stage 1-2).

        Mean over batch and spatial positions of the per-pixel L1 distance
        summed across the channel axis, between the generated image ``CHR``
        and the target ``CHR_data``.
        """
        with tf.name_scope('loss_function'):
            with tf.variable_scope('pixel-wise_loss'):
                pre_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(CHR - CHR_data), axis=3))

        return pre_loss

    def _perceptual_vgg19_loss(self, HR_data, gen_out):
        """VGG19 feature extraction for both images, CREATING the VGG variables.

        Must be called once before ``__perceptual_vgg19_loss`` (which uses
        reuse=True on the same 'perceptual_vgg19' scope).
        NOTE(review): returns (vgg_out_hr, vgg_out_gen) while the call sites
        in this file unpack (vgg_out_gen, vgg_out_hr) — the names are swapped,
        but every use is a symmetric squared difference, so the value of the
        loss is unaffected.
        """
        with tf.device("/gpu:1"):
            with tf.name_scope('perceptual_vgg19_HR'):
                with tf.variable_scope('perceptual_vgg19', reuse=False):
                    vgg_out_hr = Perceptual_VGG19().build(HR_data)

            with tf.name_scope('perceptual_vgg19_Gen'):
                with tf.variable_scope('perceptual_vgg19', reuse=True):
                    vgg_out_gen = Perceptual_VGG19().build(gen_out)

        return vgg_out_hr, vgg_out_gen

    def __perceptual_vgg19_loss(self, HR_data, gen_out):
        """VGG19 feature extraction REUSING the variables created by
        ``_perceptual_vgg19_loss`` (reuse=True on both builds).

        Same return convention (and the same swapped-unpack note) as
        ``_perceptual_vgg19_loss`` above. The double underscore triggers
        Python name mangling; the method is only called via ``self`` inside
        this class, so that is harmless.
        """
        with tf.device("/gpu:1"):
            with tf.name_scope('perceptual_vgg19_HR'):
                with tf.variable_scope('perceptual_vgg19', reuse=True):
                    vgg_out_hr = Perceptual_VGG19().build(HR_data)

            with tf.name_scope('perceptual_vgg19_Gen'):
                with tf.variable_scope('perceptual_vgg19', reuse=True):
                    vgg_out_gen = Perceptual_VGG19().build(gen_out)

        return vgg_out_hr, vgg_out_gen
CHR_I2, 355 | D_out, Y_out, C_out): 356 | with tf.name_scope('loss_function'): 357 | with tf.variable_scope('loss_generator'): 358 | w_gen = 1e-3 359 | if FLAGS.gan_loss_type == 'GAN': 360 | gen_loss = 1e-2 *tf.reduce_mean( 361 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[3], 362 | labels=tf.ones_like(D_out[3]))) 363 | gen_loss += 1e-2 *tf.reduce_mean( 364 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[1], 365 | labels=tf.ones_like(D_out[1]))) 366 | 367 | elif FLAGS.gan_loss_type == 'MaGAN': 368 | gen_loss = 2*w_gen *tf.reduce_mean( 369 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[1], 370 | labels=tf.ones_like(D_out[1]))) 371 | gen_loss += 2*w_gen * tf.reduce_mean( 372 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[1], 373 | labels=tf.ones_like(Y_out[1]))) 374 | gen_loss += w_gen * tf.reduce_mean( 375 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[1], 376 | labels=tf.ones_like(C_out[1]))) 377 | gen_loss += 2*w_gen * tf.reduce_mean( 378 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[3], 379 | labels=tf.ones_like(D_out[3]))) 380 | gen_loss += 2*w_gen * tf.reduce_mean( 381 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[3], 382 | labels=tf.ones_like(Y_out[3]))) 383 | gen_loss += w_gen * tf.reduce_mean( 384 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[3], 385 | labels=tf.ones_like(C_out[3]))) 386 | else: 387 | raise ValueError('Unknown GAN loss function type') 388 | 389 | w_l1 = 1 390 | w_vgg = 1 391 | w_ssim = 1 392 | 393 | # cycle 1 loss 394 | g_loss_cycle = w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(NLR_C1 - NLR_data), axis=[1,2,3])) 395 | vgg_out_gen, vgg_out_hr = self._perceptual_vgg19_loss(NLR_data, NLR_C1) 396 | g_loss_cycle += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 397 | g_loss_cycle += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(NLR_data, NLR_C1, 2.0))) 398 | 399 | # cycle 2 loss 400 | g_loss_cycle += w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(CLR_C2 - 
CLR_data), axis=[1,2,3])) 401 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(CLR_data, CLR_C2) 402 | g_loss_cycle += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 403 | g_loss_cycle += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(CLR_data, CLR_C2, 2.0))) 404 | 405 | # cycle 3 loss 406 | g_loss_cycle += w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(NLR_C3 - NLR_data), axis=[1,2,3])) 407 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(NLR_data, NLR_C3) 408 | g_loss_cycle += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 409 | g_loss_cycle += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(NLR_data, NLR_C3, 2.0))) 410 | 411 | # cycle 4 loss 412 | g_loss_cycle += w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(CHR_C4 - CHR_data), axis=[1,2,3])) 413 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(CHR_data, CHR_C4) 414 | g_loss_cycle += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 415 | g_loss_cycle += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(CHR_data, CHR_C4, 2.0))) 416 | 417 | # Identity 1 Loss 418 | g_loss_identity = w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(CLR_I1 - CLR_data), axis=[1,2,3])) 419 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(CLR_data, CLR_I1) 420 | g_loss_identity += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 421 | g_loss_identity += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(CLR_data, CLR_I1, 2.0))) 422 | 423 | # Identity 2 Loss 424 | g_loss_identity += w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(CHR_I1 - CHR_data), axis=[1,2,3])) 425 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(CHR_data, CHR_I1) 426 | g_loss_identity += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 427 | g_loss_identity += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(CHR_data, CHR_I1, 2.0))) 428 | 429 | # Identity 3 Loss 430 | g_loss_identity += 
w_l1*tf.reduce_mean(tf.reduce_mean(tf.abs(CHR_I2 - CHR_data), axis=[1,2,3])) 431 | vgg_out_gen, vgg_out_hr = self.__perceptual_vgg19_loss(CHR_data, CHR_I2) 432 | g_loss_identity += w_vgg*tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 433 | g_loss_identity += w_ssim*(2 - tf.reduce_mean(tf.image.ssim(CHR_data, CHR_I2, 2.0))) 434 | 435 | g_gen_loss = tf.identity(gen_loss) 436 | gen_loss += 2*g_loss_cycle 437 | gen_loss += g_loss_identity 438 | 439 | 440 | with tf.variable_scope('loss_discriminator'): 441 | if FLAGS.gan_loss_type == 'GAN': 442 | d_loss_real = tf.reduce_mean( 443 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[0], 444 | labels=tf.ones_like(D_out[0]))) 445 | d_loss_fake = tf.reduce_mean( 446 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[1], 447 | labels=tf.zeros_like(D_out[1]))) 448 | d_loss_noisy = tf.reduce_mean( 449 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[2], 450 | labels=tf.zeros_like(D_out[2]))) 451 | dis_loss = d_loss_real + d_loss_fake + d_loss_noisy 452 | elif FLAGS.gan_loss_type == 'MaGAN': 453 | d_loss_real = tf.reduce_mean( 454 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[0], 455 | labels=tf.ones_like(D_out[0]))) 456 | d_loss_fake = tf.reduce_mean( 457 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[1], 458 | labels=tf.zeros_like(D_out[1]))) 459 | d_loss_noisy = tf.reduce_mean( 460 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[2], 461 | labels=tf.zeros_like(D_out[2]))) 462 | d_loss_fakeLR = tf.reduce_mean( 463 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[3], 464 | labels=tf.zeros_like(D_out[3]))) 465 | d_loss_noisyLR = tf.reduce_mean( 466 | tf.nn.sigmoid_cross_entropy_with_logits(logits=D_out[4], 467 | labels=tf.zeros_like(D_out[4]))) 468 | dis_loss = 2 *d_loss_real + d_loss_fake + d_loss_noisy + d_loss_fakeLR + d_loss_noisyLR 469 | 470 | Y_loss_real = tf.reduce_mean( 471 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[0], 472 | 
labels=tf.ones_like(Y_out[0]))) 473 | Y_loss_fake = tf.reduce_mean( 474 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[1], 475 | labels=tf.zeros_like(Y_out[1]))) 476 | Y_loss_noisy = tf.reduce_mean( 477 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[2], 478 | labels=tf.zeros_like(Y_out[2]))) 479 | Y_loss_fakeLR = tf.reduce_mean( 480 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[3], 481 | labels=tf.zeros_like(Y_out[3]))) 482 | Y_loss_noisyLR = tf.reduce_mean( 483 | tf.nn.sigmoid_cross_entropy_with_logits(logits=Y_out[4], 484 | labels=tf.zeros_like(Y_out[4]))) 485 | Y_dis_loss = 2 * Y_loss_real + Y_loss_fake + Y_loss_noisy + Y_loss_fakeLR + Y_loss_noisyLR 486 | 487 | C_loss_real = tf.reduce_mean( 488 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[0], 489 | labels=tf.ones_like(C_out[0]))) 490 | C_loss_fake = tf.reduce_mean( 491 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[1], 492 | labels=tf.zeros_like(C_out[1]))) 493 | C_loss_noisy = tf.reduce_mean( 494 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[2], 495 | labels=tf.zeros_like(C_out[2]))) 496 | C_loss_fakeLR = tf.reduce_mean( 497 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[3], 498 | labels=tf.zeros_like(C_out[3]))) 499 | C_loss_noisyLR = tf.reduce_mean( 500 | tf.nn.sigmoid_cross_entropy_with_logits(logits=C_out[4], 501 | labels=tf.zeros_like(C_out[4]))) 502 | 503 | C_dis_loss = 2 * C_loss_real + C_loss_fake + C_loss_noisy + C_loss_fakeLR + C_loss_noisyLR 504 | 505 | else: 506 | raise ValueError('Unknown GAN loss function type') 507 | 508 | self.summary_target['generator_loss'] = gen_loss 509 | self.summary_target['discriminator_loss'] = dis_loss 510 | return gen_loss, g_gen_loss, dis_loss, Y_dis_loss, C_dis_loss 511 | 512 | def add_summary_writer(self): 513 | return [tf.summary.scalar(key, value) for key, value in self.summary_target.items()] 514 | 515 | 516 | class Optimizer(object): 517 | """class to build optimizers""" 518 | @staticmethod 519 | def 
# ---- Optimizer static methods (class header / @staticmethod above) ----
    def pretrain_optimizer(FLAGS, global_iter, pre_gen_loss, pre_dis_loss, Y_loss, C_loss):
        """Build Adam optimizers for stage-1 pre-training.

        The learning rate halves every FLAGS.pretrain_lr_decay_step steps.
        One optimizer per discriminator (D, Y, C) plus the generator; only
        the generator op advances global_iter.

        Returns:
            (pre_gen_var, pre_gen_optimizer, pre_dis_optimizer,
             Y_dis_optimizer, C_dis_optimizer)
        """
        learning_rate = tf.train.exponential_decay(FLAGS.pretrain_learning_rate, global_iter,
                                                   FLAGS.pretrain_lr_decay_step, 0.5, staircase=True)

        with tf.name_scope('optimizer'):
            with tf.variable_scope('optimizer_discriminator'):
                # dis_var is rebound for each discriminator below; each
                # minimize() captures its own snapshot via var_list.
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                pre_dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=pre_dis_loss,
                                                                                                 var_list=dis_var)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('Y_optimizer_discriminator'):
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Y_discriminator')
                Y_dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=Y_loss,
                                                                                               var_list=dis_var)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('C_optimizer_discriminator'):
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='C_discriminator')
                C_dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=C_loss,
                                                                                               var_list=dis_var)

            # NOTE(review): original nesting of this scope is ambiguous in the
            # dump; it only builds the generator train op, so graph semantics
            # do not depend on it.
            with tf.variable_scope('optimizer_generator'):
                pre_gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                pre_gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=pre_gen_loss,
                                                                                                 global_step=global_iter,
                                                                                                 var_list=pre_gen_var)

        return pre_gen_var, pre_gen_optimizer, pre_dis_optimizer, Y_dis_optimizer, C_dis_optimizer

    @staticmethod
    def pretrain_optimizer2(FLAGS, global_iter, pre_gen_loss):
        """Build the generator-only Adam optimizer for stage 1-2 pre-training.

        Returns:
            (pre_gen_var, pre_gen_optimizer)
        """
        learning_rate = tf.train.exponential_decay(FLAGS.pretrain_learning_rate, global_iter,
                                                   FLAGS.pretrain_lr_decay_step, 0.5, staircase=True)

        with tf.name_scope('optimizer'):
            with tf.variable_scope('optimizer_generator'):
                pre_gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                pre_gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=pre_gen_loss,
                                                                                                 global_step=global_iter,
                                                                                                 var_list=pre_gen_var)

        return pre_gen_var, pre_gen_optimizer

    @staticmethod
    def gan_optimizer(FLAGS, global_iter, dis_loss, gen_loss, Y_dis_loss, C_dis_loss):
        """Build the stage-2 GAN optimizers with a step-wise LR schedule.

        LR halves at 100k/200k/300k/400k iterations. The generator op is
        placed under a control dependency on the D optimizer, so each run of
        the generator train op also updates D first.

        Returns:
            (dis_var, dis_optimizer, gen_var, gen_optimizer,
             Y_dis_optimizer, C_dis_optimizer)

        NOTE(review): dis_var is rebound three times below, so the returned
        dis_var holds the C_discriminator variables, not the main D's --
        confirm callers do not rely on it being D's variable list.
        """
        boundaries = [100000, 200000, 300000, 400000]
        values = [FLAGS.learning_rate, FLAGS.learning_rate * 0.5, FLAGS.learning_rate * 0.5 ** 2,
                  FLAGS.learning_rate * 0.5 ** 3, FLAGS.learning_rate * 0.5 ** 4]
        learning_rate = tf.train.piecewise_constant(global_iter, boundaries, values)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('optimizer_discriminator'):
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                                       beta1=0.5).minimize(loss=dis_loss, var_list=dis_var)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('Y_optimizer_discriminator'):
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Y_discriminator')
                Y_dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=Y_dis_loss,
                                                                                               var_list=dis_var)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('C_optimizer_discriminator'):
                dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='C_discriminator')
                C_dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=C_dis_loss,
                                                                                               var_list=dis_var)
        with tf.name_scope('optimizer'):
            with tf.variable_scope('optimizer_generator'):
                with tf.control_dependencies([dis_optimizer]):
                    gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                    gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.5).minimize(loss=gen_loss,
                                                                                                            global_step=global_iter,
                                                                                                            var_list=gen_var)

        return dis_var, dis_optimizer, gen_var, gen_optimizer, Y_dis_optimizer, C_dis_optimizer
# ---- utils.py ----
import logging
import os
import glob

import numpy as np

# NOTE(review): cv2 and tensorflow are only required by the image-I/O and
# graph helpers below; tolerate their absence so the pure-numpy utilities
# stay importable in lightweight environments.
try:
    import cv2
except ImportError:
    cv2 = None
try:
    import tensorflow as tf
except ImportError:
    tf = None


def log(logflag, message, level='info'):
    """Print *message* to stdout and, if *logflag* is truthy, mirror it to the
    logging module at the given *level* ('info'/'warning'/'error'/'critical')."""
    print(message, flush=True)

    if logflag:
        if level == 'info':
            logging.info(message)
        elif level == 'warning':
            logging.warning(message)
        elif level == 'error':
            logging.error(message)
        elif level == 'critical':
            logging.critical(message)


def create_dirs(target_dirs):
    """Create every directory in *target_dirs* that does not already exist."""
    for dir_path in target_dirs:
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)


def normalize_images(*arrays):
    """Map image arrays from the [0, 255] range to [-1, 1]; returns a list."""
    return [arr / 127.5 - 1 for arr in arrays]


def de_normalize_image(image):
    """Map a [-1, 1] image array back to the [0, 255] range."""
    return (image + 1) * 127.5


def save_image(args, images, phase, global_iter, save_max_num=5):
    """Write up to *save_max_num* de-normalized images into the output
    directory selected by *phase*.

    NOTE(review): *global_iter* is used directly as the file name, so callers
    are expected to pass a string such as '0001.png' -- confirm at call sites.
    """
    save_dir = ''
    if phase == 'pre-train':
        save_dir = args.pre_train_result_dir
    elif phase == 'pre-raw':
        save_dir = args.pre_raw_result_dir
    elif phase == 'pre-valid':
        save_dir = args.pre_valid_result_dir
    elif phase == 'pre-valid_LR':
        save_dir = args.pre_valid_LR_result_dir
    elif phase == 'train':
        save_dir = args.train_result_dir
    elif phase == 'valid':
        save_dir = args.valid_result_dir
    elif phase == 'valid_LR':
        save_dir = args.valid_LR_result_dir
    elif phase == 'test':
        save_dir = args.test_result_dir
    elif phase == 'test_LR':
        save_dir = args.test_LR_result_dir
    elif phase == 'pre-test':
        save_dir = './pre_test_images/'
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    else:
        # BUG FIX: previously fell through with save_dir == '' and attempted
        # to write to the filesystem root; bail out instead.
        print('specified phase is invalid')
        return

    for i, img in enumerate(images):
        if i >= save_max_num:
            break
        cv2.imwrite(save_dir + '/' + global_iter, de_normalize_image(img))


def crop(img, args):
    """Crop one random args.crop_size x args.crop_size patch from *img* (H, W, C)."""
    img_h, img_w, _ = img.shape

    # BUG FIX: +1 so an image exactly crop_size wide/tall is accepted
    # (np.random.randint raises on an empty range).
    rand_h = np.random.randint(img_h - args.crop_size + 1)
    rand_w = np.random.randint(img_w - args.crop_size + 1)

    return img[rand_h:rand_h + args.crop_size, rand_w:rand_w + args.crop_size, :]


def data_augmentation(LR_images, HR_images, aug_type='horizontal_flip'):
    """data augmentation. input arrays should be [N, H, W, C]"""
    if aug_type == 'horizontal_flip':
        return LR_images[:, :, ::-1, :], HR_images[:, :, ::-1, :]
    elif aug_type == 'rotation_90':
        return np.rot90(LR_images, k=1, axes=(1, 2)), np.rot90(HR_images, k=1, axes=(1, 2))


def generate_batch(NLR_list, NHR_list, CLR_list, CHR_list, args):
    """Cut aligned training patches from the four image domains.

    NLR/NHR: noisy (source-domain) LR and HR image path lists.
    CLR/CHR: clean (target-domain) LR and HR image path lists; each CHR patch
    is taken at the CLR grid position scaled by args.scale_SR so CLR/CHR
    pairs stay spatially aligned.

    Returns four np.ndarrays, all shuffled with the same RNG seed so index i
    remains paired across the four arrays.
    """
    patch_h = args.crop_size
    patch_w = args.crop_size
    scale = args.scale_SR
    stride = args.stride

    patches_NLR = list()
    patches_NHR = list()
    patches_CLR = list()
    patches_CHR = list()
    for i in range(len(NLR_list)):
        NLR = cv2.imread(NLR_list[i])
        NHR = cv2.imread(NHR_list[i])
        CLR = cv2.imread(CLR_list[i])
        CHR = cv2.imread(CHR_list[i])

        # normalize to [-1, 1]
        NLR = NLR / 127.5 - 1.
        NHR = NHR / 127.5 - 1.
        CLR = CLR / 127.5 - 1.
        CHR = CHR / 127.5 - 1
        NLR_x, NLR_y, depth = NLR.shape
        # BUG FIX: the NHR patch grid was computed from NLR.shape, producing
        # a wrong (too small) grid for the larger HR image.
        NHR_x, NHR_y, depth = NHR.shape
        CLR_x, CLR_y, depth = CLR.shape

        for h in range(0, NLR_x - patch_h, stride):
            for w in range(0, NLR_y - patch_w, stride):
                patches_NLR.append(NLR[h:h + patch_h, w:w + patch_w])

        # NOTE(review): NHR patches are cropped at crop_size even though the
        # grid walks the HR image at 4x stride -- confirm this is intended.
        for h in range(0, NHR_x - patch_h * 4, stride * 4):
            for w in range(0, NHR_y - patch_w * 4, stride * 4):
                patches_NHR.append(NHR[h:h + patch_h, w:w + patch_w])

        for h in range(0, CLR_x - patch_h, stride):
            for w in range(0, CLR_y - patch_w, stride):
                patches_CLR.append(CLR[h:h + patch_h, w:w + patch_w])

                # matching clean-HR patch at the scaled grid position
                hs = h * scale
                ws = w * scale
                patches_CHR.append(CHR[hs:hs + patch_h * scale, ws:ws + patch_w * scale])

    # keep the noisy and clean sets the same length
    if len(patches_NLR) > len(patches_CLR):
        patches_NLR = patches_NLR[:len(patches_CLR)]
        patches_NHR = patches_NHR[:len(patches_CLR)]
    elif len(patches_NLR) < len(patches_CLR):
        patches_CLR = patches_CLR[:len(patches_NLR)]
        # BUG FIX: was patches_CHR[:len(patches_CHR)] (a no-op), which left
        # the CHR set longer than the other three.
        patches_CHR = patches_CHR[:len(patches_NLR)]

    # identical seeds keep the four arrays aligned after shuffling
    np.random.seed(36)
    np.random.shuffle(patches_NLR)
    np.random.seed(36)
    np.random.shuffle(patches_NHR)
    np.random.seed(36)
    np.random.shuffle(patches_CLR)
    np.random.seed(36)
    np.random.shuffle(patches_CHR)

    return np.array(patches_NLR), np.array(patches_NHR), np.array(patches_CLR), np.array(patches_CHR)
def generate_pretrain_batch(source_list, target_list, args):
    """Cut same-resolution patch pairs from *source_list* / *target_list*.

    Both domains are walked with the same crop_size/stride grid; the two
    returned arrays are truncated to a common length and shuffled with the
    same RNG seed so index i in each array stays paired.
    """
    patch_h = args.crop_size
    patch_w = args.crop_size
    stride = args.stride

    patch_source = list()
    patch_target = list()
    for i in range(len(source_list)):
        source = cv2.imread(source_list[i])
        target = cv2.imread(target_list[i])

        # normalize to [-1, 1]
        source = source / 127.5 - 1.
        target = target / 127.5 - 1
        source_x, source_y, depth = source.shape
        target_x, target_y, depth = target.shape

        for h in range(0, source_x - patch_h, stride):
            for w in range(0, source_y - patch_w, stride):
                patch_source.append(source[h:h + patch_h, w:w + patch_w])

        for h in range(0, target_x - patch_h, stride):
            for w in range(0, target_y - patch_w, stride):
                patch_target.append(target[h:h + patch_h, w:w + patch_w])

    if len(patch_source) > len(patch_target):
        patch_source = patch_source[:len(patch_target)]
    # BUG FIX: the elif repeated the if-branch comparison
    # (len(patch_target) < len(patch_source)), so the target list was never
    # truncated when it was the longer one.
    elif len(patch_target) > len(patch_source):
        patch_target = patch_target[:len(patch_source)]

    np.random.seed(36)
    np.random.shuffle(patch_source)
    np.random.seed(36)
    np.random.shuffle(patch_target)

    return np.array(patch_source), np.array(patch_target)


def generate_pretrain_batch2(source_list, target_list, args):
    """Cut LR/HR patch pairs: each source patch at (h, w) is paired with the
    target patch at (h*scale, w*scale), scale-times the size, so the arrays
    are aligned by construction (no truncation needed)."""
    patch_h = args.crop_size
    patch_w = args.crop_size
    scale = args.scale_SR
    stride = args.stride

    patch_source = list()
    patch_target = list()
    for i in range(len(source_list)):
        source = cv2.imread(source_list[i])
        target = cv2.imread(target_list[i])

        # normalize to [-1, 1]
        source = source / 127.5 - 1.
        target = target / 127.5 - 1
        source_x, source_y, depth = source.shape
        target_x, target_y, depth = target.shape

        for h in range(0, source_x - patch_h, stride):
            for w in range(0, source_y - patch_w, stride):
                patch_source.append(source[h:h + patch_h, w:w + patch_w])

                hs = h * scale
                ws = w * scale
                patch_target.append(target[hs:hs + patch_h * scale, ws:ws + patch_w * scale])

    np.random.seed(36)
    np.random.shuffle(patch_source)
    np.random.seed(36)
    np.random.shuffle(patch_target)

    return np.array(patch_source), np.array(patch_target)


def generate_testset(source_list, target_list, args):
    """Load one source/target image path pair and normalize both to [-1, 1].

    *args* is unused but kept for signature compatibility with callers.
    """
    source = cv2.imread(source_list)
    image_source = source / 127.5 - 1

    target = cv2.imread(target_list)
    image_target = target / 127.5 - 1

    return np.array(image_source), np.array(image_target)


# TensorFlow Better Bicubic Downsample
# https://github.com/trevor-m/tensorflow-bicubic-downsample
def bicubic_kernel(x, a=-0.5):
    """https://clouard.users.greyc.fr/Pantheon/experiments/rescaling/index-en.html#bicubic"""
    if abs(x) <= 1:
        return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1
    elif 1 < abs(x) and abs(x) < 2:
        return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a
    else:
        return 0


def build_filter(factor):
    """Build a (4*factor, 4*factor, 3, 1) depthwise bicubic filter tensor for
    apply_bicubic_downsample (one identical channel per RGB plane)."""
    size = factor * 4
    k = np.zeros((size))
    for i in range(size):
        x = (1 / factor) * (i - np.floor(size / 2) + 0.5)
        k[i] = bicubic_kernel(x)
    # normalize so the kernel sums to 1 (preserves brightness)
    k = k / np.sum(k)
    # make 2d
    k = np.outer(k, k.T)
    k = tf.constant(k, dtype=tf.float32, shape=(size, size, 1, 1))
    return tf.concat([k, k, k], axis=2)
def apply_bicubic_downsample(x, filter, factor):
    """Downsample x by a factor of factor, using the filter built by build_filter()
    x: a rank 4 tensor with format NHWC
    filter: from build_filter(factor)
    factor: downsampling factor (ex: factor=2 means the output size is (h/2, w/2))

    NOTE: the parameter name `filter` shadows the builtin; kept for
    interface compatibility with existing callers.
    """
    # using padding calculations from https://www.tensorflow.org/api_guides/python/nn#Convolution
    filter_height = factor * 4
    filter_width = factor * 4
    strides = factor
    pad_along_height = max(filter_height - strides, 0)
    pad_along_width = max(filter_width - strides, 0)
    # compute actual padding values for each side
    pad_top = pad_along_height // 2
    pad_bottom = pad_along_height - pad_top
    pad_left = pad_along_width // 2
    pad_right = pad_along_width - pad_left
    # apply mirror padding
    x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], mode='REFLECT')
    # downsampling performed by strided conv
    x = tf.nn.depthwise_conv2d(x, filter=filter, strides=[1, strides, strides, 1], padding='VALID')
    return x


# ---------------------------- unused ----------------------
def load_and_save_data(args, logflag):
    """make HR and LR data. And save them as npz files

    Crops (or resizes) every image under args.data_dir into HR/LR pairs,
    writes the individual images plus two .npz archives, and returns the
    stacked (HR, LR) arrays.
    """
    assert os.path.isdir(args.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory'

    all_file_path = glob.glob(args.data_dir + '/*')
    assert len(all_file_path) > 0, 'No file in the directory'

    ret_HR_image = []
    ret_LR_image = []

    for file in all_file_path:
        img = cv2.imread(file)
        filename = file.rsplit('/', 1)[-1]

        # crop patches if flag is true. Otherwise just resize HR and LR images
        if args.crop:
            for _ in range(args.num_crop_per_image):
                img_h, img_w, _ = img.shape

                if (img_h < args.crop_size) or (img_w < args.crop_size):
                    print('Skip crop target image because of insufficient size')
                    continue

                HR_image = crop(img, args)
                # BUG FIX: np.int is a deprecated alias of the builtin int and
                # was removed in NumPy 1.24; use int() directly.
                LR_crop_size = int(np.floor(args.crop_size / args.scale_SR))
                LR_image = cv2.resize(HR_image, (LR_crop_size, LR_crop_size), interpolation=cv2.INTER_LANCZOS4)

                cv2.imwrite(args.HR_data_dir + '/' + filename, HR_image)
                cv2.imwrite(args.LR_data_dir + '/' + filename, LR_image)

                ret_HR_image.append(HR_image)
                ret_LR_image.append(LR_image)
        else:
            HR_image = cv2.resize(img, (args.HR_image_size, args.HR_image_size), interpolation=cv2.INTER_LANCZOS4)
            LR_image = cv2.resize(img, (args.LR_image_size, args.LR_image_size), interpolation=cv2.INTER_LANCZOS4)

            cv2.imwrite(args.HR_data_dir + '/' + filename, HR_image)
            cv2.imwrite(args.LR_data_dir + '/' + filename, LR_image)

            ret_HR_image.append(HR_image)
            ret_LR_image.append(LR_image)

    # typo fix in the message: 'availale' -> 'available'
    assert len(ret_HR_image) > 0 and len(ret_LR_image) > 0, 'No available image is found in the directory'
    log(logflag, 'Data process : {} images are processed'.format(len(ret_HR_image)), 'info')

    ret_HR_image = np.array(ret_HR_image)
    ret_LR_image = np.array(ret_LR_image)

    if args.data_augmentation:
        LR_flip, HR_flip = data_augmentation(ret_LR_image, ret_HR_image, aug_type='horizontal_flip')
        LR_rot, HR_rot = data_augmentation(ret_LR_image, ret_HR_image, aug_type='rotation_90')

        ret_LR_image = np.append(ret_LR_image, LR_flip, axis=0)
        ret_HR_image = np.append(ret_HR_image, HR_flip, axis=0)
        ret_LR_image = np.append(ret_LR_image, LR_rot, axis=0)
        ret_HR_image = np.append(ret_HR_image, HR_rot, axis=0)

        # free the temporaries before the large savez below
        del LR_flip, HR_flip, LR_rot, HR_rot

    np.savez(args.npz_data_dir + '/' + args.HR_npz_filename, images=ret_HR_image)
    np.savez(args.npz_data_dir + '/' + args.LR_npz_filename, images=ret_LR_image)

    return ret_HR_image, ret_LR_image


def load_npz_data(FLAGS):
    """load array data from data_path"""
    return np.load(FLAGS.npz_data_dir + '/' + FLAGS.HR_npz_filename)['images'], \
           np.load(FLAGS.npz_data_dir + '/' + FLAGS.LR_npz_filename)['images']


def load_inference_data(FLAGS):
    """load data from directory for inference"""
    assert os.path.isdir(FLAGS.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory'

    all_file_path = glob.glob(FLAGS.data_dir + '/*')
    assert len(all_file_path) > 0, 'No file in the directory'

    ret_LR_image = []
    ret_filename = []

    for file in all_file_path:
        img = cv2.imread(file)
        # normalize_images returns a list; keep element 0 with a batch axis
        img = normalize_images(img)
        ret_LR_image.append(img[0][np.newaxis, ...])

        ret_filename.append(file.rsplit('/', 1)[-1])

    assert len(ret_LR_image) > 0, 'No available image is found in the directory'

    return ret_LR_image, ret_filename