├── img ├── 0833.png ├── 0887.png ├── 0896.png ├── 1117e6b64a7b4336df58eb351cff435529485e91.png ├── 11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png └── 111b822af95747f45f5d25a84f8094c10b27c765.png ├── evaluation.py ├── lib ├── ops.py ├── pretrain_generator.py ├── utils.py ├── network.py └── train_module.py ├── inference.py ├── visualize.py ├── network_interpolation.py ├── README.md ├── train.py └── LICENSE /img/0833.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0833.png -------------------------------------------------------------------------------- /img/0887.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0887.png -------------------------------------------------------------------------------- /img/0896.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0896.png -------------------------------------------------------------------------------- /img/1117e6b64a7b4336df58eb351cff435529485e91.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/1117e6b64a7b4336df58eb351cff435529485e91.png -------------------------------------------------------------------------------- /img/11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png -------------------------------------------------------------------------------- /img/111b822af95747f45f5d25a84f8094c10b27c765.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/111b822af95747f45f5d25a84f8094c10b27c765.png -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import cv2 6 | from skimage.measure import compare_psnr, compare_ssim 7 | 8 | 9 | def calc_measures(hr_path, calc_psnr=True, calc_ssim=True): 10 | """calculate PSNR and SSIM for all HR images and their mean. 11 | These paired images should have the same filename. 12 | """ 13 | 14 | HR_files = glob.glob(hr_path + '/*') 15 | mean_psnr = 0 16 | mean_ssim = 0 17 | 18 | for file in HR_files: 19 | hr_img = cv2.imread(file) 20 | filename = file.rsplit('/', 1)[-1] 21 | path = os.path.join(args.inference_result, filename) 22 | 23 | if not os.path.isfile(path): 24 | raise FileNotFoundError('') 25 | 26 | inf_img = cv2.imread(path) 27 | 28 | # compare HR image and inferenced image with measures 29 | print('-' * 10) 30 | if calc_psnr: 31 | psnr = compare_psnr(hr_img, inf_img) 32 | print('{0} : PSNR {1:.3f} dB'.format(filename, psnr)) 33 | mean_psnr += psnr 34 | if calc_ssim: 35 | ssim = compare_ssim(hr_img, inf_img, multichannel=True) 36 | print('{0} : SSIM {1:.3f}'.format(filename, ssim)) 37 | mean_ssim += ssim 38 | 39 | print('-' * 10) 40 | if calc_psnr: 41 | print('mean-PSNR {:.3f} dB'.format(mean_psnr / len(HR_files))) 42 | if calc_ssim: 43 | print('mean-SSIM {:.3f}'.format(mean_ssim / len(HR_files))) 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument('--HR_data_dir', default='./data/div2_inf_HR', type=str) 50 | parser.add_argument('--inference_result', default='./inference_result_div2', type=str) 51 | 52 | args = parser.parse_args() 53 | 54 | calc_measures(args.HR_data_dir, calc_psnr=True, calc_ssim=True) 55 | 
-------------------------------------------------------------------------------- /lib/ops.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras.applications.vgg19 import VGG19 5 | 6 | 7 | def scale_initialization(weights, FLAGS): 8 | return [tf.assign(weight, weight * FLAGS.weight_initialize_scale) for weight in weights] 9 | 10 | 11 | def _transfer_vgg19_weight(FLAGS, weight_dict): 12 | from_model = VGG19(include_top=False, weights='imagenet', input_tensor=None, 13 | input_shape=(FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel)) 14 | 15 | fetch_weight = [] 16 | 17 | for layer in from_model.layers: 18 | if 'conv' in layer.name: 19 | W, b = layer.get_weights() 20 | 21 | fetch_weight.append( 22 | tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/kernel'.format(layer.name)], W) 23 | ) 24 | fetch_weight.append( 25 | tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/bias'.format(layer.name)], b) 26 | ) 27 | 28 | return fetch_weight 29 | 30 | 31 | def load_vgg19_weight(FLAGS): 32 | vgg_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='loss_generator/perceptual_vgg19') 33 | 34 | assert len(vgg_weight) > 0, 'No VGG19 weight was collected. The target scope might be wrong.' 
35 | 36 | weight_dict = {} 37 | for weight in vgg_weight: 38 | weight_dict[weight.name.rsplit(':', 1)[0]] = weight 39 | 40 | return _transfer_vgg19_weight(FLAGS, weight_dict) 41 | 42 | 43 | def extract_weight(network_vars): 44 | weight_dict = OrderedDict() 45 | 46 | for weight in network_vars: 47 | weight_dict[weight.name] = weight.eval() 48 | 49 | return weight_dict 50 | 51 | 52 | def interpolate_weight(FLAGS, pretrain_weight): 53 | fetch_weight = [] 54 | alpha = FLAGS.interpolation_param 55 | 56 | for name, pre_weight in pretrain_weight.items(): 57 | esrgan_weight = tf.get_default_graph().get_tensor_by_name(name) 58 | 59 | assert pre_weight.shape == esrgan_weight.shape, 'The shape of weights does not match' 60 | 61 | fetch_weight.append(tf.assign(esrgan_weight, (1 - alpha) * pre_weight + alpha * esrgan_weight)) 62 | 63 | return fetch_weight 64 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from lib.train_module import Network 8 | from lib.utils import create_dirs, de_normalize_image, load_inference_data 9 | 10 | 11 | def set_flags(): 12 | Flags = tf.app.flags 13 | 14 | Flags.DEFINE_string('data_dir', './data/inference', 'inference data directory') 15 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory') 16 | Flags.DEFINE_string('inference_checkpoint', '', 17 | 'checkpoint to use for inference. 
Empty string means the latest checkpoint is used') 18 | Flags.DEFINE_string('inference_result_dir', './inference_result', 'output directory during inference') 19 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel') 20 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks') 21 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter') 22 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization') 23 | 24 | return Flags.FLAGS 25 | 26 | 27 | def main(): 28 | # set flag 29 | FLAGS = set_flags() 30 | 31 | # make dirs 32 | target_dirs = [FLAGS.inference_result_dir] 33 | create_dirs(target_dirs) 34 | 35 | # load test data 36 | LR_inference, LR_filenames = load_inference_data(FLAGS) 37 | 38 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input') 39 | 40 | # build Generator 41 | network = Network(FLAGS, LR_data) 42 | gen_out = network.generator() 43 | 44 | fetches = {'gen_HR': gen_out} 45 | 46 | # Start Session 47 | config = tf.ConfigProto() 48 | config.gpu_options.allow_growth = True 49 | 50 | with tf.Session(config=config) as sess: 51 | print('Inference start') 52 | 53 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')) 54 | 55 | if FLAGS.inference_checkpoint: 56 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, FLAGS.inference_checkpoint)) 57 | else: 58 | print('No checkpoint is specified. 
The latest one is used for inference') 59 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir)) 60 | 61 | for i, test_img in enumerate(LR_inference): 62 | 63 | feed_dict = { 64 | LR_data: test_img 65 | } 66 | 67 | result = sess.run(fetches=fetches, feed_dict=feed_dict) 68 | 69 | cv2.imwrite(os.path.join(FLAGS.inference_result_dir, LR_filenames[i]), 70 | de_normalize_image(np.squeeze(result['gen_HR']))) 71 | 72 | print('Inference end') 73 | 74 | 75 | if __name__ == '__main__': 76 | main() 77 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import os 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def visualize(args): 11 | """visualize images. Bicubic interpolation(generated in this script), ESRGAN, ESRGAN with network interpolation 12 | and HR are tiled into an image. 13 | The images of ESRGAN, ESRGAN with network interpolation have the same filename as HR. 
14 | """ 15 | HR_files = glob.glob(args.HR_data_dir + '/*') 16 | 17 | for file in HR_files: 18 | # HR(GT) 19 | hr_img = cv2.imread(file) 20 | h, w, _ = hr_img.shape 21 | filename = file.rsplit('/', 1)[-1].rsplit('.', 1)[0] 22 | 23 | # LR -> bicubic 24 | r_h, r_w = math.floor(h / 4), math.floor(w / 4) 25 | lr_img = cv2.resize(hr_img, (r_w, r_h), cv2.INTER_CUBIC) 26 | bic_img = cv2.resize(lr_img, (w, h), cv2.INTER_CUBIC) 27 | 28 | # inference 29 | inf_path_jpg = os.path.join(args.inference_result, filename + '.jpg') 30 | inf_path_png = os.path.join(args.inference_result, filename + '.png') 31 | 32 | if os.path.isfile(inf_path_jpg): 33 | inf_path = inf_path_jpg 34 | elif os.path.isfile(inf_path_png): 35 | inf_path = inf_path_png 36 | else: 37 | raise FileNotFoundError('Images should have the same filename as HR image and be the formats of jpg or png') 38 | 39 | inf_img = cv2.imread(inf_path) 40 | 41 | # network interpolation inference 42 | ni_path_jpg = os.path.join(args.network_interpolation_result, filename + '.jpg') 43 | ni_path_png = os.path.join(args.network_interpolation_result, filename + '.png') 44 | if os.path.isfile(ni_path_jpg): 45 | ni_path = ni_path_jpg 46 | elif os.path.isfile(inf_path_png): 47 | ni_path = ni_path_png 48 | else: 49 | raise FileNotFoundError('Images should have the same filename as HR image and be the formats of jpg or png') 50 | 51 | ni_img = cv2.imread(ni_path) 52 | 53 | h_upper = int(math.floor(h / 2) + args.path_size / 2) 54 | h_lower = int(math.floor(h / 2) - args.path_size / 2) 55 | 56 | w_right = int(math.floor(w / 2) + args.path_size / 2) 57 | w_left = int(math.floor(w / 2) - args.path_size / 2) 58 | 59 | h_size = h_upper - h_lower 60 | w_size = w_right - w_left 61 | 62 | out_arr = np.empty((h_size, w_size * 4, 3)) 63 | 64 | # tile images from left to right : bicubic -> ESRGAN-inference -> Network interpolation -> HR(GT) 65 | out_arr[:, :w_size, :] = bic_img[h_lower:h_upper, w_left:w_right, :] 66 | out_arr[:, w_size:w_size * 2, :] 
= inf_img[h_lower:h_upper, w_left:w_right, :] 67 | out_arr[:, w_size * 2:w_size * 3, :] = ni_img[h_lower:h_upper, w_left:w_right, :] 68 | out_arr[:, w_size * 3:w_size * 4, :] = hr_img[h_lower:h_upper, w_left:w_right, :] 69 | 70 | cv2.imwrite(args.output_dir + '/' + '{}.png'.format(filename), out_arr) 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | 76 | parser.add_argument('--HR_data_dir', default='./data/div2_inf_HR', type=str) 77 | parser.add_argument('--inference_result', default='./inference_result_div2', type=str) 78 | parser.add_argument('--network_interpolation_result', default='./interpolation_result_div2', type=str) 79 | parser.add_argument('--path_size', default=512, type=str) 80 | parser.add_argument('--output_dir', default='./', type=str) 81 | 82 | args = parser.parse_args() 83 | 84 | visualize(args) 85 | -------------------------------------------------------------------------------- /lib/pretrain_generator.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import math 4 | 5 | from sklearn.utils import shuffle 6 | import tensorflow as tf 7 | 8 | from lib.ops import scale_initialization 9 | from lib.train_module import Network, Loss, Optimizer 10 | from lib.utils import log, normalize_images, save_image 11 | 12 | 13 | def train_pretrain_generator(FLAGS, LR_train, HR_train, logflag): 14 | """pre-train deep network as initialization weights of ESRGAN Generator""" 15 | log(logflag, 'Pre-train : Process start', 'info') 16 | 17 | LR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel], 18 | name='LR_input') 19 | HR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel], 20 | name='HR_input') 21 | 22 | # build Generator 23 | network = Network(FLAGS, LR_data) 24 | pre_gen_out = network.generator() 25 | 26 | # build loss function 27 | loss = Loss() 28 | 
pre_gen_loss = loss.pretrain_loss(pre_gen_out, HR_data) 29 | 30 | # build optimizer 31 | global_iter = tf.Variable(0, trainable=False) 32 | pre_gen_var, pre_gen_optimizer = Optimizer().pretrain_optimizer(FLAGS, global_iter, pre_gen_loss) 33 | 34 | # build summary writer 35 | pre_summary = tf.summary.merge(loss.add_summary_writer()) 36 | 37 | num_train_data = len(HR_train) 38 | num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size)) 39 | num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train)) 40 | 41 | HR_train, LR_train = normalize_images(HR_train, LR_train) 42 | 43 | fetches = {'pre_gen_loss': pre_gen_loss, 'pre_gen_optimizer': pre_gen_optimizer, 'gen_HR': pre_gen_out, 44 | 'summary': pre_summary} 45 | 46 | gc.collect() 47 | 48 | config = tf.ConfigProto( 49 | gpu_options=tf.GPUOptions( 50 | allow_growth=True, 51 | visible_device_list=FLAGS.gpu_dev_num 52 | ) 53 | ) 54 | 55 | saver = tf.train.Saver(max_to_keep=10) 56 | 57 | # Start session 58 | with tf.Session(config=config) as sess: 59 | log(logflag, 'Pre-train : Training starts', 'info') 60 | 61 | sess.run(tf.global_variables_initializer()) 62 | sess.run(global_iter.initializer) 63 | sess.run(scale_initialization(pre_gen_var, FLAGS)) 64 | 65 | writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph, filename_suffix='pre-train') 66 | 67 | for epoch in range(num_epoch): 68 | log(logflag, 'Pre-train Epoch: {0}'.format(epoch), 'info') 69 | 70 | HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222) 71 | 72 | for iteration in range(num_batch_in_train): 73 | current_iter = tf.train.global_step(sess, global_iter) 74 | 75 | if current_iter > FLAGS.num_iter: 76 | break 77 | 78 | feed_dict = { 79 | HR_data: HR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size], 80 | LR_data: LR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size] 81 | } 82 | 83 | # update weights 84 | result = sess.run(fetches=fetches, 
feed_dict=feed_dict) 85 | 86 | # save summary every n iter 87 | if current_iter % FLAGS.train_summary_save_freq == 0: 88 | writer.add_summary(result['summary'], global_step=current_iter) 89 | 90 | # save samples every n iter 91 | if current_iter % FLAGS.train_sample_save_freq == 0: 92 | log(logflag, 93 | 'Pre-train iteration : {0}, pixel-wise_loss : {1}'.format(current_iter, result['pre_gen_loss']), 94 | 'info') 95 | save_image(FLAGS, result['gen_HR'], 'pre-train', current_iter, save_max_num=5) 96 | 97 | # save checkpoint 98 | if current_iter % FLAGS.train_ckpt_save_freq == 0: 99 | saver.save(sess, os.path.join(FLAGS.pre_train_checkpoint_dir, 'pre_gen'), global_step=current_iter) 100 | 101 | writer.close() 102 | log(logflag, 'Pre-train : Process end', 'info') 103 | -------------------------------------------------------------------------------- /network_interpolation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from lib.train_module import Network 8 | from lib.ops import extract_weight, interpolate_weight 9 | from lib.utils import create_dirs, de_normalize_image, load_inference_data 10 | 11 | 12 | def set_flags(): 13 | Flags = tf.app.flags 14 | 15 | Flags.DEFINE_string('data_dir', './data/inference', 'inference data directory') 16 | Flags.DEFINE_string('pre_train_checkpoint_dir', './pre_train_checkpoint', 'pre-train checkpoint directory') 17 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory') 18 | Flags.DEFINE_string('inference_pretrain_checkpoint', '', 19 | 'pretrain checkpoint to use for network interpolation. Empty string means the latest checkpoint is used') 20 | Flags.DEFINE_string('inference_checkpoint', '', 21 | 'checkpoint to use for network interpolation. 
Empty string means the latest checkpoint is used') 22 | Flags.DEFINE_string('interpolation_result_dir', './interpolation_result', 'output directory during inference') 23 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel') 24 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks') 25 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter') 26 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization') 27 | Flags.DEFINE_float('interpolation_param', 0.8, 'tuning parameter for ') 28 | 29 | return Flags.FLAGS 30 | 31 | 32 | def main(): 33 | # set flag 34 | FLAGS = set_flags() 35 | 36 | # make dirs 37 | target_dirs = [FLAGS.interpolation_result_dir] 38 | create_dirs(target_dirs) 39 | 40 | # load test data 41 | LR_inference, LR_filenames = load_inference_data(FLAGS) 42 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input') 43 | 44 | config = tf.ConfigProto() 45 | config.gpu_options.allow_growth = True 46 | 47 | #### Load pretrain model weights #### 48 | # build Generator 49 | network = Network(FLAGS, LR_data) 50 | gen_out = network.generator() 51 | 52 | with tf.Session(config=config) as sess: 53 | pre_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')) 54 | 55 | if FLAGS.inference_checkpoint: 56 | pre_saver.restore(sess, os.path.join(FLAGS.pre_train_checkpoint_dir, FLAGS.inference_pretrain_checkpoint)) 57 | else: 58 | print('No checkpoint is specified. 
The latest one is used for network interpolation') 59 | pre_saver.restore(sess, tf.train.latest_checkpoint(FLAGS.pre_train_checkpoint_dir)) 60 | 61 | pretrain_weight = extract_weight(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')) 62 | 63 | tf.reset_default_graph() 64 | 65 | #### Run network interpolation and inference #### 66 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input') 67 | 68 | # build Generator 69 | network = Network(FLAGS, LR_data) 70 | gen_out = network.generator() 71 | 72 | fetches = {'gen_HR': gen_out} 73 | 74 | with tf.Session(config=config) as sess: 75 | print('Inference start') 76 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')) 77 | 78 | if FLAGS.inference_checkpoint: 79 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, FLAGS.inference_checkpoint)) 80 | else: 81 | print('No checkpoint is specified. The latest one is used for network interpolation') 82 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir)) 83 | 84 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir)) 85 | 86 | sess.run(interpolate_weight(FLAGS, pretrain_weight)) 87 | 88 | for i, test_img in enumerate(LR_inference): 89 | 90 | feed_dict = { 91 | LR_data: test_img 92 | } 93 | 94 | result = sess.run(fetches=fetches, feed_dict=feed_dict) 95 | 96 | cv2.imwrite(os.path.join(FLAGS.interpolation_result_dir, LR_filenames[i]), 97 | de_normalize_image(np.squeeze(result['gen_HR']))) 98 | 99 | print('Inference end') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ESRGAN (TensorFlow) 2 | 3 | This repository provides a TensorFlow implementation of the paper "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" by X. Wang et al. 
4 | 5 | ## Dependencies 6 | tensorflow, OpenCV, scikit-learn, scikit-image (used by evaluation.py), numpy 7 | 8 | The versions in my test environment: 9 | Python==3.6.8, tensorflow-gpu==1.12.0, openCV==4.1.0, scikit-learn==0.20.3, numpy==1.16.2 10 | 11 | ## How to Use 12 | 13 | #### 1. Prepare data for training 14 | 15 | Prepare your data and put it into the directory specified by the "data_dir" flag (e.g. './data/LSUN') of train.py. Other necessary directories are 16 | created automatically as set in the script. 17 | 18 | #### 2. Train 19 | Run the train.py script. The main processes are: 20 | - Data processing: create HR patches and LR patches (the latter by downsampling the HR patches). The processed data can be saved in directories so that it can be reused. 21 | 22 | - Pre-training with pixel-wise loss: as described in the paper, the Generator is pre-trained first. You can set the "pretrain_generator" flag to False to use an existing pre-trained checkpoint instead. (Training ESRGAN without a pre-trained model is not supported.) 23 | 24 | - Training ESRGAN: starting from the pre-trained model, ESRGAN is trained. 25 | 26 | ``` 27 | # python train.py 28 | 29 | (the data directory can be passed as an optional argument) 30 | # python train.py --data_dir ./data/LSUN 31 | ``` 32 | 33 | #### 3. Inference on LR data 34 | After training has finished, LR images can be super-resolved. The input data directory can be specified with the "data_dir" flag of the inference.py script. 35 | 36 | ``` 37 | # python inference.py 38 | 39 | (the data directory can be passed as an optional argument) 40 | # python inference.py --data_dir ./data/inference 41 | ``` 42 | 43 | #### 4. Inference via network interpolation 44 | The paper proposes a network interpolation method that linearly combines the weights of the pixel-wise pre-trained model and the ESRGAN generator. You can run it after both the pre-trained model and ESRGAN have been trained. The input data directory can be specified with the "data_dir" flag of the network_interpolation.py script.
45 | 46 | ``` 47 | # python network_interpolation.py 48 | 49 | (data directory can be passed by the optional argument) 50 | # python network_interpolation.py --data_dir ./data/inference 51 | ``` 52 | 53 | ## Experiment Result 54 | #### DIV2K dataset 55 | DIV2K is a collection of 2K resolution high quality images.
56 | https://data.vision.ee.ethz.ch/cvl/DIV2K/ 57 | 58 | 59 | 60 | 61 | From left to right: bicubic interpolation, ESRGAN, ESRGAN with network interpolation, high resolution (GT). 4x super-resolution. 62 | 63 | #### LSUN 64 | The LSUN bedroom set is a collection of ordinary-resolution bedroom images.
65 | https://www.kaggle.com/jhoward/lsun_bedroom/data 66 | 67 | 68 | 69 | 70 | From left to right: bicubic interpolation, ESRGAN, ESRGAN with network interpolation, high resolution (GT). 4x super-resolution. 71 | 72 | #### Experiment conditions 73 | - DIV2K: trained on 800 images, with 2 patches cropped per image 74 | - LSUN: trained on about 5,000 images from a 20% subset of the dataset, with 2 patches cropped per image 75 | - data augmentation applied (horizontal flip and 90-degree rotation) 76 | - 15 RRDBs, batch size 32, 50,000 iterations per training phase. Other parameters are the same as in the paper. 77 | - network interpolation parameter (the "interpolation_param" flag) is 0.2 78 | 79 | 80 | ## Limitations 81 | 82 | - Only 4x super-resolution is supported 83 | - Grayscale images are not supported 84 | - Only a single GPU can be used 85 | 86 | 87 | ## To do list 88 | The following items were initially missing compared with the paper; all of them have since been implemented. 89 | 90 | - [x] Perceptual loss using VGG19 (pixel-wise loss was used in its place before) 91 | - [x] Learning rate scheduling 92 | - [x] Network interpolation 93 | - [x] Data augmentation 94 | - [x] Evaluation metrics 95 | 96 | ### Notes 97 | Some settings, such as the number of RRDB blocks, the mini-batch size and the number of iterations, were changed to fit my test environment. 98 | Please change them if you want the same conditions as in the paper. 99 | 100 | 101 | ## Reference 102 | * Paper 103 | Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, and Chen Change Loy: ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks, ECCV Workshops, 2018.
http://openaccess.thecvf.com/content_ECCVW_2018/papers/11133/Wang_ESRGAN_Enhanced_Super-Resolution_Generative_Adversarial_Networks_ECCVW_2018_paper.pdf 104 | 105 | 106 | * Official implementation with Pytorch by the paper's authors 107 | https://github.com/xinntao/BasicSR 108 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import glob 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | def log(logflag, message, level='info'): 10 | """logging to stdout and logfile if flag is true""" 11 | print(message, flush=True) 12 | 13 | if logflag: 14 | if level == 'info': 15 | logging.info(message) 16 | elif level == 'warning': 17 | logging.warning(message) 18 | elif level == 'error': 19 | logging.error(message) 20 | elif level == 'critical': 21 | logging.critical(message) 22 | 23 | 24 | def create_dirs(target_dirs): 25 | """create necessary directories to save output files""" 26 | for dir_path in target_dirs: 27 | if not os.path.isdir(dir_path): 28 | os.makedirs(dir_path) 29 | 30 | 31 | def normalize_images(*arrays): 32 | """normalize input image arrays""" 33 | return [arr / 127.5 - 1 for arr in arrays] 34 | 35 | 36 | def de_normalize_image(image): 37 | """de-normalize input image array""" 38 | return (image + 1) * 127.5 39 | 40 | 41 | def save_image(FLAGS, images, phase, global_iter, save_max_num=5): 42 | """save images in specified directory""" 43 | if phase == 'train' or phase == 'pre-train': 44 | save_dir = FLAGS.train_result_dir 45 | elif phase == 'inference': 46 | save_dir = FLAGS.inference_result_dir 47 | save_max_num = len(images) 48 | else: 49 | print('specified phase is invalid') 50 | 51 | for i, img in enumerate(images): 52 | if i >= save_max_num: 53 | break 54 | 55 | cv2.imwrite(save_dir + '/{0}_HR_{1}_{2}.jpg'.format(phase, global_iter, i), de_normalize_image(img)) 56 | 57 | 58 | def crop(img, 
FLAGS): 59 | """crop patch from an image with specified size""" 60 | img_h, img_w, _ = img.shape 61 | 62 | rand_h = np.random.randint(img_h - FLAGS.crop_size) 63 | rand_w = np.random.randint(img_w - FLAGS.crop_size) 64 | 65 | return img[rand_h:rand_h + FLAGS.crop_size, rand_w:rand_w + FLAGS.crop_size, :] 66 | 67 | 68 | def data_augmentation(LR_images, HR_images, aug_type='horizontal_flip'): 69 | """data augmentation. input arrays should be [N, H, W, C]""" 70 | 71 | if aug_type == 'horizontal_flip': 72 | return LR_images[:, :, ::-1, :], HR_images[:, :, ::-1, :] 73 | elif aug_type == 'rotation_90': 74 | return np.rot90(LR_images, k=1, axes=(1, 2)), np.rot90(HR_images, k=1, axes=(1, 2)) 75 | 76 | 77 | def load_and_save_data(FLAGS, logflag): 78 | """make HR and LR data. And save them as npz files""" 79 | assert os.path.isdir(FLAGS.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory' 80 | 81 | all_file_path = glob.glob(FLAGS.data_dir + '/*') 82 | assert len(all_file_path) > 0, 'No file in the directory' 83 | 84 | ret_HR_image = [] 85 | ret_LR_image = [] 86 | 87 | for file in all_file_path: 88 | img = cv2.imread(file) 89 | filename = file.rsplit('/', 1)[-1] 90 | 91 | # crop patches if flag is true. 
Otherwise just resize HR and LR images 92 | if FLAGS.crop: 93 | for _ in range(FLAGS.num_crop_per_image): 94 | img_h, img_w, _ = img.shape 95 | 96 | if (img_h < FLAGS.crop_size) or (img_w < FLAGS.crop_size): 97 | print('Skip crop target image because of insufficient size') 98 | continue 99 | 100 | HR_image = crop(img, FLAGS) 101 | LR_crop_size = np.int(np.floor(FLAGS.crop_size / FLAGS.scale_SR)) 102 | LR_image = cv2.resize(HR_image, (LR_crop_size, LR_crop_size), interpolation=cv2.INTER_LANCZOS4) 103 | 104 | cv2.imwrite(FLAGS.HR_data_dir + '/' + filename, HR_image) 105 | cv2.imwrite(FLAGS.LR_data_dir + '/' + filename, LR_image) 106 | 107 | ret_HR_image.append(HR_image) 108 | ret_LR_image.append(LR_image) 109 | else: 110 | HR_image = cv2.resize(img, (FLAGS.HR_image_size, FLAGS.HR_image_size), interpolation=cv2.INTER_LANCZOS4) 111 | LR_image = cv2.resize(img, (FLAGS.LR_image_size, FLAGS.LR_image_size), interpolation=cv2.INTER_LANCZOS4) 112 | 113 | cv2.imwrite(FLAGS.HR_data_dir + '/' + filename, HR_image) 114 | cv2.imwrite(FLAGS.LR_data_dir + '/' + filename, LR_image) 115 | 116 | ret_HR_image.append(HR_image) 117 | ret_LR_image.append(LR_image) 118 | 119 | assert len(ret_HR_image) > 0 and len(ret_LR_image) > 0, 'No availale image is found in the directory' 120 | log(logflag, 'Data process : {} images are processed'.format(len(ret_HR_image)), 'info') 121 | 122 | ret_HR_image = np.array(ret_HR_image) 123 | ret_LR_image = np.array(ret_LR_image) 124 | 125 | if FLAGS.data_augmentation: 126 | LR_flip, HR_flip = data_augmentation(ret_LR_image, ret_HR_image, aug_type='horizontal_flip') 127 | LR_rot, HR_rot = data_augmentation(ret_LR_image, ret_HR_image, aug_type='rotation_90') 128 | 129 | ret_LR_image = np.append(ret_LR_image, LR_flip, axis=0) 130 | ret_HR_image = np.append(ret_HR_image, HR_flip, axis=0) 131 | ret_LR_image = np.append(ret_LR_image, LR_rot, axis=0) 132 | ret_HR_image = np.append(ret_HR_image, HR_rot, axis=0) 133 | 134 | del LR_flip, HR_flip, LR_rot, HR_rot 135 
| 136 | np.savez(FLAGS.npz_data_dir + '/' + FLAGS.HR_npz_filename, images=ret_HR_image) 137 | np.savez(FLAGS.npz_data_dir + '/' + FLAGS.LR_npz_filename, images=ret_LR_image) 138 | 139 | return ret_HR_image, ret_LR_image 140 | 141 | 142 | def load_npz_data(FLAGS): 143 | """load array data from data_path""" 144 | return np.load(FLAGS.npz_data_dir + '/' + FLAGS.HR_npz_filename)['images'], \ 145 | np.load(FLAGS.npz_data_dir + '/' + FLAGS.LR_npz_filename)['images'] 146 | 147 | 148 | def load_inference_data(FLAGS): 149 | """load data from directory for inference""" 150 | assert os.path.isdir(FLAGS.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory' 151 | 152 | all_file_path = glob.glob(FLAGS.data_dir + '/*') 153 | assert len(all_file_path) > 0, 'No file in the directory' 154 | 155 | ret_LR_image = [] 156 | ret_filename = [] 157 | 158 | for file in all_file_path: 159 | img = cv2.imread(file) 160 | img = normalize_images(img) 161 | ret_LR_image.append(img[0][np.newaxis, ...]) 162 | 163 | ret_filename.append(file.rsplit('/', 1)[-1]) 164 | 165 | assert len(ret_LR_image) > 0, 'No available image is found in the directory' 166 | 167 | return ret_LR_image, ret_filename 168 | -------------------------------------------------------------------------------- /lib/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Generator(object): 5 | """the definition of Generator""" 6 | 7 | def __init__(self, FLAGS): 8 | self.channel = FLAGS.channel 9 | self.n_filter = 64 10 | self.inc_filter = 32 11 | self.num_repeat_RRDB = FLAGS.num_repeat_RRDB 12 | self.residual_scaling = FLAGS.residual_scaling 13 | self.init_kernel = tf.initializers.he_normal(seed=FLAGS.initialization_random_seed) 14 | 15 | def _conv_RRDB(self, x, out_channel, num=None, activate=True): 16 | with tf.variable_scope('block_{0}'.format(num)): 17 | x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', 
kernel_initializer=self.init_kernel, name='conv') 18 | if activate: 19 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 20 | 21 | return x 22 | 23 | def _denseBlock(self, x, num=None): 24 | with tf.variable_scope('DenseBlock_sub{0}'.format(num)): 25 | x1 = self._conv_RRDB(x, self.inc_filter, 0) 26 | x2 = self._conv_RRDB(tf.concat([x, x1], axis=3), self.inc_filter, 1) 27 | x3 = self._conv_RRDB(tf.concat([x, x1, x2], axis=3), self.inc_filter, 2) 28 | x4 = self._conv_RRDB(tf.concat([x, x1, x2, x3], axis=3), self.inc_filter, 3) 29 | x5 = self._conv_RRDB(tf.concat([x, x1, x2, x3, x4], axis=3), self.n_filter, 4, activate=False) 30 | 31 | return x5 * self.residual_scaling 32 | 33 | def _RRDB(self, x, num=None): 34 | """Residual in Residual Dense Block""" 35 | with tf.variable_scope('RRDB_sub{0}'.format(num)): 36 | x_branch = tf.identity(x) 37 | 38 | x_branch += self._denseBlock(x_branch, 0) 39 | x_branch += self._denseBlock(x_branch, 1) 40 | x_branch += self._denseBlock(x_branch, 2) 41 | 42 | return x + x_branch * self.residual_scaling 43 | 44 | def _upsampling_layer(self, x, num=None): 45 | x = tf.layers.conv2d_transpose(x, self.n_filter, 3, 2, padding='same', name='upsample_{0}'.format(num)) 46 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 47 | 48 | return x 49 | 50 | def build(self, x): 51 | with tf.variable_scope('first_conv'): 52 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 53 | name='conv') 54 | 55 | with tf.variable_scope('RRDB'): 56 | x_branch = tf.identity(x) 57 | 58 | for i in range(self.num_repeat_RRDB): 59 | x_branch = self._RRDB(x_branch, i) 60 | 61 | x_branch = tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same', 62 | kernel_initializer=self.init_kernel, name='trunk_conv') 63 | 64 | x += x_branch 65 | 66 | with tf.variable_scope('Upsampling'): 67 | x = self._upsampling_layer(x, 1) 68 | x = self._upsampling_layer(x, 2) 69 | 70 | with tf.variable_scope('last_conv'): 71 | x = 
tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel, 72 | name='conv_1') 73 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU') 74 | x = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, 75 | name='conv_2') 76 | 77 | return x 78 | 79 | 80 | class Discriminator(object): 81 | """the definition of Discriminator""" 82 | 83 | def __init__(self, FLAGS): 84 | self.channel = FLAGS.channel 85 | self.n_filter = 64 86 | self.inc_filter = 32 87 | self.init_kernel = tf.initializers.he_normal(seed=FLAGS.initialization_random_seed) 88 | 89 | def _conv_block(self, x, out_channel, num=None): 90 | with tf.variable_scope('block_{0}'.format(num)): 91 | x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', use_bias=False, 92 | kernel_initializer=self.init_kernel, name='conv_1') 93 | x = tf.layers.BatchNormalization(name='batch_norm_1')(x) 94 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1') 95 | 96 | x = tf.layers.conv2d(x, out_channel, 4, 2, padding='same', use_bias=False, 97 | kernel_initializer=self.init_kernel, name='conv_2') 98 | x = tf.layers.BatchNormalization(name='batch_norm_2')(x) 99 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2') 100 | 101 | return x 102 | 103 | def build(self, x): 104 | with tf.variable_scope('first_conv'): 105 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', use_bias=False, 106 | kernel_initializer=self.init_kernel, name='conv_1') 107 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1') 108 | x = tf.layers.conv2d(x, self.n_filter, 4, 2, padding='same', use_bias=False, 109 | kernel_initializer=self.init_kernel, name='conv_2') 110 | x = tf.layers.BatchNormalization(name='batch_norm_1')(x) 111 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2') 112 | 113 | with tf.variable_scope('conv_block'): 114 | x = self._conv_block(x, self.n_filter * 2, 0) 115 | x = self._conv_block(x, self.n_filter * 4, 1) 116 | x = self._conv_block(x, 
self.n_filter * 8, 2) 117 | x = self._conv_block(x, self.n_filter * 8, 3) 118 | 119 | with tf.variable_scope('full_connected'): 120 | x = tf.layers.flatten(x) 121 | x = tf.layers.dense(x, 100, name='fully_connected_1') 122 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1') 123 | x = tf.layers.dense(x, 1, name='fully_connected_2') 124 | 125 | return x 126 | 127 | 128 | class Perceptual_VGG19(object): 129 | """the definition of VGG19. This network is used for constructing perceptual loss""" 130 | @staticmethod 131 | def build(x): 132 | # Block 1 133 | x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv1') 134 | x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv2') 135 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block1_pool') 136 | 137 | # Block 2 138 | x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv1') 139 | x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv2') 140 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block2_pool') 141 | 142 | # Block 3 143 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv1') 144 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv2') 145 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv3') 146 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv4') 147 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block3_pool') 148 | 149 | # Block 4 150 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv1') 151 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv2') 152 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv3') 153 | x = tf.layers.conv2d(x, 512, (3, 3), 
activation='relu', padding='same', name='block4_conv4') 154 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block4_pool') 155 | 156 | # Block 5 157 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv1') 158 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv2') 159 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv3') 160 | x = tf.layers.conv2d(x, 512, (3, 3), activation=None, padding='same', name='block5_conv4') 161 | 162 | return x 163 | -------------------------------------------------------------------------------- /lib/train_module.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import tensorflow as tf 4 | 5 | from lib.network import Generator, Discriminator, Perceptual_VGG19 6 | 7 | 8 | class Network(object): 9 | """class to build networks""" 10 | def __init__(self, FLAGS, LR_data=None, HR_data=None): 11 | self.FLAGS = FLAGS 12 | self.LR_data = LR_data 13 | self.HR_data = HR_data 14 | 15 | def generator(self): 16 | with tf.name_scope('generator'): 17 | with tf.variable_scope('generator'): 18 | gen_out = Generator(self.FLAGS).build(self.LR_data) 19 | 20 | return gen_out 21 | 22 | def discriminator(self, gen_out): 23 | discriminator = Discriminator(self.FLAGS) 24 | 25 | with tf.name_scope('real_discriminator'): 26 | with tf.variable_scope('discriminator', reuse=False): 27 | dis_out_real = discriminator.build(self.HR_data) 28 | 29 | with tf.name_scope('fake_discriminator'): 30 | with tf.variable_scope('discriminator', reuse=True): 31 | dis_out_fake = discriminator.build(gen_out) 32 | 33 | return dis_out_real, dis_out_fake 34 | 35 | 36 | class Loss(object): 37 | """class to build loss functions""" 38 | def __init__(self): 39 | self.summary_target = OrderedDict() 40 | 41 | def pretrain_loss(self, pre_gen_out, HR_data): 42 | with 
tf.name_scope('loss_function'): 43 | with tf.variable_scope('pixel-wise_loss'): 44 | pre_gen_loss = tf.reduce_mean(tf.reduce_mean(tf.square(pre_gen_out - HR_data), axis=3)) 45 | 46 | self.summary_target['pre-train : pixel-wise_loss'] = pre_gen_loss 47 | return pre_gen_loss 48 | 49 | def _perceptual_vgg19_loss(self, HR_data, gen_out): 50 | with tf.name_scope('perceptual_vgg19_HR'): 51 | with tf.variable_scope('perceptual_vgg19', reuse=False): 52 | vgg_out_hr = Perceptual_VGG19().build(HR_data) 53 | 54 | with tf.name_scope('perceptual_vgg19_Gen'): 55 | with tf.variable_scope('perceptual_vgg19', reuse=True): 56 | vgg_out_gen = Perceptual_VGG19().build(gen_out) 57 | 58 | return vgg_out_hr, vgg_out_gen 59 | 60 | def gan_loss(self, FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake): 61 | 62 | with tf.name_scope('loss_function'): 63 | with tf.variable_scope('loss_generator'): 64 | if FLAGS.gan_loss_type == 'RaGAN': 65 | g_loss_p1 = tf.reduce_mean( 66 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real - tf.reduce_mean(dis_out_fake), 67 | labels=tf.zeros_like(dis_out_real))) 68 | 69 | g_loss_p2 = tf.reduce_mean( 70 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake - tf.reduce_mean(dis_out_real), 71 | labels=tf.ones_like(dis_out_fake))) 72 | 73 | gen_loss = FLAGS.gan_loss_coeff * (g_loss_p1 + g_loss_p2) / 2 74 | elif FLAGS.gan_loss_type == 'GAN': 75 | gen_loss = FLAGS.gan_loss_coeff * tf.reduce_mean( 76 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake, labels=tf.ones_like(dis_out_fake))) 77 | else: 78 | raise ValueError('Unknown GAN loss function type') 79 | 80 | # content loss : L1 distance 81 | content_loss = FLAGS.content_loss_coeff * tf.reduce_mean( 82 | tf.reduce_sum(tf.abs(gen_out - HR_data), axis=[1, 2, 3])) 83 | 84 | gen_loss += content_loss 85 | 86 | # perceptual loss 87 | if FLAGS.perceptual_loss == 'pixel-wise': 88 | perc_loss = tf.reduce_mean(tf.reduce_mean(tf.square(gen_out - HR_data), axis=3)) 89 | gen_loss += perc_loss 90 
| elif FLAGS.perceptual_loss == 'VGG19': 91 | vgg_out_gen, vgg_out_hr = self._perceptual_vgg19_loss(HR_data, gen_out) 92 | perc_loss = tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3)) 93 | gen_loss += perc_loss 94 | else: 95 | raise ValueError('Unknown perceptual loss type') 96 | 97 | with tf.variable_scope('loss_discriminator'): 98 | if FLAGS.gan_loss_type == 'RaGAN': 99 | d_loss_real = tf.reduce_mean( 100 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real - tf.reduce_mean(dis_out_fake), 101 | labels=tf.ones_like(dis_out_real))) / 2 102 | 103 | d_loss_fake = tf.reduce_mean( 104 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake - tf.reduce_mean(dis_out_real), 105 | labels=tf.zeros_like(dis_out_fake))) / 2 106 | 107 | dis_loss = d_loss_real + d_loss_fake 108 | elif FLAGS.gan_loss_type == 'GAN': 109 | d_loss_real = tf.reduce_mean( 110 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real, labels=tf.ones_like(dis_out_real))) 111 | d_loss_fake = tf.reduce_mean( 112 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake, 113 | labels=tf.zeros_like(dis_out_fake))) 114 | 115 | dis_loss = d_loss_real + d_loss_fake 116 | 117 | else: 118 | raise ValueError('Unknown GAN loss function type') 119 | 120 | self.summary_target['generator_loss'] = gen_loss 121 | self.summary_target['content_loss'] = content_loss 122 | self.summary_target['perceptual_loss'] = perc_loss 123 | self.summary_target['discriminator_loss'] = dis_loss 124 | self.summary_target['discriminator_real_loss'] = d_loss_real 125 | self.summary_target['discriminator_fake_loss'] = d_loss_fake 126 | 127 | return gen_loss, dis_loss 128 | 129 | def add_summary_writer(self): 130 | return [tf.summary.scalar(key, value) for key, value in self.summary_target.items()] 131 | 132 | 133 | class Optimizer(object): 134 | """class to build optimizers""" 135 | @staticmethod 136 | def pretrain_optimizer(FLAGS, global_iter, pre_gen_loss): 137 | learning_rate = 
tf.train.exponential_decay(FLAGS.pretrain_learning_rate, global_iter, 138 | FLAGS.pretrain_lr_decay_step, 0.5, staircase=True) 139 | 140 | with tf.name_scope('optimizer'): 141 | with tf.variable_scope('optimizer_generator'): 142 | pre_gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 143 | pre_gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=pre_gen_loss, 144 | global_step=global_iter, 145 | var_list=pre_gen_var) 146 | 147 | return pre_gen_var, pre_gen_optimizer 148 | 149 | @staticmethod 150 | def gan_optimizer(FLAGS, global_iter, dis_loss, gen_loss): 151 | boundaries = [50000, 100000, 200000, 300000] 152 | values = [FLAGS.learning_rate, FLAGS.learning_rate * 0.5, FLAGS.learning_rate * 0.5 ** 2, 153 | FLAGS.learning_rate * 0.5 ** 3, FLAGS.learning_rate * 0.5 ** 4] 154 | learning_rate = tf.train.piecewise_constant(global_iter, boundaries, values) 155 | 156 | with tf.name_scope('optimizer'): 157 | with tf.variable_scope('optimizer_discriminator'): 158 | dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 159 | dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=dis_loss, 160 | var_list=dis_var) 161 | 162 | with tf.variable_scope('optimizer_generator'): 163 | with tf.control_dependencies([dis_optimizer]): 164 | gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 165 | gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=gen_loss, 166 | global_step=global_iter, 167 | var_list=gen_var) 168 | 169 | return dis_var, dis_optimizer, gen_var, gen_optimizer 170 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | 7 | import tensorflow as tf 8 | from sklearn.utils import 
shuffle 9 | 10 | from lib.ops import load_vgg19_weight 11 | from lib.pretrain_generator import train_pretrain_generator 12 | from lib.train_module import Network, Loss, Optimizer 13 | from lib.utils import create_dirs, log, normalize_images, save_image, load_npz_data, load_and_save_data 14 | 15 | 16 | def set_flags(): 17 | Flags = tf.app.flags 18 | 19 | # About data 20 | Flags.DEFINE_string('data_dir', './data/DIV2K_train_HR', 'data directory') 21 | Flags.DEFINE_string('HR_data_dir', './data/HR_data', 'HR data directory') 22 | Flags.DEFINE_string('LR_data_dir', './data/LR_data', 'LR data directory') 23 | Flags.DEFINE_string('npz_data_dir', './data/npz', 'The npz data dir') 24 | Flags.DEFINE_string('HR_npz_filename', 'HR_image.npz', 'the filename of HR image npz file') 25 | Flags.DEFINE_string('LR_npz_filename', 'LR_image.npz', 'the filename of LR image npz file') 26 | Flags.DEFINE_boolean('save_data', True, 'Whether to load and save data as npz file') 27 | Flags.DEFINE_string('train_result_dir', './train_result', 'output directory during training') 28 | Flags.DEFINE_boolean('crop', True, 'Whether image cropping is enabled') 29 | Flags.DEFINE_integer('crop_size', 128, 'the size of crop of training HR images') 30 | Flags.DEFINE_integer('num_crop_per_image', 2, 'the number of random-cropped images per image') 31 | Flags.DEFINE_boolean('data_augmentation', True, 'whether to augment data') 32 | 33 | # About Network 34 | Flags.DEFINE_integer('scale_SR', 4, 'the scale of super-resolution') 35 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks') 36 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter') 37 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization') 38 | Flags.DEFINE_string('perceptual_loss', 'VGG19', 'the part of loss function. "VGG19" or "pixel-wise"') 39 | Flags.DEFINE_string('gan_loss_type', 'RaGAN', 'the type of GAN loss functions. 
"RaGAN or GAN"') 40 | 41 | # About training 42 | Flags.DEFINE_integer('num_iter', 50000, 'The number of iterations') 43 | Flags.DEFINE_integer('batch_size', 32, 'Mini-batch size') 44 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel') 45 | Flags.DEFINE_boolean('pretrain_generator', True, 'Whether to pretrain generator') 46 | Flags.DEFINE_float('pretrain_learning_rate', 2e-4, 'learning rate for pretrain') 47 | Flags.DEFINE_float('pretrain_lr_decay_step', 20000, 'decay by every n iteration') 48 | Flags.DEFINE_float('learning_rate', 1e-4, 'learning rate') 49 | Flags.DEFINE_float('weight_initialize_scale', 0.1, 'scale to multiply after MSRA initialization') 50 | Flags.DEFINE_integer('HR_image_size', 128, 51 | 'Image width and height of HR image. This flag is valid when crop flag is set to false.') 52 | Flags.DEFINE_integer('LR_image_size', 32, 53 | 'Image width and height of LR image. This size should be 1/4 of HR_image_size exactly. ' 54 | 'This flag is valid when crop flag is set to false.') 55 | Flags.DEFINE_float('epsilon', 1e-12, 'used in loss function') 56 | Flags.DEFINE_float('gan_loss_coeff', 0.005, 'used in perceptual loss') 57 | Flags.DEFINE_float('content_loss_coeff', 0.01, 'used in content loss') 58 | 59 | # About log 60 | Flags.DEFINE_boolean('logging', True, 'whether to record training log') 61 | Flags.DEFINE_integer('train_sample_save_freq', 2000, 'save samples during training every n iteration') 62 | Flags.DEFINE_integer('train_ckpt_save_freq', 2000, 'save checkpoint during training every n iteration') 63 | Flags.DEFINE_integer('train_summary_save_freq', 200, 'save summary during training every n iteration') 64 | Flags.DEFINE_string('pre_train_checkpoint_dir', './pre_train_checkpoint', 'pre-train checkpoint directory') 65 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory') 66 | Flags.DEFINE_string('logdir', './log', 'log directory') 67 | 68 | # About GPU setting 69 | Flags.DEFINE_string('gpu_dev_num', 
'0', 'Which GPU to use for multi-GPUs.') 70 | 71 | return Flags.FLAGS 72 | 73 | 74 | def set_logger(FLAGS): 75 | """set logger for training recording""" 76 | if FLAGS.logging: 77 | logfile = '{0}/training_logfile_{1}.log'.format(FLAGS.logdir, datetime.now().strftime("%Y%m%d_%H%M%S")) 78 | formatter = '%(levelname)s:%(asctime)s:%(message)s' 79 | logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S') 80 | 81 | return True 82 | else: 83 | print('No logging is set') 84 | return False 85 | 86 | 87 | def main(): 88 | # set flag 89 | FLAGS = set_flags() 90 | 91 | # make dirs 92 | target_dirs = [FLAGS.HR_data_dir, FLAGS.LR_data_dir, FLAGS.npz_data_dir, FLAGS.train_result_dir, 93 | FLAGS.pre_train_checkpoint_dir, FLAGS.checkpoint_dir, FLAGS.logdir] 94 | create_dirs(target_dirs) 95 | 96 | # set logger 97 | logflag = set_logger(FLAGS) 98 | log(logflag, 'Training script start', 'info') 99 | 100 | # load data 101 | if FLAGS.save_data: 102 | log(logflag, 'Data process : Data processing start', 'info') 103 | HR_train, LR_train = load_and_save_data(FLAGS, logflag) 104 | log(logflag, 'Data process : Data loading and data processing are completed', 'info') 105 | else: 106 | log(logflag, 'Data process : Data loading start', 'info') 107 | HR_train, LR_train = load_npz_data(FLAGS) 108 | log(logflag, 109 | 'Data process : Loading existing data is completed. 
{} images are loaded'.format(len(HR_train)), 110 | 'info') 111 | 112 | # pre-train generator with pixel-wise loss and save the trained model 113 | if FLAGS.pretrain_generator: 114 | train_pretrain_generator(FLAGS, LR_train, HR_train, logflag) 115 | tf.reset_default_graph() 116 | gc.collect() 117 | else: 118 | log(logflag, 'Pre-train : Pre-train skips and an existing trained model will be used', 'info') 119 | 120 | LR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel], 121 | name='LR_input') 122 | HR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel], 123 | name='HR_input') 124 | 125 | # build Generator and Discriminator 126 | network = Network(FLAGS, LR_data, HR_data) 127 | gen_out = network.generator() 128 | dis_out_real, dis_out_fake = network.discriminator(gen_out) 129 | 130 | # build loss function 131 | loss = Loss() 132 | gen_loss, dis_loss = loss.gan_loss(FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake) 133 | 134 | # define optimizers 135 | global_iter = tf.Variable(0, trainable=False) 136 | dis_var, dis_optimizer, gen_var, gen_optimizer = Optimizer().gan_optimizer(FLAGS, global_iter, dis_loss, gen_loss) 137 | 138 | # build summary writer 139 | tr_summary = tf.summary.merge(loss.add_summary_writer()) 140 | 141 | num_train_data = len(HR_train) 142 | num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size)) 143 | num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train)) 144 | 145 | HR_train, LR_train = normalize_images(HR_train, LR_train) 146 | 147 | fetches = {'dis_optimizer': dis_optimizer, 'gen_optimizer': gen_optimizer, 148 | 'dis_loss': dis_loss, 'gen_loss': gen_loss, 149 | 'gen_HR': gen_out, 150 | 'summary': tr_summary 151 | } 152 | 153 | gc.collect() 154 | 155 | config = tf.ConfigProto( 156 | gpu_options=tf.GPUOptions( 157 | allow_growth=True, 158 | visible_device_list=FLAGS.gpu_dev_num 159 | ) 160 | ) 161 | 162 | # Start 
Session 163 | with tf.Session(config=config) as sess: 164 | log(logflag, 'Training ESRGAN starts', 'info') 165 | 166 | sess.run(tf.global_variables_initializer()) 167 | sess.run(global_iter.initializer) 168 | 169 | writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph) 170 | 171 | pre_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')) 172 | pre_saver.restore(sess, tf.train.latest_checkpoint(FLAGS.pre_train_checkpoint_dir)) 173 | 174 | if FLAGS.perceptual_loss == 'VGG19': 175 | sess.run(load_vgg19_weight(FLAGS)) 176 | 177 | saver = tf.train.Saver(max_to_keep=10) 178 | 179 | for epoch in range(num_epoch): 180 | log(logflag, 'ESRGAN Epoch: {0}'.format(epoch), 'info') 181 | HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222) 182 | 183 | for iteration in range(num_batch_in_train): 184 | current_iter = tf.train.global_step(sess, global_iter) 185 | if current_iter > FLAGS.num_iter: 186 | break 187 | 188 | feed_dict = { 189 | HR_data: HR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size], 190 | LR_data: LR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size] 191 | } 192 | 193 | # update weights of G/D 194 | result = sess.run(fetches=fetches, feed_dict=feed_dict) 195 | 196 | # save summary every n iter 197 | if current_iter % FLAGS.train_summary_save_freq == 0: 198 | writer.add_summary(result['summary'], global_step=current_iter) 199 | 200 | # save samples every n iter 201 | if current_iter % FLAGS.train_sample_save_freq == 0: 202 | log(logflag, 203 | 'ESRGAN iteration : {0}, gen_loss : {1}, dis_loss : {2}'.format(current_iter, 204 | result['gen_loss'], 205 | result['dis_loss']), 206 | 'info') 207 | 208 | save_image(FLAGS, result['gen_HR'], 'train', current_iter, save_max_num=5) 209 | 210 | if current_iter % FLAGS.train_ckpt_save_freq == 0: 211 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'gen'), global_step=current_iter) 212 | 
213 | writer.close() 214 | log(logflag, 'Training ESRGAN end', 'info') 215 | log(logflag, 'Training script end', 'info') 216 | 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------