├── LICENSE
├── README.md
├── SR_models.py
├── dataset
├── SR_data_load.py
├── __init__.py
└── data_util.py
├── losses.py
├── outputs
├── ENet-E.png
├── ENet-PAT.png
└── Input.png
├── test_SR.py
├── train_SR.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Geonmo Gu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EnhanceNet
2 |
3 | TensorFlow implementation of EnhanceNet for a 4x magnification ratio.
4 |
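The code relies on `tf.app.flags` and `tf.contrib.slim`, so it targets TensorFlow 1.x, together with OpenCV and NumPy. A minimal environment sketch (the package choices below are an assumption; versions are not pinned by this repo):

```
pip install "tensorflow-gpu<2.0" opencv-python numpy
```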
5 | **We slightly changed the training procedure for ENet as follows:**
6 | + The discriminator has been changed to a DCGAN-style architecture.
7 | + We only use the ```conv5_4 and conv3_1 features``` from VGG-19 for the perceptual and texture losses. See losses.py.
8 | + Accordingly, we changed the hyper-parameters of the loss combination (see the sketch below).
9 |
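For reference, the ENet-PAT generator loss in train_SR.py combines the individual terms roughly as sketched below. The helper function and its argument names are illustrative only (they do not exist in this repo); the weights mirror this repo's defaults, with `adv_gen_w` taken from the ENet-PAT training command further down.

```
def combine_generator_losses(perceptual, texture_conv3_1, adv_gen, adv_gen_w=0.003):
    # Illustrative sketch of the weighting applied in train_SR.py (ENet-PAT).
    return 2e-2 * perceptual + 10.0 * texture_conv3_1 + adv_gen_w * adv_gen
```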
10 | ### Results
11 |
12 |
13 | | Input | ENet-E | ENet-PAT |
14 | | :---: | :----: | :------: |
15 | | ![Input](outputs/Input.png) | ![ENet-E](outputs/ENet-E.png) | ![ENet-PAT](outputs/ENet-PAT.png) |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ### How to train?
25 |
26 | 1. Download the COCO train2017 set for training ENet and unzip train2017.zip
27 | ```
28 | wget http://images.cocodataset.org/zips/train2017.zip
29 | unzip train2017.zip
30 | ```
31 |
32 | 2. Download the VGG-19 slim model and untar it
33 | ```
34 | wget http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz
35 | tar xvzf vgg_19_2016_08_28.tar.gz
36 | ```
37 |
38 | 3. Run training!
39 | ```
40 | # ENet-E
41 | python3 train_SR.py --model=enhancenet --upsample=nearest \
42 | --recon_type=residual --SR_scale=4 --run_gpu=0 \
43 | --batch_size=32 --num_readers=4 --input_size=32 \
44 | --losses='mse' \
45 | --learning_rate=0.0001 \
46 | --save_path=/your/models/will/be/saved \
47 | --image_path=/where/is/your/COCODB/train2017/*.jpg
48 |
49 | # ENet-PAT
50 | python3 train_SR.py --model=enhancenet --upsample=nearest \
51 | --recon_type=residual --SR_scale=4 --run_gpu=0 \
52 | --batch_size=32 --num_readers=4 --input_size=32 \
53 | --losses='perceptual,texture,adv' --adv_ver=ver2 \
54 | --adv_gen_w=0.003 --learning_rate=0.0001 \
55 | --save_path=/your/models/will/be/saved \
56 | --image_path=/where/is/your/COCODB/train2017/*.jpg \
57 | --vgg_path=/where/is/your/vgg19/vgg_19.ckpt
58 | ```
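Training writes TensorBoard summaries (losses, learning rate, and sample images) to `--save_path`, so progress can be monitored with, for example:

```
tensorboard --logdir=/your/models/will/be/saved
```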
59 |
60 | ### How to test?
61 |
62 | ```
63 | python3 test_SR.py --model_path=/your/pretrained/model/folder \
64 | --image_path=/your/image/folder \
65 | --save_path=/generated_image/will/be/saved/here \
66 | --run_gpu=0
67 | ```
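Each image found under `--image_path` is upscaled 4x (images whose longer side exceeds 1024 px are first resized down) and written to `--save_path` with an `SR_` prefix, e.g. a hypothetical `foo.jpg` becomes:

```
/generated_image/will/be/saved/here/SR_foo.jpg
```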
68 |
69 | ### Reference
70 | ```
71 | @inproceedings{enhancenet,
72 | title={{EnhanceNet: Single Image Super-Resolution through Automated Texture Synthesis}},
73 | author={Sajjadi, Mehdi S. M. and Sch{\"o}lkopf, Bernhard and Hirsch, Michael},
74 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
75 | pages={4501--4510},
76 | year={2017},
77 | organization={IEEE},
78 | url={https://arxiv.org/abs/1612.07919/}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/SR_models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tensorflow.contrib import slim
4 |
5 | tf.app.flags.DEFINE_string('upsample', 'nearest', 'nearest, bilinear, or pixelShuffler')
6 | tf.app.flags.DEFINE_string('model', 'enhancenet', 'for now, only enhancenet supported')
7 | tf.app.flags.DEFINE_string('recon_type', 'residual', 'residual or direct')
8 | tf.app.flags.DEFINE_boolean('use_bn', False, 'for res_block_bn')
9 |
10 | FLAGS = tf.app.flags.FLAGS
11 |
12 | class model_builder:
13 | def __init__(self):
14 | return
15 |
16 | def preprocess(self, images):
17 | pp_images = images / 255.0
18 | ## simple mean shift
19 | pp_images = pp_images * 2.0 - 1.0
20 |
21 | return pp_images
22 |
23 | def postprocess(self, images):
24 | pp_images = ((images + 1.0) / 2.0) * 255.0
25 |
26 | return pp_images
27 |
28 | def tf_nn_lrelu(self, inputs, a=0.2):
29 | with tf.name_scope('lrelu'):
30 | x = tf.identity(inputs)
31 | return (0.5 * (1.0 + a)) * x + (0.5 * (1.0 - a)) * tf.abs(x)
32 |
33 | def tf_nn_prelu(self, inputs, scope):
34 | # scope like 'prelu_1', 'prelu_2', ...
35 | with tf.variable_scope(scope):
36 | alphas = tf.get_variable('alpha', inputs.get_shape()[-1], initializer=tf.zeros_initializer(), dtype=tf.float32)
37 | pos = tf.nn.relu(inputs)
38 | neg = alphas * (inputs - tf.abs(inputs)) * 0.5
39 |
40 | return pos + neg
41 |
42 | def res_block(self, features, out_ch, scope):
43 | input_features = features
44 | with tf.variable_scope(scope):
45 | features = slim.conv2d(input_features, out_ch, 3, activation_fn=tf.nn.relu, normalizer_fn=None)
46 | features = slim.conv2d(features, out_ch, 3, activation_fn=None, normalizer_fn=None)
47 |
48 | return input_features + features
49 |
50 | def res_block_bn(self, features, out_ch, is_training, scope): # bn-relu-conv!!!
51 | batch_norm_params = {
52 | 'decay': 0.997,
53 | 'epsilon': 1e-5,
54 | 'scale': True,
55 | 'is_training': is_training
56 | }
57 |
58 | # input_features already gone through bn-relu
59 | input_features = features
60 | with tf.variable_scope(scope):
61 | features = slim.conv2d(input_features, out_ch, 3, activation_fn=tf.nn.relu, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)
62 | features = slim.conv2d(features, out_ch, 3, activation_fn=None, normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params)
63 |
64 | return input_features + features
65 |
66 | def phaseShift(self, features, scale, shape_1, shape_2):
67 | X = tf.reshape(features, shape_1)
68 | X = tf.transpose(X, [0, 1, 3, 2, 4])
69 |
70 | return tf.reshape(X, shape_2)
71 |
72 | def pixelShuffler(self, features, scale=2): # sub-pixel (depth-to-space) upsampling
73 | size = tf.shape(features)
74 | batch_size = size[0]
75 | h = size[1]
76 | w = size[2]
77 | c = features.get_shape().as_list()[-1]#size[3]
78 |
79 | channel_target = c // (scale * scale)
80 | channel_factor = c // channel_target
81 |
82 | shape_1 = [batch_size, h, w, channel_factor // scale, channel_factor // scale]
83 | shape_2 = [batch_size, h * scale, w * scale, 1]
84 |
85 | input_split = tf.split(axis=3, num_or_size_splits=channel_target, value=features) #features, channel_target, axis=3)
86 | output = tf.concat([self.phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis=3)
87 |
88 | return output
89 |
90 | def upsample(self, features, rate=2):
91 | if FLAGS.upsample == 'nearest':
92 | return tf.image.resize_nearest_neighbor(features, size=[rate * tf.shape(features)[1], rate * tf.shape(features)[2]])
93 | elif FLAGS.upsample == 'bilinear':
94 | return tf.image.resize_bilinear(features, size=[rate * tf.shape(features)[1], rate * tf.shape(features)[2]])
95 | else: #pixelShuffler
96 | return self.pixelShuffler(features, scale=rate)
97 |
98 | def recon_image(self, inputs, outputs):
99 | '''
100 | LR to HR -> inputs: LR, outputs: HR
101 | HR to LR -> inputs: HR, outputs: LR
102 | '''
103 | resized_inputs = tf.image.resize_bicubic(inputs, size=[tf.shape(outputs)[1], tf.shape(outputs)[2]])
104 | if FLAGS.recon_type == 'residual':
105 | recon_outputs = resized_inputs + outputs
106 | else:
107 | recon_outputs = outputs
108 |
109 | resized_inputs = self.postprocess(resized_inputs)
110 | resized_inputs = tf.cast(tf.clip_by_value(resized_inputs, 0, 255), tf.uint8)
111 | #tf.summary.image('4_bicubic image', resized_inputs)
112 |
113 | recon_outputs = self.postprocess(recon_outputs)
114 |
115 | return recon_outputs, resized_inputs
116 |
117 | ### model part
118 | '''
119 | list:
120 | enhancenet
121 | '''
122 | def enhancenet(self, inputs, is_training):
123 | with slim.arg_scope([slim.conv2d],
124 | activation_fn=tf.nn.relu,
125 | normalizer_fn=None):
126 |
127 | features = slim.conv2d(inputs, 64, 3, scope='conv1')
128 |
129 | for idx in range(10):
130 | if FLAGS.use_bn:
131 | features = self.res_block_bn(features, out_ch=64, is_training=is_training, scope='res_block_bn_%d' % (idx))
132 | else:
133 | features = self.res_block(features, out_ch=64, scope='res_block_%d' % (idx))
134 |
135 | features = self.upsample(features)
136 | features = slim.conv2d(features, 64, 3, scope='conv2')
137 |
138 | features = self.upsample(features)
139 | features = slim.conv2d(features, 64, 3, scope='conv3')
140 | features = slim.conv2d(features, 64, 3, scope='conv4')
141 | outputs = slim.conv2d(features, 3, 3, activation_fn=None, scope='conv5')
142 |
143 | return outputs
144 |
145 | ########## Let's enhance our method!
146 |
147 | def generator(self, inputs, is_training, model='enhancenet'):
148 | '''
149 | LR to HR
150 | '''
151 |
152 | inputs = self.preprocess(inputs)
153 |
154 | with tf.variable_scope('generator'):
155 | if model == 'enhancenet':
156 | outputs = self.enhancenet(inputs, is_training)
157 |
158 | outputs, resized_inputs = self.recon_image(inputs, outputs)
159 |
160 | return outputs, resized_inputs
161 |
162 | ### test part
163 |
164 | if __name__ == '__main__':
165 |
166 | batch_size = 64
167 | h = 512
168 | w = 512
169 | c = 3 # rgb
170 |
171 | high_images = np.zeros([batch_size, h, w, c]) # gt
172 | low_images = np.zeros([batch_size, int(h/4), int(w/4), c])
173 |
174 | input_high_images = tf.placeholder(tf.float32, shape=[batch_size, h, w, c], name='input_high_images')
175 | input_low_images = tf.placeholder(tf.float32, shape=[batch_size, int(h/4), int(w/4), c], name='input_low_images')
176 |
177 | builder = model_builder()
178 |
179 | outputs, resized_inputs = builder.generator(input_low_images, is_training=False)
180 |
181 | print(outputs)
182 |
183 |
184 |
185 |
--------------------------------------------------------------------------------
/dataset/SR_data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import glob
4 | import cv2
5 | import random
6 | import numpy as np
7 | import tensorflow as tf
8 | try:
9 | import data_util
10 | except ImportError:
11 | from dataset import data_util
12 |
13 | tf.app.flags.DEFINE_float('SR_scale', 4.0, '')
14 | tf.app.flags.DEFINE_boolean('random_resize', False, 'True or False')
15 | tf.app.flags.DEFINE_string('load_mode', 'real', 'real or text') # real -> COCO DB
16 | FLAGS = tf.app.flags.FLAGS
17 |
18 | '''
19 | image_path = '/where/your/images/*.jpg'
20 | '''
21 |
22 | def load_image(im_fn, hr_size):
23 | high_image = cv2.imread(im_fn, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)[:,:,::-1] # rgb converted
24 |
25 | resize_scale = 1 / FLAGS.SR_scale
26 |
27 | '''
28 | if FLAGS.random_resize:
29 | resize_table = [0.5, 1.0, 1.5, 2.0]
30 | selected_scale = np.random.choice(resize_table, 1)[0]
31 | shrinked_hr_size = int(hr_size / selected_scale)
32 |
33 | h, w, _ = high_image.shape
34 | if h <= shrinked_hr_size or w <= shrinked_hr_size:
35 | high_image = cv2.resize(high_image, (hr_size, hr_size))
36 | else:
37 | h_edge = h - shrinked_hr_size
38 | w_edge = w - shrinked_hr_size
39 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0]
40 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0]
41 | high_image_crop = high_image[h_start:h_start+hr_size, w_start:w_start+hr_size, :]
42 | high_image = cv2.resize(high_image_crop, (hr_size, hr_size))
43 | '''
44 | h, w, _ = high_image.shape
45 | if h <= hr_size or w <= hr_size:
46 | high_image = cv2.resize(high_image, (hr_size, hr_size),
47 | interpolation=cv2.INTER_AREA)
48 | else:
49 | h_edge = h - hr_size
50 | w_edge = w - hr_size
51 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0]
52 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0]
53 | high_image = high_image[h_start:h_start+hr_size, w_start:w_start+hr_size, :]
54 |
55 | low_image = cv2.resize(high_image, (0, 0), fx=resize_scale, fy=resize_scale)
56 | #interpolation=cv2.INTER_AREA)
57 | return high_image, low_image
58 |
59 | def load_txtimage(im_fn, hr_size, batch_size):
60 | high_image = cv2.imread(im_fn, cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)[:,:,::-1]
61 | resize_scale = 1 / FLAGS.SR_scale
62 |
63 | lr_size = int(hr_size * resize_scale)
64 |
65 | h, w, _ = high_image.shape
66 |
67 | hr_batch = np.zeros([batch_size, hr_size, hr_size, 3], dtype='float32')
68 | lr_batch = np.zeros([batch_size, lr_size, lr_size, 3], dtype='float32')
69 |
70 | h_edge = h - hr_size
71 | w_edge = w - hr_size
72 |
73 | passed_idx = 0
74 | max_iter = 200
75 | iter_idx = 0
76 | while passed_idx < batch_size:
77 | h_start = np.random.randint(low=0, high=h_edge, size=1)[0]
78 | w_start = np.random.randint(low=0, high=w_edge, size=1)[0]
79 |
80 | crop_hr_image = high_image[h_start:h_start + hr_size, w_start:w_start+hr_size,:]
81 |
82 | if np.mean(crop_hr_image) < 250.0:
83 | hr_batch[passed_idx,:,:,:] = crop_hr_image.copy()
84 | crop_lr_image = cv2.resize(crop_hr_image, (0, 0), fx=0.25, fy=0.25,
85 | interpolation=cv2.INTER_AREA)
86 | lr_batch[passed_idx,:,:,:] = crop_lr_image.copy()
87 | passed_idx += 1
88 | else:
89 | iter_idx += 1
90 |
91 | if iter_idx == max_iter:
92 | crop_lr_image = cv2.resize(crop_hr_image, (0, 0), fx=0.25, fy=0.25,
93 | interpolation=cv2.INTER_AREA)
94 | while passed_idx < batch_size:
95 | hr_batch[passed_idx,:,:,:] = crop_hr_image.copy()
96 | lr_batch[passed_idx,:,:,:] = crop_lr_image.copy()
97 | passed_idx += 1
98 | return hr_batch, lr_batch
99 |
100 | return hr_batch, lr_batch
101 |
102 | def get_record(image_path):
103 | images = glob.glob(image_path)
104 | print('%d files found' % (len(images)))
105 |
106 | if len(images) == 0:
107 | raise FileNotFoundError('check your training dataset path')
108 |
109 | index = list(range(len(images)))
110 |
111 | while True:
112 | random.shuffle(index)
113 |
114 | for i in index:
115 | im_fn = images[i]
116 |
117 | yield im_fn #high_image, low_image
118 |
119 |
120 | def generator(image_path, hr_size=512, batch_size=32):
121 | high_images = []
122 | low_images = []
123 |
124 | for im_fn in get_record(image_path):
125 | try:
126 | # TODO: data augmentation
127 | '''
128 | used augmentation methods
129 | only linear augmenation methods will be used:
130 | random resize, ...
131 | not yet implemented
132 |
133 | '''
134 | if FLAGS.load_mode == 'real':
135 | high_image, low_image = load_image(im_fn, hr_size)
136 |
137 | high_images.append(high_image)
138 | low_images.append(low_image)
139 | elif FLAGS.load_mode == 'text':
140 | high_images, low_images = load_txtimage(im_fn, hr_size, batch_size)
141 |
142 | if len(high_images) == batch_size:
143 | yield high_images, low_images
144 |
145 | high_images = []
146 | low_images = []
147 |
148 | except FileNotFoundError as e:
149 | print(e)
150 | break
151 | except Exception as e:
152 | import traceback
153 | traceback.print_exc()
154 | continue
155 |
156 | def get_generator(image_path, **kwargs):
157 | return generator(image_path, **kwargs)
158 |
159 | ## image_path = '/where/is/your/images/*.jpg'
160 | def get_batch(image_path, num_workers, **kwargs):
161 | try:
162 | generator = get_generator(image_path, **kwargs)
163 | enqueuer = data_util.GeneratorEnqueuer(generator, use_multiprocessing=True)
164 | enqueuer.start(max_queue_size=24, workers=num_workers)
165 | generator_output = None
166 | while True:
167 | while enqueuer.is_running():
168 | if not enqueuer.queue.empty():
169 | generator_output = enqueuer.queue.get()
170 | break
171 | else:
172 | time.sleep(0.001)
173 | yield generator_output
174 | generator_output = None
175 | finally:
176 | if enqueuer is not None:
177 | enqueuer.stop()
178 |
179 | if __name__ == '__main__':
180 | image_path = '/data/OCR/DB/icdar_rctw_training/icdar2013_test/*.jpg'
181 | num_workers = 4
182 | batch_size = 32
183 | input_size = 32
184 | data_generator = get_batch(image_path=image_path,
185 | num_workers=num_workers,
186 | batch_size=batch_size,
187 | hr_size=int(input_size*FLAGS.SR_scale))
188 |
189 | for _ in range(100):
190 | start_time = time.time()
191 | data = next(data_generator)
192 | high_images = np.asarray(data[0])
193 | low_images = np.asarray(data[1])
194 | print('%d done!!! %f' % (_, time.time() - start_time), high_images.shape, low_images.shape)
195 | for sub_idx, (high_image, low_image) in enumerate(zip(high_images, low_images)):
196 | hr_save_path = '/data/IE/dataset/test_hr/%03d_%02d_hr_image.jpg' % (_, sub_idx)
197 | lr_save_path = '/data/IE/dataset/test_lr/%03d_%02d_sr_image.jpg' % (_, sub_idx)
198 | cv2.imwrite(hr_save_path, high_image[:,:,::-1])
199 | cv2.imwrite(lr_save_path, low_image[:,:,::-1])
200 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/data_util.py:
--------------------------------------------------------------------------------
1 | '''
2 | this file is modified from the keras implementation of multi-threaded data processing,
3 | see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
4 | '''
5 | import time
6 | import numpy as np
7 | import threading
8 | import multiprocessing
9 | try:
10 | import queue
11 | except ImportError:
12 | import Queue as queue
13 |
14 |
15 | class GeneratorEnqueuer():
16 | """Builds a queue out of a data generator.
17 |
18 | Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
19 |
20 | # Arguments
21 | generator: a generator function which endlessly yields data
22 | use_multiprocessing: use multiprocessing if True, otherwise threading
23 | wait_time: time to sleep in-between calls to `put()`
24 | random_seed: Initial seed for workers,
25 | will be incremented by one for each workers.
26 | """
27 |
28 | def __init__(self, generator,
29 | use_multiprocessing=False,
30 | wait_time=0.05,
31 | random_seed=None):
32 | self.wait_time = wait_time
33 | self._generator = generator
34 | self._use_multiprocessing = use_multiprocessing
35 | self._threads = []
36 | self._stop_event = None
37 | self.queue = None
38 | self.random_seed = random_seed
39 |
40 | def start(self, workers=1, max_queue_size=10):
41 | """Kicks off threads which add data from the generator into the queue.
42 |
43 | # Arguments
44 | workers: number of worker threads
45 | max_queue_size: queue size
46 | (when full, threads could block on `put()`)
47 | """
48 |
49 | def data_generator_task():
50 | while not self._stop_event.is_set():
51 | try:
52 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
53 | generator_output = next(self._generator)
54 | self.queue.put(generator_output)
55 | else:
56 | time.sleep(self.wait_time)
57 | except Exception:
58 | self._stop_event.set()
59 | raise
60 |
61 | try:
62 | if self._use_multiprocessing:
63 | self.queue = multiprocessing.Queue(maxsize=max_queue_size)
64 | self._stop_event = multiprocessing.Event()
65 | else:
66 | self.queue = queue.Queue()
67 | self._stop_event = threading.Event()
68 |
69 | for _ in range(workers):
70 | if self._use_multiprocessing:
71 | # Reset random seed else all children processes
72 | # share the same seed
73 | np.random.seed(self.random_seed)
74 | thread = multiprocessing.Process(target=data_generator_task)
75 | thread.daemon = True
76 | if self.random_seed is not None:
77 | self.random_seed += 1
78 | else:
79 | thread = threading.Thread(target=data_generator_task)
80 | self._threads.append(thread)
81 | thread.start()
82 | except:
83 | self.stop()
84 | raise
85 |
86 | def is_running(self):
87 | return self._stop_event is not None and not self._stop_event.is_set()
88 |
89 | def stop(self, timeout=None):
90 | """Stops running threads and wait for them to exit, if necessary.
91 |
92 | Should be called by the same thread which called `start()`.
93 |
94 | # Arguments
95 | timeout: maximum time to wait on `thread.join()`.
96 | """
97 | if self.is_running():
98 | self._stop_event.set()
99 |
100 | for thread in self._threads:
101 | if thread.is_alive():
102 | if self._use_multiprocessing:
103 | thread.terminate()
104 | else:
105 | thread.join(timeout)
106 |
107 | if self._use_multiprocessing:
108 | if self.queue is not None:
109 | self.queue.close()
110 |
111 | self._threads = []
112 | self._stop_event = None
113 | self.queue = None
114 |
115 | def get(self):
116 | """Creates a generator to extract data from the queue.
117 |
118 | Skip the data if it is `None`.
119 |
120 | # Returns
121 | A generator
122 | """
123 | while self.is_running():
124 | if not self.queue.empty():
125 | inputs = self.queue.get()
126 | if inputs is not None:
127 | yield inputs
128 | else:
129 | time.sleep(self.wait_time)
130 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tensorflow.contrib import slim
4 |
5 | tf.app.flags.DEFINE_string('adv_ver', 'ver1', 'ver1 or ver2')
6 | FLAGS = tf.app.flags.FLAGS
7 |
8 | VGG_MEAN = [123.68, 116.779, 103.939] # RGB ordered
9 | EPS = 1e-12
10 |
11 | class loss_builder:
12 | def __init__(self):
13 | self.vgg_path = FLAGS.vgg_path
14 | self.vgg_used = False
15 |
16 | def _rgb_subtraction(self, images):
17 | channels = tf.split(axis=3, num_or_size_splits=3, value=images)
18 | for i in range(3):
19 | channels[i] -= VGG_MEAN[i]
20 | return tf.concat(axis=3, values=channels)
21 |
22 | def _build_vgg_19(self, images):
23 | input_images = self._rgb_subtraction(images) / 255.0
24 |
25 | ### vgg_19
26 | with slim.arg_scope([slim.conv2d],
27 | activation_fn=tf.nn.relu,
28 | normalizer_fn=None):
29 |
30 | self.conv1_1 = slim.conv2d(input_images, 64, 3, scope='conv1/conv1_1')
31 | self.conv1_2 = slim.conv2d(self.conv1_1, 64, 3, scope='conv1/conv1_2')
32 | self.pool1 = slim.max_pool2d(self.conv1_2, 2, scope='pool1')
33 |
34 | self.conv2_1 = slim.conv2d(self.pool1, 128, 3, scope='conv2/conv2_1')
35 | self.conv2_2 = slim.conv2d(self.conv2_1, 128, 3, scope='conv2/conv2_2')
36 | self.pool2 = slim.max_pool2d(self.conv2_2, 2, scope='pool2')
37 |
38 | self.conv3_1 = slim.conv2d(self.pool2, 256, 3, scope='conv3/conv3_1')
39 | self.conv3_2 = slim.conv2d(self.conv3_1, 256, 3, scope='conv3/conv3_2')
40 | self.conv3_3 = slim.conv2d(self.conv3_2, 256, 3, scope='conv3/conv3_3')
41 | self.conv3_4 = slim.conv2d(self.conv3_3, 256, 3, scope='conv3/conv3_4')
42 | self.pool3 = slim.max_pool2d(self.conv3_4, 2, scope='pool3')
43 |
44 | self.conv4_1 = slim.conv2d(self.pool3, 512, 3, scope='conv4/conv4_1')
45 | self.conv4_2 = slim.conv2d(self.conv4_1, 512, 3, scope='conv4/conv4_2')
46 | self.conv4_3 = slim.conv2d(self.conv4_2, 512, 3, scope='conv4/conv4_3')
47 | self.conv4_4 = slim.conv2d(self.conv4_3, 512, 3, scope='conv4/conv4_4')
48 | self.pool4 = slim.max_pool2d(self.conv4_4, 2, scope='pool4')
49 |
50 | self.conv5_1 = slim.conv2d(self.pool4, 512, 3, scope='conv5/conv5_1')
51 | self.conv5_2 = slim.conv2d(self.conv5_1, 512, 3, scope='conv5/conv5_2')
52 | self.conv5_3 = slim.conv2d(self.conv5_2, 512, 3, scope='conv5/conv5_3')
53 | self.conv5_4 = slim.conv2d(self.conv5_3, 512, 3, scope='conv5/conv5_4')
54 | self.pool5 = slim.max_pool2d(self.conv5_4, 2, scope='pool5')
55 |
56 | self.vgg_19 = {'conv1_1':self.conv1_1, 'conv1_2':self.conv1_2, 'pool1':self.pool1,
57 | 'conv2_1':self.conv2_1, 'conv2_2':self.conv2_2, 'pool2':self.pool2,
58 | 'conv3_1':self.conv3_1, 'conv3_2':self.conv3_2, 'conv3_3':self.conv3_3, 'conv3_4':self.conv3_4, 'pool3':self.pool3,
59 | 'conv4_1':self.conv4_1, 'conv4_2':self.conv4_2, 'conv4_3':self.conv4_3, 'conv4_4':self.conv4_4, 'pool4':self.pool4,
60 | 'conv5_1':self.conv5_1, 'conv5_2':self.conv5_2, 'conv5_3':self.conv5_3, 'conv5_4':self.conv5_4, 'pool5':self.pool5,
61 | }
62 |
63 | self.vgg_used = True
64 |
65 | def _lrelu(self, x, a=0.2):
66 | with tf.name_scope('lrelu'):
67 | x = tf.identity(x)
68 | return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
69 |
70 | def _build_discriminator(self, images, is_training):
71 | batch_norm_params = {
72 | 'decay': 0.997,
73 | 'epsilon': 1e-5,
74 | 'scale': True,
75 | 'is_training': is_training
76 | }
77 |
78 |
79 | with slim.arg_scope([slim.conv2d],
80 | activation_fn=None,
81 | normalizer_fn=slim.batch_norm,
82 | normalizer_params=batch_norm_params):
83 | x = slim.conv2d(images, 32, 3, scope='conv1')
84 | x = self._lrelu(x)
85 | x = slim.conv2d(x, 32, 3, stride=2, scope='conv2')
86 | x = self._lrelu(x)
87 |
88 | x = slim.conv2d(x, 64, 3, scope='conv3')
89 | x = self._lrelu(x)
90 | x = slim.conv2d(x, 64, 3, stride=2, scope='conv4')
91 | x = self._lrelu(x)
92 |
93 | x = slim.conv2d(x, 128, 3, scope='conv5')
94 | x = self._lrelu(x)
95 | x = slim.conv2d(x, 128, 3, stride=2, scope='conv6')
96 | x = self._lrelu(x)
97 |
98 | x = slim.conv2d(x, 256, 3, scope='conv7')
99 | x = self._lrelu(x)
100 | x = slim.conv2d(x, 256, 3, stride=2, scope='conv8')
101 | x = self._lrelu(x)
102 |
103 | x = slim.conv2d(x, 512, 3, scope='conv9')
104 | x = self._lrelu(x)
105 | x = slim.conv2d(x, 512, 3, stride=2, scope='conv10')
106 | x = self._lrelu(x)
107 |
108 | x = slim.flatten(x, scope='flatten')
109 | x = slim.fully_connected(x, 1024, activation_fn=None, normalizer_fn=None, scope='fc1')
110 | x = self._lrelu(x)
111 | logits = slim.fully_connected(x, 1, activation_fn=None, normalizer_fn=None, scope='fc2')
112 | outputs = tf.nn.sigmoid(logits)
113 |
114 | return outputs #logits
115 |
116 | def _build_discriminator_ver2(self, images, is_training):
117 | batch_norm_params = {
118 | 'decay': 0.997,
119 | 'epsilon': 1e-5,
120 | 'scale': True,
121 | 'is_training': is_training
122 | }
123 |
124 |
125 | with slim.arg_scope([slim.conv2d],
126 | activation_fn=None,
127 | normalizer_fn=slim.batch_norm,
128 | normalizer_params=batch_norm_params):
129 |
130 | x = slim.conv2d(images, 64, 3, stride=1, scope='conv1')
131 | x = self._lrelu(x)
132 | x = slim.conv2d(x, 64, 3, stride=2, scope='conv2')
133 | x = self._lrelu(x)
134 | x = slim.conv2d(x, 128, 3, stride=1, scope='conv3')
135 | x = self._lrelu(x)
136 | x = slim.conv2d(x, 128, 3, stride=2, scope='conv4')
137 | x = self._lrelu(x)
138 | x = slim.conv2d(x, 256, 3, stride=1, scope='conv5')
139 | x = self._lrelu(x)
140 | x = slim.conv2d(x, 256, 3, stride=2, scope='conv6')
141 | x = self._lrelu(x)
142 | x = slim.conv2d(x, 512, 3, stride=1, scope='conv7')
143 | x = self._lrelu(x)
144 | x = slim.conv2d(x, 512, 3, stride=2, scope='conv8')
145 | x = self._lrelu(x)
146 |
147 | x = slim.flatten(x, scope='flatten')
148 | x = slim.fully_connected(x, 1024, activation_fn=None, normalizer_fn=None, scope='fc1')
149 | x = self._lrelu(x)
150 | logits = slim.fully_connected(x, 1, activation_fn=None, normalizer_fn=None, scope='fc2')
151 | outputs = tf.nn.sigmoid(logits)
152 |
153 | return outputs
154 |
155 | def _mse(self, gt, pred):
156 | return tf.losses.mean_squared_error(gt, pred)
157 |
158 | def _l1_loss(self, gt, pred):
159 | return tf.reduce_mean(tf.abs(gt - pred))
160 |
161 | def _gram_matrix(self, features):
162 | dims = features.get_shape().as_list()
163 | features = tf.reshape(features, [-1, dims[1] * dims[2], dims[3]])
164 |
165 | gram_matrix = tf.matmul(features, features, transpose_a=True)
166 | normalized_gram_matrix = gram_matrix / (dims[1] * dims[2] * dims[3])
167 |
168 | return normalized_gram_matrix #tf.matmul(features, features, transpose_a=True)
169 |
170 | def _normalize(self, features):
171 | dims = features.get_shape().as_list()
172 | return features / (dims[1] * dims[2] * dims[3])
173 |
174 | def _preprocess(self, images):
175 | return (images / 255.0) * 2.0 - 1.0
176 |
177 | def _texture_loss(self, features, patch_size=16):
178 | '''
179 | the front part of features : gt features
180 | the latter part of features : pred features
181 | gt and pred features are computed in a single forward pass and split below
182 | '''
183 | #features = self._normalize(features)
184 | batch_size, h, w, c = features.get_shape().as_list()
185 | features = tf.space_to_batch_nd(features, [patch_size, patch_size], [[0, 0], [0, 0]])
186 | features = tf.reshape(features, [patch_size, patch_size, -1, h // patch_size, w // patch_size, c])
187 | features = tf.transpose(features, [2, 3, 4, 0, 1, 5])
188 | patches_gt, patches_pred = tf.split(features, 2, axis=0)
189 |
190 | patches_gt = tf.reshape(patches_gt, [-1, patch_size, patch_size, c])
191 | patches_pred = tf.reshape(patches_pred, [-1, patch_size, patch_size, c])
192 |
193 | gram_matrix_gt = self._gram_matrix(patches_gt)
194 | gram_matrix_pred = self._gram_matrix(patches_pred)
195 |
196 | tl_features = tf.reduce_mean(tf.reduce_sum(tf.square(gram_matrix_gt - gram_matrix_pred), axis=-1))
197 | return tl_features
198 |
199 | def _perceptual_loss(self):
200 | gt_pool5, pred_pool5 = tf.split(self.vgg_19['conv5_4'], 2, axis=0)
201 |
202 | pl_pool5 = tf.reduce_mean(tf.reduce_sum(tf.square(gt_pool5 - pred_pool5), axis=-1))
203 |
204 | return pl_pool5
205 |
206 | def _adv_loss(self, gt_logits, pred_logits):#gan_logits):
207 | # gt_logits -> real, pred_logits -> fake
208 | # all values went through tf.nn.sigmoid
209 |
210 |
211 | adv_gen_loss = tf.reduce_mean(-tf.log(pred_logits + EPS))
212 |
213 | adv_disc_loss = tf.reduce_mean(-(tf.log(gt_logits + EPS) + tf.log(1.0 - pred_logits + EPS)))
214 |
215 |
216 | #adv_gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_logits, labels=tf.ones_like(pred_logits)))
217 |
218 | #adv_disc_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=gt_logits, labels=tf.ones_like(gt_logits)))
219 | #adv_disc_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pred_logits, labels=tf.zeros_like(pred_logits)))
220 | #adv_disc_loss = adv_disc_real + adv_disc_fake
221 |
222 |
223 | return adv_gen_loss, adv_disc_loss
224 |
225 | def build_vgg_19(self, gt, pred):
226 | input_images = tf.concat([gt, pred], axis=0)
227 | with tf.variable_scope('vgg_19'):
228 | self._build_vgg_19(input_images)
229 |
230 | def build_discriminator(self, gt, pred, is_training=True):
231 | '''
232 | build_discriminator is only used for training!
233 | '''
234 |
235 | gt = self._preprocess(gt) # -1.0 ~ 1.0
236 | pred = self._preprocess(pred) # -1.0 ~ 1.0
237 |
238 |
239 | with tf.variable_scope('discriminator'):
240 | if FLAGS.adv_ver == 'ver1':
241 | gt_logits = self._build_discriminator(gt, is_training)
242 | else:
243 | gt_logits = self._build_discriminator_ver2(gt, is_training)
244 |
245 | with tf.variable_scope('discriminator', reuse=True):
246 | b, h, w, c = gt.get_shape().as_list()
247 | pred.set_shape([b,h,w,c])
248 |
249 | if FLAGS.adv_ver == 'ver1':
250 | pred_logits = self._build_discriminator(pred, is_training)
251 | else:
252 | pred_logits = self._build_discriminator_ver2(pred, is_training)
253 |
254 | return gt_logits, pred_logits
255 |
256 | #def build_cycle_discriminator(self,
257 |
258 | def get_loss(self, gt, pred, type='mse'):
259 | '''
260 | 'mse', 'inverse_mse', 'fft_mse'
261 | 'perceptual', 'texture'
262 | 'adv', 'cycle_adv'
263 | '''
264 | if type == 'mse': # See SRCNN. MSE is very simple loss function.
265 | gt = self._preprocess(gt)
266 | pred = self._preprocess(pred)
267 | return self._mse(gt, pred)
268 | elif type == 'inverse_mse':
269 | # gt is the input_lr image!!!
270 | gt = self._preprocess(gt)
271 | pred = self._preprocess(pred)
272 | pred = tf.image.resize_bilinear(pred, size=[tf.shape(gt)[1], tf.shape(gt)[2]])
273 | return self._mse(gt, pred)
274 | elif type == 'fft_mse':
275 | # check whether both gt and pred need preprocessing
276 | gt = self._preprocess(gt)
277 | pred = self._preprocess(pred)
278 |
279 | ### fft then mse
280 | gt = tf.cast(gt, tf.complex64)
281 | pred = tf.cast(pred, tf.complex64)
282 |
283 | gt = tf.fft2d(gt)
284 | pred = tf.fft2d(pred)
285 |
286 | return self._mse(gt, pred)
287 | elif type == 'l1_loss':
288 | gt = self._preprocess(gt)
289 | pred = self._preprocess(pred)
290 |
291 | return self._l1_loss(gt, pred)
292 | elif type == 'perceptual': # See Enhancenet.
293 | if not self.vgg_used:
294 | self.build_vgg_19(gt, pred)
295 |
296 | pl_pool5 = self._perceptual_loss()
297 |
298 | return pl_pool5
299 | elif type == 'texture': # See Enhancenet, Style transfer papers.
300 | if not self.vgg_used:
301 | self.build_vgg_19(gt, pred)
302 |
303 | tl_conv1 = self._texture_loss(self.vgg_19['conv1_1'])
304 | tl_conv2 = self._texture_loss(self.vgg_19['conv2_1'])
305 | tl_conv3 = self._texture_loss(self.vgg_19['conv3_1'])
306 |
307 | return tl_conv1, tl_conv2, tl_conv3
308 | elif type == 'adv':
309 | gt_logits, pred_logits = self.build_discriminator(gt, pred)
310 |
311 | adv_gen_loss, adv_disc_loss = self._adv_loss(gt_logits, pred_logits)
312 |
313 | return adv_gen_loss, adv_disc_loss
314 | else:
315 | print('%s is not implemented.' % (type))
316 |
317 | if __name__ == '__main__':
318 |
319 | loss_factory = loss_builder()
320 |
321 |
322 |
323 |
--------------------------------------------------------------------------------
/outputs/ENet-E.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/ENet-E.png
--------------------------------------------------------------------------------
/outputs/ENet-PAT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/ENet-PAT.png
--------------------------------------------------------------------------------
/outputs/Input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geonm/EnhanceNet-Tensorflow/d0e527418f8b3fd167a61c8777483259d04fc4ab/outputs/Input.png
--------------------------------------------------------------------------------
/test_SR.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | import SR_models
8 | import utils
9 | import cv2
10 |
11 | tf.app.flags.DEFINE_string('model_path', '/where/your/model/folder', '')
12 | tf.app.flags.DEFINE_string('image_path', '/where/your/test_image/folder', '')
13 | tf.app.flags.DEFINE_string('save_path', '/where/your/generated_image/folder', '')
14 | tf.app.flags.DEFINE_string('run_gpu', '0', '')
15 | FLAGS = tf.app.flags.FLAGS
16 |
17 | def load_model(model_path):
18 | '''
19 | model_path = '.../where/your/save/model/folder'
20 | '''
21 | input_low_images = tf.placeholder(tf.float32, shape=[1, None, None, 3], name='input_low_images')
22 |
23 | model_builder = SR_models.model_builder()
24 |
25 | generated_high_images, resized_low_images = model_builder.generator(input_low_images, is_training=False, model='enhancenet')
26 |
27 | generated_high_images = tf.cast(tf.clip_by_value(generated_high_images, 0, 255), tf.uint8)
28 |
29 | all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
30 |
31 | gen_vars = [var for var in all_vars if var.name.startswith('generator')]
32 |
33 | saver = tf.train.Saver(gen_vars)
34 |
35 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
36 |
37 | ckpt_path = utils.get_last_ckpt_path(model_path)
38 | saver.restore(sess, ckpt_path)
39 |
40 | return input_low_images, generated_high_images, sess
41 |
42 | def init_resize_image(im):
43 | h, w, _ = im.shape
44 | size = [h, w]
45 | max_arg = np.argmax(size)
46 | max_len = size[max_arg]
47 | min_arg = max_arg - 1
48 | min_len = size[min_arg]
49 |
50 | maximum_size = 1024
51 | if max_len < maximum_size:
52 | maximum_size = max_len
53 | ratio = 1.0
54 | return im, ratio
55 | else:
56 | ratio = maximum_size / max_len
57 | max_len = max_len * ratio
58 | min_len = min_len * ratio
59 | size[max_arg] = int(max_len)
60 | size[min_arg] = int(min_len)
61 |
62 | im = cv2.resize(im, (size[1], size[0]))
63 |
64 | return im, ratio
65 |
66 | if __name__ == '__main__':
67 | # set your gpus usage
68 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.run_gpu
69 |
70 | # get pre-trained generator model
71 | input_image, generated_image, sess = load_model(FLAGS.model_path)
72 |
73 | # get test_image_list
74 | test_image_list = utils.get_image_paths(FLAGS.image_path)
75 |
76 | # make save_folder
77 | if not os.path.exists(FLAGS.save_path):
78 | os.makedirs(FLAGS.save_path)
79 |
80 | # do test
81 |
82 | for test_idx, test_image in enumerate(test_image_list):
83 | loaded_image = cv2.imread(test_image)
84 | processed_image, tmp_ratio = init_resize_image(loaded_image)
85 |
86 | feed_dict = {input_image : [processed_image[:,:,::-1]]}
87 |
88 | output_image = sess.run(generated_image, feed_dict=feed_dict)
89 |
90 | output_image = output_image[0,:,:,:]
91 |
92 | image_name = os.path.basename(test_image)
93 |
94 | tmp_save_path = os.path.join(FLAGS.save_path, 'SR_' + image_name)
95 |
96 | cv2.imwrite(tmp_save_path, output_image[:,:,::-1])
97 |
98 | print('%d / %d completed!!!' % (test_idx + 1, len(test_image_list)))
99 |
100 |
--------------------------------------------------------------------------------
/train_SR.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.contrib import slim
6 | import cv2
7 | from dataset import SR_data_load
8 | import SR_models
9 | import losses
10 | import utils
11 |
12 | tf.app.flags.DEFINE_string('run_gpu', '0', 'use single gpu')
13 | tf.app.flags.DEFINE_string('save_path', '/where/your/folder', '')
14 | tf.app.flags.DEFINE_boolean('model_restore', False, '')
15 | tf.app.flags.DEFINE_string('image_path', '/where/your/saved/image/folder', '')
16 | tf.app.flags.DEFINE_integer('batch_size', 32, '')
17 | tf.app.flags.DEFINE_integer('num_readers', 4, '')
18 | tf.app.flags.DEFINE_integer('input_size', 32, '')
19 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'define your learning strategy')
20 | tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
21 | tf.app.flags.DEFINE_string('vgg_path', None, '/where/your/vgg_19.ckpt')
22 | tf.app.flags.DEFINE_integer('num_workers', 4, '')
23 | tf.app.flags.DEFINE_integer('max_to_keep', 10, 'how many model checkpoints to keep')
24 | tf.app.flags.DEFINE_integer('save_model_steps', 10000, '')
25 | tf.app.flags.DEFINE_integer('save_summary_steps', 10, '')
26 | tf.app.flags.DEFINE_integer('max_steps', 1000000, '')
27 | tf.app.flags.DEFINE_string('losses', 'perceptual', 'mse,perceptual,texture,adv')
28 | tf.app.flags.DEFINE_string('adv_direction', 'g2d', 'g2d or d2g')
29 | tf.app.flags.DEFINE_float('adv_gen_w', 0.001, '')
30 | tf.app.flags.DEFINE_float('adv_disc_w', 1.0, '')
31 | FLAGS = tf.app.flags.FLAGS
32 |
33 | def main(argv=None):
34 | ######################### System setup
35 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.run_gpu
36 | utils.prepare_checkpoint_path(FLAGS.save_path, FLAGS.model_restore)
37 |
38 | ######################### Model setup
39 | low_size = FLAGS.input_size
40 | high_size = int(FLAGS.input_size * FLAGS.SR_scale)
41 |
42 | input_low_images = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, low_size, low_size, 3], name='input_low_images')
43 | input_high_images = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, high_size, high_size, 3], name='input_high_images')
44 |
45 |
46 | model_builder = SR_models.model_builder()
47 |
48 | generated_high_images, resized_low_images = model_builder.generator(input_low_images, is_training=True, model=FLAGS.model)
49 |
50 | tf.summary.image('1_input_low_images', input_low_images)
51 | tf.summary.image('2_input_high_images', input_high_images)
52 | vis_gen_images = tf.cast(tf.clip_by_value(generated_high_images, 0, 255), tf.uint8)
53 | tf.summary.image('3_generated_images', vis_gen_images)
54 | tf.summary.image('4_bicubic_images', resized_low_images)
55 | vis_high_images = tf.cast(input_high_images, tf.uint8)
56 | vis_gen_images = tf.concat([resized_low_images, vis_gen_images, vis_high_images], axis=2)
57 | tf.summary.image('5_bicubic_gen_gt', vis_gen_images)
58 |
59 |
60 | ######################### Losses setup
61 | loss_builder = losses.loss_builder()
62 |
63 | loss_list = utils.loss_parser(FLAGS.losses)
64 | generator_loss = 0.0
65 |
66 | if 'mse' in loss_list or 'l2' in loss_list or 'l2_loss' in loss_list:
67 | mse_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='mse')
68 | generator_loss = generator_loss + 1.0 * mse_loss
69 | tf.summary.scalar('mse_loss', mse_loss)
70 | if 'inverse_mse' in loss_list:
71 | inv_mse_loss = loss_builder.get_loss(input_low_images, generated_high_images, type='inverse_mse')
72 | generator_loss = generator_loss + 100.0 * inv_mse_loss
73 | tf.summary.scalar('inv_mse_loss', inv_mse_loss)
74 | if 'fft_mse' in loss_list:
75 | fft_mse_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='fft_mse')
76 | generator_loss = generator_loss + 1.0 * fft_mse_loss
77 | tf.summary.scalar('fft_mse_loss', fft_mse_loss)
78 | if 'l1' in loss_list or 'l1_loss' in loss_list:
79 | l1_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='l1_loss')
80 | generator_loss = generator_loss + 0.01 * l1_loss
81 | tf.summary.scalar('l1_loss', l1_loss)
82 | if 'perceptual' in loss_list:
83 | pl_pool5 = loss_builder.get_loss(input_high_images, generated_high_images, type='perceptual')
84 | pl_pool5 *= 2e-2
85 | generator_loss = generator_loss + pl_pool5
86 | tf.summary.scalar('pl_pool5', pl_pool5)
87 | if 'texture' in loss_list:
88 | tl_conv1, tl_conv2, tl_conv3 = loss_builder.get_loss(input_high_images, generated_high_images, type='texture')
89 |
90 | #generator_loss = generator_loss + 1e-2 * tl_conv1 + 1e-2 * tl_conv2 + 1e-2 * tl_conv3
91 | tl_weight = 10.0
92 | #generator_loss = generator_loss + tl_weight * tl_conv1 + tl_weight * tl_conv2 + tl_weight * tl_conv3
93 | generator_loss = generator_loss + tl_weight * tl_conv3
94 |
95 | tf.summary.scalar('tl_conv1', tl_conv1)
96 | tf.summary.scalar('tl_conv2', tl_conv2)
97 | tf.summary.scalar('tl_conv3', tl_conv3)
98 | if 'adv' in loss_list:
99 | adv_gen_loss, adv_disc_loss = loss_builder.get_loss(input_high_images, generated_high_images, type='adv')
100 | tf.summary.scalar('adv_gen', adv_gen_loss)
101 | tf.summary.scalar('adv_disc', adv_disc_loss)
102 | discrim_loss = FLAGS.adv_disc_w * adv_disc_loss
103 | generator_loss = generator_loss + FLAGS.adv_gen_w * adv_gen_loss
104 |
105 | tf.summary.scalar('generator_loss', generator_loss)
106 |
107 | ######################### Training setup
108 | global_step = tf.get_variable('global_step', [], dtype=tf.int64, initializer=tf.constant_initializer(0), trainable=False)
109 |
110 | train_vars = tf.trainable_variables()
111 |
112 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
113 |
114 | learning_rate = utils.configure_learning_rate(FLAGS.learning_rate, global_step)
115 |
116 | #gen_optimizer = utils.configure_optimizer(learning_rate)
117 | #gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars)
118 | #gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients)#, global_step=global_step)
119 |
120 | if 'adv' in loss_list:
121 | discrim_vars = [var for var in train_vars if var.name.startswith('discriminator')]
122 | disc_optimizer = utils.configure_optimizer(learning_rate)
123 | disc_gradients = disc_optimizer.compute_gradients(discrim_loss, var_list=discrim_vars)
124 | disc_grad_updates = disc_optimizer.apply_gradients(disc_gradients) # global_step is advanced once per iteration by the generator update below
125 |
126 | with tf.control_dependencies([disc_grad_updates] + update_ops):
127 | generator_vars = [var for var in train_vars if var.name.startswith('generator')]
128 | gen_optimizer = utils.configure_optimizer(learning_rate)
129 | gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars)
130 | gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients, global_step=global_step)
131 |
132 | train_op = gen_grad_updates
133 |
134 | else:
135 |
136 | generator_vars = [var for var in train_vars if var.name.startswith('generator')]
137 | discrim_vars = generator_vars
138 | gen_optimizer = utils.configure_optimizer(learning_rate)
139 | gen_gradients = gen_optimizer.compute_gradients(generator_loss, var_list=generator_vars)
140 | gen_grad_updates = gen_optimizer.apply_gradients(gen_gradients, global_step=global_step) # advance global_step so LR decay and restore also work without the adversarial loss
141 | with tf.control_dependencies([gen_grad_updates] + update_ops):
142 | train_op = tf.no_op(name='train_op')
143 |
144 | saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
145 | summary_op = tf.summary.merge_all()
146 | summary_writer = tf.summary.FileWriter(FLAGS.save_path, tf.get_default_graph())
147 |
148 | ######################### Train process
149 | data_generator = SR_data_load.get_batch(image_path=FLAGS.image_path,
150 | num_workers=FLAGS.num_workers,
151 | batch_size=FLAGS.batch_size,
152 | hr_size=high_size)
153 |
154 | ## vgg_stop process
155 | #utils.print_vars(train_vars)
156 | #utils.print_vars(generator_vars)
157 | if FLAGS.vgg_path is not None:
158 | variable_restore_op = utils.get_restore_op(FLAGS.vgg_path, train_vars)
159 |
160 | #############
161 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
162 | if FLAGS.model_restore:
163 | ckpt = tf.train.latest_checkpoint(FLAGS.save_path)
164 | saver.restore(sess, ckpt)
165 | else:
166 | sess.run(tf.global_variables_initializer())
167 | if FLAGS.vgg_path is not None:
168 | variable_restore_op(sess)
169 |
170 | start_time = time.time()
171 | for iter_val in range(int(global_step.eval()) + 1, FLAGS.max_steps + 1):
172 | data = next(data_generator)
173 | high_images = np.asarray(data[0])
174 | low_images = np.asarray(data[1])
175 |
176 | feed_dict = {input_low_images: low_images,
177 | input_high_images: high_images}
178 |
179 | generator_loss_val, _, g_w, d_w = sess.run([generator_loss, train_op, generator_vars[0], discrim_vars[0]], feed_dict=feed_dict)
180 |
181 | if iter_val != 0 and iter_val % FLAGS.save_summary_steps == 0:
182 | summary_str = sess.run(summary_op, feed_dict=feed_dict)
183 | summary_writer.add_summary(summary_str, global_step=iter_val)
184 |
185 | used_time = time.time() - start_time
186 | avg_time_per_step = used_time / FLAGS.save_summary_steps
187 | avg_examples_per_second = (FLAGS.save_summary_steps * FLAGS.batch_size) / used_time
188 |
189 | print('step %d, generator_loss %.4f, weights %.2f, %.2f, %.2f seconds/step, %.2f examples/second'
190 | % (iter_val, generator_loss_val, np.sum(g_w), np.sum(d_w), avg_time_per_step, avg_examples_per_second))
191 | start_time = time.time()
192 |
193 | if iter_val != 0 and iter_val % FLAGS.save_model_steps == 0:
194 | checkpoint_fn = os.path.join(FLAGS.save_path, 'model.ckpt')
195 | saver.save(sess, checkpoint_fn, global_step=iter_val)
196 |
197 | print('')
198 | print('*' * 30)
199 | print(' Training done!!! ')
200 | print('*' * 30)
201 | print('')
202 |
203 | if __name__ == '__main__':
204 | tf.app.run()
205 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib import slim
4 | import os
5 | import glob
6 |
7 | FLAGS = tf.app.flags.FLAGS
8 |
9 | def prepare_checkpoint_path(save_path, restore):
10 | if not tf.gfile.Exists(save_path):
11 | tf.gfile.MkDir(save_path)
12 | else:
13 | if not restore:
14 | tf.gfile.DeleteRecursively(save_path)
15 | tf.gfile.MkDir(save_path)
16 |
17 | def configure_learning_rate(learning_rate_init_value, global_step):
18 | learning_rate = tf.train.exponential_decay(learning_rate_init_value, global_step, decay_steps=10000, decay_rate=0.94, staircase=True)
19 | tf.summary.scalar('learning_rate', learning_rate)
20 | return learning_rate
21 |
22 | def configure_optimizer(learning_rate):
23 | return tf.train.AdamOptimizer(learning_rate)
24 |
25 | def get_restore_op(vgg_path, train_vars, check=False):
26 | vgg_19_vars = [var for var in train_vars if var.name.startswith('vgg')]
27 | if check:
28 | print_vars(vgg_19_vars)
29 | variable_restore_op = slim.assign_from_checkpoint_fn(
30 | vgg_path,
31 | vgg_19_vars,
32 | ignore_missing_vars=True)
33 | return variable_restore_op
34 |
35 | def print_vars(var_list):
36 | print('')
37 | for var in var_list:
38 | print(var)
39 | print('')
40 |
41 | def loss_parser(str_loss):
42 | '''
43 | NOTE!!! str_loss should be like 'mse,perceptual,texture,adv'...
44 | '''
45 | selected_loss_array = str_loss.split(',')
46 | return selected_loss_array
47 |
48 | def get_last_ckpt_path(folder_path):
49 | '''
50 | folder_path = .../where/your/saved/model/folder
51 | '''
52 |
53 | meta_paths = sorted(glob.glob(os.path.join(folder_path, '*.meta')))
54 |
55 | numbers = []
56 |
57 | for meta_path in meta_paths:
58 | numbers.append(int(meta_path.split('-')[-1].split('.')[0]))
59 |
60 | numbers = np.asarray(numbers)
61 |
62 | sorted_idx = np.argsort(numbers)
63 |
64 | latest_meta_path = meta_paths[sorted_idx[-1]]
65 |
66 | ckpt_path = latest_meta_path.replace('.meta', '')
67 |
68 | return ckpt_path
69 |
70 | def get_image_paths(image_folder):
71 | possible_image_type = ['jpg', 'png', 'JPEG', 'jpeg']
72 |
73 | image_list = [image_path for image_paths in [glob.glob(os.path.join(image_folder, '*.%s' % ext)) for ext in possible_image_type] for image_path in image_paths]
74 |
75 | return image_list
76 |
77 |
--------------------------------------------------------------------------------