├── LICENSE ├── README.md ├── SR_models.py ├── dataset ├── SR_data_load.py ├── __init__.py └── data_util.py ├── losses.py ├── outputs ├── ENet-E.png ├── ENet-PAT.png └── Input.png ├── test_SR.py ├── train_SR.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Geonmo Gu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EnhanceNet 2 | 3 | TensorFlow implementation of EnhanceNet for a magnification ratio of 4. 4 | 5 | **We slightly changed the procedure for training ENet as follows:** 6 | + The discriminator has been changed to a DCGAN-style architecture. 7 | + We only used ```conv5_4 and conv3_1 features``` from VGG-19.
See losses.py 8 | + Accordingly, we changed the hyper-parameters that weight the combined losses. 9 | 10 | ### Results 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 |
Input | Enet-E | Enet-PAT
23 | 24 | ### How to train? 25 | 26 | 1. Download the COCO 2017 training images (train2017.zip) for training ENet and unzip them 27 | ``` 28 | wget http://images.cocodataset.org/zips/train2017.zip 29 | unzip train2017.zip 30 | ``` 31 | 32 | 2. Download the VGG-19 TF-slim checkpoint and extract it 33 | ``` 34 | wget http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz 35 | tar xvzf vgg_19_2016_08_28.tar.gz 36 | ``` 37 | 38 | 3. Start training! 39 | ``` 40 | # ENet-E 41 | python3 train_SR.py --model=enhancenet --upsample=nearest \ 42 | --recon_type=residual --SR_scale=4 --run_gpu=0 \ 43 | --batch_size=32 --num_readers=4 --input_size=32 \ 44 | --losses='mse' \ 45 | --learning_rate=0.0001 \ 46 | --save_path=/your/models/will/be/saved \ 47 | --image_path=/where/is/your/COCODB/train2017/*.jpg 48 | 49 | # ENet-PAT 50 | python3 train_SR.py --model=enhancenet --upsample=nearest \ 51 | --recon_type=residual --SR_scale=4 --run_gpu=0 \ 52 | --batch_size=32 --num_readers=4 --input_size=32 \ 53 | --losses='perceptual,texture,adv' --adv_ver=ver2 \ 54 | --adv_gen_w=0.003 --learning_rate=0.0001 \ 55 | --save_path=/your/models/will/be/saved \ 56 | --image_path=/where/is/your/COCODB/train2017/*.jpg \ 57 | --vgg_path=/where/is/your/vgg19/vgg_19.ckpt 58 | ``` 59 | 60 | ### How to test? 61 | 62 | ``` 63 | python3 test_SR.py --model_path=/your/pretrained/model/folder \ 64 | --image_path=/your/image/folder \ 65 | --save_path=/generated_image/will/be/saved/here \ 66 | --run_gpu=0 67 | ``` 68 | 69 | ### Reference 70 | ``` 71 | @inproceedings{enhancenet, 72 | title={{EnhanceNet: Single Image Super-Resolution through Automated Texture Synthesis}}, 73 | author={Sajjadi, Mehdi S. M.
and Sch{\"o}lkopf, Bernhard and Hirsch, Michael}, 74 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, 75 | pages={4501--4510}, 76 | year={2017}, 77 | organization={IEEE}, 78 | url={https://arxiv.org/abs/1612.07919/} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /SR_models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.contrib import slim 4 | 5 | tf.app.flags.DEFINE_string('upsample', 'nearest', 'nearest, bilinear, or pixelShuffler') 6 | tf.app.flags.DEFINE_string('model', 'enhancenet', 'for now, only enhancenet supported') 7 | tf.app.flags.DEFINE_string('recon_type', 'residual', 'residual or direct') 8 | tf.app.flags.DEFINE_boolean('use_bn', False, 'for res_block_bn') 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | 12 | class model_builder: 13 | def __init__(self): 14 | return 15 | 16 | def preprocess(self, images): 17 | pp_images = images / 255.0 18 | ## simple mean shift 19 | pp_images = pp_images * 2.0 - 1.0 20 | 21 | return pp_images 22 | 23 | def postprocess(self, images): 24 | pp_images = ((images + 1.0) / 2.0) * 255.0 25 | 26 | return pp_images 27 | 28 | def tf_nn_lrelu(self, inputs, a=0.2): 29 | with tf.name_scope('lrelu'): 30 | x = tf.identity(inputs) 31 | return (0.5 * (1.0 + a)) * x + (0.5 * (1.0 - a)) * tf.abs(x) 32 | 33 | def tf_nn_prelu(self, inputs, scope): 34 | # scope like 'prelu_1', 'prelu_2', ... 
35 | with tf.variable_scope(scope): 36 | alphas = tf.get_variable('alpha', inputs.get_shape()[-1], initializer=tf.zeros_initializer(), dtype=tf.float32) 37 | pos = tf.nn.relu(inputs) 38 | neg = alphas * (inputs - tf.abs(inputs)) * 0.5 39 | 40 | return pos + neg 41 | 42 | def res_block(self, features, out_ch, scope): 43 | input_features = features 44 | with tf.variable_scope(scope): 45 | features = slim.conv2d(input_features, out_ch, 3, activation_fn=tf.nn.relu, normalizer_fn=None) 46 | features = slim.conv2d(features, out_ch, 3, activation_fn=None, normalizer_fn=None) 47 | 48 | return input_features + features 49 | 50 | def res_block_bn(self, features, out_ch, is_training, scope): # bn-relu-conv!!! 51 | batch_norm_params = { 52 | 'decay': 0.997, 53 | 'epsilon': 1e-5, 54 | 'scale': True, 55 | 'is_training': is_training 56 | } 57 | 58 | # input_features already gone through bn-relu 59 | input_features = features 60 | with tf.variable_scope(scope): 61 | features = slim.conv2d(input_features, out_ch, 3, activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params) 62 | features = slim.conv2d(features, out_ch, 3, activation_fn=None, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params) 63 | 64 | return input_features + features 65 | 66 | def phaseShift(self, features, scale, shape_1, shape_2): 67 | X = tf.reshape(features, shape_1) 68 | X = tf.transpose(X, [0, 1, 3, 2, 4]) 69 | 70 | return tf.reshape(X, shape_2) 71 | 72 | def pixelShuffler(self, features, scale=2): 73 | size = tf.shape(features) 74 | batch_size = size[0] 75 | h = size[1] 76 | w = size[2] 77 | c = features.get_shape().as_list()[-1]#size[3] 78 | 79 | channel_target = c // (scale * scale) 80 | channel_factor = c // channel_target 81 | 82 | shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale] 83 | shape_2 = [batch_size, h * scale, w * scale, 1] 84 | 85 | input_split = tf.split(axis=3, num_or_size_splits=channel_target, value=features) 
#features, channel_target, axis=3) 86 | output = tf.concat([self.phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3) 87 | 88 | return output 89 | 90 | def upsample(self, features, rate=2): 91 | if FLAGS.upsample == 'nearest': 92 | return tf.image.resize_nearest_neighbor(features, size=[rate * tf.shape(features)[1], rate * tf.shape(features)[2]]) 93 | elif FLAGS.upsample == 'bilinear': 94 | return tf.image.resize_bilinear(features, size=[rate * tf.shape(features)[1], rate * tf.shape(features)[2]]) 95 | else: #pixelShuffler 96 | return self.pixelShuffler(features, scale=2) 97 | 98 | def recon_image(self, inputs, outputs): 99 | ''' 100 | LR to HR -> inputs: LR, outputs: HR 101 | HR to LR -> inputs: HR, outputs: LR 102 | ''' 103 | resized_inputs = tf.image.resize_bicubic(inputs, size=[tf.shape(outputs)[1], tf.shape(outputs)[2]]) 104 | if FLAGS.recon_type == 'residual': 105 | recon_outputs = resized_inputs + outputs 106 | else: 107 | recon_outputs = outputs 108 | 109 | resized_inputs = self.postprocess(resized_inputs) 110 | resized_inputs = tf.cast(tf.clip_by_value(resized_inputs, 0, 255), tf.uint8) 111 | #tf.summary.image('4_bicubic image', resized_inputs) 112 | 113 | recon_outputs = self.postprocess(recon_outputs) 114 | 115 | return recon_outputs, resized_inputs 116 | 117 | ### model part 118 | ''' 119 | list: 120 | enhancenet 121 | ''' 122 | def enhancenet(self, inputs, is_training): 123 | with slim.arg_scope([slim.conv2d], 124 | activation_fn=tf.nn.relu, 125 | normalizer_fn=None): 126 | 127 | features = slim.conv2d(inputs, 64, 3, scope='conv1') 128 | 129 | for idx in range(10): 130 | if FLAGS.use_bn: 131 | features = self.res_block_bn(features, out_ch=64, is_training=is_training, scope='res_block_bn_%d' % (idx)) 132 | else: 133 | features = self.res_block(features, out_ch=64, scope='res_block_%d' % (idx)) 134 | 135 | features = self.upsample(features) 136 | features = slim.conv2d(features, 64, 3, scope='conv2') 137 | 138 | features = 
self.upsample(features) 139 | features = slim.conv2d(features, 64, 3, scope='conv3') 140 | features = slim.conv2d(features, 64, 3, scope='conv4') 141 | outputs = slim.conv2d(features, 3, 3, activation_fn=None, scope='conv5') 142 | 143 | return outputs 144 | 145 | ########## Let's enhance our method! 146 | 147 | def generator(self, inputs, is_training, model='enhancenet'): 148 | ''' 149 | LR to HR 150 | ''' 151 | 152 | inputs = self.preprocess(inputs) 153 | 154 | with tf.variable_scope('generator'): 155 | if model == 'enhancenet': 156 | outputs = self.enhancenet(inputs, is_training) 157 | 158 | outputs, resized_inputs = self.recon_image(inputs, outputs) 159 | 160 | return outputs, resized_inputs 161 | 162 | ### test part 163 | 164 | if __name__ == '__main__': 165 | 166 | batch_size = 64 167 | h = 512 168 | w = 512 169 | c = 3 # rgb 170 | 171 | high_images = np.zeros([batch_size, h, w, c]) # gt 172 | low_images = np.zeros([batch_size, int(h/4), int(w/4), c]) 173 | 174 | input_high_images = tf.placeholder(tf.float32, shape=[batch_size, h, w, c], name='input_high_images') 175 | input_low_images = tf.placeholder(tf.float32, shape=[batch_size, int(h/4), int(w/4), c], name='input_low_images') 176 | 177 | model_builder = model_builder() 178 | 179 | outputs = model_builder.generator(input_low_images) 180 | 181 | print(outputs) 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /dataset/SR_data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import glob 4 | import cv2 5 | import random 6 | import numpy as np 7 | import tensorflow as tf 8 | try: 9 | import data_util 10 | except ImportError: 11 | from dataset import data_util 12 | 13 | tf.app.flags.DEFINE_float('SR_scale', 4.0, '') 14 | tf.app.flags.DEFINE_boolean('random_resize', False, 'True or False') 15 | tf.app.flags.DEFINE_string('load_mode', 'real', 'real or text') # real -> COCO DB 16 | 
FLAGS = tf.app.flags.FLAGS 17 | 18 | ''' 19 | image_path = '/where/your/images/*.jpg' 20 | ''' 21 | 22 | def load_image(im_fn, hr_size): 23 | high_image = cv2.imread(im_fn, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)[:,:,::-1] # rgb converted 24 | 25 | resize_scale = 1 / FLAGS.SR_scale 26 | 27 | ''' 28 | if FLAGS.random_resize: 29 | resize_table = [0.5, 1.0, 1.5, 2.0] 30 | selected_scale = np.random.choice(resize_table, 1)[0] 31 | shrinked_hr_size = int(hr_size / selected_scale) 32 | 33 | h, w, _ = high_image.shape 34 | if h <= shrinked_hr_size or w <= shrinked_hr_size: 35 | high_image = cv2.resize(high_image, (hr_size, hr_size)) 36 | else: 37 | h_edge = h - shrinked_hr_size 38 | w_edge = w - shrinked_hr_size 39 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0] 40 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0] 41 | high_image_crop = high_image[h_start:h_start+hr_size, w_start:w_start+hr_size, :] 42 | high_image = cv2.resize(high_image_crop, (hr_size, hr_size)) 43 | ''' 44 | h, w, _ = high_image.shape 45 | if h <= hr_size or w <= hr_size: 46 | high_image = cv2.resize(high_image, (hr_size, hr_size), 47 | interpolation=cv2.INTER_AREA) 48 | else: 49 | h_edge = h - hr_size 50 | w_edge = w - hr_size 51 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0] 52 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0] 53 | high_image = high_image[h_start:h_start+hr_size, w_start:w_start+hr_size, :] 54 | 55 | low_image = cv2.resize(high_image, (0, 0), fx=resize_scale, fy=resize_scale) 56 | #interpolation=cv2.INTER_AREA) 57 | return high_image, low_image 58 | 59 | def load_txtimage(im_fn, hr_size, batch_size): 60 | high_image = cv2.imread(im_fn, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)[:,:,::-1] 61 | resize_scale = 1 / FLAGS.SR_scale 62 | 63 | lr_size = int(hr_size * resize_scale) 64 | 65 | h, w, _ = high_image.shape 66 | 67 | hr_batch = np.zeros([batch_size, hr_size, hr_size, 3], dtype='float32') 68 | lr_batch = 
np.zeros([batch_size, lr_size, lr_size, 3], dtype='float32') 69 | 70 | h_edge = h - hr_size 71 | w_edge = w - hr_size 72 | 73 | passed_idx = 0 74 | max_iter = 200 75 | iter_idx = 0 76 | while passed_idx < batch_size: 77 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0] 78 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0] 79 | 80 | crop_hr_image = high_image[h_start:h_start + hr_size, w_start:w_start+hr_size,:] 81 | 82 | if np.mean(crop_hr_image) < 250.0: 83 | hr_batch[passed_idx,:,:,:] = crop_hr_image.copy() 84 | crop_lr_image = cv2.resize(crop_hr_image, (0, 0), fx=0.25, fy=0.25, 85 | interpolation=cv2.INTER_AREA) 86 | lr_batch[passed_idx,:,:,:] = crop_lr_image.copy() 87 | passed_idx += 1 88 | else: 89 | iter_idx += 1 90 | 91 | if iter_idx == max_iter: 92 | crop_lr_image = cv2.resize(crop_hr_image, (0, 0), fx=0.25, fy=0.25, 93 | interpolation=cv2.INTER_AREA) 94 | while passed_idx < batch_size: 95 | hr_batch[passed_idx,:,:,:] = crop_hr_image.copy() 96 | lr_batch[passed_idx,:,:,:] = crop_lr_image.copy() 97 | passed_idx += 1 98 | return hr_batch, lr_batch 99 | 100 | return hr_batch, lr_batch 101 | 102 | def get_record(image_path): 103 | images = glob.glob(image_path) 104 | print('%d files found' % (len(images))) 105 | 106 | if len(images) == 0: 107 | raise FileNotFoundError('check your training dataset path') 108 | 109 | index = list(range(len(images))) 110 | 111 | while True: 112 | random.shuffle(index) 113 | 114 | for i in index: 115 | im_fn = images[i] 116 | 117 | yield im_fn #high_image, low_image 118 | 119 | 120 | def generator(image_path, hr_size=512, batch_size=32): 121 | high_images = [] 122 | low_images = [] 123 | 124 | for im_fn in get_record(image_path): 125 | try: 126 | # TODO: data augmentation 127 | ''' 128 | used augmentation methods 129 | only linear augmenation methods will be used: 130 | random resize, ... 
131 | not yet implemented 132 | 133 | ''' 134 | if FLAGS.load_mode == 'real': 135 | high_image, low_image = load_image(im_fn, hr_size) 136 | 137 | high_images.append(high_image) 138 | low_images.append(low_image) 139 | elif FLAGS.load_mode == 'text': 140 | high_images, low_images = load_txtimage(im_fn, hr_size, batch_size) 141 | 142 | if len(high_images) == batch_size: 143 | yield high_images, low_images 144 | 145 | high_images = [] 146 | low_images = [] 147 | 148 | except FileNotFoundError as e: 149 | print(e) 150 | break 151 | except Exception as e: 152 | import traceback 153 | traceback.print_exc() 154 | continue 155 | 156 | def get_generator(image_path, **kwargs): 157 | return generator(image_path, **kwargs) 158 | 159 | ## image_path = '/where/is/your/images/*.jpg' 160 | def get_batch(image_path, num_workers, **kwargs): 161 | try: 162 | generator = get_generator(image_path, **kwargs) 163 | enqueuer = data_util.GeneratorEnqueuer(generator, use_multiprocessing=True) 164 | enqueuer.start(max_queue_size=24, workers=num_workers) 165 | generator_ouptut = None 166 | while True: 167 | while enqueuer.is_running(): 168 | if not enqueuer.queue.empty(): 169 | generator_output = enqueuer.queue.get() 170 | break 171 | else: 172 | time.sleep(0.001) 173 | yield generator_output 174 | generator_output = None 175 | finally: 176 | if enqueuer is not None: 177 | enqueuer.stop() 178 | 179 | if __name__ == '__main__': 180 | image_path = '/data/OCR/DB/icdar_rctw_training/icdar2013_test/*.jpg' 181 | num_workers = 4 182 | batch_size = 32 183 | input_size = 32 184 | data_generator = get_batch(image_path=image_path, 185 | num_workers=num_workers, 186 | batch_size=batch_size, 187 | hr_size=int(input_size*FLAGS.SR_scale)) 188 | 189 | for _ in range(100): 190 | start_time = time.time() 191 | data = next(data_generator) 192 | high_images = np.asarray(data[0]) 193 | low_images = np.asarray(data[1]) 194 | print('%d done!!! 
%f' % (_, time.time() - start_time), high_images.shape, low_images.shape) 195 | for sub_idx, (high_image, low_image) in enumerate(zip(high_images, low_images)): 196 | hr_save_path = '/data/IE/dataset/test_hr/%03d_%02d_hr_image.jpg' % (_, sub_idx) 197 | lr_save_path = '/data/IE/dataset/test_lr/%03d_%02d_sr_image.jpg' % (_, sub_idx) 198 | cv2.imwrite(hr_save_path, high_image[:,:,::-1]) 199 | cv2.imwrite(lr_save_path, low_image[:,:,::-1]) 200 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/data_util.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this file is modified from keras implemention of data process multi-threading, 3 | see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py 4 | ''' 5 | import time 6 | import numpy as np 7 | import threading 8 | import multiprocessing 9 | try: 10 | import queue 11 | except ImportError: 12 | import Queue as queue 13 | 14 | 15 | class GeneratorEnqueuer(): 16 | """Builds a queue out of a data generator. 17 | 18 | Used in `fit_generator`, `evaluate_generator`, `predict_generator`. 19 | 20 | # Arguments 21 | generator: a generator function which endlessly yields data 22 | use_multiprocessing: use multiprocessing if True, otherwise threading 23 | wait_time: time to sleep in-between calls to `put()` 24 | random_seed: Initial seed for workers, 25 | will be incremented by one for each workers. 
26 | """ 27 | 28 | def __init__(self, generator, 29 | use_multiprocessing=False, 30 | wait_time=0.05, 31 | random_seed=None): 32 | self.wait_time = wait_time 33 | self._generator = generator 34 | self._use_multiprocessing = use_multiprocessing 35 | self._threads = [] 36 | self._stop_event = None 37 | self.queue = None 38 | self.random_seed = random_seed 39 | 40 | def start(self, workers=1, max_queue_size=10): 41 | """Kicks off threads which add data from the generator into the queue. 42 | 43 | # Arguments 44 | workers: number of worker threads 45 | max_queue_size: queue size 46 | (when full, threads could block on `put()`) 47 | """ 48 | 49 | def data_generator_task(): 50 | while not self._stop_event.is_set(): 51 | try: 52 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size: 53 | generator_output = next(self._generator) 54 | self.queue.put(generator_output) 55 | else: 56 | time.sleep(self.wait_time) 57 | except Exception: 58 | self._stop_event.set() 59 | raise 60 | 61 | try: 62 | if self._use_multiprocessing: 63 | self.queue = multiprocessing.Queue(maxsize=max_queue_size) 64 | self._stop_event = multiprocessing.Event() 65 | else: 66 | self.queue = queue.Queue() 67 | self._stop_event = threading.Event() 68 | 69 | for _ in range(workers): 70 | if self._use_multiprocessing: 71 | # Reset random seed else all children processes 72 | # share the same seed 73 | np.random.seed(self.random_seed) 74 | thread = multiprocessing.Process(target=data_generator_task) 75 | thread.daemon = True 76 | if self.random_seed is not None: 77 | self.random_seed += 1 78 | else: 79 | thread = threading.Thread(target=data_generator_task) 80 | self._threads.append(thread) 81 | thread.start() 82 | except: 83 | self.stop() 84 | raise 85 | 86 | def is_running(self): 87 | return self._stop_event is not None and not self._stop_event.is_set() 88 | 89 | def stop(self, timeout=None): 90 | """Stops running threads and wait for them to exit, if necessary. 
91 | 92 | Should be called by the same thread which called `start()`. 93 | 94 | # Arguments 95 | timeout: maximum time to wait on `thread.join()`. 96 | """ 97 | if self.is_running(): 98 | self._stop_event.set() 99 | 100 | for thread in self._threads: 101 | if thread.is_alive(): 102 | if self._use_multiprocessing: 103 | thread.terminate() 104 | else: 105 | thread.join(timeout) 106 | 107 | if self._use_multiprocessing: 108 | if self.queue is not None: 109 | self.queue.close() 110 | 111 | self._threads = [] 112 | self._stop_event = None 113 | self.queue = None 114 | 115 | def get(self): 116 | """Creates a generator to extract data from the queue. 117 | 118 | Skip the data if it is `None`. 119 | 120 | # Returns 121 | A generator 122 | """ 123 | while self.is_running(): 124 | if not self.queue.empty(): 125 | inputs = self.queue.get() 126 | if inputs is not None: 127 | yield inputs 128 | else: 129 | time.sleep(self.wait_time) 130 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.contrib import slim 4 | 5 | tf.app.flags.DEFINE_string('adv_ver', 'ver1', 'ver1 or ver2') 6 | FLAGS = tf.app.flags.FLAGS 7 | 8 | VGG_MEAN = [123.68, 116.779, 103.939] # RGB ordered 9 | EPS = 1e-12 10 | 11 | class loss_builder: 12 | def __init__(self): 13 | self.vgg_path = FLAGS.vgg_path 14 | self.vgg_used = False 15 | 16 | def _rgb_subtraction(self, images): 17 | channels = tf.split(axis=3, num_or_size_splits=3, value=images) 18 | for i in range(3): 19 | channels[i] -= VGG_MEAN[i] 20 | return tf.concat(axis=3, values=channels) 21 | 22 | def _build_vgg_19(self, images): 23 | input_images = self._rgb_subtraction(images) / 255.0 24 | 25 | ### vgg_19 26 | with slim.arg_scope([slim.conv2d], 27 | activation_fn=tf.nn.relu, 28 | normalizer_fn=None): 29 | 30 | self.conv1_1 = slim.conv2d(input_images, 64, 3, 
scope='conv1/conv1_1') 31 | self.conv1_2 = slim.conv2d(self.conv1_1, 64, 3, scope='conv1/conv1_2') 32 | self.pool1 = slim.max_pool2d(self.conv1_2, 2, scope='pool1') 33 | 34 | self.conv2_1 = slim.conv2d(self.pool1, 128, 3, scope='conv2/conv2_1') 35 | self.conv2_2 = slim.conv2d(self.conv2_1, 128, 3, scope='conv2/conv2_2') 36 | self.pool2 = slim.max_pool2d(self.conv2_2, 2, scope='pool2') 37 | 38 | self.conv3_1 = slim.conv2d(self.pool2, 256, 3, scope='conv3/conv3_1') 39 | self.conv3_2 = slim.conv2d(self.conv3_1, 256, 3, scope='conv3/conv3_2') 40 | self.conv3_3 = slim.conv2d(self.conv3_2, 256, 3, scope='conv3/conv3_3') 41 | self.conv3_4 = slim.conv2d(self.conv3_3, 256, 3, scope='conv3/conv3_4') 42 | self.pool3 = slim.max_pool2d(self.conv3_4, 2, scope='pool3') 43 | 44 | self.conv4_1 = slim.conv2d(self.pool3, 512, 3, scope='conv4/conv4_1') 45 | self.conv4_2 = slim.conv2d(self.conv4_1, 512, 3, scope='conv4/conv4_2') 46 | self.conv4_3 = slim.conv2d(self.conv4_2, 512, 3, scope='conv4/conv4_3') 47 | self.conv4_4 = slim.conv2d(self.conv4_3, 512, 3, scope='conv4/conv4_4') 48 | self.pool4 = slim.max_pool2d(self.conv4_4, 2, scope='pool4') 49 | 50 | self.conv5_1 = slim.conv2d(self.pool4, 512, 3, scope='conv5/conv5_1') 51 | self.conv5_2 = slim.conv2d(self.conv5_1, 512, 3, scope='conv5/conv5_2') 52 | self.conv5_3 = slim.conv2d(self.conv5_2, 512, 3, scope='conv5/conv5_3') 53 | self.conv5_4 = slim.conv2d(self.conv5_3, 512, 3, scope='conv5/conv5_4') 54 | self.pool5 = slim.max_pool2d(self.conv5_4, 2, scope='pool5') 55 | 56 | self.vgg_19 = {'conv1_1':self.conv1_1, 'conv1_2':self.conv1_2, 'pool1':self.pool1, 57 | 'conv2_1':self.conv2_1, 'conv2_2':self.conv2_2, 'pool2':self.pool2, 58 | 'conv3_1':self.conv3_1, 'conv3_2':self.conv3_2, 'conv3_3':self.conv3_3, 'conv3_4':self.conv3_4, 'pool3':self.pool3, 59 | 'conv4_1':self.conv4_1, 'conv4_2':self.conv4_2, 'conv4_3':self.conv4_3, 'conv4_4':self.conv4_4, 'pool4':self.pool4, 60 | 'conv5_1':self.conv5_1, 'conv5_2':self.conv5_2, 
'conv5_3':self.conv5_3, 'conv5_4':self.conv5_4, 'pool5':self.pool5, 61 | } 62 | 63 | self.vgg_used = True 64 | 65 | def _lrelu(self, x, a=0.2): 66 | with tf.name_scope('lrelu'): 67 | x = tf.identity(x) 68 | return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) 69 | 70 | def _build_discriminator(self, images, is_training): 71 | batch_norm_params = { 72 | 'decay': 0.997, 73 | 'epsilon': 1e-5, 74 | 'scale': True, 75 | 'is_training': is_training 76 | } 77 | 78 | 79 | with slim.arg_scope([slim.conv2d], 80 | activation_fn=None, 81 | normalizer_fn=slim.batch_norm, 82 | normalizer_params=batch_norm_params): 83 | x = slim.conv2d(images, 32, 3, scope='conv1') 84 | x = self._lrelu(x) 85 | x = slim.conv2d(x, 32, 3, stride=2, scope='conv2') 86 | x = self._lrelu(x) 87 | 88 | x = slim.conv2d(x, 64, 3, scope='conv3') 89 | x = self._lrelu(x) 90 | x = slim.conv2d(x, 64, 3, stride=2, scope='conv4') 91 | x = self._lrelu(x) 92 | 93 | x = slim.conv2d(x, 128, 3, scope='conv5') 94 | x = self._lrelu(x) 95 | x = slim.conv2d(x, 128, 3, stride=2, scope='conv6') 96 | x = self._lrelu(x) 97 | 98 | x = slim.conv2d(x, 256, 3, scope='conv7') 99 | x = self._lrelu(x) 100 | x = slim.conv2d(x, 256, 3, stride=2, scope='conv8') 101 | x = self._lrelu(x) 102 | 103 | x = slim.conv2d(x, 512, 3, scope='conv9') 104 | x = self._lrelu(x) 105 | x = slim.conv2d(x, 512, 3, stride=2, scope='conv10') 106 | x = self._lrelu(x) 107 | 108 | x = slim.flatten(x, scope='flatten') 109 | x = slim.fully_connected(x, 1024, activation_fn=None, normalizer_fn=None, scope='fc1') 110 | x = self._lrelu(x) 111 | logits = slim.fully_connected(x, 1, activation_fn=None, normalizer_fn=None, scope='fc2') 112 | outputs = tf.nn.sigmoid(logits) 113 | 114 | return outputs #logits 115 | 116 | def _build_discriminator_ver2(self, images, is_training): 117 | batch_norm_params = { 118 | 'decay': 0.997, 119 | 'epsilon': 1e-5, 120 | 'scale': True, 121 | 'is_training': is_training 122 | } 123 | 124 | 125 | with slim.arg_scope([slim.conv2d], 126 | 
activation_fn=None, 127 | normalizer_fn=slim.batch_norm, 128 | normalizer_params=batch_norm_params): 129 | 130 | x = slim.conv2d(images, 64, 3, stride=1, scope='conv1') 131 | x = self._lrelu(x) 132 | x = slim.conv2d(x, 64, 3, stride=2, scope='conv2') 133 | x = self._lrelu(x) 134 | x = slim.conv2d(x, 128, 3, stride=1, scope='conv3') 135 | x = self._lrelu(x) 136 | x = slim.conv2d(x, 128, 3, stride=2, scope='conv4') 137 | x = self._lrelu(x) 138 | x = slim.conv2d(x, 256, 3, stride=1, scope='conv5') 139 | x = self._lrelu(x) 140 | x = slim.conv2d(x, 256, 3, stride=2, scope='conv6') 141 | x = self._lrelu(x) 142 | x = slim.conv2d(x, 512, 3, stride=1, scope='conv7') 143 | x = self._lrelu(x) 144 | x = slim.conv2d(x, 512, 3, stride=2, scope='conv8') 145 | x = self._lrelu(x) 146 | 147 | x = slim.flatten(x, scope='flatten') 148 | x = slim.fully_connected(x, 1024, activation_fn=None, normalizer_fn=None, scope='fc1') 149 | x = self._lrelu(x) 150 | logits = slim.fully_connected(x, 1, activation_fn=None, normalizer_fn=None, scope='fc2') 151 | outputs = tf.nn.sigmoid(logits) 152 | 153 | return outputs 154 | 155 | def _mse(self, gt, pred): 156 | return tf.losses.mean_squared_error(gt, pred) 157 | 158 | def _l1_loss(self, gt, pred): 159 | return tf.reduce_mean(tf.abs(gt - pred)) 160 | 161 | def _gram_matrix(self, features): 162 | dims = features.get_shape().as_list() 163 | features = tf.reshape(features, [-1, dims[1] * dims[2], dims[3]]) 164 | 165 | gram_matrix = tf.matmul(features, features, transpose_a=True) 166 | normalized_gram_matrix = gram_matrix / (dims[1] * dims[2] * dims[3]) 167 | 168 | return normalized_gram_matrix #tf.matmul(features, features, transpose_a=True) 169 | 170 | def _normalize(self, features): 171 | dims = features.get_shape().as_list() 172 | return features / (dims[1] * dims[2] * dims[3]) 173 | 174 | def _preprocess(self, images): 175 | return (images / 255.0) * 2.0 - 1.0 176 | 177 | def _texture_loss(self, features, patch_size=16): 178 | ''' 179 | the front 
part of features : gt features 180 | the latter part of features : pred features 181 | I will do calculating gt and pred features at once! 182 | ''' 183 | #features = self._normalize(features) 184 | batch_size, h, w, c = features.get_shape().as_list() 185 | features = tf.space_to_batch_nd(features, [patch_size, patch_size], [[0, 0], [0, 0]]) 186 | features = tf.reshape(features, [patch_size, patch_size, -1, h // patch_size, w // patch_size, c]) 187 | features = tf.transpose(features, [2, 3, 4, 0, 1, 5]) 188 | patches_gt, patches_pred = tf.split(features, 2, axis=0) 189 | 190 | patches_gt = tf.reshape(patches_gt, [-1, patch_size, patch_size, c]) 191 | patches_pred = tf.reshape(patches_pred, [-1, patch_size, patch_size, c]) 192 | 193 | gram_matrix_gt = self._gram_matrix(patches_gt) 194 | gram_matrix_pred = self._gram_matrix(patches_pred) 195 | 196 | tl_features = tf.reduce_mean(tf.reduce_sum(tf.square(gram_matrix_gt - gram_matrix_pred), axis=-1)) 197 | return tl_features 198 | 199 | def _perceptual_loss(self): 200 | gt_pool5, pred_pool5 = tf.split(self.vgg_19['conv5_4'], 2, axis=0) 201 | 202 | pl_pool5 = tf.reduce_mean(tf.reduce_sum(tf.square(gt_pool5 - pred_pool5), axis=-1)) 203 | 204 | return pl_pool5 205 | 206 | def _adv_loss(self, gt_logits, pred_logits):#gan_logits): 207 | # gt_logits -> real, pred_logits -> fake 208 | # all values went through tf.nn.sigmoid 209 | 210 | 211 | adv_gen_loss = tf.reduce_mean(-tf.log(pred_logits + EPS)) 212 | 213 | adv_disc_loss = tf.reduce_mean(-(tf.log(gt_logits + EPS) + tf.log(1.0 - pred_logits + EPS))) 214 | 215 | 216 | #adv_gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_logits, labels=tf.ones_like(pred_logits))) 217 | 218 | #adv_disc_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=gt_logits, labels=tf.ones_like(gt_logits))) 219 | #adv_disc_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_logits, labels=tf.zeros_like(pred_logits))) 220 | #adv_disc_loss 
= adv_disc_real + adv_disc_fake 221 | 222 | 223 | return adv_gen_loss, adv_disc_loss 224 | 225 | def build_vgg_19(self, gt, pred): 226 | input_images = tf.concat([gt, pred], axis=0) 227 | with tf.variable_scope('vgg_19'): 228 | self._build_vgg_19(input_images) 229 | 230 | def build_discriminator(self, gt, pred, is_training=True): 231 | ''' 232 | build_discriminator is only used for training! 233 | ''' 234 | 235 | gt = self._preprocess(gt) # -1.0 ~ 1.0 236 | pred = self._preprocess(pred) # -1.0 ~ 1.0 237 | 238 | 239 | with tf.variable_scope('discriminator'): 240 | if FLAGS.adv_ver == 'ver1': 241 | gt_logits = self._build_discriminator(gt, is_training) 242 | else: 243 | gt_logits = self._build_discriminator_ver2(gt, is_training) 244 | 245 | with tf.variable_scope('discriminator', reuse=True): 246 | b, h, w, c = gt.get_shape().as_list() 247 | pred.set_shape([b,h,w,c]) 248 | 249 | if FLAGS.adv_ver == 'ver1': 250 | pred_logits = self._build_discriminator(pred, is_training) 251 | else: 252 | pred_logits = self._build_discriminator_ver2(pred, is_training) 253 | 254 | return gt_logits, pred_logits 255 | 256 | #def build_cycle_discriminator(self, 257 | 258 | def get_loss(self, gt, pred, type='mse'): 259 | ''' 260 | 'mse', 'inverse_mse', 'fft_mse' 261 | 'perceptual', 'texture' 262 | 'adv, 'cycle_adv' 263 | ''' 264 | if type == 'mse': # See SRCNN. MSE is very simple loss function. 265 | gt = self._preprocess(gt) 266 | pred = self._preprocess(pred) 267 | return self._mse(gt, pred) 268 | elif type == 'inverse_mse': 269 | # gt is the input_lr image!!! 
270 | gt = self._preprocess(gt) 271 | pred = self._preprocess(pred) 272 | pred = tf.image.resize_bilinear(pred, size=[tf.shape(gt)[1], tf.shape(gt)[2]]) 273 | return self._mse(gt, pred) 274 | elif type == 'fft_mse': 275 | # check whether both gt and pred need preprocessing 276 | gt = self._preprocess(gt) 277 | pred = self._preprocess(pred) 278 | 279 | ### fft then mse 280 | gt = tf.cast(gt, tf.complex64) 281 | pred = tf.cast(pred, tf.complex64) 282 | 283 | gt = tf.fft2d(gt) 284 | pred = tf.fft2d(pred) 285 | 286 | return self._mse(gt, pred) 287 | elif type == 'l1_loss': 288 | gt = self._preprocess(gt) 289 | pred = self._preprocess(pred) 290 | 291 | return self._l1_loss(gt, pred) 292 | elif type == 'perceptual': # See Enhancenet. 293 | if not self.vgg_used: 294 | self.build_vgg_19(gt, pred) 295 | 296 | pl_pool5 = self._perceptual_loss() 297 | 298 | return pl_pool5 299 | elif type == 'texture': # See Enhancenet, Style transfer papers. 300 | if not self.vgg_used: 301 | self.build_vgg_19(gt, pred) 302 | 303 | tl_conv1 = self._texture_loss(self.vgg_19['conv1_1']) 304 | tl_conv2 = self._texture_loss(self.vgg_19['conv2_1']) 305 | tl_conv3 = self._texture_loss(self.vgg_19['conv3_1']) 306 | 307 | return tl_conv1, tl_conv2, tl_conv3 308 | elif type == 'adv': 309 | gt_logits, pred_logits = self.build_discriminator(gt, pred) 310 | 311 | adv_gen_loss, adv_disc_loss = self._adv_loss(gt_logits, pred_logits) 312 | 313 | return adv_gen_loss, adv_disc_loss 314 | else: 315 | print('%s is not implemented.' 
% (type)) 316 | 317 | if __name__ == '__main__': 318 | 319 | loss_factory = loss_builder() 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /outputs/ENet-E.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/ENet-E.png -------------------------------------------------------------------------------- /outputs/ENet-PAT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/ENet-PAT.png -------------------------------------------------------------------------------- /outputs/Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/Input.png -------------------------------------------------------------------------------- /test_SR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import numpy as np 5 | import tensorflow as tf 6 | import sys 7 | import SR_models 8 | import utils 9 | import cv2 10 | 11 | tf.app.flags.DEFINE_string('model_path', '/where/your/model/folder', '') 12 | tf.app.flags.DEFINE_string('image_path', '/where/your/test_image/folder', '') 13 | tf.app.flags.DEFINE_string('save_path', '/where/your/generated_image/folder', '') 14 | tf.app.flags.DEFINE_string('run_gpu', '0', '') 15 | FLAGS = tf.app.flags.FLAGS 16 | 17 | def load_model(model_path): 18 | ''' 19 | model_path = '.../where/your/save/model/folder' 20 | ''' 21 | input_low_images = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='input_low_images') 22 | 23 | model_builder = SR_models.model_builder() 24 | 25 | 
generated_high_images, resized_low_images = model_builder.generator(input_low_images, is_training=False, model='enhancenet') 26 | 27 | generated_high_images = tf.cast(tf.clip_by_value(generated_high_images, 0, 255), tf.uint8) 28 | 29 | all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 30 | 31 | gen_vars = [var for var in all_vars if var.name.startswith('generator')] 32 | 33 | saver = tf.train.Saver(gen_vars) 34 | 35 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 36 | 37 | ckpt_path = utils.get_last_ckpt_path(model_path) 38 | saver.restore(sess, ckpt_path) 39 | 40 | return input_low_images, generated_high_images, sess 41 | 42 | def init_resize_image(im): 43 | h, w, _ = im.shape 44 | size = [h, w] 45 | max_arg = np.argmax(size) 46 | max_len = size[max_arg] 47 | min_arg = max_arg - 1 48 | min_len = size[min_arg] 49 | 50 | maximum_size = 1024 51 | if max_len < maximum_size: 52 | maximum_size = max_len 53 | ratio = 1.0 54 | return im, ratio 55 | else: 56 | ratio = maximum_size / max_len 57 | max_len = max_len * ratio 58 | min_len = min_len * ratio 59 | size[max_arg] = int(max_len) 60 | size[min_arg] = int(min_len) 61 | 62 | im = cv2.resize(im, (size[1], size[0])) 63 | 64 | return im, ratio 65 | 66 | if __name__ == '__main__': 67 | # set your gpus usage 68 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.run_gpu 69 | 70 | # get pre-trained generator model 71 | input_image, generated_image, sess = load_model(FLAGS.model_path) 72 | 73 | # get test_image_list 74 | test_image_list = utils.get_image_paths(FLAGS.image_path) 75 | 76 | # make save_folder 77 | if not os.path.exists(FLAGS.save_path): 78 | os.makedirs(FLAGS.save_path) 79 | 80 | # do test 81 | 82 | for test_idx, test_image in enumerate(test_image_list): 83 | loaded_image = cv2.imread(test_image) 84 | processed_image, tmp_ratio = init_resize_image(loaded_image) 85 | 86 | feed_dict = {input_image : [processed_image[:,:,::-1]]} 87 | 88 | output_image = sess.run(generated_image, 
feed_dict=feed_dict) 89 | 90 | output_image = output_image[0,:,:,:] 91 | 92 | image_name = os.path.basename(test_image) 93 | 94 | tmp_save_path = os.path.join(FLAGS.save_path, 'SR_' + image_name) 95 | 96 | cv2.imwrite(tmp_save_path, output_image[:,:,::-1]) 97 | 98 | print('%d / %d completed!!!' % (test_idx + 1, len(test_image_list))) 99 | 100 | -------------------------------------------------------------------------------- /train_SR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib import slim 6 | import cv2 7 | from dataset import SR_data_load 8 | import SR_models 9 | import losses 10 | import utils 11 | 12 | tf.app.flags.DEFINE_string('run_gpu', '0', 'use single gpu') 13 | tf.app.flags.DEFINE_string('save_path', '/where/your/folder', '') 14 | tf.app.flags.DEFINE_boolean('model_restore', False, '') 15 | tf.app.flags.DEFINE_string('image_path', '/where/your/saved/image/folder', '') 16 | tf.app.flags.DEFINE_integer('batch_size', 32, '') 17 | tf.app.flags.DEFINE_integer('num_readers', 4, '') 18 | tf.app.flags.DEFINE_integer('input_size', 32, '') 19 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'define your learing strategy') 20 | tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '') 21 | tf.app.flags.DEFINE_string('vgg_path', None, '/where/your/vgg_19.ckpt') 22 | tf.app.flags.DEFINE_integer('num_workers', 4, '') 23 | tf.app.flags.DEFINE_integer('max_to_keep', 10, 'how many do you want to save models?') 24 | tf.app.flags.DEFINE_integer('save_model_steps', 10000, '') 25 | tf.app.flags.DEFINE_integer('save_summary_steps', 10, '') 26 | tf.app.flags.DEFINE_integer('max_steps', 1000000, '') 27 | tf.app.flags.DEFINE_string('losses', 'perceptual', 'mse,perceptual,texture,adv') 28 | tf.app.flags.DEFINE_string('adv_direction', 'g2d', 'g2d or d2g') 29 | tf.app.flags.DEFINE_float('adv_gen_w', 0.001, '') 30 | 
tf.app.flags.DEFINE_float('adv_disc_w', 1.0, '') 31 | FLAGS = tf.app.flags.FLAGS 32 | 33 | def main(argv=None): 34 | ######################### System setup 35 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.run_gpu 36 | utils.prepare_checkpoint_path(FLAGS.save_path, FLAGS.model_restore) 37 | 38 | ######################### Model setup 39 | low_size = FLAGS.input_size 40 | high_size = int(FLAGS.input_size * FLAGS.SR_scale) 41 | 42 | input_low_images = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, low_size, low_size, 3], name='input_low_images') 43 | input_high_images = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, high_size, high_size, 3], name='input_high_images') 44 | 45 | 46 | model_builder = SR_models.model_builder() 47 | 48 | generated_high_images, resized_low_images = model_builder.generator(input_low_images, is_training=True, model=FLAGS.model) 49 | 50 | tf.summary.image('1_input_low_images', input_low_images) 51 | tf.summary.image('2_input_high_images', input_high_images) 52 | vis_gen_images = tf.cast(tf.clip_by_value(generated_high_images, 0, 255), tf.uint8) 53 | tf.summary.image('3_generated_images', vis_gen_images) 54 | tf.summary.image('4_bicubic_images', resized_low_images) 55 | vis_high_images = tf.cast(input_high_images, tf.uint8) 56 | vis_gen_images = tf.concat([resized_low_images, vis_gen_images, vis_high_images], axis=2) 57 | tf.summary.image('5_bicubic_gen_gt', vis_gen_images) 58 | 59 | 60 | ######################### Losses setup 61 | loss_builder = losses.loss_builder() 62 | 63 | loss_list = utils.loss_parser(FLAGS.losses) 64 | generator_loss = 0.0 65 | 66 | if 'mse' in loss_list or 'l2' in loss_list or 'l2_loss' in loss_list: 67 | mse_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='mse') 68 | generator_loss = generator_loss + 1.0 * mse_loss 69 | tf.summary.scalar('mse_loss', mse_loss) 70 | if 'inverse_mse' in loss_list: 71 | inv_mse_loss = loss_builder.get_loss(input_low_images, generated_high_images, 
type='inverse_mse') 72 | generator_loss = generator_loss + 100.0 * inv_mse_loss 73 | tf.summary.scalar('inv_mse_loss', inv_mse_loss) 74 | if 'fft_mse' in loss_list: 75 | fft_mse_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='fft_mse') 76 | generator_loss = generator_loss + 1.0 * fft_mse_loss 77 | tf.summary.scalar('fft_mse_loss', fft_mse_loss) 78 | if 'l1' in loss_list or 'l1_loss' in loss_list: 79 | l1_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='l1_loss') 80 | generator_loss = generator_loss + 0.01 * l1_loss 81 | tf.summary.scalar('l1_loss', l1_loss) 82 | if 'perceptual' in loss_list: 83 | pl_pool5 = loss_builder.get_loss(input_high_images, generated_high_images, type='perceptual') 84 | pl_pool5 *= 2e-2 85 | generator_loss = generator_loss + pl_pool5 86 | tf.summary.scalar('pl_pool5', pl_pool5) 87 | if 'texture' in loss_list: 88 | tl_conv1, tl_conv2, tl_conv3 = loss_builder.get_loss(input_high_images, generated_high_images, type='texture') 89 | 90 | #generator_loss = generator_loss + 1e-2 * tl_conv1 + 1e-2 * tl_conv2 + 1e-2 * tl_conv3 91 | tl_weight = 10.0 92 | #generator_loss = generator_loss + tl_weight * tl_conv1 + tl_weight * tl_conv2 + tl_weight * tl_conv3 93 | generator_loss = generator_loss + tl_weight * tl_conv3 94 | 95 | tf.summary.scalar('tl_conv1', tl_conv1) 96 | tf.summary.scalar('tl_conv2', tl_conv2) 97 | tf.summary.scalar('tl_conv3', tl_conv3) 98 | if 'adv' in loss_list: 99 | adv_gen_loss, adv_disc_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='adv') 100 | tf.summary.scalar('adv_gen', adv_gen_loss) 101 | tf.summary.scalar('adv_disc', adv_disc_loss) 102 | discrim_loss = FLAGS.adv_disc_w * adv_disc_loss 103 | generator_loss = generator_loss + FLAGS.adv_gen_w * adv_gen_loss 104 | 105 | tf.summary.scalar('generator_loss', generator_loss) 106 | 107 | ######################### Training setup 108 | global_step = tf.get_variable('global_step', [], dtype=tf.int64, 
initializer=tf.constant_initializer(0), trainable=False) 109 | 110 | train_vars = tf.trainable_variables() 111 | 112 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 113 | 114 | learning_rate = utils.configure_learning_rate(FLAGS.learning_rate, global_step) 115 | 116 | #gen_optimizer = utils.configure_optimizer(learning_rate) 117 | #gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars) 118 | #gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients)#, global_step=global_step) 119 | 120 | if 'adv' in loss_list: 121 | discrim_vars = [var for var in train_vars if var.name.startswith('discriminator')] 122 | disc_optimizer = utils.configure_optimizer(learning_rate) 123 | disc_gradients = disc_optimizer.compute_gradients(discrim_loss, var_list=discrim_vars) 124 | disc_grad_updates = disc_optimizer.apply_gradients(disc_gradients, global_step=global_step) 125 | 126 | with tf.control_dependencies([disc_grad_updates] + update_ops): 127 | generator_vars = [var for var in train_vars if var.name.startswith('generator')] 128 | gen_optimizer = utils.configure_optimizer(learning_rate) 129 | gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars) 130 | gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients, global_step=global_step) 131 | 132 | train_op = gen_grad_updates 133 | 134 | else: 135 | 136 | generator_vars = [var for var in train_vars if var.name.startswith('generator')] 137 | discrim_vars = generator_vars 138 | gen_optimizer = utils.configure_optimizer(learning_rate) 139 | gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars) 140 | gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients)#, global_step=global_step) 141 | with tf.control_dependencies([gen_grad_updates] + update_ops): 142 | train_op = tf.no_op(name='train_op') 143 | 144 | saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep) 145 | summary_op = tf.summary.merge_all() 146 | 
summary_writer = tf.summary.FileWriter(FLAGS.save_path, tf.get_default_graph()) 147 | 148 | ######################### Train process 149 | data_generator = SR_data_load.get_batch(image_path=FLAGS.image_path, 150 | num_workers=FLAGS.num_workers, 151 | batch_size=FLAGS.batch_size, 152 | hr_size=high_size) 153 | 154 | ## vgg_stop process 155 | #utils.print_vars(train_vars) 156 | #utils.print_vars(generator_vars) 157 | if FLAGS.vgg_path is not None: 158 | variable_restore_op = utils.get_restore_op(FLAGS.vgg_path, train_vars) 159 | 160 | ############# 161 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 162 | if FLAGS.model_restore: 163 | ckpt = tf.train.latest_checkpoint(FLAGS.save_path) 164 | saver.restore(sess, ckpt) 165 | else: 166 | sess.run(tf.global_variables_initializer()) 167 | if FLAGS.vgg_path is not None: 168 | variable_restore_op(sess) 169 | 170 | start_time = time.time() 171 | for iter_val in range(int(global_step.eval()) + 1, FLAGS.max_steps + 1): 172 | data = next(data_generator) 173 | high_images = np.asarray(data[0]) 174 | low_images = np.asarray(data[1]) 175 | 176 | feed_dict = {input_low_images: low_images, 177 | input_high_images: high_images} 178 | 179 | generator_loss_val, _, g_w, d_w = sess.run([generator_loss, train_op, generator_vars[0], discrim_vars[0]], feed_dict=feed_dict) 180 | 181 | if iter_val != 0 and iter_val % FLAGS.save_summary_steps == 0: 182 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 183 | summary_writer.add_summary(summary_str, global_step=iter_val) 184 | 185 | used_time = time.time() - start_time 186 | avg_time_per_step = used_time / FLAGS.save_summary_steps 187 | avg_examples_per_second = (FLAGS.save_summary_steps * FLAGS.batch_size) / used_time 188 | 189 | print('step %d, generator_loss %.4f, weights %.2f, %.2f, %.2f seconds/step, %.2f examples/second' 190 | % (iter_val, generator_loss_val, np.sum(g_w), np.sum(d_w), avg_time_per_step, avg_examples_per_second)) 191 | start_time = 
time.time() 192 | 193 | if iter_val != 0 and iter_val % FLAGS.save_model_steps == 0: 194 | checkpoint_fn = os.path.join(FLAGS.save_path, 'model.ckpt') 195 | saver.save(sess, checkpoint_fn, global_step=iter_val) 196 | 197 | print('') 198 | print('*' * 30) 199 | print(' Training done!!! ') 200 | print('*' * 30) 201 | print('') 202 | 203 | if __name__ == '__main__': 204 | tf.app.run() 205 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib import slim 4 | import os 5 | import glob 6 | 7 | FLAGS = tf.app.flags.FLAGS 8 | 9 | def prepare_checkpoint_path(save_path, restore): 10 | if not tf.gfile.Exists(save_path): 11 | tf.gfile.MkDir(save_path) 12 | else: 13 | if not restore: 14 | tf.gfile.DeleteRecursively(save_path) 15 | tf.gfile.MkDir(save_path) 16 | 17 | def configure_learning_rate(learning_rate_init_value, global_step): 18 | learning_rate = tf.train.exponential_decay(learning_rate_init_value, global_step, decay_steps=10000, decay_rate=0.94, staircase=True) 19 | tf.summary.scalar('learning_rate', learning_rate) 20 | return learning_rate 21 | 22 | def configure_optimizer(learning_rate): 23 | return tf.train.AdamOptimizer(learning_rate) 24 | 25 | def get_restore_op(vgg_path, train_vars, check=False): 26 | vgg_19_vars = [var for var in train_vars if var.name.startswith('vgg')] 27 | if check: 28 | print_vars(vgg_19_vars) 29 | variable_restore_op = slim.assign_from_checkpoint_fn( 30 | vgg_path, 31 | vgg_19_vars, 32 | ignore_missing_vars=True) 33 | return variable_restore_op 34 | 35 | def print_vars(var_list): 36 | print('') 37 | for var in var_list: 38 | print(var) 39 | print('') 40 | 41 | def loss_parser(str_loss): 42 | ''' 43 | NOTE!!! str_loss should be like 'mse,perceptual,texture,adv'... 
44 | ''' 45 | selected_loss_array = str_loss.split(',') 46 | return selected_loss_array 47 | 48 | def get_last_ckpt_path(folder_path): 49 | ''' 50 | folder_path = .../where/your/saved/model/folder 51 | ''' 52 | 53 | meta_paths = sorted(glob.glob(os.path.join(folder_path, '*.meta'))) 54 | 55 | numbers = [] 56 | 57 | for meta_path in meta_paths: 58 | numbers.append(int(meta_path.split('-')[-1].split('.')[0])) 59 | 60 | numbers = np.asarray(numbers) 61 | 62 | sorted_idx = np.argsort(numbers) 63 | 64 | latest_meta_path = meta_paths[sorted_idx[-1]] 65 | 66 | ckpt_path = latest_meta_path.replace('.meta', '') 67 | 68 | return ckpt_path 69 | 70 | def get_image_paths(image_folder): 71 | possible_image_type = ['jpg', 'png', 'JPEG', 'jpeg'] 72 | 73 | image_list = [image_path for image_paths in [glob.glob(os.path.join(image_folder, '*.%s' % ext)) for ext in possible_image_type] for image_path in image_paths] 74 | 75 | return image_list 76 | 77 | --------------------------------------------------------------------------------