├── .idea
│   ├── FSR-Tensorflow.iml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── data_loader.py
├── imgs
│   ├── comp_real.jpg
│   ├── comp_real.png
│   ├── comp_sota.jpg
│   ├── teaser.jpg
│   └── teaser.png
├── models
│   ├── __init__.py
│   ├── model.py
│   └── model_bn.py
├── run_model.py
├── run_model_bn.py
├── testing_res
│   └── 9099.png
├── testing_set
│   └── 9099.png
└── util
    ├── __init__.py
    ├── util.py
    └── util_bn.py
/README.md:
--------------------------------------------------------------------------------
# FSRNet: End-to-End Learning Face Super-Resolution with Facial Priors

A TensorFlow implementation of FSRNet, based on [SRN-Deblur](https://github.com/jiangsutx/SRN-Deblur/).


## Testing

Download the pretrained models and unzip them; make sure the model path is `./checkpoints/color/checkpoints/deblur.model*`.

Test images are read from `--input_path=` and the outputs are saved to `--output_path=`.
For example:

```bash
python run_model.py --input_path=./testing_set --output_path=./testing_res --gpu=0 --model=color --phase=test --height=128 --width=128
```


## Training

1. Use the main function of `data_loader.py` to generate tfrecords (a minimal sketch follows below).
2. Hyperparameters such as batch size, learning rate, and epoch number can be tuned through the command line:

```bash
python run_model.py --phase=train --batch=16 --lr=1e-4 --epoch=500
```

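For step 1, a minimal sketch of generating the tfrecords and then sanity-checking one batch, assuming the dataset paths hard-coded in `data_loader.py` point at your data:

```python
import tensorflow as tf
from data_loader import DataLoader

# Generate the tfrecords once (writes data0.tfrecords .. data9.tfrecords),
# then load a single batch to verify the tensor shapes.
loader = DataLoader(batch_size=16)
loader.gen_tfrecords(save_dir='tfrecords1')

lr, bic, gt, prior = loader.read_tfrecords(save_dir='tfrecords1')
with tf.Session() as sess:
    a, b, c, d = sess.run([lr, bic, gt, prior])
    # expected: (16, 16, 16, 3) (16, 128, 128, 3) (16, 128, 128, 3) (16, 64, 64, 79)
    print(a.shape, b.shape, c.shape, d.shape)
```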

## Some problems

1. Since the authors did not release their face-cropping code, the dataset I use is different from theirs; my crops contain larger faces.
2. I use [face alignment](https://github.com/1adrianb/face-alignment) to generate landmarks (a sketch follows below).
3. Download the pretrained model from [Baidu Pan](https://pan.baidu.com/s/1HBzZmcty45dhhUz-uGnMLw) (password: 0z3l).
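
The `.npz` files read by `data_loader.py` store a `heatmap` array with one channel per landmark and values in [0, 1]. A rough sketch of producing it with the `face-alignment` package; the constructor enum and the Gaussian width are assumptions, so check them against the version you install:

```python
import numpy as np
import face_alignment  # pip install face-alignment

# In recent versions the enum is face_alignment.LandmarksType.TWO_D instead of _2D.
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)

def landmark_heatmaps(img, size=128, sigma=2.0):
    """Render each of the 68 detected landmarks as a Gaussian blob in its own channel."""
    pts = fa.get_landmarks(img)[0]  # (68, 2) landmarks of the first detected face
    yy, xx = np.mgrid[0:size, 0:size]
    heat = np.zeros((size, size, 68), dtype=np.float32)
    for i, (x, y) in enumerate(pts):
        heat[:, :, i] = np.exp(-((xx - x) ** 2 + (yy - y) ** 2) / (2 * sigma ** 2))
    return heat  # values in [0, 1], as data_loader.py expects

# np.savez('data/train/input_landmark/0001.npz', heatmap=landmark_heatmaps(img))
```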
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import numpy as np
import cv2
from scipy.io import loadmat
import glob
from util.util_bn import *
from scipy import misc

p = 128  # HR patch size


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


class DataLoader():
    def __init__(self, train_lst='/home/liang/train.txt', data_dir='/home/liang/', map_parallel_num=8,
                 batch_size=16, shuffle_num=2000, prefetch_num=2000):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.shuffle_num = shuffle_num
        self.prefetch_num = prefetch_num
        self.train_dir = train_lst
        self.map_parallel_num = map_parallel_num
        self.scale = 8  # super-resolution factor

    def resize(self, prior):
        # Downsample each prior channel to 64x64, matching the shape expected in _parse_one_example.
        out = np.zeros((64, 64, prior.shape[-1]), dtype=np.uint8)
        for i in range(prior.shape[-1]):
            out[:, :, i] = misc.imresize(prior[:, :, i], 0.5, 'bicubic')
        return out

    def gen_tfrecords(self, save_dir='tfrecords1', tfrecord_num=10):
        file_num = tfrecord_num
        sample_num = 0
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        fs = []
        for i in range(file_num):
            fs.append(tf.python_io.TFRecordWriter(os.path.join(save_dir, 'data%d.tfrecords' % i)))
        data_dir = os.listdir('/home/liang/PycharmProjects/facesr/data/train/input')

        for img_path in data_dir:
            print(img_path)
            hr = cv2.imread('/home/liang/PycharmProjects/facesr/data/train/input/' + img_path)
            hr = cv2.cvtColor(hr, cv2.COLOR_BGR2RGB)

            # 68-channel landmark heatmaps, stored in [0, 1] and rescaled to uint8
            landmark = np.load('/home/liang/PycharmProjects/facesr/data/train/input_landmark/' + img_path[:-4] + '.npz')
            landmark = landmark['heatmap']
            landmark = (landmark * 255).astype(np.uint8)

            print(landmark.shape, landmark.dtype)

            # 11-channel face parsing maps, one binary map per semantic label
            maps = loadmat('/home/liang/PycharmProjects/facesr/data/train/input_label/' + img_path[:-4] + '.mat')
            maps = maps['pos']
            out_maps = np.zeros((128, 128, 11), dtype=np.uint8)
            print(out_maps.shape)
            for i in range(1, 12):
                out_maps[:, :, i - 1] = ((maps == i).astype(np.uint8) * 255)
            maps = out_maps
            print(maps.shape, maps.dtype)
            prior = np.concatenate([maps, landmark], axis=-1)
            prior = self.resize(prior)
            print(prior.dtype)
            example = tf.train.Example(features=tf.train.Features(feature={
                'hr': _bytes_feature(hr.tostring()),
                'prior': _bytes_feature(prior.tostring())
            }))
            fs[sample_num % file_num].write(example.SerializeToString())
            sample_num += 1
            if sample_num % 1000 == 0:
                print(sample_num)

        print(sample_num)
        for f in fs:
            f.close()

    def _parse_one_example(self, example):
        features = tf.parse_single_example(
            example,
            features={
                'hr': tf.FixedLenFeature([], tf.string),
                'prior': tf.FixedLenFeature([], tf.string)
            })
        gt = features['hr']
        gt = tf.decode_raw(gt, tf.uint8)
        gt = tf.reshape(gt, [128, 128, 3])

        prior = features['prior']
        prior = tf.decode_raw(prior, tf.uint8)
        prior = tf.reshape(prior, [64, 64, 68 + 11])  # 68 landmark heatmaps + 11 parsing maps

        # LR input and its bicubic upsampling, generated on the fly from the HR image
        lr = tf.py_func(lambda x: misc.imresize(x, 1.0 / self.scale, 'bicubic'), [gt], tf.uint8)
        bic = tf.py_func(lambda x: misc.imresize(x, self.scale / 1.0, 'bicubic'), [lr], tf.uint8)

        lr = tf.cast(lr, tf.float32)
        bic = tf.cast(bic, tf.float32)
        gt = tf.cast(gt, tf.float32)
        prior = tf.cast(prior, tf.float32)

        gt = tf.reshape(gt, [p, p, 3]) / 255.0
        bic = tf.reshape(bic, [p, p, 3]) / 255.0
        lr = tf.reshape(lr, [p // self.scale, p // self.scale, 3]) / 255.0
        prior = prior / 255.0
        return lr, bic, gt, prior

    def read_tfrecords(self, save_dir='tfrecords1'):
        fs_paths = sorted(glob.glob(os.path.join(save_dir, '*.tfrecords')))
        if len(fs_paths) == 0:
            print('No tfrecords found. Run gen_tfrecords() first.')
            exit()
        dataset = tf.data.TFRecordDataset(fs_paths)
        print(self.batch_size)
        dataset = dataset.map(self._parse_one_example, self.map_parallel_num).shuffle(self.shuffle_num) \
            .prefetch(self.prefetch_num).batch(self.batch_size).repeat()
        lr, bic, gt, prior = dataset.make_one_shot_iterator().get_next()
        return lr, bic, gt, prior


if __name__ == '__main__':
    dataLoader = DataLoader()
    # dataLoader.gen_tfrecords()
    # exit()
    lr, bic, gt, prior = dataLoader.read_tfrecords()
    sess = tf.Session()
    im1, im2, im3, im4 = sess.run([lr, bic, gt, prior])
    print(im1.shape, im2.shape, im3.shape, im4.shape)
    for i in range(16):
        a = im2uint8(im1[i])
        b = im2uint8(im2[i])
        c = im2uint8(im3[i])
        d = im2uint8(im4[i][:, :, 0])
        e = im2uint8(im4[i][:, :, -1])

        a = cv2.cvtColor(a, cv2.COLOR_RGB2BGR)
        b = cv2.cvtColor(b, cv2.COLOR_RGB2BGR)
        c = cv2.cvtColor(c, cv2.COLOR_RGB2BGR)

        cv2.imwrite(str(i) + '_1.png', a)
        cv2.imwrite(str(i) + '_2.png', b)
        cv2.imwrite(str(i) + '_3.png', c)
        cv2.imwrite(str(i) + '_4.png', d)
        cv2.imwrite(str(i) + '_5.png', e)
--------------------------------------------------------------------------------
/imgs/comp_real.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/imgs/comp_real.jpg
--------------------------------------------------------------------------------
/imgs/comp_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/imgs/comp_real.png
--------------------------------------------------------------------------------
/imgs/comp_sota.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/imgs/comp_sota.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/imgs/teaser.jpg
--------------------------------------------------------------------------------
/imgs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/imgs/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import time
import scipy.misc
from datetime import datetime
from util.util import *
from data_loader import DataLoader
from skimage.measure import compare_psnr
import cv2


class DEBLUR(object):
    def __init__(self, args):
        self.args = args
        self.n_levels = 3
        self.scale = 0.5
        self.chns = 3 if self.args.model == 'color' else 1  # input / output channels

        # if args.phase == 'train':
        self.crop_size = 256

        self.train_dir = os.path.join('./checkpoints', args.model)
        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)

        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.data_size = 9000 // self.batch_size  # steps per epoch, assuming 9000 training samples
        self.max_steps = int(self.epoch * self.data_size)
        self.learning_rate = args.learning_rate

        self.is_training = True

    def generator(self, inputs, reuse=False, scope='g_net'):

        with tf.variable_scope(scope, reuse=reuse):
            with tf.variable_scope('coarseSR'):
                # coarse SR network: a first rough estimate of the HR face
                x = conv2d(inputs, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=3)
                x = resblock(x, 64, self.is_training, 'res1')
                x = resblock(x, 64, self.is_training, 'res2')
                x = resblock(x, 64, self.is_training, 'res3')
                out1 = conv2d(x, 'conv2', 3)
            with tf.variable_scope('fineSR_encoder'):
                # encode image features from the coarse result at half resolution
                x = conv2d(out1, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=3, stride=2)
                for i in range(12):
                    x = resblock(x, 64, self.is_training, 'res' + str(i))
                x = conv2d(x, 'conv2', 64, bn=True, is_training=self.is_training, activation=True, ksize=3)

            with tf.variable_scope('prior'):
                # prior estimation network: two stacked hourglasses predicting
                # 68 landmark heatmaps and 11 parsing maps
                y = conv2d(out1, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=7, stride=2)
                for i in range(3):
                    y = resblock(y, 128, self.is_training, 'res' + str(i))
                y = hour_glass(y, 128, 4, self.is_training, name='hourglass1')
                y = conv2d(y, 'conv2', 128, bn=True, is_training=self.is_training, activation=True)
                y = hour_glass(y, 128, 4, self.is_training, name='hourglass2')
                y1 = conv2d(y, 'conv3', 68, ksize=1)
                y2 = conv2d(y, 'conv4', 11, ksize=1)
                y = tf.concat([y1, y2], axis=-1)

            fuse = tf.concat([x, y], axis=-1)

            with tf.variable_scope('fineSR_decoder'):
                # decode the fused image/prior features back to full resolution
                x = conv2d(fuse, 'conv1', 64, bn=True, is_training=self.is_training, activation=True)
                x = deconv2d(x, 'deconv1', 64, bn=True, is_training=self.is_training, activation=True)
                for i in range(3):
                    x = resblock(x, 64, self.is_training, 'res' + str(i))
                out = conv2d(x, 'out', 3)

        return out, out1, y

    def build_model(self):

        # note: the loader batch size is hard-coded here and overrides --batch_size
        dataLoader = DataLoader(batch_size=14)

        lr, bic, gt, prior = dataLoader.read_tfrecords()
        tf.summary.image('bic', im2uint8(bic))
        tf.summary.image('gt', im2uint8(gt))

        # generator
        out, out1, y = self.generator(bic, reuse=False, scope='g_net')

        tf.summary.image('final_out', im2uint8(out))
        tf.summary.image('coarse_out', im2uint8(out1))

        # MSE losses on the coarse result, the estimated priors, and the final result
        self.coarse_loss = tf.reduce_mean((out1 - gt) ** 2)
        self.prior_loss = tf.reduce_mean((y - prior) ** 2)
        self.final_loss = tf.reduce_mean((out - gt) ** 2)
        self.loss_total = self.coarse_loss + self.prior_loss + self.final_loss

        tf.summary.scalar('coarse_loss', self.coarse_loss)
        tf.summary.scalar('prior_loss', self.prior_loss)
        tf.summary.scalar('final_loss', self.final_loss)
        tf.summary.scalar('loss_total', self.loss_total)

        # training vars
        all_vars = tf.trainable_variables()
        self.all_vars = all_vars
        self.g_vars = [var for var in all_vars if 'g_net' in var.name]

        for var in all_vars:
            print(var.name)

    def train(self):
        def get_optimizer(loss, global_step=None, var_list=None, is_gradient_clip=False):
            train_op = tf.train.RMSPropOptimizer(self.lr)
            if is_gradient_clip:
                # clip only the LSTM gradients by global norm, leave the rest unchanged
                grads_and_vars = train_op.compute_gradients(loss, var_list=var_list)
                unchanged_gvs = [(grad, var) for grad, var in grads_and_vars if 'LSTM' not in var.name]
                rnn_grad = [grad for grad, var in grads_and_vars if 'LSTM' in var.name]
                rnn_var = [var for grad, var in grads_and_vars if 'LSTM' in var.name]
                capped_grad, _ = tf.clip_by_global_norm(rnn_grad, clip_norm=3)
                capped_gvs = list(zip(capped_grad, rnn_var))
                train_op = train_op.apply_gradients(grads_and_vars=capped_gvs + unchanged_gvs, global_step=global_step)
            else:
                # if batch norm is added, its moving mean/variance must be updated as well
                train_op = train_op.minimize(loss, global_step, var_list)
            return train_op

        global_step = tf.Variable(initial_value=0, dtype=tf.int32, trainable=False)
        self.global_step = global_step

        # build model
        self.build_model()

        # learning rate decay
        self.lr = tf.train.polynomial_decay(2.5e-4, global_step, self.max_steps, end_learning_rate=0.0, power=0.3)
        tf.summary.scalar('learning_rate', self.lr)

        # training operators
        train_gnet = get_optimizer(self.loss_total, global_step, self.all_vars)

        # session and thread
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess = sess
        sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(max_to_keep=50, keep_checkpoint_every_n_hours=1)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # training summary
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph, flush_secs=30)

        for step in range(sess.run(global_step), self.max_steps + 1):

            start_time = time.time()

            # update G network
            _, loss_total_val, coarse_loss_val, prior_loss_val, final_loss_val = sess.run(
                [train_gnet, self.loss_total, self.coarse_loss, self.prior_loss, self.final_loss])

            duration = time.time() - start_time
            assert not np.isnan(loss_total_val), 'Model diverged with loss = NaN'

            if step % 5 == 0:
                num_examples_per_step = self.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = '%s: step %d, loss = (%.5f; %.5f, %.5f, %.5f)(%.1f data/s; %.3f s/bch)'
                print(format_str % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), step, loss_total_val,
                                    coarse_loss_val, prior_loss_val, final_loss_val,
                                    examples_per_sec, sec_per_batch))

            if step % 20 == 0:
                # summary_str = sess.run(summary_op, feed_dict={inputs:batch_input, gt:batch_gt})
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, global_step=step)

            # save the model checkpoint periodically
            if step % 500 == 0 or step == self.max_steps:
                checkpoint_path = os.path.join(self.train_dir, 'checkpoints')
                self.save(sess, checkpoint_path, step)

    def save(self, sess, checkpoint_dir, step):
        model_name = "deblur.model"
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)

    def load(self, sess, checkpoint_dir, step=None):
        print(" [*] Reading checkpoints...")
        model_name = "deblur.model"
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

        if step is not None:
            ckpt_name = model_name + '-' + str(step)
            self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
            print(" [*] Reading intermediate checkpoints... Success")
            return str(step)
        elif ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            ckpt_iter = ckpt_name.split('-')[1]
            self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
            print(" [*] Reading updated checkpoints... Success")
            return ckpt_iter
        else:
            print(" [*] Reading checkpoints... ERROR")
            return False

    def test(self, height, width, input_path, output_path):
        # note: util.util ignores the bn flag, so is_training has no effect in this variant
        self.is_training = True
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        imgsName = sorted(os.listdir(input_path))

        H, W = height, width
        inp_chns = 3 if self.args.model == 'color' else 1
        self.batch_size = 1 if self.args.model == 'color' else 3
        inputs = tf.placeholder(shape=[self.batch_size, H, W, inp_chns], dtype=tf.float32)
        img_sr, local_sr, prior = self.generator(inputs, reuse=False)

        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))

        self.saver = tf.train.Saver()

        best_psnr = 0.0
        best_step = -1

        # evaluate the checkpoint(s) in this range and keep the best by average PSNR;
        # as written, only the single checkpoint at step 29500 is evaluated
        for step in range(29500, 29500 + 1, 500):
            self.load(sess, os.path.join(self.train_dir, 'checkpoints'), step=step)
            avg_psnr = 0.0
            for imgName in imgsName:
                blur = scipy.misc.imread(os.path.join(input_path, imgName))

                # downsample by 8x, then upsample back: the bicubic input to the network
                lr = scipy.misc.imresize(blur, 0.125, 'bicubic')
                bic = scipy.misc.imresize(lr, 8.0, 'bicubic')

                blurPad = np.expand_dims(bic, 0)

                start = time.time()
                res = sess.run(img_sr, feed_dict={inputs: blurPad / 255.0})
                duration = time.time() - start

                res = im2uint8(res[0, :, :, :])
                avg_psnr += compare_psnr(res, blur)
                res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
                cv2.imwrite(os.path.join(output_path, imgName), res)
            avg_psnr /= len(imgsName)
            print(step, avg_psnr)
            if avg_psnr > best_psnr:
                best_psnr = avg_psnr
                best_step = step

        print(best_psnr, best_step)

        # scipy.misc.imsave(os.path.join(output_path, imgName), res)
--------------------------------------------------------------------------------
/models/model_bn.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import time
import scipy.misc
from datetime import datetime
from util.util_bn import *
from data_loader import DataLoader
from skimage.measure import compare_psnr
import cv2


class DEBLUR(object):
    def __init__(self, args):
        self.args = args
        self.n_levels = 3
        self.scale = 0.5
        self.chns = 3 if self.args.model == 'color' else 1  # input / output channels

        # if args.phase == 'train':
        self.crop_size = 256

        self.train_dir = os.path.join('./checkpoints_bn', args.model)
        if not os.path.exists(self.train_dir):
            os.makedirs(self.train_dir)

        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.data_size = 9000 // self.batch_size  # steps per epoch, assuming 9000 training samples
        self.max_steps = int(self.epoch * self.data_size)
        self.learning_rate = args.learning_rate

        self.is_training = True

    def generator(self, inputs, reuse=False, scope='g_net'):

        with tf.variable_scope(scope, reuse=reuse):
            with tf.variable_scope('coarseSR'):
                # coarse SR network: a first rough estimate of the HR face
                x = conv2d(inputs, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=3)
                x = resblock(x, 64, self.is_training, 'res1')
                x = resblock(x, 64, self.is_training, 'res2')
                x = resblock(x, 64, self.is_training, 'res3')
                out1 = conv2d(x, 'conv2', 3)
            with tf.variable_scope('fineSR_encoder'):
                # encode image features from the coarse result at half resolution
                x = conv2d(out1, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=3, stride=2)
                for i in range(12):
                    x = resblock(x, 64, self.is_training, 'res' + str(i))
                x = conv2d(x, 'conv2', 64, bn=True, is_training=self.is_training, activation=True, ksize=3)

            with tf.variable_scope('prior'):
                # prior estimation network: two stacked hourglasses predicting
                # 68 landmark heatmaps and 11 parsing maps
                y = conv2d(out1, 'conv1', 64, bn=True, is_training=self.is_training, activation=True, ksize=7, stride=2)
                for i in range(3):
                    y = resblock(y, 128, self.is_training, 'res' + str(i))
                y = hour_glass(y, 128, 4, self.is_training, name='hourglass1')
                y = conv2d(y, 'conv2', 128, bn=True, is_training=self.is_training, activation=True)
                y = hour_glass(y, 128, 4, self.is_training, name='hourglass2')
                y1 = conv2d(y, 'conv3', 68, ksize=1)
                y2 = conv2d(y, 'conv4', 11, ksize=1)
                y = tf.concat([y1, y2], axis=-1)

            fuse = tf.concat([x, y], axis=-1)

            with tf.variable_scope('fineSR_decoder'):
                # decode the fused image/prior features back to full resolution
                x = conv2d(fuse, 'conv1', 64, bn=True, is_training=self.is_training, activation=True)
                x = deconv2d(x, 'deconv1', 64, bn=True, is_training=self.is_training, activation=True)
                for i in range(3):
                    x = resblock(x, 64, self.is_training, 'res' + str(i))
                out = conv2d(x, 'out', 3)

        return out, out1, y

    def build_model(self):

        # note: the loader batch size is hard-coded here and overrides --batch_size
        dataLoader = DataLoader(batch_size=14)

        lr, bic, gt, prior = dataLoader.read_tfrecords()
        tf.summary.image('bic', im2uint8(bic))
        tf.summary.image('gt', im2uint8(gt))

        # generator
        out, out1, y = self.generator(bic, reuse=False, scope='g_net')

        tf.summary.image('final_out', im2uint8(out))
        tf.summary.image('coarse_out', im2uint8(out1))

        # MSE losses on the coarse result, the estimated priors, and the final result
        self.coarse_loss = tf.reduce_mean((out1 - gt) ** 2)
        self.prior_loss = tf.reduce_mean((y - prior) ** 2)
        self.final_loss = tf.reduce_mean((out - gt) ** 2)
        self.loss_total = self.coarse_loss + self.prior_loss + self.final_loss

        tf.summary.scalar('coarse_loss', self.coarse_loss)
        tf.summary.scalar('prior_loss', self.prior_loss)
        tf.summary.scalar('final_loss', self.final_loss)
        tf.summary.scalar('loss_total', self.loss_total)

        # training vars
        all_vars = tf.trainable_variables()
        self.all_vars = all_vars
        self.g_vars = [var for var in all_vars if 'g_net' in var.name]

        for var in all_vars:
            print(var.name)

    def train(self):
        def get_optimizer(loss, global_step=None, var_list=None, is_gradient_clip=False):
            train_op = tf.train.RMSPropOptimizer(self.lr)
            if is_gradient_clip:
                # clip only the LSTM gradients by global norm, leave the rest unchanged
                grads_and_vars = train_op.compute_gradients(loss, var_list=var_list)
                unchanged_gvs = [(grad, var) for grad, var in grads_and_vars if 'LSTM' not in var.name]
                rnn_grad = [grad for grad, var in grads_and_vars if 'LSTM' in var.name]
                rnn_var = [var for grad, var in grads_and_vars if 'LSTM' in var.name]
                capped_grad, _ = tf.clip_by_global_norm(rnn_grad, clip_norm=3)
                capped_gvs = list(zip(capped_grad, rnn_var))
                train_op = train_op.apply_gradients(grads_and_vars=capped_gvs + unchanged_gvs, global_step=global_step)
            else:
                # with batch norm, the moving mean/variance update ops must run with the train step
                update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                with tf.control_dependencies(update_ops):
                    train_op = train_op.minimize(loss, global_step, var_list)
            return train_op

        global_step = tf.Variable(initial_value=0, dtype=tf.int32, trainable=False)
        self.global_step = global_step

        # build model
        self.build_model()

        # learning rate decay
        self.lr = tf.train.polynomial_decay(2.5e-4, global_step, self.max_steps, end_learning_rate=0.0, power=0.3)
        tf.summary.scalar('learning_rate', self.lr)

        # training operators
        train_gnet = get_optimizer(self.loss_total, global_step, self.all_vars)

        # session and thread
        gpu_options = tf.GPUOptions(allow_growth=True)
        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self.sess = sess
        sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(max_to_keep=50, keep_checkpoint_every_n_hours=1)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # training summary
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph, flush_secs=30)

        for step in range(sess.run(global_step), self.max_steps + 1):

            start_time = time.time()

            # update G network
            _, loss_total_val, coarse_loss_val, prior_loss_val, final_loss_val = sess.run(
                [train_gnet, self.loss_total, self.coarse_loss, self.prior_loss, self.final_loss])

            duration = time.time() - start_time
            assert not np.isnan(loss_total_val), 'Model diverged with loss = NaN'

            if step % 5 == 0:
                num_examples_per_step = self.batch_size
                examples_per_sec = num_examples_per_step / duration
                sec_per_batch = float(duration)

                format_str = '%s: step %d, loss = (%.5f; %.5f, %.5f, %.5f)(%.1f data/s; %.3f s/bch)'
                print(format_str % (datetime.now().strftime('%Y-%m-%d %H:%M:%S'), step, loss_total_val,
                                    coarse_loss_val, prior_loss_val, final_loss_val,
                                    examples_per_sec, sec_per_batch))

            if step % 20 == 0:
                # summary_str = sess.run(summary_op, feed_dict={inputs:batch_input, gt:batch_gt})
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, global_step=step)

            # save the model checkpoint periodically
            if step % 500 == 0 or step == self.max_steps:
                checkpoint_path = os.path.join(self.train_dir, 'checkpoints')
                self.save(sess, checkpoint_path, step)

    def save(self, sess, checkpoint_dir, step):
        model_name = "deblur.model"
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)

    def load(self, sess, checkpoint_dir, step=None):
        print(" [*] Reading checkpoints...")
        model_name = "deblur.model"
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

        if step is not None:
            ckpt_name = model_name + '-' + str(step)
            self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
            print(" [*] Reading intermediate checkpoints... Success")
            return str(step)
        elif ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            ckpt_iter = ckpt_name.split('-')[1]
            self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
            print(" [*] Reading updated checkpoints... Success")
            return ckpt_iter
        else:
            print(" [*] Reading checkpoints... ERROR")
            return False

    def test(self, height, width, input_path, output_path):
        self.is_training = False  # use the batch-norm moving statistics at test time
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        imgsName = sorted(os.listdir(input_path))

        H, W = height, width
        inp_chns = 3 if self.args.model == 'color' else 1
        self.batch_size = 1 if self.args.model == 'color' else 3
        inputs = tf.placeholder(shape=[self.batch_size, H, W, inp_chns], dtype=tf.float32)
        img_sr, local_sr, prior = self.generator(inputs, reuse=False)

        sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))

        self.saver = tf.train.Saver()

        best_psnr = 0.0
        best_step = -1

        # evaluate the checkpoint(s) in this range and keep the best by average PSNR;
        # as written, only the single checkpoint at step 37000 is evaluated
        for step in range(37000, 37500, 500):
            self.load(sess, os.path.join(self.train_dir, 'checkpoints'), step=step)
            avg_psnr = 0.0
            for imgName in imgsName:
                blur = scipy.misc.imread(os.path.join(input_path, imgName))

                # downsample by 8x, then upsample back: the bicubic input to the network
                lr = scipy.misc.imresize(blur, 0.125, 'bicubic')
                bic = scipy.misc.imresize(lr, 8.0, 'bicubic')

                blurPad = np.expand_dims(bic, 0)

                start = time.time()
                res = sess.run(img_sr, feed_dict={inputs: blurPad / 255.0})
                duration = time.time() - start

                res = im2uint8(res[0, :, :, :])
                avg_psnr += compare_psnr(res, blur)
                res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
                cv2.imwrite(os.path.join(output_path, imgName), res)
            avg_psnr /= len(imgsName)
            print(step, avg_psnr)
            if avg_psnr > best_psnr:
                best_psnr = avg_psnr
                best_step = step

        print(best_psnr, best_step)

        # scipy.misc.imsave(os.path.join(output_path, imgName), res)
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
import os
import argparse
import tensorflow as tf
# import models.model_gray as model
# import models.model_color as model
import models.model as model


def parse_args():
    parser = argparse.ArgumentParser(description='deblur arguments')
    parser.add_argument('--phase', type=str, default='test', help='whether to train or test')
    parser.add_argument('--datalist', type=str, default='./datalist_gopro.txt', help='training datalist')
    parser.add_argument('--model', type=str, default='color', help='model type: [lstm | gray | color]')
    parser.add_argument('--batch_size', help='training batch size', type=int, default=16)
    parser.add_argument('--epoch', help='training epoch number', type=int, default=4000)
    parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate')
    parser.add_argument('--gpu', dest='gpu_id', type=str, default='0', help='gpu id; a negative value runs on cpu')
    parser.add_argument('--height', type=int, default=720,
                        help='height of the tensorflow placeholder, should be a multiple of 16')
    parser.add_argument('--width', type=int, default=1280,
                        help='width of the tensorflow placeholder, should be a multiple of 16 for 3 scales')
    parser.add_argument('--input_path', type=str, default='./testing_set',
                        help='input path for testing images')
    parser.add_argument('--output_path', type=str, default='./testing_res',
                        help='output path for testing images')
    args = parser.parse_args()
    return args


def main(_):
    args = parse_args()

    # set gpu/cpu mode
    if int(args.gpu_id) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    # set up deblur models
    deblur = model.DEBLUR(args)
    if args.phase == 'test':
        deblur.test(args.height, args.width, args.input_path, args.output_path)
    elif args.phase == 'train':
        deblur.train()
    else:
        print('phase should be set to either test or train')


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/run_model_bn.py:
--------------------------------------------------------------------------------
import os
import argparse
import tensorflow as tf
# import models.model_gray as model
# import models.model_color as model
import models.model_bn as model


def parse_args():
    parser = argparse.ArgumentParser(description='deblur arguments')
    parser.add_argument('--phase', type=str, default='test', help='whether to train or test')
    parser.add_argument('--datalist', type=str, default='./datalist_gopro.txt', help='training datalist')
    parser.add_argument('--model', type=str, default='color', help='model type: [lstm | gray | color]')
    parser.add_argument('--batch_size', help='training batch size', type=int, default=16)
    parser.add_argument('--epoch', help='training epoch number', type=int, default=4000)
    parser.add_argument('--lr', type=float, default=1e-4, dest='learning_rate', help='initial learning rate')
    parser.add_argument('--gpu', dest='gpu_id', type=str, default='0', help='gpu id; a negative value runs on cpu')
    parser.add_argument('--height', type=int, default=720,
                        help='height of the tensorflow placeholder, should be a multiple of 16')
    parser.add_argument('--width', type=int, default=1280,
                        help='width of the tensorflow placeholder, should be a multiple of 16 for 3 scales')
    parser.add_argument('--input_path', type=str, default='./testing_set',
                        help='input path for testing images')
    parser.add_argument('--output_path', type=str, default='./testing_res',
                        help='output path for testing images')
    args = parser.parse_args()
    return args


def main(_):
    args = parse_args()

    # set gpu/cpu mode
    if int(args.gpu_id) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    # set up deblur models
    deblur = model.DEBLUR(args)
    if args.phase == 'test':
        deblur.test(args.height, args.width, args.input_path, args.output_path)
    elif args.phase == 'train':
        deblur.train()
    else:
        print('phase should be set to either test or train')


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/testing_res/9099.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/testing_res/9099.png
--------------------------------------------------------------------------------
/testing_set/9099.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liang23333/FSRNet-Tensorflow/4509cec8f48480f29fedf798460549f97dc52eb4/testing_set/9099.png
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
import sys
import numpy as np
import tensorflow as tf

if sys.version_info.major == 3:
    xrange = range


def im2uint8(x):
    # works on both tf.Tensor and numpy inputs: clip to [0, 1] and rescale to uint8
    if x.__class__ == tf.Tensor:
        return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
    else:
        t = np.clip(x, 0.0, 1.0) * 255.0
        return t.astype(np.uint8)


def conv2d(inputs, name, out_channels, bn=False, is_training=False, activation=False, ksize=3, stride=1):
    # note: bn/is_training are accepted for interface compatibility with util_bn.conv2d,
    # but this variant does not apply batch normalization
    with tf.variable_scope(name):
        in_channels = inputs.get_shape()[-1]
        filter = tf.get_variable('weight', shape=[ksize, ksize, in_channels, out_channels],
                                 initializer=tf.contrib.layers.xavier_initializer())
        conv = tf.nn.conv2d(inputs, filter, strides=[1, stride, stride, 1], padding='SAME')
        bias = tf.get_variable('bias', shape=[out_channels], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, bias)

        if activation:
            conv = tf.nn.relu(conv)

        tf.add_to_collection('weights', filter)

        return conv


def deconv2d(inputs, name, out_channels, bn=False, is_training=False, activation=False, ksize=4, stride=2):
    with tf.variable_scope(name):
        input_shape = inputs.get_shape()
        in_channels = input_shape[-1]
        input_shape = tf.shape(inputs)  # dynamic shape for the output size
        filter = tf.get_variable('weight', shape=[ksize, ksize, out_channels, in_channels],
                                 initializer=tf.contrib.layers.xavier_initializer())
        output_shape = [input_shape[0], input_shape[1] * stride, input_shape[2] * stride, out_channels]
        deconv = tf.nn.conv2d_transpose(inputs, filter, output_shape, [1, stride, stride, 1])
        bias = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.bias_add(deconv, bias)

        if activation:
            deconv = tf.nn.relu(deconv)

        tf.add_to_collection('weights', filter)
        return deconv


def resblock(inputs, out_channels, is_training, name):
    with tf.variable_scope(name):
        in_channels = inputs.get_shape().as_list()[-1]

        x = conv2d(inputs, 'conv1', out_channels, bn=False, is_training=is_training, activation=True)
        x = conv2d(x, 'conv2', out_channels, bn=False, is_training=is_training, activation=False)

        # project the shortcut if the channel counts differ
        if in_channels != out_channels:
            inputs = conv2d(inputs, 'conv_fuse', out_channels, bn=False, is_training=is_training, activation=True)
        x = inputs + x
        return x


def conv_block(x, output_num, is_training, name):
    # pre-activation bottleneck block used inside the hourglass: three convs whose
    # outputs (out/2 + out/4 + out/4 channels) are concatenated to output_num channels
    with tf.variable_scope(name):
        inp = x
        x1 = tf.nn.relu(inp)
        x1 = conv2d(x1, name='conv1', out_channels=output_num // 2)

        x2 = tf.nn.relu(x1)
        x2 = conv2d(x2, name='conv2', out_channels=output_num // 4)

        x3 = tf.nn.relu(x2)
        x3 = conv2d(x3, name='conv3', out_channels=output_num // 4)

        x3 = tf.concat([x1, x2, x3], axis=-1)

        input_num = inp.get_shape()[-1]
        if input_num != output_num:
            inp = tf.nn.relu(inp)
            inp = conv2d(inp, name='downsample', out_channels=output_num, ksize=1)
        return x3 + inp


def hour_glass(x, output_num, depth, is_training, name):
    # recursive hourglass: a skip branch at the current resolution plus a pooled
    # branch that recurses `depth` times before being upsampled and added back
    with tf.variable_scope(name + '_%d' % depth):
        if depth <= 0:
            return x
        up1 = x
        up1 = conv_block(up1, output_num, is_training, 'cb1')

        low1 = tf.layers.average_pooling2d(x, 2, strides=2)
        low1 = conv_block(low1, output_num, is_training, 'cb2')

        if depth > 1:
            low2 = hour_glass(low1, output_num, depth - 1, is_training, name)
        else:
            low2 = low1
            low2 = conv_block(low2, output_num, is_training, 'cb3')

        low3 = low2
        low3 = conv_block(low3, output_num, is_training, 'cb4')

        up2 = tf.image.resize_bilinear(low3, tf.shape(low3)[1:3] * 2)

        return up1 + up2
--------------------------------------------------------------------------------
/util/util_bn.py:
--------------------------------------------------------------------------------
import sys
import numpy as np
import tensorflow as tf

if sys.version_info.major == 3:
    xrange = range


def im2uint8(x):
    # works on both tf.Tensor and numpy inputs: clip to [0, 1] and rescale to uint8
    if x.__class__ == tf.Tensor:
        return tf.cast(tf.clip_by_value(x, 0.0, 1.0) * 255.0, tf.uint8)
    else:
        t = np.clip(x, 0.0, 1.0) * 255.0
        return t.astype(np.uint8)


def conv2d(inputs, name, out_channels, bn=False, is_training=False, activation=False, ksize=3, stride=1):
    with tf.variable_scope(name):
        in_channels = inputs.get_shape()[-1]
        filter = tf.get_variable('weight', shape=[ksize, ksize, in_channels, out_channels],
                                 initializer=tf.contrib.layers.xavier_initializer())
        conv = tf.nn.conv2d(inputs, filter, strides=[1, stride, stride, 1], padding='SAME')
        bias = tf.get_variable('bias', shape=[out_channels], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, bias)

        if bn:
            conv = tf.layers.batch_normalization(conv, training=is_training)
        if activation:
            conv = tf.nn.relu(conv)

        tf.add_to_collection('weights', filter)

        return conv


def deconv2d(inputs, name, out_channels, bn=False, is_training=False, activation=False, ksize=4, stride=2):
    with tf.variable_scope(name):
        input_shape = inputs.get_shape()
        in_channels = input_shape[-1]
        input_shape = tf.shape(inputs)  # dynamic shape for the output size
        filter = tf.get_variable('weight', shape=[ksize, ksize, out_channels, in_channels],
                                 initializer=tf.contrib.layers.xavier_initializer())
        output_shape = [input_shape[0], input_shape[1] * stride, input_shape[2] * stride, out_channels]
        deconv = tf.nn.conv2d_transpose(inputs, filter, output_shape, [1, stride, stride, 1])
        bias = tf.get_variable('biases', [out_channels], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.bias_add(deconv, bias)

        if bn:
            deconv = tf.layers.batch_normalization(deconv, training=is_training)
        if activation:
            deconv = tf.nn.relu(deconv)

        tf.add_to_collection('weights', filter)
        return deconv


def resblock(inputs, out_channels, is_training, name):
    with tf.variable_scope(name):
        in_channels = inputs.get_shape().as_list()[-1]

        x = conv2d(inputs, 'conv1', out_channels, bn=True, is_training=is_training, activation=True)
        x = conv2d(x, 'conv2', out_channels, bn=False, is_training=is_training, activation=False)

        # project the shortcut if the channel counts differ
        if in_channels != out_channels:
            inputs = conv2d(inputs, 'conv_fuse', out_channels, bn=True, is_training=is_training, activation=True)
        x = inputs + x
        # batch norm and ReLU applied after the residual addition
        x = tf.layers.batch_normalization(x, training=is_training)
        x = tf.nn.relu(x)
        return x


def conv_block(x, output_num, is_training, name):
    # pre-activation (BN-ReLU-conv) bottleneck block used inside the hourglass: three
    # convs whose outputs (out/2 + out/4 + out/4 channels) are concatenated
    with tf.variable_scope(name):
        inp = x
        x1 = tf.layers.batch_normalization(x, training=is_training)
        x1 = tf.nn.relu(x1)
        x1 = conv2d(x1, name='conv1', out_channels=output_num // 2)

        x2 = tf.layers.batch_normalization(x1, training=is_training)
        x2 = tf.nn.relu(x2)
        x2 = conv2d(x2, name='conv2', out_channels=output_num // 4)

        x3 = tf.layers.batch_normalization(x2, training=is_training)
        x3 = tf.nn.relu(x3)
        x3 = conv2d(x3, name='conv3', out_channels=output_num // 4)

        x3 = tf.concat([x1, x2, x3], axis=-1)

        input_num = inp.get_shape()[-1]
        if input_num != output_num:
            inp = tf.layers.batch_normalization(inp, training=is_training)
            inp = tf.nn.relu(inp)
            inp = conv2d(inp, name='downsample', out_channels=output_num, ksize=1)
        return x3 + inp


def hour_glass(x, output_num, depth, is_training, name):
    # recursive hourglass: a skip branch at the current resolution plus a pooled
    # branch that recurses `depth` times before being upsampled and added back
    with tf.variable_scope(name + '_%d' % depth):
        if depth <= 0:
            return x
        up1 = x
        up1 = conv_block(up1, output_num, is_training, 'cb1')

        low1 = tf.layers.average_pooling2d(x, 2, strides=2)
        low1 = conv_block(low1, output_num, is_training, 'cb2')

        if depth > 1:
            low2 = hour_glass(low1, output_num, depth - 1, is_training, name)
        else:
            low2 = low1
            low2 = conv_block(low2, output_num, is_training, 'cb3')

        low3 = low2
        low3 = conv_block(low3, output_num, is_training, 'cb4')

        up2 = tf.image.resize_bilinear(low3, tf.shape(low3)[1:3] * 2)

        return up1 + up2
--------------------------------------------------------------------------------