├── src
│   ├── utils
│   │   ├── __init__.py
│   │   ├── color_jitter.py
│   │   ├── loss_utils.py
│   │   ├── geo_layer_utils.py
│   │   ├── config.py
│   │   ├── affine_transformation.py
│   │   └── curve.py
│   ├── test_FAB.py
│   ├── datagen.py
│   ├── train_FAB.py
│   └── FAB.py
├── data
│   └── datasets
│       ├── RWMB
│       │   └── README.md
│       └── Blurred-300VW
│           └── README.md
├── fig
│   ├── deblur.png
│   ├── effects.png
│   ├── framework.png
│   └── structure_predictor.png
├── scripts
│   ├── test.sh
│   └── train.sh
├── LICENSE
└── README.md
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/datasets/RWMB/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/datasets/Blurred-300VW/README.md:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/fig/deblur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/deblur.png
--------------------------------------------------------------------------------
/fig/effects.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/effects.png
--------------------------------------------------------------------------------
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/framework.png
--------------------------------------------------------------------------------
/fig/structure_predictor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/structure_predictor.png
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python ./src/test_FAB.py \
4 | --structure_predictor_train_dir ./data/checkpoints/structure_predictor_train_dir/ \
5 | --video_deblur_train_dir ./data/checkpoints/video_deblur_train_dir/ \
6 | --resnet_train_dir ./data/checkpoints/resnet_train_dir/ \
7 | --resume_structure_predictor True \
8 | --resume_video_deblur True \
9 | --resume_resnet True \
10 | --resume_all False \
11 | --data_dir ./data/300VW/Images/ \
12 | --img_list ./data/300VW/labels_68pt_256_train_sorted.txt \
13 | --end_2_end_test_dir ./data/test_results/ &
14 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python ./src/train_FAB.py \
4 | --structure_predictor_train_dir ./data/checkpoints/structure_predictor_train_dir/ \
5 | --video_deblur_train_dir ./data/checkpoints/video_deblur_train_dir/ \
6 | --resnet_train_dir ./data/checkpoints/resnet_train_dir/ \
7 | --end_2_end_train_dir ./data/checkpoints/end_2_end_train_dir/ \
8 | --end_2_end_valid_dir ./data/checkpoints/end_2_end_valid_dir/ \
9 | --max_steps 2000000 \
10 | --resume_structure_predictor False \
11 | --resume_video_deblur False \
12 | --resume_resnet False \
13 | --data_dir ./data/300VW/Images/ \
14 | --img_list ./data/300VW/labels_68pt_256_train_sorted.txt \
15 | --data_dir_valid None \
16 | --img_list_valid None \
17 | --training_period train &
18 |
--------------------------------------------------------------------------------
/src/utils/color_jitter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import ImageEnhance
3 |
4 |
5 | transformtypedict = dict(Brightness=ImageEnhance.Brightness,
6 | Contrast=ImageEnhance.Contrast,
7 | Sharpness=ImageEnhance.Sharpness,
8 | Color=ImageEnhance.Color)
9 |
10 | class ImageJitter(object):
11 | def __init__(self, transformdict):
12 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict]
13 |
14 | def __call__(self, img):
15 | out = img
16 | randtensor = np.random.uniform(0, 1, len(self.transforms))
17 |
18 | for i, (transformer, alpha) in enumerate(self.transforms):
19 | r = alpha * (randtensor[i] * 2.0 - 1.0) + 1  # enhancement factor drawn from [1 - alpha, 1 + alpha]
20 | out = transformer(out).enhance(r)
21 |
22 | return out
23 |
--------------------------------------------------------------------------------
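A minimal usage sketch for `ImageJitter` (not part of the repo; the dictionary keys must match `transformtypedict` above, and the random image is a stand-in):

```python
import numpy as np
from PIL import Image

from utils.color_jitter import ImageJitter

# Each value is the maximum deviation of the PIL enhancement factor from 1.0,
# so Brightness=0.4 draws a factor uniformly from [0.6, 1.4].
jitter = ImageJitter({'Brightness': 0.4, 'Contrast': 0.4, 'Color': 0.4})

img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
out = jitter(img)  # a PIL Image of the same size with jittered appearance
```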
/src/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def l1_loss(predictions, targets):
9 | total_elements = (tf.shape(targets)[0] * tf.shape(targets)[1] * tf.shape(targets)[2]
10 | * tf.shape(targets)[3])
11 | total_elements = tf.to_float(total_elements)
12 |
13 | loss = tf.reduce_sum(tf.abs(predictions - targets))
14 | loss = tf.div(loss, total_elements)
15 |
16 | return loss
17 |
18 |
19 | def l2_loss(predictions, targets):
20 | total_elements = (tf.shape(targets)[0] * tf.shape(targets)[1] * tf.shape(targets)[2]
21 | * tf.shape(targets)[3])
22 | total_elements = tf.to_float(total_elements)
23 |
24 | loss = tf.reduce_sum(tf.square(predictions - targets))
25 | loss = tf.div(loss, total_elements)
26 |
27 | return loss
28 |
29 |
--------------------------------------------------------------------------------
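A quick sanity check for the two losses (a sketch, not part of the repo; assumes the TF1-style session used throughout this codebase). Both functions average over every element of the target tensor:

```python
import tensorflow as tf

from utils.loss_utils import l1_loss, l2_loss

# All-ones predictions against all-zeros targets: the summed error equals the
# element count, so both losses evaluate to exactly 1.0.
preds = tf.ones((2, 4, 4, 3))
targets = tf.zeros((2, 4, 4, 3))

with tf.Session() as sess:
    print(sess.run([l1_loss(preds, targets), l2_loss(preds, targets)]))  # [1.0, 1.0]
```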
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 KeqiangSun
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/utils/geo_layer_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def bilinear_interp(im, x, y, name):
9 | with tf.variable_scope(name):
10 | x = tf.reshape(x, [-1])
11 | y = tf.reshape(y, [-1])
12 |
13 | num_batch = tf.shape(im)[0]
14 | _, height, width, channels = im.get_shape().as_list()
15 |
16 | x = tf.to_float(x)
17 | y = tf.to_float(y)
18 |
19 | height_f = tf.cast(height, 'float32')
20 | width_f = tf.cast(width, 'float32')
21 | zero = tf.constant(0, dtype=tf.int32)
22 |
23 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
24 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
25 | x = (x + 1.0) * (width_f - 1.0) / 2.0
26 | y = (y + 1.0) * (height_f - 1.0) / 2.0
27 |
28 | x0 = tf.cast(tf.floor(x), 'int32')
29 | x1 = x0 + 1
30 | y0 = tf.cast(tf.floor(y), 'int32')
31 | y1 = y0 + 1
32 |
33 | x0 = tf.clip_by_value(x0, zero, max_x)
34 | x1 = tf.clip_by_value(x1, zero, max_x)
35 | y0 = tf.clip_by_value(y0, zero, max_y)
36 | y1 = tf.clip_by_value(y1, zero, max_y)
37 |
38 | dim2 = width
39 | dim1 = width * height
40 |
41 | base = tf.range(num_batch) * dim1
42 | base = tf.reshape(base, [-1, 1])
43 | base = tf.tile(base, [1, height * width])
44 | base = tf.reshape(base, [-1])
45 |
46 | base_y0 = base + y0 * dim2
47 | base_y1 = base + y1 * dim2
48 | idx_a = base_y0 + x0
49 | idx_b = base_y1 + x0
50 | idx_c = base_y0 + x1
51 | idx_d = base_y1 + x1
52 |
53 | im_flat = tf.reshape(im, tf.stack([-1, channels]))
54 | im_flat = tf.to_float(im_flat)
55 | pixel_a = tf.gather(im_flat, idx_a)
56 | pixel_b = tf.gather(im_flat, idx_b)
57 | pixel_c = tf.gather(im_flat, idx_c)
58 | pixel_d = tf.gather(im_flat, idx_d)
59 |
60 | x1_f = tf.to_float(x1)
61 | y1_f = tf.to_float(y1)
62 |
63 | wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1)
64 | wb = tf.expand_dims((x1_f - x) * (1.0 - (y1_f - y)), 1)
65 | wc = tf.expand_dims((1.0 - (x1_f - x)) * (y1_f - y), 1)
66 | wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1)
67 |
68 | output = tf.add_n([wa*pixel_a, wb*pixel_b, wc*pixel_c, wd*pixel_d])
69 | output = tf.reshape(output, shape=tf.stack([num_batch, height, width, channels]))
70 |
71 | return output
72 |
73 | def meshgrid(height, width):
74 | with tf.variable_scope('meshgrid'):
75 | x_t = tf.matmul(
76 | tf.ones(shape=tf.stack([height,1])),
77 | tf.transpose(
78 | tf.expand_dims(
79 | tf.linspace(-1.0,1.0,width),1),[1,0]))
80 | y_t = tf.matmul(
81 | tf.expand_dims(
82 | tf.linspace(-1.0, 1.0, height), 1),
83 | tf.ones(shape=tf.stack([1, width])))
84 | x_t_flat = tf.reshape(x_t, (1,-1))
85 | y_t_flat = tf.reshape(y_t, (1,-1))
86 | grid_x = tf.reshape(x_t_flat, [1, height, width])
87 | grid_y = tf.reshape(y_t_flat, [1, height, width])
88 |
89 | return grid_x, grid_y
90 |
91 |
--------------------------------------------------------------------------------
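`bilinear_interp` expects sampling coordinates normalized to [-1, 1], which is exactly the range `meshgrid` produces, so sampling an image at its own mesh grid is an identity warp. A hedged sketch of that check (assumed TF1 setup, not part of the repo):

```python
import numpy as np
import tensorflow as tf

from utils.geo_layer_utils import bilinear_interp, meshgrid

# Resampling an image at the meshgrid coordinates should reproduce it
# (up to floating-point rounding in the interpolation weights).
im = tf.constant(np.random.rand(1, 8, 8, 3).astype(np.float32))
grid_x, grid_y = meshgrid(8, 8)  # normalized coordinates in [-1, 1]
warped = bilinear_interp(im, grid_x, grid_y, 'identity_warp')

with tf.Session() as sess:
    src, out = sess.run([im, warped])
    print(np.allclose(src, out, atol=1e-5))  # expected: True
```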
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a variable-scope-aware configuration object for TensorFlow.
3 | """
4 | import tensorflow as tf
5 |
6 |
7 | FLAGS = tf.app.flags.FLAGS
8 | class Config:
9 | def __init__(self):
10 | root = self.Scope('')
11 | for k, v in FLAGS.__dict__['__flags'].iteritems():
12 | root[k] = v
13 | self.stack = [ root ]
14 |
15 | def iteritems(self):
16 | return self.to_dict().iteritems()
17 |
18 | def to_dict(self):
19 | self._pop_stale()
20 | out = {}
21 | for i in range(len(self.stack)):
22 | cs = self.stack[-(i + 1)]  # walk from the root scope upward so inner scopes override duplicates
23 | for name in cs:
24 | out[name] = cs[name]
25 | return out
26 |
27 | def _pop_stale(self):
28 | var_scope_name = tf.get_variable_scope().name
29 | top = self.stack[0]
30 | while not top.contains(var_scope_name):
31 | self.stack.pop(0)
32 | top = self.stack[0]
33 |
34 | def __getitem__(self, name):
35 | self._pop_stale()
36 | for i in range(len(self.stack)):
37 | cs = self.stack[i]
38 | if name in cs:
39 | return cs[name]
40 |
41 | raise KeyError(name)
42 |
43 | def set_default(self, name, value):
44 | if name not in self:
45 | self[name] = value
46 |
47 | def __contains__(self, name):
48 | self._pop_stale()
49 | for i in range(len(self.stack)):
50 | cs = self.stack[i]
51 | if name in cs:
52 | return True
53 | return False
54 |
55 | def __setitem__(self, name, value):
56 | self._pop_stale()
57 | top = self.stack[0]
58 | var_scope_name = tf.get_variable_scope().name
59 | assert top.contains(var_scope_name)
60 |
61 | if top.name != var_scope_name:
62 | top = self.Scope(var_scope_name)
63 | self.stack.insert(0, top)
64 |
65 | top[name] = value
66 |
67 | class Scope(dict):
68 | def __init__(self, name):
69 | self.name = name
70 |
71 | def contains(self, var_scope_name):
72 | return var_scope_name.startswith(self.name)
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 | def assert_raises(exception, fn):
78 | try:
79 | fn()
80 | except exception:
81 | pass
82 | else:
83 | assert False, "Expected exception"
84 |
85 | c = Config()
86 |
87 | c['hello'] = 1
88 | assert c['hello'] == 1
89 |
90 | with tf.variable_scope('foo'):
91 | c.set_default("bar", 10)
92 | c['bar'] = 2
93 | assert c['bar'] == 2
94 | assert c['hello'] == 1
95 |
96 | c.set_default("mario", True)
97 |
98 | with tf.variable_scope('meow'):
99 | c['dog'] = 3
100 | assert c['dog'] == 3
101 | assert c['bar'] == 2
102 | assert c['hello'] == 1
103 |
104 | assert c['mario'] == True
105 |
106 | assert_raises(KeyError, lambda: c['dog'])
107 | assert c['bar'] == 2
108 | assert c['hello'] == 1
109 |
--------------------------------------------------------------------------------
/src/utils/affine_transformation.py:
--------------------------------------------------------------------------------
1 | import random
2 | import copy
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | def get_affine_mat(width, height,
8 | max_trans, max_rotate, max_zoom,
9 | min_trans, min_rotate, min_zoom):
10 | rotate = random.uniform(min_rotate, max_rotate)
11 | trans = random.uniform(min_trans, max_trans)
12 | zoom = random.uniform(min_zoom, max_zoom)
13 |
14 | # rotate
15 | transform_matrix = np.zeros((3,3))
16 | center = (width/2.-0.5, height/2.-0.5)
17 | M = cv2.getRotationMatrix2D(center, rotate, 1)
18 | transform_matrix[:2,:] = copy.deepcopy(M)
19 | transform_matrix[2,:] = np.array([0, 0, 1])
20 |
21 | # translate
22 | transform_matrix[0,2] += trans
23 | transform_matrix[1,2] += trans
24 |
25 | # zoom
26 | for i in range(3):
27 | transform_matrix[0,i] *= zoom
28 | transform_matrix[1,i] *= zoom
29 | transform_matrix[0,2] += (1.0 - zoom) * center[0]
30 | transform_matrix[1,2] += (1.0 - zoom) * center[1]
31 |
32 | # random horizontal mirror
33 | do_mirror = False
34 | mirror_rng = random.uniform(0.,1.)
35 | if mirror_rng>0.5:
36 | do_mirror = True
37 |
38 | return transform_matrix,do_mirror
39 |
40 | def AffinePoint(points, affine_mat):
41 | """
42 | Affine a 2d point
43 | """
44 | assert(affine_mat.shape[0] == 2)
45 | assert(affine_mat.shape[1] == 3)
46 | assert(points.shape[1] == 2)
47 | results = np.zeros(points.shape)
48 | for i in range(points.shape[0]):
49 | point_x = points[i,0]
50 | point_y = points[i,1]
51 | results[i,0] = affine_mat[0,0] * point_x + \
52 | affine_mat[0,1] * point_y + \
53 | affine_mat[0,2]
54 | results[i,1] = affine_mat[1,0] * point_x + \
55 | affine_mat[1,1] * point_y + \
56 | affine_mat[1,2]
57 |
58 | return results
59 |
60 | def affine2d(x, matrix, output_img_width, output_img_height,
61 | center=True, is_landmarks=False, do_mirror=False):
62 | assert(len(matrix.shape) == 2)
63 | if is_landmarks:
64 | transform_matrix = matrix[:2,:]
65 | src = x.squeeze()
66 | dst = np.empty((src.shape[0],2), dtype=np.float32)
67 | for i in range(src.shape[0]):
68 | dst[i,:] = AffinePoint(np.expand_dims(src[i,:], axis=0), transform_matrix)
69 | if do_mirror:
70 | dst = exchange_landmarks(dst, np.array([0,16,1,15,2,14,3,13,4,12,5,11,6,10,7,9,17,26,18,25,19,24,20,23,21,22,36,45,37,44,38,
71 | 43,39,42,41,46,40,47,31,35,32,34,48,54,49,53,50,52,60,64,61,63,67,65,59,55,58,56]).reshape(-1, 2))  # assign back to dst so the mirrored index swap is actually returned
72 | else:
73 | if do_mirror:
74 | matrix[0,0] = -matrix[0,0]
75 | matrix[0,1] = -matrix[0,1]
76 | matrix[0,2] = float(output_img_width)-matrix[0,2]
77 | transform_matrix = matrix[:2,:]
78 | src = x.astype(np.uint8)
79 | dst = cv2.warpAffine(src, transform_matrix,
80 | (output_img_width, output_img_height),
81 | flags=cv2.INTER_LINEAR,
82 | borderMode=cv2.BORDER_CONSTANT,
83 | borderValue=(127,127,127))
84 |
85 | if not is_landmarks and len(dst.shape) == 2:
86 | dst = np.expand_dims(np.asarray(dst), axis=2)  # restore the channel axis for grayscale images only
87 |
88 | return dst
89 |
90 | def exchange_landmarks(input_tf, corr_list):
91 | for i in range(corr_list.shape[0]):
92 | temp = copy.deepcopy(input_tf[corr_list[i][0], :])
93 | input_tf[corr_list[i][0], :] = input_tf[corr_list[i][1], :]
94 | input_tf[corr_list[i][1], :] = temp
95 |
96 | return input_tf
97 |
--------------------------------------------------------------------------------
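A hedged sketch of how these helpers combine during augmentation, mirroring the call pattern in `datagen._resnet_generator` (the image and landmarks are random stand-ins). The image must be warped before the landmarks: with `do_mirror=True`, `affine2d` flips `matrix` in place on the image call, so the landmark call then sees the mirrored transform:

```python
import numpy as np

from utils import affine_transformation

img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
pts = np.random.rand(68, 2) * 255.0  # stand-in (x, y) landmark coordinates

matrix, do_mirror = affine_transformation.get_affine_mat(
    width=256, height=256,
    max_trans=40, max_rotate=30, max_zoom=1.1,
    min_trans=-40, min_rotate=-30, min_zoom=0.9)

# Image first (mutates `matrix` when mirroring), then landmarks.
img_aug = affine_transformation.affine2d(img, matrix, 256, 256,
                                         is_landmarks=False, do_mirror=do_mirror)
pts_aug = affine_transformation.affine2d(pts, matrix, 256, 256,
                                         is_landmarks=True, do_mirror=do_mirror)
```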
/README.md:
--------------------------------------------------------------------------------
1 | # FAB: A Robust Facial Landmark Detection Framework for Motion-Blurred Videos
2 |
3 | [Keqiang Sun](https://keqiangsun.github.io/),
4 | [Wayne Wu](https://wywu.github.io),
5 | [Tinghao Liu](https://github.com/KeqiangSun/FAB),
6 | [Shuo Yang](http://shuoyang1213.me/),
7 | [Quan Wang](https://github.com/KeqiangSun/FAB),
8 | [Qiang Zhou](https://github.com/KeqiangSun/FAB),
9 | [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=en),
10 | and [Zuochang Ye](https://github.com/KeqiangSun/FAB)
11 |
12 | [International Conference on Computer Vision (ICCV), 2019](http://iccv2019.thecvf.com/)
13 |
14 |
15 |
16 |
17 |
18 |
19 | We present FAB, a framework that takes advantage of structure consistency in the temporal dimension for facial landmark detection in motion-blurred videos. A structure predictor is proposed to temporally predict the missing face structural information, which serves as a geometry prior. This allows the framework to work as a virtuous circle. It is also a flexible video-based framework that can incorporate any static image-based method to gain a performance boost on video datasets. Extensive experiments on Blurred-300VW, the proposed Real-World Motion Blur (RWMB) dataset, and 300VW demonstrate superior performance over state-of-the-art methods.
20 |
21 | Moreover, we propose a new benchmark named Real-World Motion Blur (RWMB). It contains videos with obvious motion blur picked from YouTube, including dancing, boxing, jumping, etc. A detailed description of the system can be found in our [paper](https://keqiangsun.github.io/projects/FAB/FAB.html).
22 |
23 | ## Citation
24 | If you use this code or the RWMB dataset for your research, please cite our paper.
25 | ```
26 | @inproceedings{keqiang2019fab,
27 | author = {Sun, Keqiang and Wu, Wayne and Liu, Tinghao and Yang, Shuo and Wang, Quan and Zhou, Qiang and Ye, Zuochang and Qian, Chen},
28 | title = {FAB: A Robust Facial Landmark Detection Framework for Motion-Blurred Videos},
29 | booktitle = {ICCV},
30 | month = {October},
31 | year = {2019}
32 | }
33 | ```
34 |
35 | ## Prerequisites
36 | - Linux
37 | - Python 2
38 | - [TensorFlow](https://www.tensorflow.org/)
39 |
40 | ## Getting Started
41 |
42 | ### Blurred-300VW Dataset Download
43 | [Blurred-300VW](https://keqiangsun.github.io/projects/FAB/Blurred-300VW.html) is a video facial landmark dataset with artificial motion blur, based on the original [300VW](https://ibug.doc.ic.ac.uk/resources/300-VW/).
44 |
45 | 0. Blurred-300VW [[Google Drive](https://drive.google.com/drive/folders/1aAe1vBoHZ78QlGjBEOup416tHNp4Ztcp?usp=sharing)] [[Baidu Drive]()]
46 | 1. Unzip the package and put it in './data/datasets/Blurred-300VW'
47 |
48 | ### Real-World Motion Blur (RWMB) Dataset Download
49 | [Real-World Motion Blur (RWMB)](https://keqiangsun.github.io/projects/FAB/RWMB.html) is a newly proposed facial landmark benchmark with real-world motion blur.
50 |
51 | 0. RWMB Testing images [[Google Drive](https://drive.google.com/file/d/1vv7Qppg9R3xlj_O2dmtXZHzEnObOwoDh/view?usp=sharing)] [[Baidu Drive]()]
52 | 1. Unzip the package and put it in './data/datasets/RWMB'
53 |
54 |
55 | ### Training FAB on Blurred-300VW
56 |
57 | ```bash
58 | bash ./scripts/train.sh
59 | ```
60 |
61 | ### Testing FAB on Blurred-300VW
62 |
63 | ```bash
64 | bash ./scripts/test.sh
65 | ```
66 |
67 |
68 | ## To Do List
69 | Supported datasets
70 | - [x] [300 Faces In-the-Wild (300W)](https://ibug.doc.ic.ac.uk/resources/300-W/)
71 | - [x] [300 Videos in the Wild (300VW)](https://ibug.doc.ic.ac.uk/resources/300-VW/)
72 | - [x] [Blurred-300VW](https://keqiangsun.github.io/projects/FAB/Blurred-300VW.html)
73 | - [ ] [Real-World Motion Blur (RWMB)](https://keqiangsun.github.io/projects/FAB/RWMB.html)
74 |
75 |
76 | Supported models
77 | - [ ] [Pretrained Model of Structure Predictor Block]
78 | - [ ] [Pretrained Model of Video Deblur Block]
79 | - [ ] [Pretrained Model of Resnet Block]
80 | - [ ] [Pretrained Model of Final model]
81 |
82 |
83 | ## Questions
84 | Please contact skq719@gmail.com
85 |
--------------------------------------------------------------------------------
/src/test_FAB.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | import tarfile
8 | import cv2
9 | import copy
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from utils.curve import points_to_heatmap_rectangle_68pt
14 | from six.moves import xrange
15 | from six.moves import urllib
16 | from datagen import DataGenerator
17 | from datagen import ensure_dir
18 | from FAB import FAB
19 |
20 | MOMENTUM = 0.9
21 | POINTS_NUM = 68
22 | IMAGE_SIZE = 256
23 | PIC_CHANNEL = 3
24 | num_input_imgs = 3
25 | NUM_CLASSES = POINTS_NUM*2
26 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
27 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
28 | structure_predictor_net_channel = 64
29 |
30 | FLAGS = tf.app.flags.FLAGS
31 | tf.app.flags.DEFINE_string('structure_predictor_train_dir', '', """Directory where to write train_checkpoints.""")
32 | tf.app.flags.DEFINE_string('video_deblur_train_dir', '', """Directory where to write train_checkpoints.""")
33 | tf.app.flags.DEFINE_string('resnet_train_dir', '', """Directory where to write train_checkpoints.""")
34 | tf.app.flags.DEFINE_string('end_2_end_train_dir', '', """Directory where to write train_checkpoints.""")
35 | tf.app.flags.DEFINE_string('end_2_end_test_dir', '', """Directory where to write test logs.""")
36 | tf.app.flags.DEFINE_string('data_dir', '', """Directory where the dataset stores.""")
37 | tf.app.flags.DEFINE_string('img_list', '', """Directory where the img_list stores.""")
38 |
39 | tf.app.flags.DEFINE_float('learning_rate', 0.0, "learning rate.")
40 | tf.app.flags.DEFINE_integer('batch_size', 1, "batch size")
41 | tf.app.flags.DEFINE_boolean('resume_structure_predictor', False, """Resume from latest saved state.""")
42 | tf.app.flags.DEFINE_boolean('resume_resnet', False, """Resume from latest saved state.""")
43 | tf.app.flags.DEFINE_boolean('resume_video_deblur', False, """Resume from latest saved state.""")
44 | tf.app.flags.DEFINE_boolean('resume_all', False, """Resume from latest saved state.""")
45 | tf.app.flags.DEFINE_boolean('minimal_summaries', False, """Produce fewer summaries to save HD space.""")
46 | tf.app.flags.DEFINE_boolean('use_bn', False, """Use batch normalization. Otherwise use biases.""")
47 |
48 | def resume(sess, do_resume, ckpt_path, key_word):
49 | var = tf.global_variables()
50 | if do_resume:
51 | structure_predictor_latest = tf.train.latest_checkpoint(ckpt_path)
52 | if not structure_predictor_latest:
53 | print("\n No checkpoint to continue from in ", ckpt_path, '\n')
54 | return  # bail out instead of calling restore() with a None checkpoint path
55 | structure_predictor_var_to_restore = [val for val in var if key_word in val.name]
56 | tf.train.Saver(structure_predictor_var_to_restore).restore(sess, structure_predictor_latest)
57 |
58 | def test(resnet_model, is_training, F, H, F_curr, H_curr, input_images_blur,
59 | input_images_boundary, next_boundary_gt, labels, data_dir, img_list,
60 | dropout_ratio):
61 |
62 | global_step = tf.get_variable('global_step', [],
63 | initializer=tf.constant_initializer(0),
64 | trainable=False)
65 |
66 | init = tf.initialize_all_variables()
67 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
68 | sess.run(init)
69 | val_save_root = os.path.join(FLAGS.end_2_end_test_dir,'visualization')
70 |
71 | ################################ resume part #################################
72 |
73 | # resume weights
74 | resume(sess, FLAGS.resume_structure_predictor, FLAGS.structure_predictor_train_dir, 'voxel_flow_model_')
75 | resume(sess, FLAGS.resume_video_deblur, FLAGS.video_deblur_train_dir, 'video_deblur_model_')
76 | resume(sess, FLAGS.resume_resnet, FLAGS.resnet_train_dir, 'resnet_model_')
77 | resume(sess, FLAGS.resume_all, FLAGS.end_2_end_train_dir, '')
78 |
79 | ##############################################################################
80 |
81 | gt_file_path = os.path.join(FLAGS.end_2_end_test_dir,'gt.txt')
82 | pre_file_path = os.path.join(FLAGS.end_2_end_test_dir,'pre.txt')
83 | ensure_dir(gt_file_path)
84 | ensure_dir(pre_file_path)
85 | gt_file = open(gt_file_path,'w')
86 | pre_file = open(pre_file_path,'w')
87 |
88 | dataset = DataGenerator(data_dir,img_list)
89 | dataset._create_train_table()
90 | dataset._create_sets_for_300VW()
91 | test_gen = dataset._aux_generator(batch_size = FLAGS.batch_size, num_input_imgs = num_input_imgs,
92 | NUM_CLASSES = POINTS_NUM*2, sample_set='test')
93 |
94 | test_break_flag = False
95 | for x in xrange(len(dataset.train_table)-2):
96 |
97 | step = sess.run(global_step)
98 |
99 | if not test_break_flag:
100 | test_line_num, frame_name, input_boundaries, boundary_gt_test, input_images_blur_generated, landmark_gt_test, names, test_break_flag = next(test_gen)
101 |
102 | if (frame_name == '2.jpg') or test_line_num <= 3:
103 | input_images_boundary_init = copy.deepcopy(input_boundaries)
104 | F_init = np.zeros([FLAGS.batch_size, IMAGE_SIZE//2,
105 | IMAGE_SIZE//2, structure_predictor_net_channel//2], dtype=np.float32)
106 |
107 | H_init = np.zeros([1, FLAGS.batch_size, IMAGE_SIZE//2,
108 | IMAGE_SIZE//2, structure_predictor_net_channel], dtype=np.float32)
109 |
110 | feed_dict={
111 | input_images_boundary:input_images_boundary_init,
112 | input_images_blur:input_images_blur_generated,
113 | F:F_init,
114 | H:H_init,
115 | labels:landmark_gt_test,
116 | next_boundary_gt:boundary_gt_test,
117 | dropout_ratio:1.0
118 | }
119 | else:
120 | output_points = o[0]
121 | output_points = np.reshape(output_points,(POINTS_NUM,2))
122 | boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
123 | boundary_from_points = np.expand_dims(boundary_from_points,axis=0)
124 | boundary_from_points = np.expand_dims(boundary_from_points,axis=3)
125 |
126 | input_images_boundary_init = np.concatenate([input_images_boundary_init[:,:,:,1:2],
127 | boundary_from_points], axis=3)
128 | feed_dict={
129 | input_images_boundary:input_images_boundary_init,
130 | input_images_blur:input_images_blur_generated,
131 | F:o[-2],
132 | H:o[-1],
133 | labels:landmark_gt_test,
134 | next_boundary_gt:boundary_gt_test,
135 | dropout_ratio:1.0
136 | }
137 |
138 | i = [resnet_model.logits, F_curr, H_curr]
139 | o = sess.run(i, feed_dict=feed_dict)
140 | pres = o[0]
141 |
142 | for batch_num,pre in enumerate(pres):
143 | for v in pre:
144 | pre_file.write(str(v*255.0)+' ')
145 | if len(names) > 1:
146 | pre_file.write(names[-1])
147 | else:
148 | pre_file.write(names[batch_num])
149 | pre_file.write('\n')
150 | for batch_num,g in enumerate(landmark_gt_test):
151 | for v in g:
152 | gt_file.write(str(v*255.0)+' ')
153 | if len(names) > 1:
154 | gt_file.write(names[-1])
155 | else:
156 | gt_file.write(names[batch_num])
157 | gt_file.write('\n')
158 |
159 | img = input_images_blur_generated[0,:,:,0:3]*255
160 | points = o[0][0]*255
161 |
162 | for point_num in range(int(points.shape[0]/2)):
163 | cv2.circle(img,(int(round(points[point_num*2])),int(round(points[point_num*2+1]))),1,(55,225,155),2)
164 | val_save_path = os.path.join(val_save_root,str(step)+'.jpg')
165 | ensure_dir(val_save_path)
166 | cv2.imwrite(val_save_path,img)
167 |
168 | global_step = global_step + 1
169 | print('Test done!')
170 |
171 | def main(argv=None):
172 |
173 | resnet_model = FAB()
174 |
175 | is_training = tf.placeholder('bool', [], name='is_training')
176 | input_images_boundary = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 2))
177 | input_images_blur = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, PIC_CHANNEL*3))
178 | next_boundary_gt = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 1))
179 | labels = tf.placeholder(tf.float32,shape=(FLAGS.batch_size,NUM_CLASSES))
180 | dropout_ratio = tf.placeholder(tf.float32)
181 | F = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel//2])
182 | H = tf.placeholder(tf.float32, [1, FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel])
183 | F_curr, H_curr= \
184 | resnet_model.FAB_inference(input_images_boundary, input_images_blur, F, H, FLAGS.batch_size,
185 | net_channel=structure_predictor_net_channel, num_classes=136, num_blocks=[2, 2, 2, 2],
186 | use_bias=(not FLAGS.use_bn), bottleneck=True, dropout_ratio=1.0)
187 |
188 | test(resnet_model, is_training, F, H, F_curr, H_curr, input_images_blur,
189 | input_images_boundary, next_boundary_gt, labels, FLAGS.data_dir, FLAGS.img_list,
190 | dropout_ratio)
191 |
192 | if __name__ == '__main__':
193 | tf.app.run()
194 |
--------------------------------------------------------------------------------
/src/utils/curve.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | from numpy import linalg as LA
5 | from six.moves import xrange
6 |
7 |
8 | def distance(p1, p2):
9 | return math.sqrt((p1[0] - p2[0]) * (p1[0] - p2[0]) + \
10 | (p1[1] - p2[1]) * (p1[1] - p2[1]))
11 |
12 | def curve_interp(src, samples, index):
13 | assert(src.shape[0] > 2)
14 | assert(samples >= 2)
15 |
16 | src_1 = src[0:src.shape[0] - 1, :]
17 | src_2 = src[1:src.shape[0], :]
18 | src_delta = src_1 - src_2
19 | length = np.sqrt(src_delta[:, 0]**2 + src_delta[:, 1]**2)
20 | assert(length.shape[0] == src.shape[0] - 1)
21 |
22 | accu_length = np.zeros((src.shape[0]))
23 | for i in xrange(1, accu_length.shape[0]):
24 | accu_length[i] = accu_length[i - 1] + length[i - 1]
25 | dst = np.zeros((samples, 2))
26 | pre_raw = 0
27 |
28 | step_interp = accu_length[accu_length.shape[0] - 1] / float(samples - 1)
29 | dst[0, :] = src[0, :]
30 | dst[dst.shape[0] - 1, :] = src[src.shape[0] - 1, :]
31 | for i in xrange(1, samples - 1):
32 | covered_interp = step_interp * i
33 | while (covered_interp > accu_length[pre_raw + 1]):
34 | pre_raw += 1
35 | assert(pre_raw < accu_length.shape[0] - 1)
36 | dx = (covered_interp - accu_length[pre_raw]) / length[pre_raw]
37 | dst[i, :] = src[pre_raw, :] * (1.0 - dx) + src[pre_raw + 1, :] * dx
38 |
39 | return dst
40 |
41 | def curve_fitting(points, samples, index):
42 | num_points = points.shape[0]
43 | assert(num_points > 1)
44 | valid_points = [points[0]]
45 | for i in xrange(1, num_points):
46 | if (distance(points[i, :], points[i - 1, :]) > 0.001):
47 | valid_points.append(points[i, :])
48 | assert(len(valid_points) > 1)
49 | valid_points = np.asarray(valid_points)
50 | functions = np.zeros((valid_points.shape[0] - 1, 9))
51 |
52 | if valid_points.shape[0] == 2:
53 | functions[0, 0] = LA.norm(valid_points[0, :] - valid_points[1, :])
54 | functions[0, 1] = valid_points[0, 0]
55 | functions[0, 2] = (valid_points[1, 0] - valid_points[0, 0]) / functions[0, 0]
56 | functions[0, 3] = 0
57 | functions[0, 4] = 0
58 | functions[0, 5] = valid_points[0, 1]
59 | functions[0, 6] = (valid_points[1, 1] - valid_points[0, 1]) / functions[0, 0]
60 | functions[0, 7] = 0
61 | functions[0, 8] = 0
62 | else:
63 | Mx = np.zeros((valid_points.shape[0]))
64 | My = np.zeros((valid_points.shape[0]))
65 | A = np.zeros((valid_points.shape[0] - 2))
66 | B = np.zeros((valid_points.shape[0] - 2))
67 | C = np.zeros((valid_points.shape[0] - 2))
68 | Dx = np.zeros((valid_points.shape[0] - 2))
69 | Dy = np.zeros((valid_points.shape[0] - 2))
70 | for i in xrange(functions.shape[0]):
71 | functions[i, 0] = LA.norm(valid_points[i, :] - valid_points[i + 1, :])
72 | for i in xrange(A.shape[0]):
73 | A[i] = functions[i, 0]
74 | B[i] = 2.0 * (functions[i, 0] + functions[i + 1, 0])
75 | C[i] = functions[i + 1, 0]
76 | Dx[i] = 6.0 * ((valid_points[i + 2, 0] - valid_points[i + 1, 0]) / functions[i + 1, 0] - \
77 | (valid_points[i + 1, 0] - valid_points[i, 0]) / functions[i, 0])
78 |
79 | Dy[i] = 6.0 * ((valid_points[i + 2, 1] - valid_points[i + 1, 1]) / functions[i + 1, 0] - \
80 | (valid_points[i + 1, 1] - valid_points[i, 1]) / functions[i, 0])
81 |
82 | C[0] = C[0] / B[0]
83 | Dx[0] = Dx[0] / B[0]
84 | Dy[0] = Dy[0] / B[0]
85 | for i in xrange(1, A.shape[0]):
86 | tmp = B[i] - A[i] * C[i - 1]
87 | C[i] = C[i] / tmp
88 | Dx[i] = (Dx[i] - A[i] * Dx[i - 1]) / tmp
89 | Dy[i] = (Dy[i] - A[i] * Dy[i - 1]) / tmp
90 | Mx[valid_points.shape[0] - 2] = Dx[valid_points.shape[0] - 3]
91 | My[valid_points.shape[0] - 2] = Dy[valid_points.shape[0] - 3]
92 | for i in xrange(valid_points.shape[0] - 4, -1, -1):
93 | Mx[i + 1] = Dx[i] - C[i] * Mx[i + 2]
94 | My[i + 1] = Dy[i] - C[i] * My[i + 2]
95 | Mx[0] = 0
96 | Mx[valid_points.shape[0] - 1] = 0
97 | My[0] = 0
98 | My[valid_points.shape[0] - 1] = 0
99 |
100 | for i in xrange(functions.shape[0]):
101 | functions[i, 1] = valid_points[i, 0]
102 | functions[i, 2] = (valid_points[i + 1, 0] - valid_points[i, 0]) / functions[i, 0] - \
103 | (2.0 * functions[i, 0] * Mx[i] + functions[i, 0] * Mx[i + 1]) / 6.0
104 | functions[i, 3] = Mx[i] / 2.0
105 | functions[i, 4] = (Mx[i + 1] - Mx[i]) / (6.0 * functions[i, 0])
106 | functions[i, 5] = valid_points[i, 1]
107 | functions[i, 6] = (valid_points[i + 1, 1] - valid_points[i, 1]) / functions[i, 0] - \
108 | (2.0 * functions[i, 0] * My[i] + functions[i, 0] * My[i + 1]) / 6.0
109 | functions[i, 7] = My[i] / 2.0
110 | functions[i, 8] = (My[i + 1] - My[i]) / (6.0 * functions[i, 0])
111 |
112 | samples_per_segment = samples // functions.shape[0] + 1  # integer count of samples per spline segment
113 | rawcurve = np.zeros((functions.shape[0] * samples_per_segment, 2))
114 | for i in xrange(functions.shape[0]):
115 | step = functions[i, 0] / samples_per_segment
116 | for j in xrange(samples_per_segment):
117 | t = step * j
118 | rawcurve[i * samples_per_segment + j, :] = np.asarray([functions[i, 1] + functions[i, 2] * t + functions[i, 3] * t * t + functions[i, 4] * t * t * t,
119 | functions[i, 5] + functions[i, 6] * t + functions[i, 7] * t * t + functions[i, 8] * t * t * t])
120 |
121 | curve_tmp = curve_interp(rawcurve, samples, index)
122 |
123 | return curve_tmp
124 |
125 |
126 | def points_to_heatmap_rectangle_68pt(points,
127 | heatmap_num=13,
128 | heatmap_size=(256, 256),
129 | label_size=(256, 256),
130 | sigma=1):
131 |
132 | for i in range(points.shape[0]):
133 | points[i][0] *= (float(heatmap_size[1]) / float(label_size[1]))
134 | points[i][1] *= (float(heatmap_size[0]) / float(label_size[0]))
135 |
136 | align_on_curve = [0] * heatmap_num
137 | curves = [0] * heatmap_num
138 | align_on_curve[0] = np.zeros((17, 2))
139 | align_on_curve[1] = np.zeros((5, 2))
140 | align_on_curve[2] = np.zeros((5, 2))
141 | align_on_curve[3] = np.zeros((4, 2))
142 | align_on_curve[4] = np.zeros((5, 2))
143 | align_on_curve[5] = np.zeros((4, 2))
144 | align_on_curve[6] = np.zeros((4, 2))
145 | align_on_curve[7] = np.zeros((4, 2))
146 | align_on_curve[8] = np.zeros((4, 2))
147 | align_on_curve[9] = np.zeros((7, 2))
148 | align_on_curve[10] = np.zeros((5, 2))
149 | align_on_curve[11] = np.zeros((5, 2))
150 | align_on_curve[12] = np.zeros((7, 2))
151 |
152 | for i in range(17):
153 | align_on_curve[0][i] = points[i]
154 |
155 | for i in range(5):
156 | align_on_curve[1][i] = points[i + 17]
157 |
158 | for i in range(5):
159 | align_on_curve[2][i] = points[i + 22]
160 |
161 | for i in range(4):
162 | align_on_curve[3][i] = points[i + 27]
163 |
164 | for i in range(5):
165 | align_on_curve[4][i] = points[i + 31]
166 |
167 | align_on_curve[5][0] = points[36]
168 | align_on_curve[5][1] = points[37]
169 | align_on_curve[5][2] = points[38]
170 | align_on_curve[5][3] = points[39]
171 |
172 | align_on_curve[6][0] = points[39]
173 | align_on_curve[6][1] = points[40]
174 | align_on_curve[6][2] = points[41]
175 | align_on_curve[6][3] = points[36]
176 |
177 | align_on_curve[7][0] = points[42]
178 | align_on_curve[7][1] = points[43]
179 | align_on_curve[7][2] = points[44]
180 | align_on_curve[7][3] = points[45]
181 |
182 | align_on_curve[8][0] = points[45]
183 | align_on_curve[8][1] = points[46]
184 | align_on_curve[8][2] = points[47]
185 | align_on_curve[8][3] = points[42]
186 |
187 | for i in range(7):
188 | align_on_curve[9][i] = points[i + 48]
189 |
190 | for i in range(5):
191 | align_on_curve[10][i] = points[i + 60]
192 |
193 | align_on_curve[11][0] = points[60]
194 | align_on_curve[11][1] = points[67]
195 | align_on_curve[11][2] = points[66]
196 | align_on_curve[11][3] = points[65]
197 | align_on_curve[11][4] = points[64]
198 |
199 | align_on_curve[12][0] = points[48]
200 | align_on_curve[12][1] = points[59]
201 | align_on_curve[12][2] = points[58]
202 | align_on_curve[12][3] = points[57]
203 | align_on_curve[12][4] = points[56]
204 | align_on_curve[12][5] = points[55]
205 | align_on_curve[12][6] = points[54]
206 |
207 | heatmap = np.zeros((heatmap_size[0], heatmap_size[1], heatmap_num))
208 | for i in range(heatmap_num):
209 | curve_map = np.full((heatmap_size[0], heatmap_size[1]), 255, dtype=np.uint8)
210 |
211 | valid_points = [align_on_curve[i][0, :]]
212 | for j in range(1, align_on_curve[i].shape[0]):
213 | if (distance(align_on_curve[i][j, :], align_on_curve[i][j - 1, :]) > 0.001):
214 | valid_points.append(align_on_curve[i][j, :])
215 |
216 | if len(valid_points) > 1:
217 | curves[i] = curve_fitting(align_on_curve[i], align_on_curve[i].shape[0] * 10, i)
218 | for j in range(curves[i].shape[0]):
219 | if (int(curves[i][j, 0] + 0.5) >= 0 and int(curves[i][j, 0] + 0.5) < heatmap_size[1] and
220 | int(curves[i][j, 1] + 0.5) >= 0 and int(curves[i][j, 1] + 0.5) < heatmap_size[0]):
221 | curve_map[int(curves[i][j, 1] + 0.5), int(curves[i][j, 0] + 0.5)] = 0
222 |
223 | image_dis = cv2.distanceTransform(
224 | curve_map, cv2.cv.CV_DIST_L2, cv2.cv.CV_DIST_MASK_PRECISE)
225 |
226 | image_dis = image_dis.astype(np.float64)
227 | image_gaussian = (1.0 / (2.0 * np.pi * (sigma**2))) * np.exp(-1.0 * image_dis**2 / (2.0 * sigma**2))
228 | image_gaussian = np.where(image_dis < (3.0 * sigma), image_gaussian, 0)
229 |
230 | maxVal = image_gaussian.max()
231 | minVal = image_gaussian.min()
232 |
233 | if maxVal == minVal:
234 | image_gaussian = 0
235 | else:
236 | image_gaussian = (image_gaussian - minVal) / (maxVal - minVal)
237 |
238 | heatmap[:, :, i] = image_gaussian
239 |
240 | heatmap = np.sum(heatmap, axis=2)
241 |
242 | return heatmap
243 |
--------------------------------------------------------------------------------
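A hedged sketch of rasterizing landmarks into the boundary heatmap (not part of the repo; assumes the Python 2 / OpenCV 2.x environment implied by `cv2.cv.CV_DIST_L2` above). The 68 points are random stand-ins for real annotations:

```python
import numpy as np

from utils.curve import points_to_heatmap_rectangle_68pt

pts = np.random.rand(68, 2) * 255.0  # stand-in landmarks in a 256x256 frame
heatmap = points_to_heatmap_rectangle_68pt(pts)  # 13 per-curve maps, summed
print(heatmap.shape)  # (256, 256)
```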
/src/datagen.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import cv2
3 | import os
4 | import random
5 | import time
6 | import copy
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import scipy.misc as scm
10 | import tensorflow as tf
11 |
12 | from PIL import Image
13 | from utils import affine_transformation
14 | from utils.color_jitter import ImageJitter
15 | from skimage import transform, util
16 | from utils.curve import points_to_heatmap_rectangle_68pt
17 |
18 | def ensure_dir(file_path):
19 | directory = os.path.dirname(file_path)
20 | if not os.path.exists(directory):
21 | os.makedirs(directory)
22 |
23 | class DataGenerator():
24 |
25 | def __init__(self, img_dir=None, train_list_file=None,
26 | img_dir_valid=None, valid_list_file=None):
27 | self.img_dir = img_dir
28 | self.img_dir_valid = img_dir_valid
29 | self.train_list_file = train_list_file
30 | self.valid_list_file = valid_list_file
31 |
32 | def _create_train_table(self):
33 | self.train_table = []
34 | input_file = open(self.train_list_file, 'r')
35 | for line in input_file.readlines():
36 | self.train_table.append(line)
37 | input_file.close()
38 |
39 | def _randomize(self):
40 | random.shuffle(self.train_table)
41 |
42 | def _create_train_sets_for_300W(self):
43 | self.train_set = []
44 | input_file = open(self.train_list_file, 'r')
45 | for line in input_file.readlines():
46 | self.train_set.append(line)
47 | input_file.close()
48 |
49 | def _create_valid_sets_for_300W(self):
50 | self.valid_set = []
51 | input_file = open(self.valid_list_file, 'r')
52 | for line in input_file.readlines():
53 | self.valid_set.append(line)
54 | input_file.close()
55 |
56 | def _create_sets_for_300VW(self, validation_rate = 0.05):
57 | self.sample = len(self.train_table)
58 | valid_sample = int(self.sample * validation_rate)
59 | self.train_set = self.train_table[:self.sample - valid_sample]
60 | self.valid_set = self.train_table[self.sample - valid_sample:]
61 | self.test_set = self.train_table[:]
62 |
63 | def _aux_generator(self, batch_size = 1, NUM_CLASSES = 136,
64 | num_input_imgs = 3, normalize = True, sample_set = 'train'):
65 | train_line_num = 0
66 | valid_line_num = 0
67 | test_line_num = 0
68 | test_break_flag = False
69 |
70 | while True:
71 | train_img = np.zeros((batch_size, 256,256,3*num_input_imgs), dtype = np.float32)
72 | train_gtmap = np.zeros((batch_size, NUM_CLASSES), dtype = np.float32)
73 | i = 0
74 | names = []
75 | max_lines = 3
76 |
77 | while i < batch_size:
78 | input_boundaries = []
79 |
80 | if sample_set == 'train':
81 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) :
82 | train_line_num = 0
83 | elif sample_set == 'valid':
84 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set):
85 | valid_line_num = 0
86 | elif sample_set == 'test':
87 | if test_line_num+1 == len(self.test_set):
88 | print('The end of the testing set!')
89 | test_break_flag = True
90 |
91 | for cntr in range(max_lines):
92 | if sample_set == 'train':
93 | line = self.train_set[train_line_num]
94 | train_line_num += 1
95 | elif sample_set == 'valid':
96 | line = self.valid_set[valid_line_num]
97 | valid_line_num += 1
98 | elif sample_set == 'test':
99 | line = self.test_set[test_line_num]
100 | test_line_num += 1
101 |
102 | eles = line.strip().split()
103 | frame_path = eles[-1]
104 | name = frame_path.split('/')[-1]
105 | names.append(name)
106 | gt = np.array(list(map(float, eles[:-1])))
107 | gt_flatten = np.reshape(gt, (gt.shape[0] // 2, 2))
108 |
109 | boundary_gt_train = points_to_heatmap_rectangle_68pt(gt_flatten)
110 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=0)
111 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=3)
112 | input_boundaries.append(boundary_gt_train)
113 |
114 | if sample_set == 'train':
115 | if name != '0.jpg' and name != '1.jpg':
116 | break
117 | elif sample_set == 'valid':
118 | if (name != '0.jpg' and name != '1.jpg' and valid_line_num > 2):
119 | break
120 | elif sample_set == 'test':
121 | if (name != '0.jpg' and name != '1.jpg' and test_line_num > 2):
122 | break
123 |
124 | input_boundaries = input_boundaries[:-1]
125 | if len(input_boundaries) > 0:
126 | input_boundaries = np.concatenate(input_boundaries,axis=3)
127 |
128 | path_eles = frame_path.split('/')
129 | name_eles = path_eles[-1].split('.')
130 | frame_num = int(name_eles[0])
131 |
132 | frame_path_2 = os.path.join(path_eles[0],str(frame_num-2)+'.'+name_eles[-1])
133 | input_img_path_2 = os.path.join(self.img_dir, frame_path_2)
134 | img_2 = self.open_img(input_img_path_2)
135 | img_2 = scm.imresize(img_2, (256,256))
136 |
137 | frame_path_1 = os.path.join(path_eles[0],str(frame_num-1)+'.'+name_eles[-1])
138 | input_img_path_1 = os.path.join(self.img_dir, frame_path_1)
139 | img_1 = self.open_img(input_img_path_1)
140 | img_1 = scm.imresize(img_1, (256,256))
141 |
142 | frame_path_0 = os.path.join(path_eles[0],str(frame_num)+'.'+name_eles[-1])
143 | input_img_path_0 = os.path.join(self.img_dir, frame_path_0)
144 | img_0 = self.open_img(input_img_path_0)
145 | img_0 = scm.imresize(img_0, (256,256))
146 |
147 | img = np.concatenate([img_2,img_1,img_0],axis=2)
148 |
149 | if normalize:
150 | train_img[i] = img.astype(np.float32) / 255
151 | train_gtmap[i] = gt.astype(np.float32) / 255
152 | else :
153 | train_img[i] = img.astype(np.float32)
154 | train_gtmap[i] = gt.astype(np.float32)
155 |
156 | i = i + 1
157 |
158 | if sample_set == 'train':
159 | yield train_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap
160 | elif sample_set == 'valid':
161 | yield valid_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap
162 | elif sample_set == 'test':
163 | print("name = {}".format(name))
164 | yield test_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap, names, test_break_flag
165 |
166 | def _voxel_flow_generator_(self, batch_size = 1, sample_set = 'train'):
167 |
168 | train_line_num = 0
169 | valid_line_num = 0
170 |
171 | while True:
172 | input_boundaries = np.zeros((batch_size, 256, 256, 2), dtype = np.float32)
173 | boundary_gts_train = np.zeros((batch_size, 256, 256, 1), dtype = np.float32)
174 | i = 0
175 | max_lines = 3
176 |
177 | while i < batch_size:
178 | input_boundary = []
179 |
180 | if sample_set == 'train':
181 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) :
182 | train_line_num = 0
183 | line_num = copy.deepcopy(train_line_num)
184 | elif sample_set == 'valid':
185 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set):
186 | valid_line_num = 0
187 | line_num = copy.deepcopy(valid_line_num)
188 |
189 | for cntr in range(max_lines):
190 | if sample_set == 'train':
191 | line = self.train_set[line_num]
192 | elif sample_set == 'valid':
193 | line = self.valid_set[line_num]
194 |
195 | line_num += 1
196 | eles = line.strip().split()
197 | frame_path = eles[-1]
198 | gt = np.array(list(map(float, eles[:-1])))
199 | gt_flatten = np.reshape(gt, (gt.shape[0] // 2, 2))
200 |
201 | boundary_gt_train = points_to_heatmap_rectangle_68pt(gt_flatten)
202 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=2)
203 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=0)
204 | input_boundary.append(boundary_gt_train[0])
205 |
206 | train_line_num += 1
207 | valid_line_num += 1
208 | input_boundary = input_boundary[:-1]
209 | input_boundaries[i] = np.concatenate(input_boundary,axis=2)
210 | boundary_gts_train[i] = boundary_gt_train[0]
211 |
212 | i = i + 1
213 |
214 | if sample_set == 'train':
215 | yield input_boundaries, boundary_gts_train
216 | elif sample_set == 'valid':
217 | yield input_boundaries, boundary_gts_train
218 |
219 | def _video_deblur_generator_(self, batch_size = 1,normalize = True,
220 | num_input_imgs = 3,sample_set='train'):
221 |
222 | train_line_num = 0
223 | valid_line_num = 0
224 |
225 | while True:
226 | train_img = np.zeros((batch_size, 256, 256, 3*num_input_imgs), dtype = np.float32)
227 | i = 0
228 | max_lines = 3
229 |
230 | while i < batch_size:
231 | input_images = []
232 |
233 | if sample_set == 'train':
234 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) :
235 | train_line_num = 0
236 | line_num = copy.deepcopy(train_line_num)
237 | elif sample_set == 'valid':
238 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set):
239 | valid_line_num = 0
240 | line_num = copy.deepcopy(valid_line_num)
241 |
242 | for cntr in range(max_lines):
243 | if sample_set == 'train':
244 | line = self.train_set[line_num]
245 | elif sample_set == 'valid':
246 | line = self.valid_set[line_num]
247 | line_num += 1
248 |
249 | eles = line.strip().split()
250 | frame_path = eles[-1]
251 | input_img_path = os.path.join(self.img_dir, frame_path)
252 | name = frame_path.split('/')[-1]
253 |
254 | img = self.open_img(input_img_path)
255 | img = scm.imresize(img, (256,256))
256 |
257 | if normalize:
258 | input_images.append(img.astype(np.float32) / 255)
259 | else :
260 | input_images.append(img.astype(np.float32))
261 |
262 | train_line_num += 1
263 | valid_line_num += 1
264 | train_img[i] = np.concatenate(input_images,axis=2)
265 |
266 | i = i + 1
267 |
268 | if sample_set == 'train':
269 | yield train_line_num, name, train_img
270 | elif sample_set == 'valid':
271 | yield valid_line_num, name, train_img
272 |
273 | def _resnet_generator(self, batch_size = 16, NUM_CLASSES = 136,
274 | normalize = True, sample_set = 'train'):
275 |
276 | while True:
277 | train_img = np.zeros((batch_size, 256,256,3), dtype = np.float32)
278 | train_gtmap = np.zeros((batch_size, NUM_CLASSES), dtype = np.float32)
279 | i = 0
280 |
281 | while i < batch_size:
282 | if sample_set == 'train':
283 | line = random.choice(self.train_set)
284 | elif sample_set == 'valid':
285 | line = random.choice(self.valid_set)
286 |
287 | eles = line.strip().split()
288 | name = eles[-1]
289 | if sample_set == 'train':
290 | input_img_path = os.path.join(self.img_dir, name)
291 | elif sample_set == 'valid':
292 | input_img_path = os.path.join(self.img_dir_valid, name)
293 |
294 | img = self.open_img(input_img_path)
295 |
296 | if sample_set == 'train':
297 | gt = np.array(list(map(float, eles[:-1])))
298 | gt = gt.reshape(-1, 2)
299 |
300 | transform_matrix, do_mirror = affine_transformation.get_affine_mat(
301 | width=256, height=256,
302 | max_trans=40, max_rotate=30, max_zoom=1.1,
303 | min_trans=-40, min_rotate=-30, min_zoom=0.9)
304 |
305 | img = affine_transformation.affine2d(img, transform_matrix, output_img_width=256,
306 | output_img_height=256, center=True,
307 | is_landmarks=False, do_mirror=do_mirror)
308 | gt = affine_transformation.affine2d(gt, transform_matrix, output_img_width=256,
309 | output_img_height=256, center=True,
310 | is_landmarks=True, do_mirror=do_mirror)
311 |
312 | transformdict = {'Brightness':0.5025, 'Contrast':0.5136,
313 | 'Sharpness':0.5568, 'Color':0.5203}
314 | image_jitter = ImageJitter(transformdict)
315 | img = Image.fromarray(img)
316 | img = image_jitter(img)
317 | img = np.array(img)
318 |
319 | img = util.random_noise(img, mode='gaussian')
320 | img = (img*255).astype(np.uint8)
321 | gt = gt.reshape(1, -1).squeeze()
322 |
323 | elif sample_set == 'valid':
324 | gt = np.array(list(map(float, eles[:-1])))
325 |
326 | if normalize:
327 | train_img[i] = img.astype(np.float32) / 255
328 | train_gtmap[i] = gt.astype(np.float32) / 255
329 | else:
330 | train_img[i] = img.astype(np.float32)
331 | train_gtmap[i] = gt.astype(np.float32)
332 |
333 | i = i + 1
334 |
335 | yield train_img, train_gtmap
336 |
337 | def open_img(self, img_path, color = 'RGB'):
338 | img = cv2.imread(img_path)
339 | if color == 'RGB':
340 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
341 | return img
342 | elif color == 'BGR':
343 | return img
344 | elif color == 'GRAY':
345 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
346 | return img
347 | else:
348 | print('Color mode supported: RGB/BGR/GRAY. If you need another mode do it yourself :p')
349 |
--------------------------------------------------------------------------------
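A hedged sketch of driving the generator the way `test_FAB.py` does (paths are placeholders for a 300VW-style layout; each label line ends with the frame path, preceded by 136 landmark values):

```python
from datagen import DataGenerator

# Placeholder paths; the same ones the training script passes via flags.
dataset = DataGenerator('./data/300VW/Images/',
                        './data/300VW/labels_68pt_256_train_sorted.txt')
dataset._create_train_table()
dataset._create_sets_for_300VW()

gen = dataset._aux_generator(batch_size=1, NUM_CLASSES=136,
                             num_input_imgs=3, sample_set='train')
# Yields: line number, frame name, the two input boundary maps, the ground-
# truth boundary, the stacked 3-frame image tensor, and landmark targets.
line_num, name, boundaries, boundary_gt, imgs, gts = next(gen)
```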
/src/train_FAB.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 | import tarfile
8 | import cv2
9 | import time
10 | import copy
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from utils.curve import points_to_heatmap_rectangle_68pt
15 | from six.moves import xrange
16 | from six.moves import urllib
17 | from datagen import DataGenerator
18 | from datagen import ensure_dir
19 | from FAB import FAB
20 |
21 | MOMENTUM = 0.9
22 | POINTS_NUM = 68
23 | IMAGE_SIZE = 256
24 | PIC_CHANNEL = 3
25 | num_input_imgs = 3
26 | NUM_CLASSES = POINTS_NUM*2
27 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
28 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
29 | structure_predictor_net_channel = 64
30 |
31 | FLAGS = tf.app.flags.FLAGS
32 | # address
33 | tf.app.flags.DEFINE_string('structure_predictor_train_dir', '', """Directory where to write train_checkpoints.""")
34 | tf.app.flags.DEFINE_string('video_deblur_train_dir', '', """Directory where to write train_checkpoints.""")
35 | tf.app.flags.DEFINE_string('resnet_train_dir', '', """Directory where to write train_checkpoints.""")
36 | tf.app.flags.DEFINE_string('end_2_end_train_dir', '', """Directory where to write train_checkpoints.""")
37 | tf.app.flags.DEFINE_string('end_2_end_valid_dir', '', """Directory where to write valid logs.""")
38 | tf.app.flags.DEFINE_string('data_dir', '', """Directory where the dataset stores.""")
39 | tf.app.flags.DEFINE_string('img_list', '', """Directory where the img_list stores.""")
40 | tf.app.flags.DEFINE_string('data_dir_valid', '', """Directory where the valid image stores. Only used for pretraining on 300W datasets.""")
41 | tf.app.flags.DEFINE_string('img_list_valid', '', """Directory where the valid image_list stores. Only used for pretraining on 300W datasets.""")
42 | # parameters
43 | tf.app.flags.DEFINE_float('learning_rate', 0.00003, "learning rate.")
44 | tf.app.flags.DEFINE_integer('batch_size', 1, "batch size")
45 | tf.app.flags.DEFINE_integer('max_steps', 2000000, "max steps")
46 | tf.app.flags.DEFINE_boolean('resume_structure_predictor', True, """Resume from latest saved state.""")
47 | tf.app.flags.DEFINE_boolean('resume_resnet', True, """Resume from latest saved state.""")
48 | tf.app.flags.DEFINE_boolean('resume_video_deblur', True, """Resume from latest saved state.""")
49 | tf.app.flags.DEFINE_boolean('resume_all', False, """Resume from latest saved state.""")
50 | tf.app.flags.DEFINE_boolean('minimal_summaries', False, """Produce fewer summaries to save HD space.""")
51 | tf.app.flags.DEFINE_string('training_period', 'pretrain', """Choose the training period: pretrain/train.""")
52 | tf.app.flags.DEFINE_boolean('use_bn', False, """Use batch normalization. Otherwise use biases.""")
53 |
54 | def resume(sess, do_resume, ckpt_path, key_word):
55 | var = tf.global_variables()
56 | if do_resume:
57 | structure_predictor_latest = tf.train.latest_checkpoint(ckpt_path)
58 | if not structure_predictor_latest:
59 | print("\n No checkpoint to continue from in ", ckpt_path, '\n')
60 | return  # bail out instead of calling restore() with a None checkpoint path
61 | structure_predictor_var_to_restore = [val for val in var if key_word in val.name]
62 | tf.train.Saver(structure_predictor_var_to_restore).restore(sess, structure_predictor_latest)
63 |
64 | def train(resnet_model, is_training, F, H, F_curr, H_curr,
65 | input_images_blur, input_images_boundary, next_boundary_gt, labels,
66 | data_dir, data_dir_valid, img_list, img_list_valid,
67 | dropout_ratio):
68 |
69 | global_step = tf.get_variable('global_step', [],
70 | initializer=tf.constant_initializer(0),
71 | trainable=False)
72 | val_step = tf.get_variable('val_step', [],
73 | initializer=tf.constant_initializer(0),
74 | trainable=False)
75 |
76 | # define the losses.
77 | lambda_ = 1e-5
78 |
79 | loss_1 = resnet_model.l2_loss_(resnet_model.logits, labels)
80 | loss_2 = resnet_model.l2_loss_(resnet_model.next_frame,next_boundary_gt)
81 | loss_3 = resnet_model.l2_loss_(input_images_blur[:,:,:,-3:],resnet_model.video_deblur_output)
82 | loss_ = loss_1 + loss_2 + loss_3 + tf.reduce_sum(tf.square(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))) * lambda_
83 |
84 | ema = tf.train.ExponentialMovingAverage(resnet_model.MOVING_AVERAGE_DECAY, global_step)
85 | tf.add_to_collection(resnet_model.UPDATE_OPS_COLLECTION, ema.apply([loss_]))
86 | tf.summary.scalar('loss_avg', ema.average(loss_))
87 |
88 | ema = tf.train.ExponentialMovingAverage(0.9, val_step)
89 | val_op = tf.group(val_step.assign_add(1), ema.apply([loss_]))
90 | tf.summary.scalar('loss_valid', ema.average(loss_))
91 |
92 | tf.summary.scalar('learning_rate', FLAGS.learning_rate)
93 |
94 | # define the optimizer and back propagate.
95 | opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
96 | grads = opt.compute_gradients(loss_)
97 | for grad, var in grads:
98 | if grad is not None and not FLAGS.minimal_summaries:
99 | tf.summary.histogram(var.op.name + '/gradients', grad)
100 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
101 |
102 | batchnorm_updates = tf.get_collection(resnet_model.UPDATE_OPS_COLLECTION)
103 | batchnorm_updates_op = tf.group(*batchnorm_updates)
104 | train_op = tf.group(apply_gradient_op, batchnorm_updates_op)
105 |
106 | saver_all = tf.train.Saver(tf.all_variables())
107 |
108 | summary_op = tf.summary.merge_all()
109 |
110 | # initialize all variables
111 | init = tf.initialize_all_variables()
112 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
113 | sess.run(init)
114 |
115 | summary_writer = tf.summary.FileWriter(FLAGS.end_2_end_train_dir, sess.graph)
116 | val_summary_writer = tf.summary.FileWriter(FLAGS.end_2_end_valid_dir)
117 | val_save_root = os.path.join(FLAGS.end_2_end_valid_dir,'visualization')
118 | compare_save_root = os.path.join(FLAGS.end_2_end_valid_dir,'deblur_compare')
119 |
120 | # resume weights
121 | resume(sess, FLAGS.resume_structure_predictor, FLAGS.structure_predictor_train_dir, 'voxel_flow_model_')
122 | resume(sess, FLAGS.resume_video_deblur, FLAGS.video_deblur_train_dir, 'video_deblur_model_')
123 | resume(sess, FLAGS.resume_resnet, FLAGS.resnet_train_dir, 'resnet_model_')
124 | resume(sess, FLAGS.resume_all, FLAGS.end_2_end_train_dir, '')
125 |
126 | # create data generator
127 | if FLAGS.training_period == 'pretrain':
128 | dataset = DataGenerator(data_dir, img_list, data_dir_valid, img_list_valid)
129 | dataset._create_train_sets_for_300W()
130 | dataset._create_valid_sets_for_300W()
131 | elif FLAGS.training_period == 'train':
132 | dataset = DataGenerator(data_dir,img_list)
133 | dataset._create_train_table()
134 | dataset._create_sets_for_300VW()
135 | else:
136 |         raise ValueError("No such training_period: %s" % FLAGS.training_period)
137 | train_gen = dataset._aux_generator(batch_size = FLAGS.batch_size,
138 | num_input_imgs = num_input_imgs,
139 | NUM_CLASSES = POINTS_NUM*2,
140 | sample_set='train')
141 | valid_gen = dataset._aux_generator(batch_size = FLAGS.batch_size,
142 | num_input_imgs = num_input_imgs,
143 | NUM_CLASSES = POINTS_NUM*2,
144 | sample_set='valid')
145 |
146 | # main training process.
147 |     for x in range(FLAGS.max_steps + 1):
148 |
149 | start_time = time.time()
150 | step = sess.run(global_step)
151 | i = [train_op, loss_]
152 | write_summary = step > 1 and not (step % 100)
153 | if write_summary:
154 | i.append(summary_op)
155 | i.append(resnet_model.logits)
156 | i.append(F_curr)
157 | i.append(H_curr)
158 |
159 | train_line_num, frame_name, input_boundaries, boundary_gt_train, input_images_blur_generated, landmark_gt_train = next(train_gen)
160 |
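        # Frame name '2.jpg' appears to mark the start of a new clip: reset
        # the boundary inputs and zero the recurrent deblurring features F/H.
        # On later frames, F/H and the boundary heatmap are rebuilt from the
        # previous step's outputs `o`.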
161 | if (frame_name == '2.jpg'):
162 | input_images_boundary_init = copy.deepcopy(input_boundaries)
163 | F_init = np.zeros([FLAGS.batch_size,
164 | IMAGE_SIZE//2,
165 | IMAGE_SIZE//2,
166 | structure_predictor_net_channel//2], dtype=np.float32)
167 |
168 | H_init = np.zeros([1,
169 | FLAGS.batch_size,
170 | IMAGE_SIZE//2,
171 | IMAGE_SIZE//2,
172 | structure_predictor_net_channel], dtype=np.float32)
173 | feed_dict={
174 | input_images_boundary:input_images_boundary_init,
175 | input_images_blur:input_images_blur_generated,
176 | F:F_init,
177 | H:H_init,
178 | labels:landmark_gt_train,
179 | next_boundary_gt:boundary_gt_train,
180 | dropout_ratio:0.5
181 | }
182 | else:
183 | output_points = o[-3]
184 | output_points = np.reshape(output_points,(POINTS_NUM,2))
185 |
186 | boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
187 | boundary_from_points = np.expand_dims(boundary_from_points,axis=0)
188 | boundary_from_points = np.expand_dims(boundary_from_points,axis=3)
189 | input_images_boundary_init = np.concatenate([input_images_boundary_init[:,:,:,1:2],
190 | boundary_from_points], axis=3)
191 | feed_dict={
192 | input_images_boundary:input_images_boundary_init,
193 | input_images_blur:input_images_blur_generated,
194 | F:o[-2],
195 | H:o[-1],
196 | labels:landmark_gt_train,
197 | next_boundary_gt:boundary_gt_train,
198 | dropout_ratio:0.5
199 | }
200 |
201 | o = sess.run(i,feed_dict=feed_dict)
202 | loss_value = o[1]
203 | duration = time.time() - start_time
204 | assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
205 |
206 | if step > 1 and step % 300 == 0:
207 | examples_per_sec = FLAGS.batch_size / float(duration)
208 | format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
209 | 'sec/batch)')
210 | print(format_str % (step, loss_value, examples_per_sec, duration))
211 |
212 | if write_summary:
213 | summary_str = o[2]
214 | summary_writer.add_summary(summary_str, step)
215 |
216 | if step > 1 and step % 300 == 0:
217 | checkpoint_path = os.path.join(FLAGS.end_2_end_train_dir, 'model.ckpt')
218 | ensure_dir(checkpoint_path)
219 | saver_all.save(sess, checkpoint_path, global_step=global_step)
220 |
221 | # Run validation periodically
222 | if step > 1 and step % 300 == 0:
223 | valid_line_num, frame_name, input_boundaries, boundary_gt_valid, input_images_blur_generated, landmark_gt_valid = next(valid_gen)
224 |
225 | if (frame_name == '2.jpg') or valid_line_num <= 3:
226 | input_images_boundary_init = copy.deepcopy(input_boundaries)
227 | F_init = np.zeros([FLAGS.batch_size,
228 | IMAGE_SIZE//2,
229 | IMAGE_SIZE//2,
230 | structure_predictor_net_channel//2], dtype=np.float32)
231 |
232 | H_init = np.zeros([1, FLAGS.batch_size,
233 | IMAGE_SIZE//2,
234 | IMAGE_SIZE//2,
235 | structure_predictor_net_channel], dtype=np.float32)
236 |
237 | feed_dict={input_images_boundary:input_images_boundary_init,
238 | input_images_blur:input_images_blur_generated,
239 | F:F_init,
240 | H:H_init,
241 | labels:landmark_gt_valid,
242 | next_boundary_gt:boundary_gt_valid,
243 | dropout_ratio:1.0
244 | }
245 | else:
246 | output_points = o_valid[-3]
247 | output_points = np.reshape(output_points,(POINTS_NUM,2))
248 | boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
249 | boundary_from_points = np.expand_dims(boundary_from_points,axis=0)
250 | boundary_from_points = np.expand_dims(boundary_from_points,axis=3)
251 |
252 | input_images_boundary_init = np.concatenate([input_images_boundary_init[:,:,:,1:2],
253 | boundary_from_points], axis=3)
254 | feed_dict={
255 | input_images_boundary:input_images_boundary_init,
256 | input_images_blur:input_images_blur_generated,
257 | F:o_valid[-2],
258 | H:o_valid[-1],
259 | labels:landmark_gt_valid,
260 | next_boundary_gt:boundary_gt_valid,
261 | dropout_ratio:1.0
262 | }
263 | i_valid = [loss_,resnet_model.logits,F_curr,H_curr]
264 | o_valid = sess.run(i_valid,feed_dict=feed_dict)
265 |             print('Validation loss %.2f' % o_valid[0])
266 | if write_summary:
267 | val_summary_writer.add_summary(summary_str, step)
268 | img_video_deblur_output = sess.run(resnet_model.video_deblur_output,feed_dict=feed_dict)[0]*255
269 | img = input_images_blur_generated[0,:,:,0:3]*255
270 | compare_img = np.concatenate([img,img_video_deblur_output],axis=1)
271 | points = o_valid[1][0]*255
272 |
273 | for point_num in range(int(points.shape[0]/2)):
274 | cv2.circle(img,(int(round(points[point_num*2])),int(round(points[point_num*2+1]))),1,(55,225,155),2)
275 | val_save_path = os.path.join(val_save_root,str(step)+'.jpg')
276 | compare_save_path = os.path.join(compare_save_root,str(step)+'.jpg')
277 | ensure_dir(val_save_path)
278 | ensure_dir(compare_save_path)
279 | cv2.imwrite(val_save_path,img)
280 | cv2.imwrite(compare_save_path,compare_img)
281 |
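# Build the end-to-end graph: placeholders for the two-channel boundary
# heatmaps, the stacked blurred frames, the ground-truth next boundary and
# landmarks, plus the recurrent deblurring features F and H, then hand
# everything to train().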
282 | def main(argv=None):
283 | resnet_model = FAB(structure_predictor_is_train=False,
284 | deblur_is_train=True,
285 | resnet_is_train=False)
286 |
287 | is_training = tf.placeholder('bool', [], name='is_training')
288 |
289 | input_images_boundary = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 2))
290 | input_images_blur = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, PIC_CHANNEL*3))
291 | next_boundary_gt = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 1))
292 | labels = tf.placeholder(tf.float32,shape=(FLAGS.batch_size,NUM_CLASSES))
293 | dropout_ratio = tf.placeholder(tf.float32)
294 | F = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel//2])
295 | H = tf.placeholder(tf.float32, [1, FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel])
296 |
297 | F_curr, H_curr = resnet_model.FAB_inference(input_images_boundary, input_images_blur, F, H, FLAGS.batch_size,
298 | net_channel=structure_predictor_net_channel, num_classes=136,
299 | num_blocks=[2, 2, 2, 2], use_bias=(not FLAGS.use_bn),
300 | bottleneck=True,dropout_ratio=1.0)
301 |
302 | train(resnet_model, is_training, F, H, F_curr, H_curr,
303 | input_images_blur, input_images_boundary, next_boundary_gt, labels,
304 | FLAGS.data_dir, FLAGS.data_dir_valid, FLAGS.img_list, FLAGS.img_list_valid,
305 | dropout_ratio)
306 |
307 | if __name__ == '__main__':
308 | tf.app.run()
309 |
--------------------------------------------------------------------------------
/src/FAB.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 | import datetime
8 | import numpy as np
9 | import os
10 | import time
11 | import math
12 | import skimage.io
13 | import skimage.transform
14 |
15 | from utils.loss_utils import l2_loss
16 | from utils.geo_layer_utils import bilinear_interp
17 | from utils.geo_layer_utils import meshgrid
18 | from tensorflow.python.ops import control_flow_ops
19 | from tensorflow.python.training import moving_averages
20 | from utils.config import Config
21 |
22 |
23 | class FAB(object):
24 | def __init__(self, structure_predictor_is_train=True, deblur_is_train=True,
25 | resnet_is_train=True, is_training=True,
26 | MOVING_AVERAGE_DECAY=0.9997, BN_EPSILON=0.001,
27 | CONV_WEIGHT_DECAY=0.0005, CONV_WEIGHT_STDDEV=0.1,
28 | FC_WEIGHT_DECAY=0.0005, FC_WEIGHT_STDDEV=0.01,
29 | RESNET_VARIABLES='RESNET_VARIABLES',
30 | UPDATE_OPS_COLLECTION='resnet_update_ops',
31 | IMAGENET_MEAN_BGR=[103.062623801, 115.902882574, 123.151630838, ],
32 | input_size = 224):
33 |
34 | self.structure_predictor_is_train = structure_predictor_is_train
35 | self.deblur_is_train = deblur_is_train
36 | self.resnet_is_train = resnet_is_train
37 |
38 | self.MOVING_AVERAGE_DECAY = MOVING_AVERAGE_DECAY
39 | self.BN_DECAY = self.MOVING_AVERAGE_DECAY
40 | self.BN_EPSILON = BN_EPSILON
41 | self.CONV_WEIGHT_DECAY = CONV_WEIGHT_DECAY
42 | self.CONV_WEIGHT_STDDEV = CONV_WEIGHT_STDDEV
43 | self.FC_WEIGHT_DECAY = FC_WEIGHT_DECAY
44 | self.FC_WEIGHT_STDDEV = FC_WEIGHT_STDDEV
45 | self.RESNET_VARIABLES = RESNET_VARIABLES
46 | self.UPDATE_OPS_COLLECTION = UPDATE_OPS_COLLECTION
47 | self.IMAGENET_MEAN_BGR = IMAGENET_MEAN_BGR
48 | self.input_size = input_size
49 |
50 | ### loss function ###
51 | def l1_loss_(self, logits, labels):
52 | logits = tf.cast(logits,tf.float32)
53 | labels = tf.cast(labels,tf.float32)
54 | losses = tf.reduce_sum(tf.abs(tf.subtract(logits,labels)), axis=1)
55 | losses_mean = tf.reduce_mean(losses)
56 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
57 | loss_ = tf.add_n([losses_mean] + regularization_losses)
58 |
59 | return loss_
60 |
61 | def l2_loss_(self, logits, labels):
62 | logits = tf.cast(logits,tf.float32)
63 | labels = tf.cast(labels,tf.float32)
64 | losses = tf.nn.l2_loss(tf.subtract(logits,labels))
65 | losses_mean = tf.reduce_mean(losses)
66 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
67 | loss_ = tf.add_n([losses_mean] + regularization_losses)
68 |
69 | return loss_
70 |
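    # Wing loss (Feng et al., CVPR 2018): w * ln(1 + |x|/epsilon) for |x| < w,
    # |x| - C otherwise, where C = w - w * ln(1 + w/epsilon) keeps the two
    # pieces continuous.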
71 | def wing_loss(self, logits, labels, w=10.0, epsilon=2.0):
72 | logits = tf.cast(logits,tf.float32)
73 | labels = tf.cast(labels,tf.float32)
74 | x = tf.subtract(logits,labels)
75 | C = w * (1.0 - math.log(1.0 + w/epsilon))
76 | absolute_x = tf.abs(x)
77 |         losses = tf.where(tf.greater(w, absolute_x),
78 |                           w * tf.log(1.0 + absolute_x/epsilon),
79 |                           absolute_x - C)
80 |         losses_mean = tf.reduce_mean(tf.reduce_sum(losses, axis=1))
81 |         loss_ = tf.add_n([losses_mean] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
82 |
83 | return loss_
84 |
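    # Normalized Mean Error: mean per-point Euclidean distance divided by the
    # inter-ocular distance. Indices 72/73 and 90/91 are the x/y coordinates
    # of the outer eye corners (points 36 and 45) in the flattened 136-dim
    # 68-point layout.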
85 | def calculate_NME(self, logits, labels):
86 | logits = tf.cast(logits,tf.float32)
87 | labels = tf.cast(labels,tf.float32)
88 |
89 | subtract_square_distance = tf.square(tf.subtract(logits, labels))
90 | mean_distance = tf.reduce_mean([tf.sqrt(tf.add(subtract_square_distance[:, column],
91 | subtract_square_distance[:, column+1])) for column in range(0, 136, 2)], axis=0)
92 |
93 | outer_eye_x = tf.square(tf.subtract(labels[:, 72], labels[:, 90]))
94 | outer_eye_y = tf.square(tf.subtract(labels[:, 73], labels[:, 91]))
95 | inter_ocular_distance = tf.sqrt(tf.add(outer_eye_x, outer_eye_y))
96 |
97 | normalized_mean_error = tf.divide(mean_distance, inter_ocular_distance,
98 | name='normalized_mean_error')
99 | loss_ = tf.reduce_mean(normalized_mean_error)
100 |
101 | return loss_
102 |
103 | ### structure predictor model ###
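    # Encoder-decoder that maps two consecutive boundary heatmaps to a dense
    # flow field plus a blending mask (conv7's three output channels), warps
    # each input heatmap along the flow with bilinear interpolation, and
    # blends the two warps into the predicted next-frame boundary.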
104 | def structure_predictor_inference(self,input_images_boundary,batch_size):
105 | with tf.variable_scope('structure_predictor_model_'):
106 | with slim.arg_scope([slim.conv2d],
107 | activation_fn=tf.nn.relu,
108 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
109 | weights_regularizer=slim.l2_regularizer(0.0001)):
110 |
111 | batch_norm_params = {'decay': 0.9997,
112 | 'epsilon': 0.0001,
113 | 'is_training': self.structure_predictor_is_train}
114 |
115 | with slim.arg_scope([slim.batch_norm],
116 | is_training = self.structure_predictor_is_train,
117 | updates_collections=None):
118 | with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
119 | normalizer_params=batch_norm_params):
120 | net = slim.conv2d(input_images_boundary, 64, [5, 5], stride=1, scope='conv1')
121 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
122 | net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
123 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
124 | net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
125 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
126 | net = tf.image.resize_bilinear(net, [64,64])
127 | net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
128 | net = tf.image.resize_bilinear(net, [128,128])
129 | net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
130 | net = tf.image.resize_bilinear(net, [256,256])
131 | net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
132 |
133 | net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
134 | normalizer_fn=None, scope='conv7')
135 | flow = net[:, :, :, 0:2]
136 | mask = tf.expand_dims(net[:, :, :, 2], 3)
137 |
138 | grid_x, grid_y = meshgrid(256, 256)
139 | grid_x = tf.tile(grid_x, [batch_size, 1, 1])
140 | grid_y = tf.tile(grid_y, [batch_size, 1, 1])
141 |
142 | coor_x_1 = grid_x + flow[:, :, :, 0]*2
143 | coor_y_1 = grid_y + flow[:, :, :, 1]*2
144 | coor_x_2 = grid_x + flow[:, :, :, 0]
145 | coor_y_2 = grid_y + flow[:, :, :, 1]
146 |
147 | output_1 = bilinear_interp(input_images_boundary[:, :, :, 0:1],
148 | coor_x_1, coor_y_1, 'extrapolate')
149 | output_2 = bilinear_interp(input_images_boundary[:, :, :, 1:2],
150 | coor_x_2, coor_y_2, 'extrapolate')
151 |
152 | mask = 0.33 * (1.0 + mask)
153 | mask = tf.tile(mask, [1, 1, 1, 3])
154 | next_frame = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)
155 |
156 | return next_frame
157 |
158 | ### video deblur function ###
159 | def get_shape(self, x, i):
160 | return x.get_shape().as_list()[i]
161 |
162 | def weight_variable(self, shape, stddev=0.02, name = 'weight'):
163 | w = tf.get_variable(name, shape,
164 | initializer=tf.random_normal_initializer(stddev=stddev),
165 | trainable=self.deblur_is_train)
166 | return w
167 |
168 | def bias_variable(self, shape, name):
169 | b = tf.get_variable(name, initializer = tf.zeros(shape),
170 | trainable= self.deblur_is_train)
171 | return b
172 |
173 | def conv2d(self, x, W, stride = 1):
174 | return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='SAME')
175 |
176 | def conv2d_transpose(self, x, w, output_shape, stride = 2):
177 | return tf.nn.conv2d_transpose(x, w, output_shape=output_shape,
178 | strides=[1, stride, stride, 1], padding='SAME')
179 |
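    # Batch norm over the current batch only: statistics come from
    # tf.nn.moments on every call (no moving averages, unlike resnet_bn
    # below), and gamma is drawn from a small random normal rather than ones.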
180 | def bn(self, x):
181 | net = x
182 | out_channels = self.get_shape(net, 3)
183 | mean, var = tf.nn.moments(net, axes=[0,1,2])
184 | beta = self.bias_variable([out_channels], name="beta")
185 | gamma = self.weight_variable([out_channels], name="gamma")
186 | net = tf.nn.batch_normalization(net, mean, var, beta, gamma, 0.001)
187 | return net
188 |
189 | def conv_bn(self, x, filter_shape):
190 | net = x
191 | net = tf.nn.conv2d(net, self.weight_variable(filter_shape, name = "weight"),
192 | strides=[1, 1, 1, 1], padding="SAME")
193 | out_channels = filter_shape[3]
194 | mean, var = tf.nn.moments(net, axes=[0,1,2])
195 | beta = self.bias_variable([out_channels], name="beta")
196 | gamma = self.weight_variable([out_channels], name="gamma")
197 | net = tf.nn.batch_normalization(net, mean, var, beta, gamma, 0.001)
198 | return net
199 |
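    # Two-layer residual block: conv-BN-ReLU, conv-BN, then an identity skip
    # connection followed by ReLU.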
200 | def resnet_block(self, x, out_channel, filter_size = 3):
201 | x_channel = x.get_shape().as_list()[3]
202 | with tf.variable_scope("conv_bn_relu"):
203 | net = self.conv_bn(x, filter_shape=[filter_size,
204 | filter_size,
205 | out_channel,
206 | out_channel])
207 | net = tf.nn.relu(net)
208 | with tf.variable_scope("conv_bn"):
209 | net = self.conv_bn(net, filter_shape=[filter_size,
210 | filter_size,
211 | out_channel,
212 | out_channel])
213 |         net = net + x
214 |         net = tf.nn.relu(net)
215 | return net
216 |
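    # Fuse current features x with the recurrent state h through a learned
    # per-pixel gate: a conv over [x, h] yields a similarity map, alpha is
    # pushed toward 0/1 via 2*|sigmoid - 0.5|, and the VALID-conv border is
    # padded so that alpha = 1 (pure x) at the edges.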
217 | def dynamic_fusion(self, x, h, filter_size = 5):
218 | n_channel = self.get_shape(x, 3)
219 | t = tf.concat([x, h], 3)
220 | similarity = tf.nn.conv2d(t, self.weight_variable([filter_size,
221 | filter_size,
222 | n_channel*2,
223 | n_channel],
224 | name = "wt"),
225 | strides=[1, 1, 1, 1],
226 | padding='VALID')
227 | epsilon = self.bias_variable([1], name = 'bias_epsilon')
228 | alpha = 2*tf.abs(tf.sigmoid(similarity) - 0.5) + epsilon
229 | alpha = tf.clip_by_value(alpha, 0, 1)
230 | hflt_filter_size = filter_size // 2
231 | alpha = tf.pad(alpha-1, [[0, 0],
232 | [hflt_filter_size, hflt_filter_size],
233 | [hflt_filter_size, hflt_filter_size],
234 | [0, 0]], "CONSTANT") + 1
235 | y = alpha*x + (1-alpha)*h
236 | return y, alpha
237 |
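    # Recurrent deblurring network: two encoding convs (the second strided),
    # concatenation with the previous features F, eight residual blocks with
    # dynamic fusion of the hidden state H after block 4, then a transposed
    # conv back to full resolution. Returns the deblurred frame Y, the new
    # features F and the new hidden state H_curr.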
238 | def video_deblur_inference(self, X, F, H, net_channel = 64):
239 | with tf.variable_scope('video_deblur_model_'):
240 | H_curr = []
241 | with tf.variable_scope("encoding"):
242 | with tf.variable_scope("conv1"):
243 | filter_size = 5
244 | net_X = self.conv2d(X, self.weight_variable([filter_size,
245 | filter_size,
246 | self.get_shape(X, 3),
247 | net_channel]))
248 | net_X = tf.nn.relu(net_X)
249 | with tf.variable_scope("conv2"):
250 | filter_size = 3
251 | net_X = self.conv2d(net_X, self.weight_variable([filter_size,
252 | filter_size,
253 | self.get_shape(net_X, 3),
254 | net_channel//2]),
255 | stride = 2)
256 | net_X = tf.nn.relu(net_X)
257 | net = tf.concat([net_X, F], 3)
258 | f0 = net
259 | filter_size = 3
260 | num_resnet_layers = 8
261 |             for i in range(num_resnet_layers):
262 | with tf.variable_scope('resnet_block%d' % (i+1)):
263 | net = self.resnet_block(net, net_channel)
264 | if i == 3:
265 | (net, alpha) = self.dynamic_fusion(net, H[0])
266 | h = tf.expand_dims(net, axis=0)
267 | H_curr = h
268 | with tf.variable_scope("feat_out"):
269 | F = self.conv2d(net, self.weight_variable([filter_size,
270 | filter_size,
271 | self.get_shape(net, 3),
272 | net_channel//2],
273 | name = 'conv_F'))
274 | F = tf.nn.relu(F)
275 | with tf.variable_scope("img_out"):
276 | filter_size = 4
277 | shape = [self.get_shape(X, 0),
278 | self.get_shape(X, 1),
279 | self.get_shape(X, 2),
280 | net_channel]
281 | Y = self.conv2d_transpose(net, self.weight_variable([filter_size,
282 | filter_size,
283 | net_channel,
284 | net_channel],
285 | name = "deconv"),
286 | shape,
287 | stride = 2)
288 | Y = tf.nn.relu(Y)
289 | filter_size = 3
290 | Y = self.conv2d(Y, self.weight_variable([filter_size,
291 | filter_size,
292 | self.get_shape(Y, 3),
293 | 3],
294 | name = 'conv'))
295 | return Y, F, H_curr
296 |
297 | ### resnet inference ###
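    # ResNet-style landmark regressor: stacked residual scales, global
    # average pooling, then three fully connected layers (with dropout)
    # down to the 136 landmark coordinates.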
298 | def resnet_inference(self,
299 | input_images_blur,
300 | batch_size,
301 | num_classes=136,
302 | num_blocks=[2, 2, 2, 2],
303 | use_bias=False,
304 | bottleneck=True,
305 | dropout_ratio=1.0):
306 | ####resnet_model####
307 | with tf.variable_scope('resnet_model_'):
308 | c = Config()
309 | c['bottleneck'] = bottleneck
310 | c['is_training'] = tf.convert_to_tensor(self.resnet_is_train,
311 | dtype='bool',
312 | name='is_training')
313 | c['ksize'] = 3
314 | c['stride'] = 1
315 | c['use_bias'] = use_bias
316 | c['fc_units_out'] = num_classes
317 | c['num_blocks'] = num_blocks
318 | c['stack_stride'] = 2
319 |
320 | with tf.variable_scope('scale1'):
321 | c['conv_filters_out'] = 16
322 | c['ksize'] = 7
323 | c['stride'] = 2
324 | x = self.conv(input_images_blur, c)
325 | x = self.resnet_bn(x, c)
326 | x = self.activation(x)
327 |
328 | with tf.variable_scope('scale1_pool'):
329 | x = self._max_pool(x, ksize=3, stride=2)
330 | x = self.resnet_bn(x, c)
331 | x = self.activation(x)
332 |
333 | with tf.variable_scope('scale2'):
334 | x = self._max_pool(x, ksize=3, stride=2)
335 | c['num_blocks'] = num_blocks[0]
336 | c['stack_stride'] = 1
337 | c['block_filters_internal'] = 8
338 | x = self.stack(x, c)
339 |
340 | with tf.variable_scope('scale3'):
341 | c['num_blocks'] = num_blocks[1]
342 | c['block_filters_internal'] = 16
343 | assert c['stack_stride'] == 2
344 | x = self.stack(x, c)
345 |
346 | with tf.variable_scope('scale4'):
347 | c['num_blocks'] = num_blocks[2]
348 | c['block_filters_internal'] = 32
349 | x = self.stack(x, c)
350 |
351 | with tf.variable_scope('scale5'):
352 | c['num_blocks'] = num_blocks[3]
353 | c['block_filters_internal'] = 64
354 | x = self.stack(x, c)
355 |
356 |             x = tf.reduce_mean(x, axis=[1, 2], name="avg_pool")
357 |
358 |             if num_classes is not None:
359 | with tf.variable_scope('fc1'):
360 | c['fc_units_out'] = 256
361 | x = self.fc(x, c)
362 |
363 | with tf.variable_scope('dropout1'):
364 | x = tf.nn.dropout(x, dropout_ratio)
365 |
366 | with tf.variable_scope('fc2'):
367 | c['fc_units_out'] = 256
368 | x = self.fc(x, c)
369 |
370 | with tf.variable_scope('dropout2'):
371 | x = tf.nn.dropout(x, dropout_ratio)
372 |
373 | with tf.variable_scope('fc3'):
374 | c['fc_units_out'] = 136
375 | landmark_localization = self.fc(x, c)
376 |
377 | return landmark_localization
378 |
379 | def stack(self, x, c):
380 | for n in range(c['num_blocks']):
381 | s = c['stack_stride'] if n == 0 else 1
382 | c['block_stride'] = s
383 | with tf.variable_scope('block%d' % (n + 1)):
384 | x = self.block(x, c, n)
385 | return x
386 |
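    # NOTE: block() returns only for n == 0 (plain residual add) or n == 1
    # (add followed by BN + activation); with more than two blocks per stack
    # it would fall through and return None.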
387 | def block(self, x, c, n):
388 | filters_in = x.get_shape()[-1]
389 | m = 4 if c['bottleneck'] else 1
390 | filters_out = m * c['block_filters_internal']
391 | c['conv_filters_out'] = c['block_filters_internal']
392 |
393 | shortcut = x
394 |
395 | if c['bottleneck']:
396 | if n == 1:
397 | with tf.variable_scope('pre_activation'):
398 | x = self.resnet_bn(x, c)
399 | x = self.activation(x)
400 |
401 | with tf.variable_scope('a'):
402 | c['ksize'] = 1
403 | c['stride'] = c['block_stride']
404 | x = self.conv(x, c)
405 | x = self.resnet_bn(x, c)
406 | x = self.activation(x)
407 |
408 | with tf.variable_scope('b'):
409 | x = self.conv(x, c)
410 | x = self.resnet_bn(x, c)
411 | x = self.activation(x)
412 |
413 | with tf.variable_scope('c'):
414 | c['conv_filters_out'] = filters_out
415 | c['ksize'] = 1
416 | assert c['stride'] == 1
417 | x = self.conv(x, c)
418 | else:
419 | with tf.variable_scope('A'):
420 | c['stride'] = c['block_stride']
421 | assert c['ksize'] == 3
422 | x = self.conv(x, c)
423 | x = self.resnet_bn(x, c)
424 | x = self.activation(x)
425 |
426 | with tf.variable_scope('B'):
427 | c['conv_filters_out'] = filters_out
428 | assert c['ksize'] == 3
429 | assert c['stride'] == 1
430 | x = self.conv(x, c)
431 | x = self.resnet_bn(x, c)
432 |
433 | with tf.variable_scope('shortcut'):
434 | if filters_out != filters_in or c['block_stride'] != 1:
435 | c['ksize'] = 1
436 | c['stride'] = c['block_stride']
437 | c['conv_filters_out'] = filters_out
438 | shortcut = self.conv(shortcut, c)
439 |
440 | if n == 0:
441 | return x + shortcut
442 | elif n == 1:
443 | x = self.resnet_bn(x+shortcut, c)
444 | return self.activation(x)
445 |
446 | def resnet_bn(self, x, c):
447 | x_shape = x.get_shape()
448 | params_shape = x_shape[-1:]
449 |
450 | if c['use_bias']:
451 | bias = self._get_variable('bias',
452 | params_shape,
453 | initializer=tf.zeros_initializer)
454 | return x + bias
455 |
456 | axis = list(range(len(x_shape) - 1))
457 | beta = self._get_variable('beta',
458 | params_shape,
459 | initializer=tf.zeros_initializer)
460 | gamma = self._get_variable('gamma',
461 | params_shape,
462 | initializer=tf.ones_initializer)
463 |
464 | moving_mean = self._get_variable('moving_mean',
465 | params_shape,
466 | initializer=tf.zeros_initializer,
467 | trainable=False)
468 | moving_variance = self._get_variable('moving_variance',
469 | params_shape,
470 | initializer=tf.ones_initializer,
471 | trainable=False)
472 |
473 |         # These ops will only be performed when training.
474 | mean, variance = tf.nn.moments(x, axis)
475 | update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, self.BN_DECAY)
476 | update_moving_variance = moving_averages.assign_moving_average(
477 | moving_variance, variance, self.BN_DECAY)
478 |
479 | tf.add_to_collection(self.UPDATE_OPS_COLLECTION, update_moving_mean)
480 | tf.add_to_collection(self.UPDATE_OPS_COLLECTION, update_moving_variance)
481 |
482 | mean, variance = control_flow_ops.cond(
483 | c['is_training'], lambda: (mean, variance),
484 | lambda: (moving_mean, moving_variance))
485 |
486 | x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, self.BN_EPSILON)
487 |
488 | return x
489 |
490 | def activation(self, x):
491 | alphas = tf.get_variable('alpha', x.get_shape()[-1],
492 | initializer=tf.constant_initializer(0.25),
493 | dtype=tf.float32)
494 | pos = tf.nn.relu(x)
495 | neg = alphas * (x - abs(x)) * 0.5
496 |
497 | return pos + neg
498 |
499 | def fc(self, x, c):
500 | num_units_in = x.get_shape()[1]
501 | num_units_out = c['fc_units_out']
502 | weights_initializer = tf.truncated_normal_initializer(
503 | stddev=self.FC_WEIGHT_STDDEV)
504 | weights = self._get_variable('weights',
505 | shape=[num_units_in, num_units_out],
506 | initializer=weights_initializer,
507 |                                      weight_decay=self.FC_WEIGHT_DECAY)
508 | biases = self._get_variable('biases',
509 | shape=[num_units_out],
510 | initializer=tf.zeros_initializer)
511 | x = tf.nn.xw_plus_b(x, weights, biases)
512 |
513 | return x
514 |
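    # Note: three consecutive xw_plus_b layers with no nonlinearity between
    # them compose into a single affine map.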
515 | def stack_fc(self, x, c):
516 | num_units_in = x.get_shape()[1]
517 |
518 | weights_initializer = tf.truncated_normal_initializer(
519 | stddev=self.FC_WEIGHT_STDDEV)
520 |
521 | weights = self._get_variable('weights',
522 | shape=[num_units_in, 256],
523 | initializer=weights_initializer,
524 |                                      weight_decay=self.FC_WEIGHT_DECAY)
525 | biases = self._get_variable('biases',
526 | shape=[256],
527 | initializer=tf.zeros_initializer)
528 | x = tf.nn.xw_plus_b(x, weights, biases)
529 |
530 | weights_2 = self._get_variable('weights_2',
531 | shape=[256, 256],
532 | initializer=weights_initializer,
533 |                                        weight_decay=self.FC_WEIGHT_DECAY)
534 | biases_2 = self._get_variable('biases_2',
535 | shape=[256],
536 | initializer=tf.zeros_initializer)
537 | x = tf.nn.xw_plus_b(x, weights_2, biases_2)
538 |
539 | num_units_out = c['fc_units_out']
540 |
541 | weights_3 = self._get_variable('weights_3',
542 | shape=[256, num_units_out],
543 | initializer=weights_initializer,
544 |                                        weight_decay=self.FC_WEIGHT_DECAY)
545 | biases_3 = self._get_variable('biases_3',
546 | shape=[num_units_out],
547 | initializer=tf.zeros_initializer)
548 | x = tf.nn.xw_plus_b(x, weights_3, biases_3)
549 |
550 | return x
551 |
552 | def _get_variable(self, name,
553 | shape,
554 | initializer,
555 | weight_decay=0.0,
556 | dtype='float',
557 | trainable=True):
558 | if weight_decay > 0:
559 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
560 | else:
561 | regularizer = None
562 |         collections = [tf.GraphKeys.GLOBAL_VARIABLES, self.RESNET_VARIABLES]
563 |
564 | return tf.get_variable(name,
565 | shape=shape,
566 | initializer=initializer,
567 | dtype=dtype,
568 | regularizer=regularizer,
569 | collections=collections,
570 | trainable=trainable)
571 |
572 | def conv(self, x, c):
573 | ksize = c['ksize']
574 | stride = c['stride']
575 | filters_out = c['conv_filters_out']
576 |
577 | filters_in = x.get_shape()[-1]
578 | shape = [ksize, ksize, filters_in, filters_out]
579 | initializer = tf.truncated_normal_initializer(stddev=self.CONV_WEIGHT_STDDEV)
580 | weights = self._get_variable('weights',
581 | shape=shape,
582 | dtype='float',
583 | initializer=initializer,
584 | weight_decay=self.CONV_WEIGHT_DECAY)
585 |
586 | return tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')
587 |
588 | def _max_pool(self, x, ksize=3, stride=2):
589 | return tf.nn.max_pool(x,
590 | ksize=[1, ksize, ksize, 1],
591 | strides=[1, stride, stride, 1],
592 | padding='SAME')
593 |
594 | ### FAB model ###
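    # End-to-end graph: the structure predictor warps the boundary heatmaps
    # into the next-frame boundary, which is concatenated with the blurred
    # frames and deblurred by the recurrent video-deblur network; the ResNet
    # head then regresses 136 landmark coordinates from the deblurred frame.
    # Intermediate results are exposed as attributes (next_frame,
    # structure_predictor_output, video_deblur_output, logits); the method
    # returns the updated recurrent features (F, H_curr).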
595 | def FAB_inference(self,
596 | input_images_boundary,
597 | input_images_blur,
598 | F,H,
599 | batch_size,
600 | net_channel=64,
601 | num_classes=136,
602 | num_blocks=[2, 2, 2, 2],
603 | use_bias=False,
604 | bottleneck=True,
605 | dropout_ratio=1.0):
606 |
607 | ####structure_predictor_model####
608 | with tf.variable_scope('structure_predictor_model_'):
609 | with slim.arg_scope([slim.conv2d],
610 | activation_fn=tf.nn.relu,
611 | weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
612 | weights_regularizer=slim.l2_regularizer(0.0001)):
613 |
614 | batch_norm_params = {
615 | 'decay': 0.9997,
616 | 'epsilon': 0.001,
617 | 'is_training': self.structure_predictor_is_train,
618 | }
619 | with slim.arg_scope([slim.batch_norm],
620 | is_training=self.structure_predictor_is_train,
621 | updates_collections=None):
622 | with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
623 | normalizer_params=batch_norm_params):
624 | net = slim.conv2d(input_images_boundary, 64, [5, 5], stride=1, scope='conv1')
625 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
626 | net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
627 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
628 | net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
629 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
630 | net = tf.image.resize_bilinear(net, [64,64])
631 | net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
632 | net = tf.image.resize_bilinear(net, [128,128])
633 | net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
634 | net = tf.image.resize_bilinear(net, [256,256])
635 | net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
636 | net = slim.conv2d(net, 3, [5, 5], stride=1,
637 | activation_fn=tf.tanh, normalizer_fn=None, scope='conv7')
638 | flow = net[:, :, :, 0:2]
639 | mask = tf.expand_dims(net[:, :, :, 2], 3)
640 |
641 | grid_x, grid_y = meshgrid(256, 256)
642 | grid_x = tf.tile(grid_x, [batch_size, 1, 1])
643 | grid_y = tf.tile(grid_y, [batch_size, 1, 1])
644 |
645 | coor_x_1 = grid_x + flow[:, :, :, 0]*2
646 | coor_y_1 = grid_y + flow[:, :, :, 1]*2
647 | coor_x_2 = grid_x + flow[:, :, :, 0]
648 | coor_y_2 = grid_y + flow[:, :, :, 1]
649 |
650 | output_1 = bilinear_interp(input_images_boundary[:, :, :, 0:1],
651 | coor_x_1, coor_y_1, 'extrapolate')
652 | output_2 = bilinear_interp(input_images_boundary[:, :, :, 1:2],
653 | coor_x_2, coor_y_2, 'extrapolate')
654 |
655 | mask = 0.5 * (1.0 + mask)
656 | mask = tf.tile(mask, [1, 1, 1, 3])
657 | self.next_frame = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)
658 | self.structure_predictor_output = tf.concat([self.next_frame,input_images_blur],3)
659 |
660 | ####video_deblur_model####
661 | with tf.variable_scope('video_deblur_model_'):
662 | H_curr = []
663 | with tf.variable_scope("encoding"):
664 |
665 | with tf.variable_scope("conv1"):
666 | filter_size = 5
667 | net_X = self.conv2d(self.structure_predictor_output, self.weight_variable([filter_size,
668 | filter_size,
669 | self.get_shape(self.structure_predictor_output, 3),
670 | net_channel]))
671 | net_X = tf.nn.relu(net_X)
672 |
673 | with tf.variable_scope("conv2"):
674 | filter_size = 3
675 | net_X = self.conv2d(net_X, self.weight_variable([filter_size,
676 | filter_size,
677 | self.get_shape(net_X, 3),
678 | net_channel//2]),
679 | stride = 2)
680 | net_X = tf.nn.relu(net_X)
681 |
682 | net = tf.concat([net_X, F], 3)
683 | f0 = net
684 | filter_size = 3
685 | num_resnet_layers = 8
686 |
687 |             for i in range(num_resnet_layers):
688 | with tf.variable_scope('resnet_block%d' % (i+1)):
689 | net = self.resnet_block(net, net_channel)
690 |
691 | if i == 3:
692 | (net, alpha) = self.dynamic_fusion(net, H[0])
693 | h = tf.expand_dims(net, axis=0)
694 | H_curr = h
695 |
696 | with tf.variable_scope("feat_out"):
697 | F = self.conv2d(net, self.weight_variable([filter_size,
698 | filter_size,
699 | self.get_shape(net, 3),
700 | net_channel//2],
701 | name = 'conv_F'))
702 | F = tf.nn.relu(F)
703 |
704 | with tf.variable_scope("img_out"):
705 | filter_size = 4
706 | shape = [self.get_shape(self.structure_predictor_output, 0),
707 | self.get_shape(self.structure_predictor_output, 1),
708 | self.get_shape(self.structure_predictor_output, 2),
709 | net_channel]
710 | Y = self.conv2d_transpose(net, self.weight_variable([filter_size,
711 | filter_size,
712 | net_channel,
713 | net_channel],
714 | name = "deconv"),
715 | shape,
716 | stride = 2)
717 | Y = tf.nn.relu(Y)
718 | filter_size = 3
719 | self.video_deblur_output = self.conv2d(Y, self.weight_variable([filter_size,
720 | filter_size,
721 | self.get_shape(Y, 3),
722 | 3],
723 | name = 'conv'))
724 |
725 | ####resnet_model####
726 | with tf.variable_scope('resnet_model_'):
727 | c = Config()
728 | c['bottleneck'] = bottleneck
729 | c['is_training'] = tf.convert_to_tensor(self.resnet_is_train,
730 | dtype='bool',
731 | name='is_training')
732 | c['ksize'] = 3
733 | c['stride'] = 1
734 | c['use_bias'] = use_bias
735 | c['fc_units_out'] = num_classes
736 | c['num_blocks'] = num_blocks
737 | c['stack_stride'] = 2
738 |
739 | with tf.variable_scope('scale1'):
740 | c['conv_filters_out'] = 16
741 | c['ksize'] = 7
742 | c['stride'] = 2
743 | x = self.conv(self.video_deblur_output, c)
744 | x = self.resnet_bn(x, c)
745 | x = self.activation(x)
746 |
747 | with tf.variable_scope('scale1_pool'):
748 | x = self._max_pool(x, ksize=3, stride=2)
749 | x = self.resnet_bn(x, c)
750 | x = self.activation(x)
751 |
752 | with tf.variable_scope('scale2'):
753 | x = self._max_pool(x, ksize=3, stride=2)
754 | c['num_blocks'] = num_blocks[0]
755 | c['stack_stride'] = 1
756 | c['block_filters_internal'] = 8
757 | x = self.stack(x, c)
758 |
759 | with tf.variable_scope('scale3'):
760 | c['num_blocks'] = num_blocks[1]
761 | c['block_filters_internal'] = 16
762 | assert c['stack_stride'] == 2
763 | x = self.stack(x, c)
764 |
765 | with tf.variable_scope('scale4'):
766 | c['num_blocks'] = num_blocks[2]
767 | c['block_filters_internal'] = 32
768 | x = self.stack(x, c)
769 |
770 | with tf.variable_scope('scale5'):
771 | c['num_blocks'] = num_blocks[3]
772 | c['block_filters_internal'] = 64
773 | x = self.stack(x, c)
774 |
775 |             x = tf.reduce_mean(x, axis=[1, 2], name="avg_pool")
776 |
777 |             if num_classes is not None:
778 | with tf.variable_scope('fc1'):
779 | c['fc_units_out'] = 256
780 | x = self.fc(x, c)
781 |
782 | with tf.variable_scope('dropout1'):
783 | x = tf.nn.dropout(x, dropout_ratio)
784 |
785 | with tf.variable_scope('fc2'):
786 | c['fc_units_out'] = 256
787 | x = self.fc(x, c)
788 |
789 | with tf.variable_scope('dropout2'):
790 | x = tf.nn.dropout(x, dropout_ratio)
791 |
792 | with tf.variable_scope('fc3'):
793 | c['fc_units_out'] = 136
794 | self.logits = self.fc(x, c)
795 |
796 | return F, H_curr
797 |
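# A minimal usage sketch for FAB_inference, assuming 256x256 inputs, batch
# size 1, net_channel=64, and PIC_CHANNEL=3 (the placeholder shapes mirror
# those built in train_FAB.py; all names below are illustrative):
#
#   model = FAB(structure_predictor_is_train=False,
#               deblur_is_train=False, resnet_is_train=False)
#   boundary = tf.placeholder(tf.float32, [1, 256, 256, 2])
#   blur = tf.placeholder(tf.float32, [1, 256, 256, 9])
#   F = tf.placeholder(tf.float32, [1, 128, 128, 32])
#   H = tf.placeholder(tf.float32, [1, 1, 128, 128, 64])
#   F_curr, H_curr = model.FAB_inference(boundary, blur, F, H, batch_size=1)
#   landmarks = model.logits  # [1, 136] flattened (x, y) coordinates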
--------------------------------------------------------------------------------