├── img
│   ├── 0833.png
│   ├── 0887.png
│   ├── 0896.png
│   ├── 1117e6b64a7b4336df58eb351cff435529485e91.png
│   ├── 11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png
│   └── 111b822af95747f45f5d25a84f8094c10b27c765.png
├── evaluation.py
├── lib
│   ├── ops.py
│   ├── pretrain_generator.py
│   ├── utils.py
│   ├── network.py
│   └── train_module.py
├── inference.py
├── visualize.py
├── network_interpolation.py
├── README.md
├── train.py
└── LICENSE
/img/0833.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0833.png
--------------------------------------------------------------------------------
/img/0887.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0887.png
--------------------------------------------------------------------------------
/img/0896.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/0896.png
--------------------------------------------------------------------------------
/img/1117e6b64a7b4336df58eb351cff435529485e91.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/1117e6b64a7b4336df58eb351cff435529485e91.png
--------------------------------------------------------------------------------
/img/11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png
--------------------------------------------------------------------------------
/img/111b822af95747f45f5d25a84f8094c10b27c765.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiram64/ESRGAN-tensorflow/HEAD/img/111b822af95747f45f5d25a84f8094c10b27c765.png
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import cv2
6 | from skimage.measure import compare_psnr, compare_ssim
7 |
8 |
9 | def calc_measures(hr_path, calc_psnr=True, calc_ssim=True):
10 |     """Calculate PSNR and SSIM for every HR image and report their means.
11 |     Each HR image and its inference result must share the same filename.
12 |     """
13 |
14 | HR_files = glob.glob(hr_path + '/*')
15 | mean_psnr = 0
16 | mean_ssim = 0
17 |
18 | for file in HR_files:
19 | hr_img = cv2.imread(file)
20 | filename = file.rsplit('/', 1)[-1]
21 | path = os.path.join(args.inference_result, filename)
22 |
23 | if not os.path.isfile(path):
24 |             raise FileNotFoundError('No inference result found for {}'.format(filename))
25 |
26 | inf_img = cv2.imread(path)
27 |
28 |         # compare the HR image and the inference result with the selected measures
29 | print('-' * 10)
30 | if calc_psnr:
31 | psnr = compare_psnr(hr_img, inf_img)
32 | print('{0} : PSNR {1:.3f} dB'.format(filename, psnr))
33 | mean_psnr += psnr
34 | if calc_ssim:
35 | ssim = compare_ssim(hr_img, inf_img, multichannel=True)
36 | print('{0} : SSIM {1:.3f}'.format(filename, ssim))
37 | mean_ssim += ssim
38 |
39 | print('-' * 10)
40 | if calc_psnr:
41 | print('mean-PSNR {:.3f} dB'.format(mean_psnr / len(HR_files)))
42 | if calc_ssim:
43 | print('mean-SSIM {:.3f}'.format(mean_ssim / len(HR_files)))
44 |
45 |
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser()
48 |
49 | parser.add_argument('--HR_data_dir', default='./data/div2_inf_HR', type=str)
50 | parser.add_argument('--inference_result', default='./inference_result_div2', type=str)
51 |
52 | args = parser.parse_args()
53 |
54 | calc_measures(args.HR_data_dir, calc_psnr=True, calc_ssim=True)
55 |
--------------------------------------------------------------------------------
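
Note: `compare_psnr` and `compare_ssim` were removed from `skimage.measure` in newer scikit-image releases. A minimal sketch of the same measurement on scikit-image >= 0.19 (hypothetical file paths; adjust to your own data):

```
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

hr_img = cv2.imread('./data/div2_inf_HR/0833.png')         # ground-truth HR image
inf_img = cv2.imread('./inference_result_div2/0833.png')   # super-resolved result

psnr = peak_signal_noise_ratio(hr_img, inf_img)
# multichannel=True became channel_axis in scikit-image >= 0.19
ssim = structural_similarity(hr_img, inf_img, channel_axis=2)
print('PSNR {0:.3f} dB / SSIM {1:.3f}'.format(psnr, ssim))
```
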
/lib/ops.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import tensorflow as tf
4 | from tensorflow.keras.applications.vgg19 import VGG19
5 |
6 |
7 | def scale_initialization(weights, FLAGS):
8 | return [tf.assign(weight, weight * FLAGS.weight_initialize_scale) for weight in weights]
9 |
10 |
11 | def _transfer_vgg19_weight(FLAGS, weight_dict):
12 | from_model = VGG19(include_top=False, weights='imagenet', input_tensor=None,
13 | input_shape=(FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel))
14 |
15 | fetch_weight = []
16 |
17 | for layer in from_model.layers:
18 | if 'conv' in layer.name:
19 | W, b = layer.get_weights()
20 |
21 | fetch_weight.append(
22 | tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/kernel'.format(layer.name)], W)
23 | )
24 | fetch_weight.append(
25 | tf.assign(weight_dict['loss_generator/perceptual_vgg19/{}/bias'.format(layer.name)], b)
26 | )
27 |
28 | return fetch_weight
29 |
30 |
31 | def load_vgg19_weight(FLAGS):
32 | vgg_weight = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='loss_generator/perceptual_vgg19')
33 |
34 | assert len(vgg_weight) > 0, 'No VGG19 weight was collected. The target scope might be wrong.'
35 |
36 | weight_dict = {}
37 | for weight in vgg_weight:
38 | weight_dict[weight.name.rsplit(':', 1)[0]] = weight
39 |
40 | return _transfer_vgg19_weight(FLAGS, weight_dict)
41 |
42 |
43 | def extract_weight(network_vars):
44 | weight_dict = OrderedDict()
45 |
46 | for weight in network_vars:
47 | weight_dict[weight.name] = weight.eval()
48 |
49 | return weight_dict
50 |
51 |
52 | def interpolate_weight(FLAGS, pretrain_weight):
53 | fetch_weight = []
54 | alpha = FLAGS.interpolation_param
55 |
56 | for name, pre_weight in pretrain_weight.items():
57 | esrgan_weight = tf.get_default_graph().get_tensor_by_name(name)
58 |
59 | assert pre_weight.shape == esrgan_weight.shape, 'The shape of weights does not match'
60 |
61 | fetch_weight.append(tf.assign(esrgan_weight, (1 - alpha) * pre_weight + alpha * esrgan_weight))
62 |
63 | return fetch_weight
64 |
--------------------------------------------------------------------------------
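
`interpolate_weight` implements the paper's network interpolation: every ESRGAN weight is replaced by a convex combination of the pre-trained (pixel-wise loss) weight and the GAN-trained weight. A toy NumPy sketch of the same arithmetic outside the TF graph (hypothetical arrays; alpha plays the role of `FLAGS.interpolation_param`):

```
import numpy as np

alpha = 0.8                          # FLAGS.interpolation_param
theta_pre = np.array([0.0, 1.0])     # toy pre-trained weights
theta_esrgan = np.array([1.0, 3.0])  # toy GAN-trained weights

theta_interp = (1 - alpha) * theta_pre + alpha * theta_esrgan
print(theta_interp)                  # [0.8 2.6] -> leans toward the GAN weights
```
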
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from lib.train_module import Network
8 | from lib.utils import create_dirs, de_normalize_image, load_inference_data
9 |
10 |
11 | def set_flags():
12 | Flags = tf.app.flags
13 |
14 | Flags.DEFINE_string('data_dir', './data/inference', 'inference data directory')
15 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory')
16 | Flags.DEFINE_string('inference_checkpoint', '',
17 | 'checkpoint to use for inference. Empty string means the latest checkpoint is used')
18 | Flags.DEFINE_string('inference_result_dir', './inference_result', 'output directory during inference')
19 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel')
20 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks')
21 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter')
22 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization')
23 |
24 | return Flags.FLAGS
25 |
26 |
27 | def main():
28 | # set flag
29 | FLAGS = set_flags()
30 |
31 | # make dirs
32 | target_dirs = [FLAGS.inference_result_dir]
33 | create_dirs(target_dirs)
34 |
35 | # load test data
36 | LR_inference, LR_filenames = load_inference_data(FLAGS)
37 |
38 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input')
39 |
40 | # build Generator
41 | network = Network(FLAGS, LR_data)
42 | gen_out = network.generator()
43 |
44 | fetches = {'gen_HR': gen_out}
45 |
46 | # Start Session
47 | config = tf.ConfigProto()
48 | config.gpu_options.allow_growth = True
49 |
50 | with tf.Session(config=config) as sess:
51 | print('Inference start')
52 |
53 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))
54 |
55 | if FLAGS.inference_checkpoint:
56 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, FLAGS.inference_checkpoint))
57 | else:
58 | print('No checkpoint is specified. The latest one is used for inference')
59 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
60 |
61 | for i, test_img in enumerate(LR_inference):
62 |
63 | feed_dict = {
64 | LR_data: test_img
65 | }
66 |
67 | result = sess.run(fetches=fetches, feed_dict=feed_dict)
68 |
69 | cv2.imwrite(os.path.join(FLAGS.inference_result_dir, LR_filenames[i]),
70 | de_normalize_image(np.squeeze(result['gen_HR'])))
71 |
72 | print('Inference end')
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
77 |
--------------------------------------------------------------------------------
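
Because the placeholder is declared as `[1, None, None, channel]`, inference runs one image at a time and each LR input may have its own height and width. A short sketch of how a single image is prepared for `feed_dict`, mirroring `load_inference_data` (the file path is illustrative):

```
import cv2
import numpy as np

img = cv2.imread('./data/inference/sample.png')  # BGR uint8 image of any size
img = img / 127.5 - 1                            # same mapping as normalize_images: [0, 255] -> [-1, 1]
img = img[np.newaxis, ...]                       # add the batch dimension -> shape (1, H, W, 3)
```
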
/visualize.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import math
4 | import os
5 |
6 | import cv2
7 | import numpy as np
8 |
9 |
10 | def visualize(args):
11 |     """Visualize images: bicubic interpolation (generated in this script), ESRGAN, ESRGAN with network
12 |     interpolation, and HR are tiled into a single image.
13 |     The ESRGAN and network-interpolation results must have the same filename as the HR image.
14 |     """
15 | HR_files = glob.glob(args.HR_data_dir + '/*')
16 |
17 | for file in HR_files:
18 | # HR(GT)
19 | hr_img = cv2.imread(file)
20 | h, w, _ = hr_img.shape
21 | filename = file.rsplit('/', 1)[-1].rsplit('.', 1)[0]
22 |
23 | # LR -> bicubic
24 | r_h, r_w = math.floor(h / 4), math.floor(w / 4)
25 |         lr_img = cv2.resize(hr_img, (r_w, r_h), interpolation=cv2.INTER_CUBIC)
26 |         bic_img = cv2.resize(lr_img, (w, h), interpolation=cv2.INTER_CUBIC)
27 |
28 | # inference
29 | inf_path_jpg = os.path.join(args.inference_result, filename + '.jpg')
30 | inf_path_png = os.path.join(args.inference_result, filename + '.png')
31 |
32 | if os.path.isfile(inf_path_jpg):
33 | inf_path = inf_path_jpg
34 | elif os.path.isfile(inf_path_png):
35 | inf_path = inf_path_png
36 | else:
37 |             raise FileNotFoundError('Inference results must have the same filename as the HR image and be in jpg or png format')
38 |
39 | inf_img = cv2.imread(inf_path)
40 |
41 | # network interpolation inference
42 | ni_path_jpg = os.path.join(args.network_interpolation_result, filename + '.jpg')
43 | ni_path_png = os.path.join(args.network_interpolation_result, filename + '.png')
44 | if os.path.isfile(ni_path_jpg):
45 | ni_path = ni_path_jpg
46 |         elif os.path.isfile(ni_path_png):
47 | ni_path = ni_path_png
48 | else:
49 |             raise FileNotFoundError('Network interpolation results must have the same filename as the HR image and be in jpg or png format')
50 |
51 | ni_img = cv2.imread(ni_path)
52 |
53 | h_upper = int(math.floor(h / 2) + args.path_size / 2)
54 | h_lower = int(math.floor(h / 2) - args.path_size / 2)
55 |
56 | w_right = int(math.floor(w / 2) + args.path_size / 2)
57 | w_left = int(math.floor(w / 2) - args.path_size / 2)
58 |
59 | h_size = h_upper - h_lower
60 | w_size = w_right - w_left
61 |
62 | out_arr = np.empty((h_size, w_size * 4, 3))
63 |
64 | # tile images from left to right : bicubic -> ESRGAN-inference -> Network interpolation -> HR(GT)
65 | out_arr[:, :w_size, :] = bic_img[h_lower:h_upper, w_left:w_right, :]
66 | out_arr[:, w_size:w_size * 2, :] = inf_img[h_lower:h_upper, w_left:w_right, :]
67 | out_arr[:, w_size * 2:w_size * 3, :] = ni_img[h_lower:h_upper, w_left:w_right, :]
68 | out_arr[:, w_size * 3:w_size * 4, :] = hr_img[h_lower:h_upper, w_left:w_right, :]
69 |
70 | cv2.imwrite(args.output_dir + '/' + '{}.png'.format(filename), out_arr)
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 |
76 | parser.add_argument('--HR_data_dir', default='./data/div2_inf_HR', type=str)
77 | parser.add_argument('--inference_result', default='./inference_result_div2', type=str)
78 | parser.add_argument('--network_interpolation_result', default='./interpolation_result_div2', type=str)
79 |     parser.add_argument('--path_size', default=512, type=int)
80 | parser.add_argument('--output_dir', default='./', type=str)
81 |
82 | args = parser.parse_args()
83 |
84 | visualize(args)
85 |
--------------------------------------------------------------------------------
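
The crop indices cut a `path_size` x `path_size` window centered on the image, and the four crops are tiled horizontally. A worked example of the window arithmetic (toy image size; `--path_size` left at its default of 512):

```
import math

h, w, patch = 1080, 1920, 512

h_lower = int(math.floor(h / 2) - patch / 2)   # 284
h_upper = int(math.floor(h / 2) + patch / 2)   # 796
w_left = int(math.floor(w / 2) - patch / 2)    # 704
w_right = int(math.floor(w / 2) + patch / 2)   # 1216
# each crop is 512 x 512, so the output canvas is 512 x (512 * 4) x 3
```
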
/lib/pretrain_generator.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import math
4 |
5 | from sklearn.utils import shuffle
6 | import tensorflow as tf
7 |
8 | from lib.ops import scale_initialization
9 | from lib.train_module import Network, Loss, Optimizer
10 | from lib.utils import log, normalize_images, save_image
11 |
12 |
13 | def train_pretrain_generator(FLAGS, LR_train, HR_train, logflag):
14 |     """Pre-train the Generator with pixel-wise loss to provide initialization weights for ESRGAN."""
15 | log(logflag, 'Pre-train : Process start', 'info')
16 |
17 | LR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel],
18 | name='LR_input')
19 | HR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel],
20 | name='HR_input')
21 |
22 | # build Generator
23 | network = Network(FLAGS, LR_data)
24 | pre_gen_out = network.generator()
25 |
26 | # build loss function
27 | loss = Loss()
28 | pre_gen_loss = loss.pretrain_loss(pre_gen_out, HR_data)
29 |
30 | # build optimizer
31 | global_iter = tf.Variable(0, trainable=False)
32 | pre_gen_var, pre_gen_optimizer = Optimizer().pretrain_optimizer(FLAGS, global_iter, pre_gen_loss)
33 |
34 | # build summary writer
35 | pre_summary = tf.summary.merge(loss.add_summary_writer())
36 |
37 | num_train_data = len(HR_train)
38 | num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size))
39 | num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train))
40 |
41 | HR_train, LR_train = normalize_images(HR_train, LR_train)
42 |
43 | fetches = {'pre_gen_loss': pre_gen_loss, 'pre_gen_optimizer': pre_gen_optimizer, 'gen_HR': pre_gen_out,
44 | 'summary': pre_summary}
45 |
46 | gc.collect()
47 |
48 | config = tf.ConfigProto(
49 | gpu_options=tf.GPUOptions(
50 | allow_growth=True,
51 | visible_device_list=FLAGS.gpu_dev_num
52 | )
53 | )
54 |
55 | saver = tf.train.Saver(max_to_keep=10)
56 |
57 | # Start session
58 | with tf.Session(config=config) as sess:
59 | log(logflag, 'Pre-train : Training starts', 'info')
60 |
61 | sess.run(tf.global_variables_initializer())
62 | sess.run(global_iter.initializer)
63 | sess.run(scale_initialization(pre_gen_var, FLAGS))
64 |
65 | writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph, filename_suffix='pre-train')
66 |
67 | for epoch in range(num_epoch):
68 | log(logflag, 'Pre-train Epoch: {0}'.format(epoch), 'info')
69 |
70 | HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222)
71 |
72 | for iteration in range(num_batch_in_train):
73 | current_iter = tf.train.global_step(sess, global_iter)
74 |
75 | if current_iter > FLAGS.num_iter:
76 | break
77 |
78 | feed_dict = {
79 | HR_data: HR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size],
80 | LR_data: LR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size]
81 | }
82 |
83 | # update weights
84 | result = sess.run(fetches=fetches, feed_dict=feed_dict)
85 |
86 | # save summary every n iter
87 | if current_iter % FLAGS.train_summary_save_freq == 0:
88 | writer.add_summary(result['summary'], global_step=current_iter)
89 |
90 | # save samples every n iter
91 | if current_iter % FLAGS.train_sample_save_freq == 0:
92 | log(logflag,
93 | 'Pre-train iteration : {0}, pixel-wise_loss : {1}'.format(current_iter, result['pre_gen_loss']),
94 | 'info')
95 | save_image(FLAGS, result['gen_HR'], 'pre-train', current_iter, save_max_num=5)
96 |
97 | # save checkpoint
98 | if current_iter % FLAGS.train_ckpt_save_freq == 0:
99 | saver.save(sess, os.path.join(FLAGS.pre_train_checkpoint_dir, 'pre_gen'), global_step=current_iter)
100 |
101 | writer.close()
102 | log(logflag, 'Pre-train : Process end', 'info')
103 |
--------------------------------------------------------------------------------
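
The epoch count is derived so that roughly `num_iter` weight updates run in total, and the `current_iter > FLAGS.num_iter` check stops training once that budget is reached. A worked example of the bookkeeping (the training-set size is hypothetical, e.g. 800 DIV2K images with 2 crops each, before augmentation):

```
import math

num_train_data = 1600              # hypothetical: 800 images x 2 crops
batch_size, num_iter = 32, 50000   # repository defaults

num_batch_in_train = int(math.floor(num_train_data / batch_size))  # 50 updates per epoch
num_epoch = int(math.ceil(num_iter / num_batch_in_train))          # 1000 epochs
```
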
/network_interpolation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from lib.train_module import Network
8 | from lib.ops import extract_weight, interpolate_weight
9 | from lib.utils import create_dirs, de_normalize_image, load_inference_data
10 |
11 |
12 | def set_flags():
13 | Flags = tf.app.flags
14 |
15 | Flags.DEFINE_string('data_dir', './data/inference', 'inference data directory')
16 | Flags.DEFINE_string('pre_train_checkpoint_dir', './pre_train_checkpoint', 'pre-train checkpoint directory')
17 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory')
18 | Flags.DEFINE_string('inference_pretrain_checkpoint', '',
19 | 'pretrain checkpoint to use for network interpolation. Empty string means the latest checkpoint is used')
20 | Flags.DEFINE_string('inference_checkpoint', '',
21 | 'checkpoint to use for network interpolation. Empty string means the latest checkpoint is used')
22 | Flags.DEFINE_string('interpolation_result_dir', './interpolation_result', 'output directory during inference')
23 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel')
24 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks')
25 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter')
26 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization')
27 |     Flags.DEFINE_float('interpolation_param', 0.8, 'interpolation parameter alpha used to blend pre-trained and ESRGAN weights')
28 |
29 | return Flags.FLAGS
30 |
31 |
32 | def main():
33 | # set flag
34 | FLAGS = set_flags()
35 |
36 | # make dirs
37 | target_dirs = [FLAGS.interpolation_result_dir]
38 | create_dirs(target_dirs)
39 |
40 | # load test data
41 | LR_inference, LR_filenames = load_inference_data(FLAGS)
42 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input')
43 |
44 | config = tf.ConfigProto()
45 | config.gpu_options.allow_growth = True
46 |
47 | #### Load pretrain model weights ####
48 | # build Generator
49 | network = Network(FLAGS, LR_data)
50 | gen_out = network.generator()
51 |
52 | with tf.Session(config=config) as sess:
53 | pre_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
54 |
55 |         if FLAGS.inference_pretrain_checkpoint:
56 | pre_saver.restore(sess, os.path.join(FLAGS.pre_train_checkpoint_dir, FLAGS.inference_pretrain_checkpoint))
57 | else:
58 | print('No checkpoint is specified. The latest one is used for network interpolation')
59 | pre_saver.restore(sess, tf.train.latest_checkpoint(FLAGS.pre_train_checkpoint_dir))
60 |
61 | pretrain_weight = extract_weight(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
62 |
63 | tf.reset_default_graph()
64 |
65 | #### Run network interpolation and inference ####
66 | LR_data = tf.placeholder(tf.float32, shape=[1, None, None, FLAGS.channel], name='LR_input')
67 |
68 | # build Generator
69 | network = Network(FLAGS, LR_data)
70 | gen_out = network.generator()
71 |
72 | fetches = {'gen_HR': gen_out}
73 |
74 | with tf.Session(config=config) as sess:
75 | print('Inference start')
76 | saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator'))
77 |
78 | if FLAGS.inference_checkpoint:
79 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, FLAGS.inference_checkpoint))
80 | else:
81 | print('No checkpoint is specified. The latest one is used for network interpolation')
82 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
83 |
84 |         # blend the restored ESRGAN weights with the pre-trained generator weights
85 |
86 | sess.run(interpolate_weight(FLAGS, pretrain_weight))
87 |
88 | for i, test_img in enumerate(LR_inference):
89 |
90 | feed_dict = {
91 | LR_data: test_img
92 | }
93 |
94 | result = sess.run(fetches=fetches, feed_dict=feed_dict)
95 |
96 | cv2.imwrite(os.path.join(FLAGS.interpolation_result_dir, LR_filenames[i]),
97 | de_normalize_image(np.squeeze(result['gen_HR'])))
98 |
99 | print('Inference end')
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ESRGAN (TensorFlow)
2 |
3 | This repository provides a TensorFlow implementation of the paper "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" by X. Wang et al.
4 |
5 | ## Dependencies
6 | TensorFlow, OpenCV, scikit-learn, NumPy
7 |
8 | The versions of my test environment :
9 | Python==3.6.8, tensorflow-gpu==1.12.0, openCV==4.1.0, scikit-learn==0.20.3, numpy==1.16.2
10 |
11 | ## How to Use
12 |
13 | #### 1. Prepare data for training
14 |
15 | Prepare your data and put it into the directory specified by the "data_dir" flag (e.g. './data/LSUN') of train.py. Other necessary directories are
16 | created automatically as set in the script.
17 |
18 | #### 2. Train the model
19 | Run the train.py script. The main processes are :
20 | - Data processing : create HR patches and LR patches (by downsampling the HR patches). The processed data can be saved to disk so that it can be reused in later runs.
21 |
22 | - Pre-train with pixel-wise loss : as described in the paper, the Generator is pre-trained first. You can set the "pretrain_generator" flag to False to reuse an existing pre-trained checkpoint. (Training ESRGAN without a pre-trained model is not supported.)
23 |
24 | - Training ESRGAN : ESRGAN is trained starting from the pre-trained model
25 |
26 | ```
27 | # python train.py
28 |
29 | (data directory can be passed by the optional argument)
30 | # python train.py --data_dir ./data/LSUN
31 | ```
32 |
33 | #### 3. Inference LR data
34 | After training is finished, super-resolution of LR images is available. Input data can be specified by the "data_dir" flag of the inference.py script.
35 |
36 | ```
37 | # python inference.py
38 |
39 | (data directory can be passed by the optional argument)
40 | # python inference.py --data_dir ./data/inference
41 | ```
42 |
43 | #### 4. Inference via Network interpolation
44 | The paper proposes network interpolation, which linearly combines the weights of the pixel-wise pre-trained model and the ESRGAN generator : theta_interp = (1 - alpha) * theta_PSNR + alpha * theta_ESRGAN, where alpha is the "interpolation_param" flag. You can run this after both the pre-trained model and ESRGAN have been trained. Input data can be specified by the "data_dir" flag of the network_interpolation.py script.
45 |
46 | ```
47 | # python network_interpolation.py
48 |
49 | (data directory can be passed by the optional argument)
50 | # python network_interpolation.py --data_dir ./data/inference
51 | ```
52 |
53 | ## Experiment Result
54 | #### DIV2K dataset
55 | DIV2K is a collection of 2K resolution high quality images.
56 | https://data.vision.ee.ethz.ch/cvl/DIV2K/
57 |
58 | ![DIV2K result 1](img/0833.png)
59 | ![DIV2K result 2](img/0887.png)
60 | ![DIV2K result 3](img/0896.png)
61 | from left to right : bicubic interpolation, ESRGAN, ESRGAN with network interpolation, High resolution (GT). 4x super-resolution.
62 |
63 | #### LSUN
64 | LSUN is a collection of ordinary-resolution bedroom images.
65 | https://www.kaggle.com/jhoward/lsun_bedroom/data
66 |
67 | ![LSUN result 1](img/1117e6b64a7b4336df58eb351cff435529485e91.png)
68 | ![LSUN result 2](img/11183b7a2e0ee4be9990721d9ddc7fa34997b41f.png)
69 | ![LSUN result 3](img/111b822af95747f45f5d25a84f8094c10b27c765.png)
70 | from left to right : bicubic interpolation, ESRGAN, ESRGAN with network interpolation, High resolution (GT). 4x super-resolution.
71 |
72 | #### Experiment condition
73 | - training with 800 images and 2 cropped patches per image for DIV2K
74 | - training with about 5000 images from the 20% subset of the dataset and 2 cropped patches per image for LSUN
75 | - data augmentation applied (horizontal flip and 90-degree rotation)
76 | - 15 RRDBs, batch size 32, 50,000 iterations per training phase. Other parameters are the same as in the paper.
77 | - Network interpolation parameter is 0.2
78 |
79 |
80 | ## Limitations
81 |
82 | - Only 4x super-resolution is supported
83 | - Grayscale images are not supported
84 | - Single-GPU usage only
85 |
86 |
87 | ## To do list
88 | The following features from the paper have been implemented:
89 |
90 | - [x] Perceptual loss using VGG19
91 | - [x] Learning rate scheduling
92 | - [x] Network interpolation
93 | - [x] Data augmentation
94 | - [x] Evaluation metrics
95 |
96 | ### Notes
97 | Some parameters, such as the number of RRDB blocks, the mini-batch size, and the number of iterations, were changed to fit my test environment.
98 | Please change them if you prefer the same conditions as the paper.
99 |
100 |
101 | ## Reference
102 | * Paper
103 | Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, and Chen Change Loy : ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks, ECCV Workshops, 2018. http://openaccess.thecvf.com/content_ECCVW_2018/papers/11133/Wang_ESRGAN_Enhanced_Super-Resolution_Generative_Adversarial_Networks_ECCVW_2018_paper.pdf
104 |
105 |
106 | * Official implementation in PyTorch by the paper's authors
107 | https://github.com/xinntao/BasicSR
108 |
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import glob
4 |
5 | import cv2
6 | import numpy as np
7 |
8 |
9 | def log(logflag, message, level='info'):
10 | """logging to stdout and logfile if flag is true"""
11 | print(message, flush=True)
12 |
13 | if logflag:
14 | if level == 'info':
15 | logging.info(message)
16 | elif level == 'warning':
17 | logging.warning(message)
18 | elif level == 'error':
19 | logging.error(message)
20 | elif level == 'critical':
21 | logging.critical(message)
22 |
23 |
24 | def create_dirs(target_dirs):
25 | """create necessary directories to save output files"""
26 | for dir_path in target_dirs:
27 | if not os.path.isdir(dir_path):
28 | os.makedirs(dir_path)
29 |
30 |
31 | def normalize_images(*arrays):
32 | """normalize input image arrays"""
33 | return [arr / 127.5 - 1 for arr in arrays]
34 |
35 |
36 | def de_normalize_image(image):
37 | """de-normalize input image array"""
38 | return (image + 1) * 127.5
39 |
40 |
41 | def save_image(FLAGS, images, phase, global_iter, save_max_num=5):
42 | """save images in specified directory"""
43 | if phase == 'train' or phase == 'pre-train':
44 | save_dir = FLAGS.train_result_dir
45 | elif phase == 'inference':
46 | save_dir = FLAGS.inference_result_dir
47 | save_max_num = len(images)
48 | else:
49 |         raise ValueError('Specified phase is invalid: {}'.format(phase))
50 |
51 | for i, img in enumerate(images):
52 | if i >= save_max_num:
53 | break
54 |
55 | cv2.imwrite(save_dir + '/{0}_HR_{1}_{2}.jpg'.format(phase, global_iter, i), de_normalize_image(img))
56 |
57 |
58 | def crop(img, FLAGS):
59 | """crop patch from an image with specified size"""
60 | img_h, img_w, _ = img.shape
61 |
62 | rand_h = np.random.randint(img_h - FLAGS.crop_size)
63 | rand_w = np.random.randint(img_w - FLAGS.crop_size)
64 |
65 | return img[rand_h:rand_h + FLAGS.crop_size, rand_w:rand_w + FLAGS.crop_size, :]
66 |
67 |
68 | def data_augmentation(LR_images, HR_images, aug_type='horizontal_flip'):
69 | """data augmentation. input arrays should be [N, H, W, C]"""
70 |
71 | if aug_type == 'horizontal_flip':
72 | return LR_images[:, :, ::-1, :], HR_images[:, :, ::-1, :]
73 | elif aug_type == 'rotation_90':
74 | return np.rot90(LR_images, k=1, axes=(1, 2)), np.rot90(HR_images, k=1, axes=(1, 2))
75 |
76 |
77 | def load_and_save_data(FLAGS, logflag):
78 | """make HR and LR data. And save them as npz files"""
79 | assert os.path.isdir(FLAGS.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory'
80 |
81 | all_file_path = glob.glob(FLAGS.data_dir + '/*')
82 | assert len(all_file_path) > 0, 'No file in the directory'
83 |
84 | ret_HR_image = []
85 | ret_LR_image = []
86 |
87 | for file in all_file_path:
88 | img = cv2.imread(file)
89 | filename = file.rsplit('/', 1)[-1]
90 |
91 | # crop patches if flag is true. Otherwise just resize HR and LR images
92 | if FLAGS.crop:
93 | for _ in range(FLAGS.num_crop_per_image):
94 | img_h, img_w, _ = img.shape
95 |
96 | if (img_h < FLAGS.crop_size) or (img_w < FLAGS.crop_size):
97 | print('Skip crop target image because of insufficient size')
98 | continue
99 |
100 | HR_image = crop(img, FLAGS)
101 |                 LR_crop_size = int(np.floor(FLAGS.crop_size / FLAGS.scale_SR))
102 | LR_image = cv2.resize(HR_image, (LR_crop_size, LR_crop_size), interpolation=cv2.INTER_LANCZOS4)
103 |
104 | cv2.imwrite(FLAGS.HR_data_dir + '/' + filename, HR_image)
105 | cv2.imwrite(FLAGS.LR_data_dir + '/' + filename, LR_image)
106 |
107 | ret_HR_image.append(HR_image)
108 | ret_LR_image.append(LR_image)
109 | else:
110 | HR_image = cv2.resize(img, (FLAGS.HR_image_size, FLAGS.HR_image_size), interpolation=cv2.INTER_LANCZOS4)
111 | LR_image = cv2.resize(img, (FLAGS.LR_image_size, FLAGS.LR_image_size), interpolation=cv2.INTER_LANCZOS4)
112 |
113 | cv2.imwrite(FLAGS.HR_data_dir + '/' + filename, HR_image)
114 | cv2.imwrite(FLAGS.LR_data_dir + '/' + filename, LR_image)
115 |
116 | ret_HR_image.append(HR_image)
117 | ret_LR_image.append(LR_image)
118 |
119 |     assert len(ret_HR_image) > 0 and len(ret_LR_image) > 0, 'No available image is found in the directory'
120 | log(logflag, 'Data process : {} images are processed'.format(len(ret_HR_image)), 'info')
121 |
122 | ret_HR_image = np.array(ret_HR_image)
123 | ret_LR_image = np.array(ret_LR_image)
124 |
125 | if FLAGS.data_augmentation:
126 | LR_flip, HR_flip = data_augmentation(ret_LR_image, ret_HR_image, aug_type='horizontal_flip')
127 | LR_rot, HR_rot = data_augmentation(ret_LR_image, ret_HR_image, aug_type='rotation_90')
128 |
129 | ret_LR_image = np.append(ret_LR_image, LR_flip, axis=0)
130 | ret_HR_image = np.append(ret_HR_image, HR_flip, axis=0)
131 | ret_LR_image = np.append(ret_LR_image, LR_rot, axis=0)
132 | ret_HR_image = np.append(ret_HR_image, HR_rot, axis=0)
133 |
134 | del LR_flip, HR_flip, LR_rot, HR_rot
135 |
136 | np.savez(FLAGS.npz_data_dir + '/' + FLAGS.HR_npz_filename, images=ret_HR_image)
137 | np.savez(FLAGS.npz_data_dir + '/' + FLAGS.LR_npz_filename, images=ret_LR_image)
138 |
139 | return ret_HR_image, ret_LR_image
140 |
141 |
142 | def load_npz_data(FLAGS):
143 | """load array data from data_path"""
144 | return np.load(FLAGS.npz_data_dir + '/' + FLAGS.HR_npz_filename)['images'], \
145 | np.load(FLAGS.npz_data_dir + '/' + FLAGS.LR_npz_filename)['images']
146 |
147 |
148 | def load_inference_data(FLAGS):
149 | """load data from directory for inference"""
150 | assert os.path.isdir(FLAGS.data_dir) is True, 'Directory specified by data_dir does not exist or is not a directory'
151 |
152 | all_file_path = glob.glob(FLAGS.data_dir + '/*')
153 | assert len(all_file_path) > 0, 'No file in the directory'
154 |
155 | ret_LR_image = []
156 | ret_filename = []
157 |
158 | for file in all_file_path:
159 | img = cv2.imread(file)
160 | img = normalize_images(img)
161 | ret_LR_image.append(img[0][np.newaxis, ...])
162 |
163 | ret_filename.append(file.rsplit('/', 1)[-1])
164 |
165 | assert len(ret_LR_image) > 0, 'No available image is found in the directory'
166 |
167 | return ret_LR_image, ret_filename
168 |
--------------------------------------------------------------------------------
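
`normalize_images` maps 8-bit pixel values into [-1, 1] and `de_normalize_image` inverts the mapping, so the two form a round trip up to floating-point error:

```
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])
normalized = pixels / 127.5 - 1      # [-1.  0.  1.]
restored = (normalized + 1) * 127.5  # [  0.  127.5 255. ]
```
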
/lib/network.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | class Generator(object):
5 | """the definition of Generator"""
6 |
7 | def __init__(self, FLAGS):
8 | self.channel = FLAGS.channel
9 | self.n_filter = 64
10 | self.inc_filter = 32
11 | self.num_repeat_RRDB = FLAGS.num_repeat_RRDB
12 | self.residual_scaling = FLAGS.residual_scaling
13 | self.init_kernel = tf.initializers.he_normal(seed=FLAGS.initialization_random_seed)
14 |
15 | def _conv_RRDB(self, x, out_channel, num=None, activate=True):
16 | with tf.variable_scope('block_{0}'.format(num)):
17 | x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', kernel_initializer=self.init_kernel, name='conv')
18 | if activate:
19 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU')
20 |
21 | return x
22 |
23 | def _denseBlock(self, x, num=None):
24 | with tf.variable_scope('DenseBlock_sub{0}'.format(num)):
25 | x1 = self._conv_RRDB(x, self.inc_filter, 0)
26 | x2 = self._conv_RRDB(tf.concat([x, x1], axis=3), self.inc_filter, 1)
27 | x3 = self._conv_RRDB(tf.concat([x, x1, x2], axis=3), self.inc_filter, 2)
28 | x4 = self._conv_RRDB(tf.concat([x, x1, x2, x3], axis=3), self.inc_filter, 3)
29 | x5 = self._conv_RRDB(tf.concat([x, x1, x2, x3, x4], axis=3), self.n_filter, 4, activate=False)
30 |
31 | return x5 * self.residual_scaling
32 |
33 | def _RRDB(self, x, num=None):
34 | """Residual in Residual Dense Block"""
35 | with tf.variable_scope('RRDB_sub{0}'.format(num)):
36 | x_branch = tf.identity(x)
37 |
38 | x_branch += self._denseBlock(x_branch, 0)
39 | x_branch += self._denseBlock(x_branch, 1)
40 | x_branch += self._denseBlock(x_branch, 2)
41 |
42 | return x + x_branch * self.residual_scaling
43 |
44 | def _upsampling_layer(self, x, num=None):
45 | x = tf.layers.conv2d_transpose(x, self.n_filter, 3, 2, padding='same', name='upsample_{0}'.format(num))
46 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU')
47 |
48 | return x
49 |
50 | def build(self, x):
51 | with tf.variable_scope('first_conv'):
52 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel,
53 | name='conv')
54 |
55 | with tf.variable_scope('RRDB'):
56 | x_branch = tf.identity(x)
57 |
58 | for i in range(self.num_repeat_RRDB):
59 | x_branch = self._RRDB(x_branch, i)
60 |
61 | x_branch = tf.layers.conv2d(x_branch, self.n_filter, 3, 1, padding='same',
62 | kernel_initializer=self.init_kernel, name='trunk_conv')
63 |
64 | x += x_branch
65 |
66 | with tf.variable_scope('Upsampling'):
67 | x = self._upsampling_layer(x, 1)
68 | x = self._upsampling_layer(x, 2)
69 |
70 | with tf.variable_scope('last_conv'):
71 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', kernel_initializer=self.init_kernel,
72 | name='conv_1')
73 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU')
74 | x = tf.layers.conv2d(x, self.channel, 3, 1, padding='same', kernel_initializer=self.init_kernel,
75 | name='conv_2')
76 |
77 | return x
78 |
79 |
80 | class Discriminator(object):
81 | """the definition of Discriminator"""
82 |
83 | def __init__(self, FLAGS):
84 | self.channel = FLAGS.channel
85 | self.n_filter = 64
86 | self.inc_filter = 32
87 | self.init_kernel = tf.initializers.he_normal(seed=FLAGS.initialization_random_seed)
88 |
89 | def _conv_block(self, x, out_channel, num=None):
90 | with tf.variable_scope('block_{0}'.format(num)):
91 | x = tf.layers.conv2d(x, out_channel, 3, 1, padding='same', use_bias=False,
92 | kernel_initializer=self.init_kernel, name='conv_1')
93 | x = tf.layers.BatchNormalization(name='batch_norm_1')(x)
94 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')
95 |
96 | x = tf.layers.conv2d(x, out_channel, 4, 2, padding='same', use_bias=False,
97 | kernel_initializer=self.init_kernel, name='conv_2')
98 | x = tf.layers.BatchNormalization(name='batch_norm_2')(x)
99 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2')
100 |
101 | return x
102 |
103 | def build(self, x):
104 | with tf.variable_scope('first_conv'):
105 | x = tf.layers.conv2d(x, self.n_filter, 3, 1, padding='same', use_bias=False,
106 | kernel_initializer=self.init_kernel, name='conv_1')
107 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')
108 | x = tf.layers.conv2d(x, self.n_filter, 4, 2, padding='same', use_bias=False,
109 | kernel_initializer=self.init_kernel, name='conv_2')
110 | x = tf.layers.BatchNormalization(name='batch_norm_1')(x)
111 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_2')
112 |
113 | with tf.variable_scope('conv_block'):
114 | x = self._conv_block(x, self.n_filter * 2, 0)
115 | x = self._conv_block(x, self.n_filter * 4, 1)
116 | x = self._conv_block(x, self.n_filter * 8, 2)
117 | x = self._conv_block(x, self.n_filter * 8, 3)
118 |
119 | with tf.variable_scope('full_connected'):
120 | x = tf.layers.flatten(x)
121 | x = tf.layers.dense(x, 100, name='fully_connected_1')
122 | x = tf.nn.leaky_relu(x, alpha=0.2, name='leakyReLU_1')
123 | x = tf.layers.dense(x, 1, name='fully_connected_2')
124 |
125 | return x
126 |
127 |
128 | class Perceptual_VGG19(object):
129 | """the definition of VGG19. This network is used for constructing perceptual loss"""
130 | @staticmethod
131 | def build(x):
132 | # Block 1
133 | x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv1')
134 | x = tf.layers.conv2d(x, 64, (3, 3), activation='relu', padding='same', name='block1_conv2')
135 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block1_pool')
136 |
137 | # Block 2
138 | x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv1')
139 | x = tf.layers.conv2d(x, 128, (3, 3), activation='relu', padding='same', name='block2_conv2')
140 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block2_pool')
141 |
142 | # Block 3
143 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv1')
144 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv2')
145 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv3')
146 | x = tf.layers.conv2d(x, 256, (3, 3), activation='relu', padding='same', name='block3_conv4')
147 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block3_pool')
148 |
149 | # Block 4
150 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv1')
151 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv2')
152 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv3')
153 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block4_conv4')
154 | x = tf.layers.max_pooling2d(x, (2, 2), strides=(2, 2), name='block4_pool')
155 |
156 | # Block 5
157 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv1')
158 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv2')
159 | x = tf.layers.conv2d(x, 512, (3, 3), activation='relu', padding='same', name='block5_conv3')
160 | x = tf.layers.conv2d(x, 512, (3, 3), activation=None, padding='same', name='block5_conv4')
161 |
162 | return x
163 |
--------------------------------------------------------------------------------
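
The generator upsamples twice with stride-2 transposed convolutions, so the output is 4x the LR input in each spatial dimension. A minimal shape check, assuming a TF 1.x environment with this repository on the Python path (the `Namespace` object stands in for the real FLAGS):

```
import tensorflow as tf
from argparse import Namespace

from lib.network import Generator

FLAGS = Namespace(channel=3, num_repeat_RRDB=15, residual_scaling=0.2,
                  initialization_random_seed=111)

x = tf.placeholder(tf.float32, shape=[1, 32, 32, 3], name='LR_input')
out = Generator(FLAGS).build(x)
print(out.shape)  # expected: (1, 128, 128, 3), i.e. two x2 upsampling layers give x4
```
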
/lib/train_module.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import tensorflow as tf
4 |
5 | from lib.network import Generator, Discriminator, Perceptual_VGG19
6 |
7 |
8 | class Network(object):
9 | """class to build networks"""
10 | def __init__(self, FLAGS, LR_data=None, HR_data=None):
11 | self.FLAGS = FLAGS
12 | self.LR_data = LR_data
13 | self.HR_data = HR_data
14 |
15 | def generator(self):
16 | with tf.name_scope('generator'):
17 | with tf.variable_scope('generator'):
18 | gen_out = Generator(self.FLAGS).build(self.LR_data)
19 |
20 | return gen_out
21 |
22 | def discriminator(self, gen_out):
23 | discriminator = Discriminator(self.FLAGS)
24 |
25 | with tf.name_scope('real_discriminator'):
26 | with tf.variable_scope('discriminator', reuse=False):
27 | dis_out_real = discriminator.build(self.HR_data)
28 |
29 | with tf.name_scope('fake_discriminator'):
30 | with tf.variable_scope('discriminator', reuse=True):
31 | dis_out_fake = discriminator.build(gen_out)
32 |
33 | return dis_out_real, dis_out_fake
34 |
35 |
36 | class Loss(object):
37 | """class to build loss functions"""
38 | def __init__(self):
39 | self.summary_target = OrderedDict()
40 |
41 | def pretrain_loss(self, pre_gen_out, HR_data):
42 | with tf.name_scope('loss_function'):
43 | with tf.variable_scope('pixel-wise_loss'):
44 | pre_gen_loss = tf.reduce_mean(tf.reduce_mean(tf.square(pre_gen_out - HR_data), axis=3))
45 |
46 | self.summary_target['pre-train : pixel-wise_loss'] = pre_gen_loss
47 | return pre_gen_loss
48 |
49 | def _perceptual_vgg19_loss(self, HR_data, gen_out):
50 | with tf.name_scope('perceptual_vgg19_HR'):
51 | with tf.variable_scope('perceptual_vgg19', reuse=False):
52 | vgg_out_hr = Perceptual_VGG19().build(HR_data)
53 |
54 | with tf.name_scope('perceptual_vgg19_Gen'):
55 | with tf.variable_scope('perceptual_vgg19', reuse=True):
56 | vgg_out_gen = Perceptual_VGG19().build(gen_out)
57 |
58 | return vgg_out_hr, vgg_out_gen
59 |
60 | def gan_loss(self, FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake):
61 |
62 | with tf.name_scope('loss_function'):
63 | with tf.variable_scope('loss_generator'):
64 | if FLAGS.gan_loss_type == 'RaGAN':
65 | g_loss_p1 = tf.reduce_mean(
66 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real - tf.reduce_mean(dis_out_fake),
67 | labels=tf.zeros_like(dis_out_real)))
68 |
69 | g_loss_p2 = tf.reduce_mean(
70 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake - tf.reduce_mean(dis_out_real),
71 | labels=tf.ones_like(dis_out_fake)))
72 |
73 | gen_loss = FLAGS.gan_loss_coeff * (g_loss_p1 + g_loss_p2) / 2
74 | elif FLAGS.gan_loss_type == 'GAN':
75 | gen_loss = FLAGS.gan_loss_coeff * tf.reduce_mean(
76 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake, labels=tf.ones_like(dis_out_fake)))
77 | else:
78 | raise ValueError('Unknown GAN loss function type')
79 |
80 | # content loss : L1 distance
81 | content_loss = FLAGS.content_loss_coeff * tf.reduce_mean(
82 | tf.reduce_sum(tf.abs(gen_out - HR_data), axis=[1, 2, 3]))
83 |
84 | gen_loss += content_loss
85 |
86 | # perceptual loss
87 | if FLAGS.perceptual_loss == 'pixel-wise':
88 | perc_loss = tf.reduce_mean(tf.reduce_mean(tf.square(gen_out - HR_data), axis=3))
89 | gen_loss += perc_loss
90 | elif FLAGS.perceptual_loss == 'VGG19':
91 |                     vgg_out_hr, vgg_out_gen = self._perceptual_vgg19_loss(HR_data, gen_out)
92 | perc_loss = tf.reduce_mean(tf.reduce_mean(tf.square(vgg_out_gen - vgg_out_hr), axis=3))
93 | gen_loss += perc_loss
94 | else:
95 | raise ValueError('Unknown perceptual loss type')
96 |
97 | with tf.variable_scope('loss_discriminator'):
98 | if FLAGS.gan_loss_type == 'RaGAN':
99 | d_loss_real = tf.reduce_mean(
100 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real - tf.reduce_mean(dis_out_fake),
101 | labels=tf.ones_like(dis_out_real))) / 2
102 |
103 | d_loss_fake = tf.reduce_mean(
104 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake - tf.reduce_mean(dis_out_real),
105 | labels=tf.zeros_like(dis_out_fake))) / 2
106 |
107 | dis_loss = d_loss_real + d_loss_fake
108 | elif FLAGS.gan_loss_type == 'GAN':
109 | d_loss_real = tf.reduce_mean(
110 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_real, labels=tf.ones_like(dis_out_real)))
111 | d_loss_fake = tf.reduce_mean(
112 | tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_out_fake,
113 | labels=tf.zeros_like(dis_out_fake)))
114 |
115 | dis_loss = d_loss_real + d_loss_fake
116 |
117 | else:
118 | raise ValueError('Unknown GAN loss function type')
119 |
120 | self.summary_target['generator_loss'] = gen_loss
121 | self.summary_target['content_loss'] = content_loss
122 | self.summary_target['perceptual_loss'] = perc_loss
123 | self.summary_target['discriminator_loss'] = dis_loss
124 | self.summary_target['discriminator_real_loss'] = d_loss_real
125 | self.summary_target['discriminator_fake_loss'] = d_loss_fake
126 |
127 | return gen_loss, dis_loss
128 |
129 | def add_summary_writer(self):
130 | return [tf.summary.scalar(key, value) for key, value in self.summary_target.items()]
131 |
132 |
133 | class Optimizer(object):
134 | """class to build optimizers"""
135 | @staticmethod
136 | def pretrain_optimizer(FLAGS, global_iter, pre_gen_loss):
137 | learning_rate = tf.train.exponential_decay(FLAGS.pretrain_learning_rate, global_iter,
138 | FLAGS.pretrain_lr_decay_step, 0.5, staircase=True)
139 |
140 | with tf.name_scope('optimizer'):
141 | with tf.variable_scope('optimizer_generator'):
142 | pre_gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
143 | pre_gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=pre_gen_loss,
144 | global_step=global_iter,
145 | var_list=pre_gen_var)
146 |
147 | return pre_gen_var, pre_gen_optimizer
148 |
149 | @staticmethod
150 | def gan_optimizer(FLAGS, global_iter, dis_loss, gen_loss):
151 | boundaries = [50000, 100000, 200000, 300000]
152 | values = [FLAGS.learning_rate, FLAGS.learning_rate * 0.5, FLAGS.learning_rate * 0.5 ** 2,
153 | FLAGS.learning_rate * 0.5 ** 3, FLAGS.learning_rate * 0.5 ** 4]
154 | learning_rate = tf.train.piecewise_constant(global_iter, boundaries, values)
155 |
156 | with tf.name_scope('optimizer'):
157 | with tf.variable_scope('optimizer_discriminator'):
158 | dis_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
159 | dis_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=dis_loss,
160 | var_list=dis_var)
161 |
162 | with tf.variable_scope('optimizer_generator'):
163 | with tf.control_dependencies([dis_optimizer]):
164 | gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
165 | gen_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss=gen_loss,
166 | global_step=global_iter,
167 | var_list=gen_var)
168 |
169 | return dis_var, dis_optimizer, gen_var, gen_optimizer
170 |
--------------------------------------------------------------------------------
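
For reference, the `RaGAN` branch of `gan_loss` follows the relativistic average GAN used in the paper. With $C(\cdot)$ the raw discriminator logits, $\sigma$ the sigmoid, and $x_r, x_f$ real (HR) and generated samples, the losses being approximated are:

```
D_{Ra}(x_r, x_f) = \sigma\big(C(x_r) - \mathbb{E}_{x_f}[C(x_f)]\big)

L_D^{Ra} = -\mathbb{E}_{x_r}[\log D_{Ra}(x_r, x_f)] - \mathbb{E}_{x_f}[\log(1 - D_{Ra}(x_f, x_r))]

L_G^{Ra} = -\mathbb{E}_{x_r}[\log(1 - D_{Ra}(x_r, x_f))] - \mathbb{E}_{x_f}[\log D_{Ra}(x_f, x_r)]
```

`tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=y)` computes exactly `-y*log(sigmoid(l)) - (1-y)*log(1-sigmoid(l))`, so the ones/zeros labels in the code recover these terms; the code additionally halves the discriminator terms and averages the two generator terms.
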
/train.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import gc
3 | import logging
4 | import math
5 | import os
6 |
7 | import tensorflow as tf
8 | from sklearn.utils import shuffle
9 |
10 | from lib.ops import load_vgg19_weight
11 | from lib.pretrain_generator import train_pretrain_generator
12 | from lib.train_module import Network, Loss, Optimizer
13 | from lib.utils import create_dirs, log, normalize_images, save_image, load_npz_data, load_and_save_data
14 |
15 |
16 | def set_flags():
17 | Flags = tf.app.flags
18 |
19 | # About data
20 | Flags.DEFINE_string('data_dir', './data/DIV2K_train_HR', 'data directory')
21 | Flags.DEFINE_string('HR_data_dir', './data/HR_data', 'HR data directory')
22 | Flags.DEFINE_string('LR_data_dir', './data/LR_data', 'LR data directory')
23 | Flags.DEFINE_string('npz_data_dir', './data/npz', 'The npz data dir')
24 | Flags.DEFINE_string('HR_npz_filename', 'HR_image.npz', 'the filename of HR image npz file')
25 | Flags.DEFINE_string('LR_npz_filename', 'LR_image.npz', 'the filename of LR image npz file')
26 | Flags.DEFINE_boolean('save_data', True, 'Whether to load and save data as npz file')
27 | Flags.DEFINE_string('train_result_dir', './train_result', 'output directory during training')
28 | Flags.DEFINE_boolean('crop', True, 'Whether image cropping is enabled')
29 | Flags.DEFINE_integer('crop_size', 128, 'the size of crop of training HR images')
30 | Flags.DEFINE_integer('num_crop_per_image', 2, 'the number of random-cropped images per image')
31 | Flags.DEFINE_boolean('data_augmentation', True, 'whether to augment data')
32 |
33 | # About Network
34 | Flags.DEFINE_integer('scale_SR', 4, 'the scale of super-resolution')
35 | Flags.DEFINE_integer('num_repeat_RRDB', 15, 'The number of repeats of RRDB blocks')
36 | Flags.DEFINE_float('residual_scaling', 0.2, 'residual scaling parameter')
37 | Flags.DEFINE_integer('initialization_random_seed', 111, 'random seed of networks initialization')
38 | Flags.DEFINE_string('perceptual_loss', 'VGG19', 'the part of loss function. "VGG19" or "pixel-wise"')
39 | Flags.DEFINE_string('gan_loss_type', 'RaGAN', 'the type of GAN loss functions. "RaGAN or GAN"')
40 |
41 | # About training
42 | Flags.DEFINE_integer('num_iter', 50000, 'The number of iterations')
43 | Flags.DEFINE_integer('batch_size', 32, 'Mini-batch size')
44 | Flags.DEFINE_integer('channel', 3, 'Number of input/output image channel')
45 | Flags.DEFINE_boolean('pretrain_generator', True, 'Whether to pretrain generator')
46 | Flags.DEFINE_float('pretrain_learning_rate', 2e-4, 'learning rate for pretrain')
47 | Flags.DEFINE_float('pretrain_lr_decay_step', 20000, 'decay by every n iteration')
48 | Flags.DEFINE_float('learning_rate', 1e-4, 'learning rate')
49 | Flags.DEFINE_float('weight_initialize_scale', 0.1, 'scale to multiply after MSRA initialization')
50 | Flags.DEFINE_integer('HR_image_size', 128,
51 | 'Image width and height of HR image. This flag is valid when crop flag is set to false.')
52 | Flags.DEFINE_integer('LR_image_size', 32,
53 | 'Image width and height of LR image. This size should be 1/4 of HR_image_size exactly. '
54 | 'This flag is valid when crop flag is set to false.')
55 | Flags.DEFINE_float('epsilon', 1e-12, 'used in loss function')
56 |     Flags.DEFINE_float('gan_loss_coeff', 0.005, 'coefficient of the adversarial (GAN) loss term')
57 |     Flags.DEFINE_float('content_loss_coeff', 0.01, 'coefficient of the content (L1) loss term')
58 |
59 | # About log
60 | Flags.DEFINE_boolean('logging', True, 'whether to record training log')
61 | Flags.DEFINE_integer('train_sample_save_freq', 2000, 'save samples during training every n iteration')
62 | Flags.DEFINE_integer('train_ckpt_save_freq', 2000, 'save checkpoint during training every n iteration')
63 | Flags.DEFINE_integer('train_summary_save_freq', 200, 'save summary during training every n iteration')
64 | Flags.DEFINE_string('pre_train_checkpoint_dir', './pre_train_checkpoint', 'pre-train checkpoint directory')
65 | Flags.DEFINE_string('checkpoint_dir', './checkpoint', 'checkpoint directory')
66 | Flags.DEFINE_string('logdir', './log', 'log directory')
67 |
68 | # About GPU setting
69 | Flags.DEFINE_string('gpu_dev_num', '0', 'Which GPU to use for multi-GPUs.')
70 |
71 | return Flags.FLAGS
72 |
73 |
74 | def set_logger(FLAGS):
75 | """set logger for training recording"""
76 | if FLAGS.logging:
77 | logfile = '{0}/training_logfile_{1}.log'.format(FLAGS.logdir, datetime.now().strftime("%Y%m%d_%H%M%S"))
78 | formatter = '%(levelname)s:%(asctime)s:%(message)s'
79 | logging.basicConfig(level=logging.INFO, filename=logfile, format=formatter, datefmt='%Y-%m-%d %I:%M:%S')
80 |
81 | return True
82 | else:
83 | print('No logging is set')
84 | return False
85 |
86 |
87 | def main():
88 | # set flag
89 | FLAGS = set_flags()
90 |
91 | # make dirs
92 | target_dirs = [FLAGS.HR_data_dir, FLAGS.LR_data_dir, FLAGS.npz_data_dir, FLAGS.train_result_dir,
93 | FLAGS.pre_train_checkpoint_dir, FLAGS.checkpoint_dir, FLAGS.logdir]
94 | create_dirs(target_dirs)
95 |
96 | # set logger
97 | logflag = set_logger(FLAGS)
98 | log(logflag, 'Training script start', 'info')
99 |
100 | # load data
101 | if FLAGS.save_data:
102 | log(logflag, 'Data process : Data processing start', 'info')
103 | HR_train, LR_train = load_and_save_data(FLAGS, logflag)
104 | log(logflag, 'Data process : Data loading and data processing are completed', 'info')
105 | else:
106 | log(logflag, 'Data process : Data loading start', 'info')
107 | HR_train, LR_train = load_npz_data(FLAGS)
108 | log(logflag,
109 | 'Data process : Loading existing data is completed. {} images are loaded'.format(len(HR_train)),
110 | 'info')
111 |
112 | # pre-train generator with pixel-wise loss and save the trained model
113 | if FLAGS.pretrain_generator:
114 | train_pretrain_generator(FLAGS, LR_train, HR_train, logflag)
115 | tf.reset_default_graph()
116 | gc.collect()
117 | else:
118 | log(logflag, 'Pre-train : Pre-train skips and an existing trained model will be used', 'info')
119 |
120 | LR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.LR_image_size, FLAGS.LR_image_size, FLAGS.channel],
121 | name='LR_input')
122 | HR_data = tf.placeholder(tf.float32, shape=[None, FLAGS.HR_image_size, FLAGS.HR_image_size, FLAGS.channel],
123 | name='HR_input')
124 |
125 | # build Generator and Discriminator
126 | network = Network(FLAGS, LR_data, HR_data)
127 | gen_out = network.generator()
128 | dis_out_real, dis_out_fake = network.discriminator(gen_out)
129 |
130 | # build loss function
131 | loss = Loss()
132 | gen_loss, dis_loss = loss.gan_loss(FLAGS, HR_data, gen_out, dis_out_real, dis_out_fake)
133 |
134 | # define optimizers
135 | global_iter = tf.Variable(0, trainable=False)
136 | dis_var, dis_optimizer, gen_var, gen_optimizer = Optimizer().gan_optimizer(FLAGS, global_iter, dis_loss, gen_loss)
137 |
138 | # build summary writer
139 | tr_summary = tf.summary.merge(loss.add_summary_writer())
140 |
141 | num_train_data = len(HR_train)
142 | num_batch_in_train = int(math.floor(num_train_data / FLAGS.batch_size))
143 | num_epoch = int(math.ceil(FLAGS.num_iter / num_batch_in_train))
144 |
145 | HR_train, LR_train = normalize_images(HR_train, LR_train)
146 |
147 | fetches = {'dis_optimizer': dis_optimizer, 'gen_optimizer': gen_optimizer,
148 | 'dis_loss': dis_loss, 'gen_loss': gen_loss,
149 | 'gen_HR': gen_out,
150 | 'summary': tr_summary
151 | }
152 |
153 | gc.collect()
154 |
155 | config = tf.ConfigProto(
156 | gpu_options=tf.GPUOptions(
157 | allow_growth=True,
158 | visible_device_list=FLAGS.gpu_dev_num
159 | )
160 | )
161 |
162 | # Start Session
163 | with tf.Session(config=config) as sess:
164 | log(logflag, 'Training ESRGAN starts', 'info')
165 |
166 | sess.run(tf.global_variables_initializer())
167 | sess.run(global_iter.initializer)
168 |
169 | writer = tf.summary.FileWriter(FLAGS.logdir, graph=sess.graph)
170 |
171 | pre_saver = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'))
172 | pre_saver.restore(sess, tf.train.latest_checkpoint(FLAGS.pre_train_checkpoint_dir))
173 |
174 | if FLAGS.perceptual_loss == 'VGG19':
175 | sess.run(load_vgg19_weight(FLAGS))
176 |
177 | saver = tf.train.Saver(max_to_keep=10)
178 |
179 | for epoch in range(num_epoch):
180 | log(logflag, 'ESRGAN Epoch: {0}'.format(epoch), 'info')
181 | HR_train, LR_train = shuffle(HR_train, LR_train, random_state=222)
182 |
183 | for iteration in range(num_batch_in_train):
184 | current_iter = tf.train.global_step(sess, global_iter)
185 | if current_iter > FLAGS.num_iter:
186 | break
187 |
188 | feed_dict = {
189 | HR_data: HR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size],
190 | LR_data: LR_train[iteration * FLAGS.batch_size:iteration * FLAGS.batch_size + FLAGS.batch_size]
191 | }
192 |
193 | # update weights of G/D
194 | result = sess.run(fetches=fetches, feed_dict=feed_dict)
195 |
196 | # save summary every n iter
197 | if current_iter % FLAGS.train_summary_save_freq == 0:
198 | writer.add_summary(result['summary'], global_step=current_iter)
199 |
200 | # save samples every n iter
201 | if current_iter % FLAGS.train_sample_save_freq == 0:
202 | log(logflag,
203 | 'ESRGAN iteration : {0}, gen_loss : {1}, dis_loss : {2}'.format(current_iter,
204 | result['gen_loss'],
205 | result['dis_loss']),
206 | 'info')
207 |
208 | save_image(FLAGS, result['gen_HR'], 'train', current_iter, save_max_num=5)
209 |
210 | if current_iter % FLAGS.train_ckpt_save_freq == 0:
211 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'gen'), global_step=current_iter)
212 |
213 | writer.close()
214 | log(logflag, 'Training ESRGAN end', 'info')
215 | log(logflag, 'Training script end', 'info')
216 |
217 |
218 | if __name__ == '__main__':
219 | main()
220 |
--------------------------------------------------------------------------------
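
Two details of the optimizer are worth noting: `gan_optimizer` chains the Generator update behind the Discriminator update via `tf.control_dependencies`, so a single `sess.run` over both fetches performs a D step followed by a G step, and the learning rate is halved at fixed iteration boundaries. A plain-Python sketch of that schedule (hypothetical helper mirroring `tf.train.piecewise_constant`):

```
def piecewise_lr(iteration, base_lr=1e-4,
                 boundaries=(50000, 100000, 200000, 300000)):
    """Return base_lr halved once for every boundary already passed."""
    halvings = sum(iteration > b for b in boundaries)
    return base_lr * 0.5 ** halvings

print(piecewise_lr(0))       # 0.0001
print(piecewise_lr(150000))  # 2.5e-05 (two halvings)
```
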
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------