├── src
│   ├── utils
│   │   ├── __init__.py
│   │   ├── color_jitter.py
│   │   ├── loss_utils.py
│   │   ├── geo_layer_utils.py
│   │   ├── config.py
│   │   ├── affine_transformation.py
│   │   └── curve.py
│   ├── test_FAB.py
│   ├── datagen.py
│   ├── train_FAB.py
│   └── FAB.py
├── data
│   └── datasets
│       ├── RWMB
│       │   └── README.md
│       └── Blurred-300VW
│           └── README.md
├── fig
│   ├── deblur.png
│   ├── effects.png
│   ├── framework.png
│   └── structure_predictor.png
├── scripts
│   ├── test.sh
│   └── train.sh
├── LICENSE
└── README.md
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/datasets/RWMB/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/datasets/Blurred-300VW/README.md:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/fig/deblur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/deblur.png
--------------------------------------------------------------------------------
/fig/effects.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/effects.png
--------------------------------------------------------------------------------
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/framework.png
--------------------------------------------------------------------------------
/fig/structure_predictor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KeqiangSun/FAB/HEAD/fig/structure_predictor.png
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | python ./src/test_FAB.py \
4 |     --structure_predictor_train_dir ./data/checkpoints/structure_predictor_train_dir/ \
5 |     --video_deblur_train_dir ./data/checkpoints/video_deblur_train_dir/ \
6 |     --resnet_train_dir ./data/checkpoints/resnet_train_dir/ \
7 |     --resume_structure_predictor True \
8 |     --resume_video_deblur True \
9 |     --resume_resnet True \
10 |     --resume_all False \
11 |     --data_dir ./data/300VW/Images/ \
12 |     --img_list ./data/300VW/labels_68pt_256_train_sorted.txt \
13 |     --end_2_end_test_dir ./data/test_results/ &
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | python ./src/train_FAB.py \
4 |     --structure_predictor_train_dir ./data/checkpoints/structure_predictor_train_dir/ \
5 |     --video_deblur_train_dir ./data/checkpoints/video_deblur_train_dir/ \
6 |     --resnet_train_dir ./data/checkpoints/resnet_train_dir/ \
7 |     --end_2_end_train_dir ./data/checkpoints/end_2_end_train_dir/ \
8 |     --end_2_end_valid_dir ./data/checkpoints/end_2_end_valid_dir/ \
9 |     --max_steps 2000000 \
10 |     --resume_structure_predictor False \
11 |     --resume_video_deblur False \
12 |     --resume_resnet False \
13 |     --data_dir ./data/300VW/Images/ \
14 |     --img_list 
./data/300VW/labels_68pt_256_train_sorted.txt \ 15 | --data_dir_valid None \ 16 | --img_list_valid None \ 17 | --training_period train & 18 | -------------------------------------------------------------------------------- /src/utils/color_jitter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageEnhance 3 | 4 | 5 | transformtypedict=dict(Brightness=ImageEnhance.Brightness, 6 | Contrast=ImageEnhance.Contrast, 7 | Sharpness=ImageEnhance.Sharpness, 8 | Color=ImageEnhance.Color) 9 | 10 | class ImageJitter(object): 11 | def __init__(self, transformdict): 12 | self.transforms = [(transformtypedict[k], transformdict[k]) for k in transformdict] 13 | 14 | def __call__(self, img): 15 | out = img 16 | randtensor = np.random.uniform(0, 1, len(self.transforms)) 17 | 18 | for i, (transformer, alpha) in enumerate(self.transforms): 19 | r = alpha*(randtensor[i]*2.0 -1.0) + 1 20 | out = transformer(out).enhance(r) 21 | 22 | return out 23 | -------------------------------------------------------------------------------- /src/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def l1_loss(predictions, targets): 9 | total_elements = (tf.shape(targets)[0] * tf.shape(targets)[1] * tf.shape(targets)[2] 10 | * tf.shape(targets)[3]) 11 | total_elements = tf.to_float(total_elements) 12 | 13 | loss = tf.reduce_sum(tf.abs(predictions- targets)) 14 | loss = tf.div(loss, total_elements) 15 | 16 | return loss 17 | 18 | 19 | def l2_loss(predictions, targets): 20 | total_elements = (tf.shape(targets)[0] * tf.shape(targets)[1] * tf.shape(targets)[2] 21 | * tf.shape(targets)[3]) 22 | total_elements = tf.to_float(total_elements) 23 | 24 | loss = tf.reduce_sum(tf.square(predictions-targets)) 25 | loss = tf.div(loss, total_elements) 26 | 27 | return loss 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 KeqiangSun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/utils/geo_layer_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def bilinear_interp(im, x, y, name): 9 | with tf.variable_scope(name): 10 | x = tf.reshape(x, [-1]) 11 | y = tf.reshape(y, [-1]) 12 | 13 | num_batch = tf.shape(im)[0] 14 | _, height, width, channels = im.get_shape().as_list() 15 | 16 | x = tf.to_float(x) 17 | y = tf.to_float(y) 18 | 19 | height_f = tf.cast(height, 'float32') 20 | width_f = tf.cast(width, 'float32') 21 | zero = tf.constant(0, dtype=tf.int32) 22 | 23 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32') 24 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32') 25 | x = (x + 1.0) * (width_f - 1.0) / 2.0 26 | y = (y + 1.0) * (height_f - 1.0) / 2.0 27 | 28 | x0 = tf.cast(tf.floor(x), 'int32') 29 | x1 = x0 + 1 30 | y0 = tf.cast(tf.floor(y), 'int32') 31 | y1 = y0 + 1 32 | 33 | x0 = tf.clip_by_value(x0, zero, max_x) 34 | x1 = tf.clip_by_value(x1, zero, max_x) 35 | y0 = tf.clip_by_value(y0, zero, max_y) 36 | y1 = tf.clip_by_value(y1, zero, max_y) 37 | 38 | dim2 = width 39 | dim1 = width * height 40 | 41 | base = tf.range(num_batch) * dim1 42 | base = tf.reshape(base, [-1, 1]) 43 | base = tf.tile(base, [1, height * width]) 44 | base = tf.reshape(base, [-1]) 45 | 46 | base_y0 = base + y0 * dim2 47 | base_y1 = base + y1 * dim2 48 | idx_a = base_y0 + x0 49 | idx_b = base_y1 + x0 50 | idx_c = base_y0 + x1 51 | idx_d = base_y1 + x1 52 | 53 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 54 | im_flat = tf.to_float(im_flat) 55 | pixel_a = tf.gather(im_flat, idx_a) 56 | pixel_b = tf.gather(im_flat, idx_b) 57 | pixel_c = tf.gather(im_flat, idx_c) 58 | pixel_d = tf.gather(im_flat, idx_d) 59 | 60 | x1_f = tf.to_float(x1) 61 | y1_f = tf.to_float(y1) 62 | 63 | wa = tf.expand_dims((x1_f - x) * (y1_f - y), 1) 64 | wb = tf.expand_dims((x1_f - x) * (1.0 - (y1_f - y)), 1) 65 | wc = tf.expand_dims((1.0 - (x1_f - x)) * (y1_f - y), 1) 66 | wd = tf.expand_dims((1.0 - (x1_f - x)) * (1.0 - (y1_f - y)), 1) 67 | 68 | output = tf.add_n([wa*pixel_a, wb*pixel_b, wc*pixel_c, wd*pixel_d]) 69 | output = tf.reshape(output, shape=tf.stack([num_batch, height, width, channels])) 70 | 71 | return output 72 | 73 | def meshgrid(height, width): 74 | with tf.variable_scope('meshgrid'): 75 | x_t = tf.matmul( 76 | tf.ones(shape=tf.stack([height,1])), 77 | tf.transpose( 78 | tf.expand_dims( 79 | tf.linspace(-1.0,1.0,width),1),[1,0])) 80 | y_t = tf.matmul( 81 | tf.expand_dims( 82 | tf.linspace(-1.0, 1.0, height), 1), 83 | tf.ones(shape=tf.stack([1, width]))) 84 | x_t_flat = tf.reshape(x_t, (1,-1)) 85 | y_t_flat = tf.reshape(y_t, (1,-1)) 86 | grid_x = tf.reshape(x_t_flat, [1, height, width]) 87 | grid_y = tf.reshape(y_t_flat, [1, height, width]) 88 | 89 | return grid_x, grid_y 90 | 91 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a variable scope aware configuation object for TensorFlow 3 | """ 4 | import tensorflow as tf 5 | 6 | 7 | FLAGS = tf.app.flags.FLAGS 8 | class Config: 9 | def __init__(self): 10 | root = self.Scope('') 11 | for k, v in FLAGS.__dict__['__flags'].iteritems(): 12 | root[k] = v 13 | self.stack = [ root ] 14 | 15 | def iteritems(self): 16 | return 
self.to_dict().iteritems() 17 | 18 | def to_dict(self): 19 | self._pop_stale() 20 | out = {} 21 | for i in range(len(self.stack)): 22 | cs = self.stack[-i] 23 | for name in cs: 24 | out[name] = cs[name] 25 | return out 26 | 27 | def _pop_stale(self): 28 | var_scope_name = tf.get_variable_scope().name 29 | top = self.stack[0] 30 | while not top.contains(var_scope_name): 31 | self.stack.pop(0) 32 | top = self.stack[0] 33 | 34 | def __getitem__(self, name): 35 | self._pop_stale() 36 | for i in range(len(self.stack)): 37 | cs = self.stack[i] 38 | if name in cs: 39 | return cs[name] 40 | 41 | raise KeyError(name) 42 | 43 | def set_default(self, name, value): 44 | if not name in self: 45 | self[name] = value 46 | 47 | def __contains__(self, name): 48 | self._pop_stale() 49 | for i in range(len(self.stack)): 50 | cs = self.stack[i] 51 | if name in cs: 52 | return True 53 | return False 54 | 55 | def __setitem__(self, name, value): 56 | self._pop_stale() 57 | top = self.stack[0] 58 | var_scope_name = tf.get_variable_scope().name 59 | assert top.contains(var_scope_name) 60 | 61 | if top.name != var_scope_name: 62 | top = self.Scope(var_scope_name) 63 | self.stack.insert(0, top) 64 | 65 | top[name] = value 66 | 67 | class Scope(dict): 68 | def __init__(self, name): 69 | self.name = name 70 | 71 | def contains(self, var_scope_name): 72 | return var_scope_name.startswith(self.name) 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | def assert_raises(exception, fn): 78 | try: 79 | fn() 80 | except exception: 81 | pass 82 | else: 83 | assert False, "Expected exception" 84 | 85 | c = Config() 86 | 87 | c['hello'] = 1 88 | assert c['hello'] == 1 89 | 90 | with tf.variable_scope('foo'): 91 | c.set_default("bar", 10) 92 | c['bar'] = 2 93 | assert c['bar'] == 2 94 | assert c['hello'] == 1 95 | 96 | c.set_default("mario", True) 97 | 98 | with tf.variable_scope('meow'): 99 | c['dog'] = 3 100 | assert c['dog'] == 3 101 | assert c['bar'] == 2 102 | assert c['hello'] == 1 103 | 104 | assert c['mario'] == True 105 | 106 | assert_raises(KeyError, lambda: c['dog']) 107 | assert c['bar'] == 2 108 | assert c['hello'] == 1 109 | -------------------------------------------------------------------------------- /src/utils/affine_transformation.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def get_affine_mat(width, height, 8 | max_trans, max_rotate, max_zoom, 9 | min_trans, min_rotate, min_zoom): 10 | rotate = random.uniform(min_rotate, max_rotate) 11 | trans = random.uniform(min_trans, max_trans) 12 | zoom = random.uniform(min_zoom, max_zoom) 13 | 14 | # rotate 15 | transform_matrix = np.zeros((3,3)) 16 | center = (width/2.-0.5, height/2.-0.5) 17 | M = cv2.getRotationMatrix2D(center, rotate, 1) 18 | transform_matrix[:2,:] = copy.deepcopy(M) 19 | transform_matrix[2,:] = np.array([0, 0, 1]) 20 | 21 | # translate 22 | transform_matrix[0,2] += trans 23 | transform_matrix[1,2] += trans 24 | 25 | # zoom 26 | for i in range(3): 27 | transform_matrix[0,i] *= zoom 28 | transform_matrix[1,i] *= zoom 29 | transform_matrix[0,2] += (1.0 - zoom) * center[0] 30 | transform_matrix[1,2] += (1.0 - zoom) * center[1] 31 | 32 | # random horizontal mirror 33 | do_mirror = False 34 | mirror_rng = random.uniform(0.,1.) 
35 | if mirror_rng>0.5: 36 | do_mirror = True 37 | 38 | return transform_matrix,do_mirror 39 | 40 | def AffinePoint(points, affine_mat): 41 | """ 42 | Affine a 2d point 43 | """ 44 | assert(affine_mat.shape[0] == 2) 45 | assert(affine_mat.shape[1] == 3) 46 | assert(points.shape[1] == 2) 47 | results = np.zeros(points.shape) 48 | for i in range(points.shape[0]): 49 | point_x = points[i,0] 50 | point_y = points[i,1] 51 | results[i,0] = affine_mat[0,0] * point_x + \ 52 | affine_mat[0,1] * point_y + \ 53 | affine_mat[0,2] 54 | results[i,1] = affine_mat[1,0] * point_x + \ 55 | affine_mat[1,1] * point_y + \ 56 | affine_mat[1,2] 57 | 58 | return results 59 | 60 | def affine2d(x, matrix, output_img_width, output_img_height, 61 | center=True, is_landmarks=False, do_mirror=False): 62 | assert(len(matrix.shape) == 2) 63 | if is_landmarks: 64 | transform_matrix = matrix[:2,:] 65 | src = x.squeeze() 66 | dst = np.empty((src.shape[0],2), dtype=np.float32) 67 | for i in range(src.shape[0]): 68 | dst[i,:] = AffinePoint(np.expand_dims(src[i,:], axis=0), transform_matrix) 69 | if do_mirror: 70 | results = exchange_landmarks(dst,np.array([0,16,1,15,2,14,3,13,4,12,5,11,6,10,7,9,17,26,18,25,19,24,20,23,21,22,36,45,37,44,38, 71 | 43,39,42,41,46,40,47,31,35,32,34,48,54,49,53,50,52,60,64,61,63,67,65,59,55,58,56]).reshape(-1, 2)) 72 | else: 73 | if do_mirror: 74 | matrix[0,0] = -matrix[0,0] 75 | matrix[0,1] = -matrix[0,1] 76 | matrix[0,2] = float(output_img_width)-matrix[0,2] 77 | transform_matrix = matrix[:2,:] 78 | src = x.astype(np.uint8) 79 | dst = cv2.warpAffine(src, transform_matrix, 80 | (output_img_width, output_img_height), 81 | flags=cv2.INTER_LINEAR, 82 | borderMode=cv2.BORDER_CONSTANT, 83 | borderValue=(127,127,127)) 84 | 85 | if len(dst.shape) == 2: 86 | dst = np.expand_dims(np.asarray(dst), axis=2) 87 | 88 | return dst 89 | 90 | def exchange_landmarks(input_tf, corr_list): 91 | for i in range(corr_list.shape[0]): 92 | temp = copy.deepcopy(input_tf[corr_list[i][0], :]) 93 | input_tf[corr_list[i][0], :] = input_tf[corr_list[i][1], :] 94 | input_tf[corr_list[i][1], :] = temp 95 | 96 | return input_tf 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FAB: A Robust Facial Landmark Detection Framework for Motion-Blurred Videos 2 | 3 | [Keqiang Sun](https://keqiangsun.github.io/), 4 | [Wayne Wu](https://wywu.github.io), 5 | [Tinghao Liu](https://github.com/KeqiangSun/FAB), 6 | [Shuo Yang](http://shuoyang1213.me/), 7 | [Quan Wang](https://github.com/KeqiangSun/FAB), 8 | [Qiang Zhou](https://github.com/KeqiangSun/FAB), 9 | [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=en), 10 | and [Zuochang Ye](https://github.com/KeqiangSun/FAB) 11 | 12 | [International Conference on Computer Vision (ICCV), 2019](http://iccv2019.thecvf.com/) 13 | 14 | 15 |
16 | 17 |
18 | 
19 | We present a framework named FAB that takes advantage of structure consistency in the temporal dimension for facial landmark detection in motion-blurred videos. A structure predictor is proposed to predict the missing face structural information temporally, which serves as a geometry prior; this allows our framework to work as a virtuous circle. It is also a flexible video-based framework that can incorporate any static image-based method to gain a performance boost on video datasets. Extensive experiments on Blurred-300VW, the proposed Real-World Motion Blur (RWMB) dataset, and 300VW demonstrate performance superior to state-of-the-art methods.
20 | 
21 | Moreover, we propose a new benchmark named Real-World Motion Blur (RWMB). It contains videos with pronounced motion blur collected from YouTube, covering dancing, boxing, jumping, etc. A detailed description of the system can be found in our [paper](https://keqiangsun.github.io/projects/FAB/FAB.html).
22 | 
23 | ## Citation
24 | If you use this code or the RWMB dataset for your research, please cite our paper.
25 | ```
26 | @inproceedings{keqiang2019fab,
27 |   author = {Sun, Keqiang and Wu, Wayne and Liu, Tinghao and Yang, Shuo and Wang, Quan and Zhou, Qiang and Ye, Zuochang and Qian, Chen},
28 |   title = {FAB: A Robust Facial Landmark Detection Framework for Motion-Blurred Videos},
29 |   booktitle = {ICCV},
30 |   month = {October},
31 |   year = {2019}
32 | }
33 | ```
34 | 
35 | ## Prerequisites
36 | - Linux
37 | - Python 2
38 | - [TensorFlow](https://www.tensorflow.org/)
39 | 
40 | ## Getting Started
41 | 
42 | ### Blurred-300VW Dataset Download
43 | [Blurred-300VW](https://keqiangsun.github.io/projects/FAB/Blurred-300VW.html) is a video facial landmark dataset with artificial motion blur, based on the [original 300VW](https://ibug.doc.ic.ac.uk/resources/300-VW/).
44 | 
45 | 0. Blurred-300VW [[Google Drive](https://drive.google.com/drive/folders/1aAe1vBoHZ78QlGjBEOup416tHNp4Ztcp?usp=sharing)] [[Baidu Drive]()]
46 | 1. Unzip the package and put it in './data/Blurred-300VW'
47 | 
48 | ### Real-World Motion Blur (RWMB) Dataset Download
49 | [Real-World Motion Blur (RWMB)](https://keqiangsun.github.io/projects/FAB/RWMB.html) is a newly proposed facial landmark benchmark with real-world motion blur.
50 | 
51 | 0. RWMB testing images [[Google Drive](https://drive.google.com/file/d/1vv7Qppg9R3xlj_O2dmtXZHzEnObOwoDh/view?usp=sharing)] [[Baidu Drive]()]
52 | 1. Unzip the package and put it in './data/RWMB'
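Both datasets end up side by side under `./data/`. A minimal setup sketch (the archive names are assumptions; use whatever files the drive links actually provide):

```bash
# Hypothetical archive names -- substitute the actual downloads.
unzip Blurred-300VW.zip -d ./data/Blurred-300VW
unzip RWMB.zip -d ./data/RWMB
```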
53 | 
54 | 
55 | ### Training FAB on Blurred-300VW
56 | 
57 | ```bash
58 | bash ./scripts/train.sh
59 | ```
60 | 
61 | ### Testing FAB on Blurred-300VW
62 | 
63 | ```bash
64 | bash ./scripts/test.sh
65 | ```
66 | 
67 | 
68 | ## To Do List
69 | Supported datasets
70 | - [x] [300 Faces In-the-Wild (300W)](https://ibug.doc.ic.ac.uk/resources/300-W/)
71 | - [x] [300 Videos in the Wild (300VW)](https://ibug.doc.ic.ac.uk/resources/300-VW/)
72 | - [x] [Blurred-300VW](https://keqiangsun.github.io/projects/FAB/Blurred-300VW.html)
73 | - [ ] [Real-World Motion Blur (RWMB)](https://keqiangsun.github.io/projects/FAB/RWMB.html)
74 | 
75 | 
76 | Supported models
77 | - [ ] [Pretrained Model of Structure Predictor Block]
78 | - [ ] [Pretrained Model of Video Deblur Block]
79 | - [ ] [Pretrained Model of ResNet Block]
80 | - [ ] [Pretrained Model of Final Model]
81 | 
82 | 
83 | ## Questions
84 | Please contact skq719@gmail.com
85 | 
--------------------------------------------------------------------------------
/src/test_FAB.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import sys
7 | import tarfile
8 | import cv2
9 | import copy
10 | import numpy as np
11 | import tensorflow as tf
12 | 
13 | from utils.curve import points_to_heatmap_rectangle_68pt
14 | from six.moves import xrange
15 | from six.moves import urllib
16 | from datagen import DataGenerator
17 | from datagen import ensure_dir
18 | from FAB import FAB
19 | 
20 | MOMENTUM = 0.9
21 | POINTS_NUM = 68
22 | IMAGE_SIZE = 256
23 | PIC_CHANNEL = 3
24 | num_input_imgs = 3
25 | NUM_CLASSES = POINTS_NUM*2
26 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
27 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
28 | structure_predictor_net_channel = 64
29 | 
30 | FLAGS = tf.app.flags.FLAGS
31 | tf.app.flags.DEFINE_string('structure_predictor_train_dir', '', """Directory where to write train_checkpoints.""")
32 | tf.app.flags.DEFINE_string('video_deblur_train_dir', '', """Directory where to write train_checkpoints.""")
33 | tf.app.flags.DEFINE_string('resnet_train_dir', '', """Directory where to write train_checkpoints.""")
34 | tf.app.flags.DEFINE_string('end_2_end_train_dir', '', """Directory where to write train_checkpoints.""")
35 | tf.app.flags.DEFINE_string('end_2_end_test_dir', '', """Directory where to write test logs.""")
36 | tf.app.flags.DEFINE_string('data_dir', '', """Directory where the dataset stores.""")
37 | tf.app.flags.DEFINE_string('img_list', '', """Directory where the img_list stores.""")
38 | 
39 | tf.app.flags.DEFINE_float('learning_rate', 0.0, "learning rate.")
40 | tf.app.flags.DEFINE_integer('batch_size', 1, "batch size")
41 | tf.app.flags.DEFINE_boolean('resume_structure_predictor', False, """Resume from latest saved state.""")
42 | tf.app.flags.DEFINE_boolean('resume_resnet', False, """Resume from latest saved state.""")
43 | tf.app.flags.DEFINE_boolean('resume_video_deblur', False, """Resume from latest saved state.""")
44 | tf.app.flags.DEFINE_boolean('resume_all', False, """Resume from latest saved state.""")
45 | tf.app.flags.DEFINE_boolean('minimal_summaries', False, """Produce fewer summaries to save HD space.""")
46 | tf.app.flags.DEFINE_boolean('use_bn', False, """Use batch normalization. Otherwise use biases.""")
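# Note on the flags above (inferred from scripts/test.sh): the *_train_dir flags
# point at the checkpoint folders produced by training, e.g.
# ./data/checkpoints/resnet_train_dir/, and each resume_* switch controls whether
# the corresponding sub-network's weights are restored by resume() below.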
47 | 
48 | def resume(sess, do_resume, ckpt_path, key_word):
49 |     var = tf.global_variables()
50 |     if do_resume:
51 |         structure_predictor_latest = tf.train.latest_checkpoint(ckpt_path)
52 |         if not structure_predictor_latest:
53 |             print("\n No checkpoint to continue from in ", ckpt_path, '\n')
54 |             return
55 |         structure_predictor_var_to_restore = [val for val in var if key_word in val.name]
56 |         saver_structure_predictor = tf.train.Saver(structure_predictor_var_to_restore)
57 |         saver_structure_predictor.restore(sess, structure_predictor_latest)
58 | def test(resnet_model, is_training, F, H, F_curr, H_curr, input_images_blur,
59 |          input_images_boundary, next_boundary_gt, labels, data_dir, img_list,
60 |          dropout_ratio):
61 | 
62 |     global_step = tf.get_variable('global_step', [],
63 |                                   initializer=tf.constant_initializer(0),
64 |                                   trainable=False)
65 | 
66 |     init = tf.initialize_all_variables()
67 |     sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
68 |     sess.run(init)
69 |     val_save_root = os.path.join(FLAGS.end_2_end_test_dir,'visualization')
70 | 
71 |     ################################ resume part #################################
72 | 
73 |     # resume weights
74 |     resume(sess, FLAGS.resume_structure_predictor, FLAGS.structure_predictor_train_dir, 'voxel_flow_model_')
75 |     resume(sess, FLAGS.resume_video_deblur, FLAGS.video_deblur_train_dir, 'video_deblur_model_')
76 |     resume(sess, FLAGS.resume_resnet, FLAGS.resnet_train_dir, 'resnet_model_')
77 |     resume(sess, FLAGS.resume_all, FLAGS.end_2_end_train_dir, '')
78 | 
79 |     ##############################################################################
80 | 
81 |     gt_file_path = os.path.join(FLAGS.end_2_end_test_dir,'gt.txt')
82 |     pre_file_path = os.path.join(FLAGS.end_2_end_test_dir,'pre.txt')
83 |     ensure_dir(gt_file_path)
84 |     ensure_dir(pre_file_path)
85 |     gt_file = open(gt_file_path,'w')
86 |     pre_file = open(pre_file_path,'w')
87 | 
88 |     dataset = DataGenerator(data_dir,img_list)
89 |     dataset._create_train_table()
90 |     dataset._create_sets_for_300VW()
91 |     test_gen = dataset._aux_generator(batch_size = FLAGS.batch_size, num_input_imgs = num_input_imgs,
92 |                                       NUM_CLASSES = POINTS_NUM*2, sample_set='test')
93 | 
94 |     test_break_flag = False
95 |     for x in xrange(len(dataset.train_table)-2):
96 | 
97 |         step = sess.run(global_step)
98 | 
99 |         if not test_break_flag:
100 |             test_line_num, frame_name, input_boundaries, boundary_gt_test, input_images_blur_generated, landmark_gt_test, names, test_break_flag = next(test_gen)
101 | 
102 |         if (frame_name == '2.jpg') or test_line_num <= 3:
103 |             input_images_boundary_init = copy.deepcopy(input_boundaries)
104 |             F_init = np.zeros([FLAGS.batch_size, IMAGE_SIZE//2,
105 |                                IMAGE_SIZE//2, structure_predictor_net_channel//2], dtype=np.float32)
106 | 
107 |             H_init = np.zeros([1, FLAGS.batch_size, IMAGE_SIZE//2,
108 |                                IMAGE_SIZE//2, structure_predictor_net_channel], dtype=np.float32)
109 | 
110 |             feed_dict={
111 |                 input_images_boundary:input_images_boundary_init,
112 |                 input_images_blur:input_images_blur_generated,
113 |                 F:F_init,
114 |                 H:H_init,
115 |                 labels:landmark_gt_test,
116 |                 next_boundary_gt:boundary_gt_test,
117 |                 dropout_ratio:1.0
118 |             }
119 |         else:
120 |             output_points = o[0]
121 |             output_points = np.reshape(output_points,(POINTS_NUM,2))
122 |             boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
123 |             boundary_from_points = np.expand_dims(boundary_from_points,axis=0)
124 |             boundary_from_points = np.expand_dims(boundary_from_points,axis=3)
125 | 
126 | 
input_images_boundary_init = np.concatenate([input_images_boundary_init[:,:,:,1:2], 127 | boundary_from_points], axis=3) 128 | feed_dict={ 129 | input_images_boundary:input_images_boundary_init, 130 | input_images_blur:input_images_blur_generated, 131 | F:o[-2], 132 | H:o[-1], 133 | labels:landmark_gt_test, 134 | next_boundary_gt:boundary_gt_test, 135 | dropout_ratio:1.0 136 | } 137 | 138 | i = [resnet_model.logits, F_curr, H_curr] 139 | o = sess.run(i, feed_dict=feed_dict) 140 | pres = o[0] 141 | 142 | for batch_num,pre in enumerate(pres): 143 | for v in pre: 144 | pre_file.write(str(v*255.0)+' ') 145 | if len(names) > 1: 146 | pre_file.write(names[-1]) 147 | else: 148 | pre_file.write(names[batch_num]) 149 | pre_file.write('\n') 150 | for batch_num,g in enumerate(landmark_gt_test): 151 | for v in g: 152 | gt_file.write(str(v*255.0)+' ') 153 | if len(names) > 1: 154 | gt_file.write(names[-1]) 155 | else: 156 | gt_file.write(names[batch_num]) 157 | gt_file.write('\n') 158 | 159 | img = input_images_blur_generated[0,:,:,0:3]*255 160 | points = o[0][0]*255 161 | 162 | for point_num in range(int(points.shape[0]/2)): 163 | cv2.circle(img,(int(round(points[point_num*2])),int(round(points[point_num*2+1]))),1,(55,225,155),2) 164 | val_save_path = os.path.join(val_save_root,str(step)+'.jpg') 165 | ensure_dir(val_save_path) 166 | cv2.imwrite(val_save_path,img) 167 | 168 | global_step = global_step + 1 169 | print('Test done!') 170 | 171 | def main(argv=None): 172 | 173 | resnet_model = FAB() 174 | 175 | is_training = tf.placeholder('bool', [], name='is_training') 176 | input_images_boundary = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 2)) 177 | input_images_blur = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, PIC_CHANNEL*3)) 178 | next_boundary_gt = tf.placeholder(tf.float32,shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 1)) 179 | labels = tf.placeholder(tf.float32,shape=(FLAGS.batch_size,NUM_CLASSES)) 180 | dropout_ratio = tf.placeholder(tf.float32) 181 | F = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel//2]) 182 | H = tf.placeholder(tf.float32, [1, FLAGS.batch_size, IMAGE_SIZE//2, IMAGE_SIZE//2, structure_predictor_net_channel]) 183 | F_curr, H_curr= \ 184 | resnet_model.FAB_inference(input_images_boundary, input_images_blur, F, H, FLAGS.batch_size, 185 | net_channel=structure_predictor_net_channel, num_classes=136, num_blocks=[2, 2, 2, 2], 186 | use_bias=(not FLAGS.use_bn), bottleneck=True, dropout_ratio=1.0) 187 | 188 | test(resnet_model, is_training, F, H, F_curr, H_curr, input_images_blur, 189 | input_images_boundary, next_boundary_gt, labels, FLAGS.data_dir, FLAGS.img_list, 190 | dropout_ratio) 191 | 192 | if __name__ == '__main__': 193 | tf.app.run() 194 | -------------------------------------------------------------------------------- /src/utils/curve.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import time 4 | import numpy as np 5 | from numpy import linalg as LA 6 | 7 | 8 | def distance(p1, p2): 9 | return math.sqrt((p1[0] - p2[0]) * (p1[0] - p2[0]) + \ 10 | (p1[1] - p2[1]) * (p1[1] - p2[1])) 11 | 12 | def curve_interp(src, samples, index): 13 | assert(src.shape[0] > 2) 14 | assert(samples >= 2) 15 | 16 | src_1 = src[0:src.shape[0] - 1, :] 17 | src_2 = src[1:src.shape[0], :] 18 | src_delta = src_1 - src_2 19 | length = np.sqrt(src_delta[:, 0]**2 + src_delta[:, 1]**2) 20 | 
assert(length.shape[0] == src.shape[0] - 1) 21 | 22 | accu_length = np.zeros((src.shape[0])) 23 | for i in xrange(1, accu_length.shape[0]): 24 | accu_length[i] = accu_length[i - 1] + length[i - 1] 25 | dst = np.zeros((samples, 2)) 26 | pre_raw = 0 27 | 28 | step_interp = accu_length[accu_length.shape[0] - 1] / float(samples - 1) 29 | dst[0, :] = src[0, :] 30 | dst[dst.shape[0] - 1, :] = src[src.shape[0] - 1, :] 31 | for i in xrange(1, samples - 1): 32 | covered_interp = step_interp * i 33 | while (covered_interp > accu_length[pre_raw + 1]): 34 | pre_raw += 1 35 | assert(pre_raw < accu_length.shape[0] - 1) 36 | dx = (covered_interp - accu_length[pre_raw]) / length[pre_raw] 37 | dst[i, :] = src[pre_raw, :] * (1.0 - dx) + src[pre_raw + 1, :] * dx 38 | 39 | return dst 40 | 41 | def curve_fitting(points, samples, index): 42 | num_points = points.shape[0] 43 | assert(num_points > 1) 44 | valid_points = [points[0]] 45 | for i in xrange(1, num_points): 46 | if (distance(points[i, :], points[i - 1, :]) > 0.001): 47 | valid_points.append(points[i, :]) 48 | assert(len(valid_points) > 1) 49 | valid_points = np.asarray(valid_points) 50 | functions = np.zeros((valid_points.shape[0] - 1, 9)) 51 | 52 | if valid_points.shape[0] == 2: 53 | functions[0, 0] = LA.norm(valid_points[0, :] - valid_points[1, :]) 54 | functions[0, 1] = valid_points[0, 0] 55 | functions[0, 2] = (valid_points[1, 0] - valid_points[0, 0]) / functions[0, 0] 56 | functions[0, 3] = 0 57 | functions[0, 4] = 0 58 | functions[0, 5] = valid_points[0, 1] 59 | functions[0, 6] = (valid_points[1, 1] - valid_points[0, 1]) / functions[0, 0] 60 | functions[0, 7] = 0 61 | functions[0, 8] = 0 62 | else: 63 | Mx = np.zeros((valid_points.shape[0])) 64 | My = np.zeros((valid_points.shape[0])) 65 | A = np.zeros((valid_points.shape[0] - 2)) 66 | B = np.zeros((valid_points.shape[0] - 2)) 67 | C = np.zeros((valid_points.shape[0] - 2)) 68 | Dx = np.zeros((valid_points.shape[0] - 2)) 69 | Dy = np.zeros((valid_points.shape[0] - 2)) 70 | for i in xrange(functions.shape[0]): 71 | functions[i, 0] = LA.norm(valid_points[i, :] - valid_points[i + 1, :]) 72 | for i in xrange(A.shape[0]): 73 | A[i] = functions[i, 0] 74 | B[i] = 2.0 * (functions[i, 0] + functions[i + 1, 0]) 75 | C[i] = functions[i + 1, 0] 76 | Dx[i] = 6.0 * ((valid_points[i + 2, 0] - valid_points[i + 1, 0]) / functions[i + 1, 0] - \ 77 | (valid_points[i + 1, 0] - valid_points[i, 0]) / functions[i, 0]) 78 | 79 | Dy[i] = 6.0 * ((valid_points[i + 2, 1] - valid_points[i + 1, 1]) / functions[i + 1, 0] - \ 80 | (valid_points[i + 1, 1] - valid_points[i, 1]) / functions[i, 0]) 81 | 82 | C[0] = C[0] / B[0] 83 | Dx[0] = Dx[0] / B[0] 84 | Dy[0] = Dy[0] / B[0] 85 | for i in xrange(1, A.shape[0]): 86 | tmp = B[i] - A[i] * C[i - 1] 87 | C[i] = C[i] / tmp 88 | Dx[i] = (Dx[i] - A[i] * Dx[i - 1]) / tmp 89 | Dy[i] = (Dy[i] - A[i] * Dy[i - 1]) / tmp 90 | Mx[valid_points.shape[0] - 2] = Dx[valid_points.shape[0] - 3] 91 | My[valid_points.shape[0] - 2] = Dy[valid_points.shape[0] - 3] 92 | for i in xrange(valid_points.shape[0] - 4, -1, -1): 93 | Mx[i + 1] = Dx[i] - C[i] * Mx[i + 2] 94 | My[i + 1] = Dy[i] - C[i] * My[i + 2] 95 | Mx[0] = 0 96 | Mx[valid_points.shape[0] - 1] = 0 97 | My[0] = 0 98 | My[valid_points.shape[0] - 1] = 0 99 | 100 | for i in xrange(functions.shape[0]): 101 | functions[i, 1] = valid_points[i, 0] 102 | functions[i, 2] = (valid_points[i + 1, 0] - valid_points[i, 0]) / functions[i, 0] - \ 103 | (2.0 * functions[i, 0] * Mx[i] + functions[i, 0] * Mx[i + 1]) / 6.0 104 | functions[i, 3] = Mx[i] / 2.0 105 | 
functions[i, 4] = (Mx[i + 1] - Mx[i]) / (6.0 * functions[i, 0]) 106 | functions[i, 5] = valid_points[i, 1] 107 | functions[i, 6] = (valid_points[i + 1, 1] - valid_points[i, 1]) / functions[i, 0] - \ 108 | (2.0 * functions[i, 0] * My[i] + functions[i, 0] * My[i + 1]) / 6.0 109 | functions[i, 7] = My[i] / 2.0 110 | functions[i, 8] = (My[i + 1] - My[i]) / (6.0 * functions[i, 0]) 111 | 112 | samples_per_segment = samples * 1 / functions.shape[0] + 1 113 | rawcurve = np.zeros((functions.shape[0] * samples_per_segment, 2)) 114 | for i in xrange(functions.shape[0]): 115 | step = functions[i, 0] / samples_per_segment 116 | for j in xrange(samples_per_segment): 117 | t = step * j 118 | rawcurve[i * samples_per_segment + j, :] = np.asarray([functions[i, 1] + functions[i, 2] * t + functions[i, 3] * t * t + functions[i, 4] * t * t * t, 119 | functions[i, 5] + functions[i, 6] * t + functions[i, 7] * t * t + functions[i, 8] * t * t * t]) 120 | 121 | curve_tmp = curve_interp(rawcurve, samples, index) 122 | 123 | return curve_tmp 124 | 125 | 126 | def points_to_heatmap_rectangle_68pt(points, 127 | heatmap_num=13, 128 | heatmap_size=(256, 256), 129 | label_size=(256, 256), 130 | sigma=1): 131 | 132 | for i in range(points.shape[0]): 133 | points[i][0] *= (float(heatmap_size[1]) / float(label_size[1])) 134 | points[i][1] *= (float(heatmap_size[0]) / float(label_size[0])) 135 | 136 | align_on_curve = [0] * heatmap_num 137 | curves = [0] * heatmap_num 138 | align_on_curve[0] = np.zeros((17, 2)) 139 | align_on_curve[1] = np.zeros((5, 2)) 140 | align_on_curve[2] = np.zeros((5, 2)) 141 | align_on_curve[3] = np.zeros((4, 2)) 142 | align_on_curve[4] = np.zeros((5, 2)) 143 | align_on_curve[5] = np.zeros((4, 2)) 144 | align_on_curve[6] = np.zeros((4, 2)) 145 | align_on_curve[7] = np.zeros((4, 2)) 146 | align_on_curve[8] = np.zeros((4, 2)) 147 | align_on_curve[9] = np.zeros((7, 2)) 148 | align_on_curve[10] = np.zeros((5, 2)) 149 | align_on_curve[11] = np.zeros((5, 2)) 150 | align_on_curve[12] = np.zeros((7, 2)) 151 | 152 | for i in range(17): 153 | align_on_curve[0][i] = points[i] 154 | 155 | for i in range(5): 156 | align_on_curve[1][i] = points[i + 17] 157 | 158 | for i in range(5): 159 | align_on_curve[2][i] = points[i + 22] 160 | 161 | for i in range(4): 162 | align_on_curve[3][i] = points[i + 27] 163 | 164 | for i in range(5): 165 | align_on_curve[4][i] = points[i + 31] 166 | 167 | align_on_curve[5][0] = points[36] 168 | align_on_curve[5][1] = points[37] 169 | align_on_curve[5][2] = points[38] 170 | align_on_curve[5][3] = points[39] 171 | 172 | align_on_curve[6][0] = points[39] 173 | align_on_curve[6][1] = points[40] 174 | align_on_curve[6][2] = points[41] 175 | align_on_curve[6][3] = points[36] 176 | 177 | align_on_curve[7][0] = points[42] 178 | align_on_curve[7][1] = points[43] 179 | align_on_curve[7][2] = points[44] 180 | align_on_curve[7][3] = points[45] 181 | 182 | align_on_curve[8][0] = points[45] 183 | align_on_curve[8][1] = points[46] 184 | align_on_curve[8][2] = points[47] 185 | align_on_curve[8][3] = points[42] 186 | 187 | for i in range(7): 188 | align_on_curve[9][i] = points[i + 48] 189 | 190 | for i in range(5): 191 | align_on_curve[10][i] = points[i + 60] 192 | 193 | align_on_curve[11][0] = points[60] 194 | align_on_curve[11][1] = points[67] 195 | align_on_curve[11][2] = points[66] 196 | align_on_curve[11][3] = points[65] 197 | align_on_curve[11][4] = points[64] 198 | 199 | align_on_curve[12][0] = points[48] 200 | align_on_curve[12][1] = points[59] 201 | align_on_curve[12][2] = points[58] 202 | 
align_on_curve[12][3] = points[57] 203 | align_on_curve[12][4] = points[56] 204 | align_on_curve[12][5] = points[55] 205 | align_on_curve[12][6] = points[54] 206 | 207 | heatmap = np.zeros((heatmap_size[0], heatmap_size[1], heatmap_num)) 208 | for i in range(heatmap_num): 209 | curve_map = np.full((heatmap_size[0], heatmap_size[1]), 255, dtype=np.uint8) 210 | 211 | valid_points = [align_on_curve[i][0, :]] 212 | for j in range(1, align_on_curve[i].shape[0]): 213 | if (distance(align_on_curve[i][j, :], align_on_curve[i][j - 1, :]) > 0.001): 214 | valid_points.append(align_on_curve[i][j, :]) 215 | 216 | if len(valid_points) > 1: 217 | curves[i] = curve_fitting(align_on_curve[i], align_on_curve[i].shape[0] * 10, i) 218 | for j in range(curves[i].shape[0]): 219 | if (int(curves[i][j, 0] + 0.5) >= 0 and int(curves[i][j, 0] + 0.5) < heatmap_size[1] and 220 | int(curves[i][j, 1] + 0.5) >= 0 and int(curves[i][j, 1] + 0.5) < heatmap_size[0]): 221 | curve_map[int(curves[i][j, 1] + 0.5), int(curves[i][j, 0] + 0.5)] = 0 222 | 223 | image_dis = cv2.distanceTransform( 224 | curve_map, cv2.cv.CV_DIST_L2, cv2.cv.CV_DIST_MASK_PRECISE) 225 | 226 | image_dis = image_dis.astype(np.float64) 227 | image_gaussian = (1.0 / (2.0 * np.pi * (sigma**2))) * np.exp(-1.0 * image_dis**2 / (2.0 * sigma**2)) 228 | image_gaussian = np.where(image_dis < (3.0 * sigma), image_gaussian, 0) 229 | 230 | maxVal = image_gaussian.max() 231 | minVal = image_gaussian.min() 232 | 233 | if maxVal == minVal: 234 | image_gaussian = 0 235 | else: 236 | image_gaussian = (image_gaussian - minVal) / (maxVal - minVal) 237 | 238 | heatmap[:, :, i] = image_gaussian 239 | 240 | heatmap = np.sum(heatmap, axis=2) 241 | 242 | return heatmap 243 | -------------------------------------------------------------------------------- /src/datagen.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import cv2 3 | import os 4 | import random 5 | import time 6 | import copy 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import scipy.misc as scm 10 | import tensorflow as tf 11 | 12 | from PIL import Image 13 | from utils import affine_transformation 14 | from utils.color_jitter import ImageJitter 15 | from skimage import transform, util 16 | from utils.curve import points_to_heatmap_rectangle_68pt 17 | 18 | def ensure_dir(file_path): 19 | directory = os.path.dirname(file_path) 20 | if not os.path.exists(directory): 21 | os.makedirs(directory) 22 | 23 | class DataGenerator(): 24 | 25 | def __init__(self, img_dir=None, train_list_file=None, 26 | img_dir_valid=None, valid_list_file=None): 27 | self.img_dir = img_dir 28 | self.img_dir_valid = img_dir_valid 29 | self.train_list_file = train_list_file 30 | self.valid_list_file = valid_list_file 31 | 32 | def _create_train_table(self): 33 | self.train_table = [] 34 | input_file = open(self.train_list_file, 'r') 35 | for line in input_file.readlines(): 36 | self.train_table.append(line) 37 | input_file.close() 38 | 39 | def _randomize(self): 40 | random.shuffle(self.train_table) 41 | 42 | def _create_train_sets_for_300W(self): 43 | self.train_set = [] 44 | input_file = open(self.train_list_file, 'r') 45 | for line in input_file.readlines(): 46 | self.train_set.append(line) 47 | input_file.close() 48 | 49 | def _create_valid_sets_for_300W(self): 50 | self.valid_set = [] 51 | input_file = open(self.valid_list_file, 'r') 52 | for line in input_file.readlines(): 53 | self.valid_set.append(line) 54 | input_file.close() 55 | 56 | 
def _create_sets_for_300VW(self, validation_rate = 0.05): 57 | self.sample = len(self.train_table) 58 | valid_sample = int(self.sample * validation_rate) 59 | self.train_set = self.train_table[:self.sample - valid_sample] 60 | self.valid_set = self.train_table[self.sample - valid_sample:] 61 | self.test_set = self.train_table[:] 62 | 63 | def _aux_generator(self, batch_size = 1, NUM_CLASSES = 136, 64 | num_input_imgs = 3, normalize = True, sample_set = 'train'): 65 | train_line_num = 0 66 | valid_line_num = 0 67 | test_line_num = 0 68 | test_break_flag = False 69 | 70 | while True: 71 | train_img = np.zeros((batch_size, 256,256,3*num_input_imgs), dtype = np.float32) 72 | train_gtmap = np.zeros((batch_size, NUM_CLASSES), dtype = np.float32) 73 | i = 0 74 | names = [] 75 | max_lines = 3 76 | 77 | while i < batch_size: 78 | input_boundaries = [] 79 | 80 | if sample_set == 'train': 81 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) : 82 | train_line_num = 0 83 | elif sample_set == 'valid': 84 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set): 85 | valid_line_num = 0 86 | elif sample_set == 'test': 87 | if test_line_num+1 == len(self.test_set): 88 | print('The end of the testing set!') 89 | test_break_flag = True 90 | 91 | for cntr in range(max_lines): 92 | if sample_set == 'train': 93 | line = self.train_set[train_line_num] 94 | train_line_num += 1 95 | elif sample_set == 'valid': 96 | line = self.valid_set[valid_line_num] 97 | valid_line_num += 1 98 | elif sample_set == 'test': 99 | line = self.test_set[test_line_num] 100 | test_line_num += 1 101 | 102 | eles = line.strip().split() 103 | frame_path = eles[-1] 104 | name = frame_path.split('/')[-1] 105 | names.append(name) 106 | gt = np.array(map(float,eles[:-1])) 107 | gt_flatten = np.reshape(gt,(gt.shape[0]/2,2)) 108 | 109 | boundary_gt_train = points_to_heatmap_rectangle_68pt(gt_flatten) 110 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=0) 111 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=3) 112 | input_boundaries.append(boundary_gt_train) 113 | 114 | if sample_set == 'train': 115 | if name != '0.jpg' and name != '1.jpg': 116 | break 117 | elif sample_set == 'valid': 118 | if (name != '0.jpg' and name != '1.jpg' and valid_line_num > 2): 119 | break 120 | elif sample_set == 'test': 121 | if (name != '0.jpg' and name != '1.jpg' and test_line_num > 2): 122 | break 123 | 124 | input_boundaries = input_boundaries[:-1] 125 | if len(input_boundaries) > 0: 126 | input_boundaries = np.concatenate(input_boundaries,axis=3) 127 | 128 | path_eles = frame_path.split('/') 129 | name_eles = path_eles[-1].split('.') 130 | frame_num = int(name_eles[0]) 131 | 132 | frame_path_2 = os.path.join(path_eles[0],str(frame_num-2)+'.'+name_eles[-1]) 133 | input_img_path_2 = os.path.join(self.img_dir, frame_path_2) 134 | img_2 = self.open_img(input_img_path_2) 135 | img_2 = scm.imresize(img_2, (256,256)) 136 | 137 | frame_path_1 = os.path.join(path_eles[0],str(frame_num-1)+'.'+name_eles[-1]) 138 | input_img_path_1 = os.path.join(self.img_dir, frame_path_1) 139 | img_1 = self.open_img(input_img_path_1) 140 | img_1 = scm.imresize(img_1, (256,256)) 141 | 142 | frame_path_0 = os.path.join(path_eles[0],str(frame_num)+'.'+name_eles[-1]) 143 | input_img_path_0 = os.path.join(self.img_dir, frame_path_0) 144 | img_0 = self.open_img(input_img_path_0) 145 | img_0 = scm.imresize(img_0, (256,256)) 146 | 147 | img = np.concatenate([img_2,img_1,img_0],axis=2) 148 | 149 | if 
normalize: 150 | train_img[i] = img.astype(np.float32) / 255 151 | train_gtmap[i] = gt.astype(np.float32) /255 152 | else : 153 | train_img[i] = img.astype(np.float32) 154 | train_gtmap[i] = gt.astype(np.float32) 155 | 156 | i = i + 1 157 | 158 | if sample_set == 'train': 159 | yield train_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap 160 | elif sample_set == 'valid': 161 | yield valid_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap 162 | elif sample_set == 'test': 163 | print("name = {}".format(name)) 164 | yield test_line_num, name, input_boundaries, boundary_gt_train, train_img, train_gtmap, names, test_break_flag 165 | 166 | def _voxel_flow_generator_(self, batch_size = 1, sample_set = 'train'): 167 | 168 | train_line_num = 0 169 | valid_line_num = 0 170 | 171 | while True: 172 | input_boundaries = np.zeros((batch_size, 256, 256, 2), dtype = np.float32) 173 | boundary_gts_train = np.zeros((batch_size, 256, 256, 1), dtype = np.float32) 174 | i = 0 175 | max_lines = 3 176 | 177 | while i < batch_size: 178 | input_boundary = [] 179 | 180 | if sample_set == 'train': 181 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) : 182 | train_line_num = 0 183 | line_num = copy.deepcopy(train_line_num) 184 | elif sample_set == 'valid': 185 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set): 186 | valid_line_num = 0 187 | line_num = copy.deepcopy(valid_line_num) 188 | 189 | for cntr in range(max_lines): 190 | if sample_set == 'train': 191 | line = self.train_set[line_num] 192 | elif sample_set == 'valid': 193 | line = self.valid_set[line_num] 194 | 195 | line_num += 1 196 | eles = line.strip().split() 197 | frame_path = eles[-1] 198 | gt = np.array(map(float,eles[:-1])) 199 | gt_flatten = np.reshape(gt,(gt.shape[0]/2,2)) 200 | 201 | boundary_gt_train = points_to_heatmap_rectangle_68pt(gt_flatten) 202 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=2) 203 | boundary_gt_train = np.expand_dims(boundary_gt_train,axis=0) 204 | input_boundary.append(boundary_gt_train[0]) 205 | 206 | train_line_num += 1 207 | valid_line_num += 1 208 | input_boundary = input_boundary[:-1] 209 | input_boundaries[i] = np.concatenate(input_boundary,axis=2) 210 | boundary_gts_train[i] = boundary_gt_train[0] 211 | 212 | i = i + 1 213 | 214 | if sample_set == 'train': 215 | yield input_boundaries, boundary_gts_train 216 | elif sample_set == 'valid': 217 | yield input_boundaries, boundary_gts_train 218 | 219 | def _video_deblur_generator_(self, batch_size = 1,normalize = True, 220 | num_input_imgs = 3,sample_set='train'): 221 | 222 | train_line_num = 0 223 | valid_line_num = 0 224 | 225 | while True: 226 | train_img = np.zeros((batch_size, 256, 256, 3*num_input_imgs), dtype = np.float32) 227 | i = 0 228 | max_lines = 3 229 | 230 | while i < batch_size: 231 | input_images = [] 232 | 233 | if sample_set == 'train': 234 | if train_line_num+1 == len(self.train_set) or train_line_num+2 == len(self.train_set) : 235 | train_line_num = 0 236 | line_num = copy.deepcopy(train_line_num) 237 | elif sample_set == 'valid': 238 | if valid_line_num+1 == len(self.valid_set) or valid_line_num+2 == len(self.valid_set): 239 | valid_line_num = 0 240 | line_num = copy.deepcopy(valid_line_num) 241 | 242 | for cntr in range(max_lines): 243 | if sample_set == 'train': 244 | line = self.train_set[line_num] 245 | elif sample_set == 'valid': 246 | line = self.valid_set[line_num] 247 | line_num += 1 248 | 249 | eles = 
line.strip().split()
250 |                 frame_path = eles[-1]
251 |                 input_img_path = os.path.join(self.img_dir, frame_path)
252 |                 name = frame_path.split('/')[-1]
253 | 
254 |                 img = self.open_img(input_img_path)
255 |                 img = scm.imresize(img, (256,256))
256 | 
257 |                 if normalize:
258 |                     input_images.append(img.astype(np.float32) / 255)
259 |                 else:
260 |                     input_images.append(img.astype(np.float32))
261 | 
262 |                 train_line_num += 1
263 |                 valid_line_num += 1
264 |                 train_img[i] = np.concatenate(input_images,axis=2)
265 | 
266 |                 i = i + 1
267 | 
268 |             if sample_set == 'train':
269 |                 yield train_line_num, name, train_img
270 |             elif sample_set == 'valid':
271 |                 yield valid_line_num, name, train_img
272 | 
273 |     def _resnet_generator(self, batch_size = 16, NUM_CLASSES = 136,
274 |                           normalize = True, sample_set = 'train'):
275 | 
276 |         while True:
277 |             train_img = np.zeros((batch_size, 256,256,3), dtype = np.float32)
278 |             train_gtmap = np.zeros((batch_size, NUM_CLASSES), dtype = np.float32)
279 |             i = 0
280 | 
281 |             while i < batch_size:
282 |                 if sample_set == 'train':
283 |                     line = random.choice(self.train_set)
284 |                 elif sample_set == 'valid':
285 |                     line = random.choice(self.valid_set)
286 | 
287 |                 eles = line.strip().split()
288 |                 name = eles[-1]
289 |                 if sample_set == 'train':
290 |                     input_img_path = os.path.join(self.img_dir, name)
291 |                 elif sample_set == 'valid':
292 |                     input_img_path = os.path.join(self.img_dir_valid, name)
293 | 
294 |                 img = self.open_img(input_img_path)
295 | 
296 |                 if sample_set == 'train':
297 |                     gt = np.array(list(map(float, eles[:-1])))
298 |                     gt = gt.reshape(-1, 2)
299 | 
300 |                     transform_matrix, do_mirror = affine_transformation.get_affine_mat(
301 |                         width=256, height=256,
302 |                         max_trans=40, max_rotate=30, max_zoom=1.1,
303 |                         min_trans=-40, min_rotate=-30, min_zoom=0.9)
304 | 
305 |                     img = affine_transformation.affine2d(img, transform_matrix, output_img_width=256,
306 |                                                          output_img_height=256, center=True,
307 |                                                          is_landmarks=False, do_mirror=do_mirror)
308 |                     gt = affine_transformation.affine2d(gt, transform_matrix, output_img_width=256,
309 |                                                         output_img_height=256, center=True,
310 |                                                         is_landmarks=True, do_mirror=do_mirror)
311 | 
312 |                     transformdict = {'Brightness':0.5025, 'Contrast':0.5136,
313 |                                      'Sharpness':0.5568, 'Color':0.5203}
314 |                     image_jitter = ImageJitter(transformdict)
315 |                     img = Image.fromarray(img)
316 |                     img = image_jitter(img)
317 |                     img = np.array(img)
318 | 
319 |                     img = util.random_noise(img, mode='gaussian')
320 |                     img = (img*255).astype(np.uint8)
321 |                     gt = gt.reshape(1, -1).squeeze()
322 | 
323 |                 elif sample_set == 'valid':
324 |                     gt = np.array(map(float,eles[:-1]))
325 | 
326 |                 if normalize:
327 |                     train_img[i] = img.astype(np.float32) / 255
328 |                     train_gtmap[i] = gt.astype(np.float32) / 255
329 |                 else:
330 |                     train_img[i] = img.astype(np.float32)
331 |                     train_gtmap[i] = gt.astype(np.float32)
332 | 
333 |                 i = i + 1
334 | 
335 |             yield train_img, train_gtmap
336 | 
337 |     def open_img(self, img_path, color = 'RGB'):
338 |         img = cv2.imread(img_path)
339 |         if color == 'RGB':
340 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
341 |             return img
342 |         elif color == 'BGR':
343 |             return img
344 |         elif color == 'GRAY':
345 |             return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
346 |         else:
347 |             print('Color mode supported: RGB/BGR. 
If you need another mode do it yourself :p') 348 | -------------------------------------------------------------------------------- /src/train_FAB.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import tarfile 8 | import cv2 9 | import time 10 | import copy 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from utils.curve import points_to_heatmap_rectangle_68pt 15 | from six.moves import xrange 16 | from six.moves import urllib 17 | from datagen import DataGenerator 18 | from datagen import ensure_dir 19 | from FAB import FAB 20 | 21 | MOMENTUM = 0.9 22 | POINTS_NUM = 68 23 | IMAGE_SIZE = 256 24 | PIC_CHANNEL = 3 25 | num_input_imgs = 3 26 | NUM_CLASSES = POINTS_NUM*2 27 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 28 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 29 | structure_predictor_net_channel = 64 30 | 31 | FLAGS = tf.app.flags.FLAGS 32 | # address 33 | tf.app.flags.DEFINE_string('structure_predictor_train_dir', '', """Directory where to write train_checkpoints.""") 34 | tf.app.flags.DEFINE_string('video_deblur_train_dir', '', """Directory where to write train_checkpoints.""") 35 | tf.app.flags.DEFINE_string('resnet_train_dir', '', """Directory where to write train_checkpoints.""") 36 | tf.app.flags.DEFINE_string('end_2_end_train_dir', '', """Directory where to write train_checkpoints.""") 37 | tf.app.flags.DEFINE_string('end_2_end_valid_dir', '', """Directory where to write valid logs.""") 38 | tf.app.flags.DEFINE_string('data_dir', '', """Directory where the dataset stores.""") 39 | tf.app.flags.DEFINE_string('img_list', '', """Directory where the img_list stores.""") 40 | tf.app.flags.DEFINE_string('data_dir_valid', '', """Directory where the valid image stores. Only used for pretraining on 300W datasets.""") 41 | tf.app.flags.DEFINE_string('img_list_valid', '', """Directory where the valid image_list stores. Only used for pretraining on 300W datasets.""") 42 | # parameters 43 | tf.app.flags.DEFINE_float('learning_rate', 0.00003, "learning rate.") 44 | tf.app.flags.DEFINE_integer('batch_size', 1, "batch size") 45 | tf.app.flags.DEFINE_integer('max_steps', 2000000, "max steps") 46 | tf.app.flags.DEFINE_boolean('resume_structure_predictor', True, """Resume from latest saved state.""") 47 | tf.app.flags.DEFINE_boolean('resume_resnet', True, """Resume from latest saved state.""") 48 | tf.app.flags.DEFINE_boolean('resume_video_deblur', True, """Resume from latest saved state.""") 49 | tf.app.flags.DEFINE_boolean('resume_all', False, """Resume from latest saved state.""") 50 | tf.app.flags.DEFINE_boolean('minimal_summaries', False, """Produce fewer summaries to save HD space.""") 51 | tf.app.flags.DEFINE_string('training_period', 'pretrain', """Choose the training period: pretrain/train.""") 52 | tf.app.flags.DEFINE_boolean('use_bn', False, """Use batch normalization. 
Otherwise use biases.""")
53 | 
54 | def resume(sess, do_resume, ckpt_path, key_word):
55 |     var = tf.global_variables()
56 |     if do_resume:
57 |         structure_predictor_latest = tf.train.latest_checkpoint(ckpt_path)
58 |         if not structure_predictor_latest:
59 |             print("\n No checkpoint to continue from in ", ckpt_path, '\n')
60 |             return
61 |         structure_predictor_var_to_restore = [val for val in var if key_word in val.name]
62 |         saver_structure_predictor = tf.train.Saver(structure_predictor_var_to_restore)
63 |         saver_structure_predictor.restore(sess, structure_predictor_latest)
64 | def train(resnet_model, is_training, F, H, F_curr, H_curr,
65 |           input_images_blur, input_images_boundary, next_boundary_gt, labels,
66 |           data_dir, data_dir_valid, img_list, img_list_valid,
67 |           dropout_ratio):
68 | 
69 |     global_step = tf.get_variable('global_step', [],
70 |                                   initializer=tf.constant_initializer(0),
71 |                                   trainable=False)
72 |     val_step = tf.get_variable('val_step', [],
73 |                                initializer=tf.constant_initializer(0),
74 |                                trainable=False)
75 | 
76 |     # define the losses.
77 |     lambda_ = 1e-5  # weight of the L2 regularization term
78 | 
79 |     loss_1 = resnet_model.l2_loss_(resnet_model.logits, labels)  # landmark coordinate loss
80 |     loss_2 = resnet_model.l2_loss_(resnet_model.next_frame, next_boundary_gt)  # predicted next boundary vs. its ground truth
81 |     loss_3 = resnet_model.l2_loss_(input_images_blur[:,:,:,-3:], resnet_model.video_deblur_output)
82 |     loss_ = loss_1 + loss_2 + loss_3 + tf.reduce_sum(tf.square(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))*lambda_
83 | 
84 |     ema = tf.train.ExponentialMovingAverage(resnet_model.MOVING_AVERAGE_DECAY, global_step)
85 |     tf.add_to_collection(resnet_model.UPDATE_OPS_COLLECTION, ema.apply([loss_]))
86 |     tf.summary.scalar('loss_avg', ema.average(loss_))
87 | 
88 |     ema = tf.train.ExponentialMovingAverage(0.9, val_step)
89 |     val_op = tf.group(val_step.assign_add(1), ema.apply([loss_]))
90 |     tf.summary.scalar('loss_valid', ema.average(loss_))
91 | 
92 |     tf.summary.scalar('learning_rate', FLAGS.learning_rate)
93 | 
94 |     # define the optimizer and back propagate.
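    # Adam with a fixed learning rate (FLAGS.learning_rate); a histogram summary
    # is recorded for every gradient unless --minimal_summaries is set.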
95 | opt = tf.train.AdamOptimizer(FLAGS.learning_rate) 96 | grads = opt.compute_gradients(loss_) 97 | for grad, var in grads: 98 | if grad is not None and not FLAGS.minimal_summaries: 99 | tf.summary.histogram(var.op.name + '/gradients', grad) 100 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 101 | 102 | batchnorm_updates = tf.get_collection(resnet_model.UPDATE_OPS_COLLECTION) 103 | batchnorm_updates_op = tf.group(*batchnorm_updates) 104 | train_op = tf.group(apply_gradient_op, batchnorm_updates_op) 105 | 106 | saver_all = tf.train.Saver(tf.all_variables()) 107 | 108 | summary_op = tf.summary.merge_all() 109 | 110 | # initialize all variables 111 | init = tf.initialize_all_variables() 112 | sess = tf.Session(config=tf.ConfigProto(log_device_placement=False)) 113 | sess.run(init) 114 | 115 | summary_writer = tf.summary.FileWriter(FLAGS.end_2_end_train_dir, sess.graph) 116 | val_summary_writer = tf.summary.FileWriter(FLAGS.end_2_end_valid_dir) 117 | val_save_root = os.path.join(FLAGS.end_2_end_valid_dir,'visualization') 118 | compare_save_root = os.path.join(FLAGS.end_2_end_valid_dir,'deblur_compare') 119 | 120 | # resume weights 121 | resume(sess, FLAGS.resume_structure_predictor, FLAGS.structure_predictor_train_dir, 'voxel_flow_model_') 122 | resume(sess, FLAGS.resume_video_deblur, FLAGS.video_deblur_train_dir, 'video_deblur_model_') 123 | resume(sess, FLAGS.resume_resnet, FLAGS.resnet_train_dir, 'resnet_model_') 124 | resume(sess, FLAGS.resume_all, FLAGS.end_2_end_train_dir, '') 125 | 126 | # create data generator 127 | if FLAGS.training_period == 'pretrain': 128 | dataset = DataGenerator(data_dir, img_list, data_dir_valid, img_list_valid) 129 | dataset._create_train_sets_for_300W() 130 | dataset._create_valid_sets_for_300W() 131 | elif FLAGS.training_period == 'train': 132 | dataset = DataGenerator(data_dir,img_list) 133 | dataset._create_train_table() 134 | dataset._create_sets_for_300VW() 135 | else: 136 | raise NameError("No such training_period!") 137 | train_gen = dataset._aux_generator(batch_size = FLAGS.batch_size, 138 | num_input_imgs = num_input_imgs, 139 | NUM_CLASSES = POINTS_NUM*2, 140 | sample_set='train') 141 | valid_gen = dataset._aux_generator(batch_size = FLAGS.batch_size, 142 | num_input_imgs = num_input_imgs, 143 | NUM_CLASSES = POINTS_NUM*2, 144 | sample_set='valid') 145 | 146 | # main training process. 
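    # F and H thread the structure predictor's recurrent state between steps:
    # both are zero-initialized at the start of each clip (frame '2.jpg') and are
    # otherwise re-fed from the F_curr/H_curr values fetched on the previous
    # iteration.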
    for x in range(FLAGS.max_steps + 1):

        start_time = time.time()
        step = sess.run(global_step)
        i = [train_op, loss_]
        write_summary = step > 1 and not (step % 100)
        if write_summary:
            i.append(summary_op)
        i.append(resnet_model.logits)
        i.append(F_curr)
        i.append(H_curr)

        train_line_num, frame_name, input_boundaries, boundary_gt_train, input_images_blur_generated, landmark_gt_train = next(train_gen)

        if frame_name == '2.jpg' or x == 0:
            input_images_boundary_init = copy.deepcopy(input_boundaries)
            F_init = np.zeros([FLAGS.batch_size,
                               IMAGE_SIZE // 2,
                               IMAGE_SIZE // 2,
                               structure_predictor_net_channel // 2], dtype=np.float32)

            H_init = np.zeros([1,
                               FLAGS.batch_size,
                               IMAGE_SIZE // 2,
                               IMAGE_SIZE // 2,
                               structure_predictor_net_channel], dtype=np.float32)
            feed_dict = {
                input_images_boundary: input_images_boundary_init,
                input_images_blur: input_images_blur_generated,
                F: F_init,
                H: H_init,
                labels: landmark_gt_train,
                next_boundary_gt: boundary_gt_train,
                dropout_ratio: 0.5
            }
        else:
            output_points = o[-3]
            output_points = np.reshape(output_points, (POINTS_NUM, 2))

            boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
            boundary_from_points = np.expand_dims(boundary_from_points, axis=0)
            boundary_from_points = np.expand_dims(boundary_from_points, axis=3)
            input_images_boundary_init = np.concatenate([input_images_boundary_init[:, :, :, 1:2],
                                                         boundary_from_points], axis=3)
            feed_dict = {
                input_images_boundary: input_images_boundary_init,
                input_images_blur: input_images_blur_generated,
                F: o[-2],
                H: o[-1],
                labels: landmark_gt_train,
                next_boundary_gt: boundary_gt_train,
                dropout_ratio: 0.5
            }

        o = sess.run(i, feed_dict=feed_dict)
        loss_value = o[1]
        duration = time.time() - start_time
        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step > 1 and step % 300 == 0:
            examples_per_sec = FLAGS.batch_size / float(duration)
            format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (step, loss_value, examples_per_sec, duration))

        if write_summary:
            summary_str = o[2]
            summary_writer.add_summary(summary_str, step)

        if step > 1 and step % 300 == 0:
            checkpoint_path = os.path.join(FLAGS.end_2_end_train_dir, 'model.ckpt')
            ensure_dir(checkpoint_path)
            saver_all.save(sess, checkpoint_path, global_step=global_step)

        # Run validation periodically
        if step > 1 and step % 300 == 0:
            valid_line_num, frame_name, input_boundaries, boundary_gt_valid, input_images_blur_generated, landmark_gt_valid = next(valid_gen)

            if frame_name == '2.jpg' or valid_line_num <= 3:
                input_images_boundary_init = copy.deepcopy(input_boundaries)
                F_init = np.zeros([FLAGS.batch_size,
                                   IMAGE_SIZE // 2,
                                   IMAGE_SIZE // 2,
                                   structure_predictor_net_channel // 2], dtype=np.float32)

                H_init = np.zeros([1, FLAGS.batch_size,
                                   IMAGE_SIZE // 2,
                                   IMAGE_SIZE // 2,
                                   structure_predictor_net_channel], dtype=np.float32)

                feed_dict = {input_images_boundary: input_images_boundary_init,
                             input_images_blur: input_images_blur_generated,
                             F: F_init,
                             H: H_init,
                             labels: landmark_gt_valid,
                             next_boundary_gt: boundary_gt_valid,
                             dropout_ratio: 1.0}
            else:
                output_points = o_valid[-3]
                output_points = np.reshape(output_points, (POINTS_NUM, 2))
                boundary_from_points = points_to_heatmap_rectangle_68pt(output_points)
                boundary_from_points = np.expand_dims(boundary_from_points, axis=0)
                boundary_from_points = np.expand_dims(boundary_from_points, axis=3)

                input_images_boundary_init = np.concatenate([input_images_boundary_init[:, :, :, 1:2],
                                                             boundary_from_points], axis=3)
                feed_dict = {
                    input_images_boundary: input_images_boundary_init,
                    input_images_blur: input_images_blur_generated,
                    F: o_valid[-2],
                    H: o_valid[-1],
                    labels: landmark_gt_valid,
                    next_boundary_gt: boundary_gt_valid,
                    dropout_ratio: 1.0
                }
            i_valid = [loss_, resnet_model.logits, F_curr, H_curr]
            o_valid = sess.run(i_valid, feed_dict=feed_dict)
            print('Validation loss %.2f' % o_valid[0])
            if write_summary:
                val_summary_writer.add_summary(summary_str, step)
            img_video_deblur_output = sess.run(resnet_model.video_deblur_output, feed_dict=feed_dict)[0] * 255
            img = input_images_blur_generated[0, :, :, 0:3] * 255
            compare_img = np.concatenate([img, img_video_deblur_output], axis=1)
            points = o_valid[1][0] * 255

            for point_num in range(int(points.shape[0] / 2)):
                cv2.circle(img, (int(round(points[point_num * 2])), int(round(points[point_num * 2 + 1]))),
                           1, (55, 225, 155), 2)
            val_save_path = os.path.join(val_save_root, str(step) + '.jpg')
            compare_save_path = os.path.join(compare_save_root, str(step) + '.jpg')
            ensure_dir(val_save_path)
            ensure_dir(compare_save_path)
            cv2.imwrite(val_save_path, img)
            cv2.imwrite(compare_save_path, compare_img)


def main(argv=None):
    resnet_model = FAB(structure_predictor_is_train=False,
                       deblur_is_train=True,
                       resnet_is_train=False)

    is_training = tf.placeholder('bool', [], name='is_training')

    input_images_boundary = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 2))
    input_images_blur = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, PIC_CHANNEL * 3))
    next_boundary_gt = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, IMAGE_SIZE, IMAGE_SIZE, 1))
    labels = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, NUM_CLASSES))
    dropout_ratio = tf.placeholder(tf.float32)
    F = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_SIZE // 2, IMAGE_SIZE // 2,
                                    structure_predictor_net_channel // 2])
    H = tf.placeholder(tf.float32, [1, FLAGS.batch_size, IMAGE_SIZE // 2, IMAGE_SIZE // 2,
                                    structure_predictor_net_channel])

    F_curr, H_curr = resnet_model.FAB_inference(input_images_boundary, input_images_blur, F, H, FLAGS.batch_size,
                                                net_channel=structure_predictor_net_channel, num_classes=136,
                                                num_blocks=[2, 2, 2, 2], use_bias=(not FLAGS.use_bn),
                                                bottleneck=True, dropout_ratio=1.0)

    train(resnet_model, is_training, F, H, F_curr, H_curr,
          input_images_blur, input_images_boundary, next_boundary_gt, labels,
          FLAGS.data_dir, FLAGS.data_dir_valid, FLAGS.img_list, FLAGS.img_list_valid,
          dropout_ratio)


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/src/FAB.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import tensorflow.contrib.slim as slim
import datetime
import numpy as np
import os
import time
import math
import skimage.io
import skimage.transform

from utils.loss_utils import l2_loss
from utils.geo_layer_utils import bilinear_interp
from utils.geo_layer_utils import meshgrid
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import moving_averages
from utils.config import Config


class FAB(object):
    def __init__(self, structure_predictor_is_train=True, deblur_is_train=True,
                 resnet_is_train=True, is_training=True,
                 MOVING_AVERAGE_DECAY=0.9997, BN_EPSILON=0.001,
                 CONV_WEIGHT_DECAY=0.0005, CONV_WEIGHT_STDDEV=0.1,
                 FC_WEIGHT_DECAY=0.0005, FC_WEIGHT_STDDEV=0.01,
                 RESNET_VARIABLES='RESNET_VARIABLES',
                 UPDATE_OPS_COLLECTION='resnet_update_ops',
                 IMAGENET_MEAN_BGR=[103.062623801, 115.902882574, 123.151630838],
                 input_size=224):

        self.structure_predictor_is_train = structure_predictor_is_train
        self.deblur_is_train = deblur_is_train
        self.resnet_is_train = resnet_is_train

        self.MOVING_AVERAGE_DECAY = MOVING_AVERAGE_DECAY
        self.BN_DECAY = self.MOVING_AVERAGE_DECAY
        self.BN_EPSILON = BN_EPSILON
        self.CONV_WEIGHT_DECAY = CONV_WEIGHT_DECAY
        self.CONV_WEIGHT_STDDEV = CONV_WEIGHT_STDDEV
        self.FC_WEIGHT_DECAY = FC_WEIGHT_DECAY
        self.FC_WEIGHT_STDDEV = FC_WEIGHT_STDDEV
        self.RESNET_VARIABLES = RESNET_VARIABLES
        self.UPDATE_OPS_COLLECTION = UPDATE_OPS_COLLECTION
        self.IMAGENET_MEAN_BGR = IMAGENET_MEAN_BGR
        self.input_size = input_size

    ### loss function ###
    def l1_loss_(self, logits, labels):
        logits = tf.cast(logits, tf.float32)
        labels = tf.cast(labels, tf.float32)
        losses = tf.reduce_sum(tf.abs(tf.subtract(logits, labels)), axis=1)
        losses_mean = tf.reduce_mean(losses)
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss_ = tf.add_n([losses_mean] + regularization_losses)

        return loss_

    def l2_loss_(self, logits, labels):
        logits = tf.cast(logits, tf.float32)
        labels = tf.cast(labels, tf.float32)
        losses = tf.nn.l2_loss(tf.subtract(logits, labels))
        losses_mean = tf.reduce_mean(losses)
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss_ = tf.add_n([losses_mean] + regularization_losses)

        return loss_
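    # Wing loss (Feng et al., CVPR 2018): behaves like a scaled logarithm for
    # small residuals and like L1 for large ones,
    #   wing(x) = w * ln(1 + |x| / epsilon)   if |x| < w
    #           = |x| - C                     otherwise,
    # where C = w - w * ln(1 + w / epsilon) makes the two pieces meet at |x| = w.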
    def wing_loss(self, logits, labels, w=10.0, epsilon=2.0):
        logits = tf.cast(logits, tf.float32)
        labels = tf.cast(labels, tf.float32)
        x = tf.subtract(logits, labels)
        C = w * (1.0 - math.log(1.0 + w / epsilon))
        absolute_x = tf.abs(x)
        losses = tf.where(tf.greater(w, absolute_x),
                          w * tf.log(1.0 + absolute_x / epsilon),
                          absolute_x - C)
        losses_mean = tf.reduce_mean(tf.reduce_sum(losses, axis=1))
        regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss_ = tf.add_n([losses_mean] + regularization_losses)

        return loss_

    def calculate_NME(self, logits, labels):
        logits = tf.cast(logits, tf.float32)
        labels = tf.cast(labels, tf.float32)

        subtract_square_distance = tf.square(tf.subtract(logits, labels))
        mean_distance = tf.reduce_mean([tf.sqrt(tf.add(subtract_square_distance[:, column],
                                                       subtract_square_distance[:, column + 1]))
                                        for column in range(0, 136, 2)], axis=0)

        outer_eye_x = tf.square(tf.subtract(labels[:, 72], labels[:, 90]))
        outer_eye_y = tf.square(tf.subtract(labels[:, 73], labels[:, 91]))
        inter_ocular_distance = tf.sqrt(tf.add(outer_eye_x, outer_eye_y))

        normalized_mean_error = tf.divide(mean_distance, inter_ocular_distance,
                                          name='normalized_mean_error')
        loss_ = tf.reduce_mean(normalized_mean_error)

        return loss_

    ### structure predictor model ###
    def structure_predictor_inference(self, input_images_boundary, batch_size):
        with tf.variable_scope('structure_predictor_model_'):
            with slim.arg_scope([slim.conv2d],
                                activation_fn=tf.nn.relu,
                                weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                                weights_regularizer=slim.l2_regularizer(0.0001)):

                batch_norm_params = {'decay': 0.9997,
                                     'epsilon': 0.0001,
                                     'is_training': self.structure_predictor_is_train}

                with slim.arg_scope([slim.batch_norm],
                                    is_training=self.structure_predictor_is_train,
                                    updates_collections=None):
                    with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                                        normalizer_params=batch_norm_params):
                        net = slim.conv2d(input_images_boundary, 64, [5, 5], stride=1, scope='conv1')
                        net = slim.max_pool2d(net, [2, 2], scope='pool1')
                        net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
                        net = slim.max_pool2d(net, [2, 2], scope='pool2')
                        net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
                        net = slim.max_pool2d(net, [2, 2], scope='pool3')
                        net = tf.image.resize_bilinear(net, [64, 64])
                        net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
                        net = tf.image.resize_bilinear(net, [128, 128])
                        net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
                        net = tf.image.resize_bilinear(net, [256, 256])
                        net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')

                        net = slim.conv2d(net, 3, [5, 5], stride=1, activation_fn=tf.tanh,
                                          normalizer_fn=None, scope='conv7')
                        flow = net[:, :, :, 0:2]
                        mask = tf.expand_dims(net[:, :, :, 2], 3)

                        grid_x, grid_y = meshgrid(256, 256)
                        grid_x = tf.tile(grid_x, [batch_size, 1, 1])
                        grid_y = tf.tile(grid_y, [batch_size, 1, 1])

                        coor_x_1 = grid_x + flow[:, :, :, 0] * 2
                        coor_y_1 = grid_y + flow[:, :, :, 1] * 2
                        coor_x_2 = grid_x + flow[:, :, :, 0]
                        coor_y_2 = grid_y + flow[:, :, :, 1]

                        output_1 = bilinear_interp(input_images_boundary[:, :, :, 0:1],
                                                   coor_x_1, coor_y_1, 'extrapolate')
                        output_2 = bilinear_interp(input_images_boundary[:, :, :, 1:2],
                                                   coor_x_2, coor_y_2, 'extrapolate')

                        mask = 0.33 * (1.0 + mask)
                        mask = tf.tile(mask, [1, 1, 1, 3])
                        next_frame = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)

        return next_frame

    ### video deblur function ###
    def get_shape(self, x, i):
        return x.get_shape().as_list()[i]

    def weight_variable(self, shape, stddev=0.02, name='weight'):
        w = tf.get_variable(name, shape,
                            initializer=tf.random_normal_initializer(stddev=stddev),
                            trainable=self.deblur_is_train)
        return w

    def bias_variable(self, shape, name):
        b = tf.get_variable(name, initializer=tf.zeros(shape),
                            trainable=self.deblur_is_train)
        return b

    def conv2d(self, x, W, stride=1):
        return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='SAME')

    def conv2d_transpose(self, x, w, output_shape, stride=2):
        return tf.nn.conv2d_transpose(x, w, output_shape=output_shape,
                                      strides=[1, stride, stride, 1], padding='SAME')
    def bn(self, x):
        net = x
        out_channels = self.get_shape(net, 3)
        mean, var = tf.nn.moments(net, axes=[0, 1, 2])
        beta = self.bias_variable([out_channels], name="beta")
        gamma = self.weight_variable([out_channels], name="gamma")
        net = tf.nn.batch_normalization(net, mean, var, beta, gamma, 0.001)
        return net

    def conv_bn(self, x, filter_shape):
        net = x
        net = tf.nn.conv2d(net, self.weight_variable(filter_shape, name="weight"),
                           strides=[1, 1, 1, 1], padding="SAME")
        out_channels = filter_shape[3]
        mean, var = tf.nn.moments(net, axes=[0, 1, 2])
        beta = self.bias_variable([out_channels], name="beta")
        gamma = self.weight_variable([out_channels], name="gamma")
        net = tf.nn.batch_normalization(net, mean, var, beta, gamma, 0.001)
        return net

    def resnet_block(self, x, out_channel, filter_size=3):
        x_channel = x.get_shape().as_list()[3]
        with tf.variable_scope("conv_bn_relu"):
            net = self.conv_bn(x, filter_shape=[filter_size,
                                                filter_size,
                                                out_channel,
                                                out_channel])
            net = tf.nn.relu(net)
        with tf.variable_scope("conv_bn"):
            net = self.conv_bn(net, filter_shape=[filter_size,
                                                  filter_size,
                                                  out_channel,
                                                  out_channel])
        net = net + x
        net = tf.nn.relu(net)
        return net
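    # Dynamic fusion gate: a learned similarity map between the current
    # features x and the recurrent state h is squashed into a weight alpha
    # in [0, 1]; the output is alpha * x + (1 - alpha) * h. The (alpha - 1)
    # padding trick sets alpha = 1 on the borders lost by the VALID
    # convolution, so border pixels simply keep x.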
    def dynamic_fusion(self, x, h, filter_size=5):
        n_channel = self.get_shape(x, 3)
        t = tf.concat([x, h], 3)
        similarity = tf.nn.conv2d(t, self.weight_variable([filter_size,
                                                           filter_size,
                                                           n_channel * 2,
                                                           n_channel],
                                                          name="wt"),
                                  strides=[1, 1, 1, 1],
                                  padding='VALID')
        epsilon = self.bias_variable([1], name='bias_epsilon')
        alpha = 2 * tf.abs(tf.sigmoid(similarity) - 0.5) + epsilon
        alpha = tf.clip_by_value(alpha, 0, 1)
        hflt_filter_size = filter_size // 2
        alpha = tf.pad(alpha - 1, [[0, 0],
                                   [hflt_filter_size, hflt_filter_size],
                                   [hflt_filter_size, hflt_filter_size],
                                   [0, 0]], "CONSTANT") + 1
        y = alpha * x + (1 - alpha) * h
        return y, alpha

    def video_deblur_inference(self, X, F, H, net_channel=64):
        with tf.variable_scope('video_deblur_model_'):
            H_curr = []
            with tf.variable_scope("encoding"):
                with tf.variable_scope("conv1"):
                    filter_size = 5
                    net_X = self.conv2d(X, self.weight_variable([filter_size,
                                                                 filter_size,
                                                                 self.get_shape(X, 3),
                                                                 net_channel]))
                    net_X = tf.nn.relu(net_X)
                with tf.variable_scope("conv2"):
                    filter_size = 3
                    net_X = self.conv2d(net_X, self.weight_variable([filter_size,
                                                                     filter_size,
                                                                     self.get_shape(net_X, 3),
                                                                     net_channel // 2]),
                                        stride=2)
                    net_X = tf.nn.relu(net_X)
            net = tf.concat([net_X, F], 3)
            f0 = net
            filter_size = 3
            num_resnet_layers = 8
            for i in range(num_resnet_layers):
                with tf.variable_scope('resnet_block%d' % (i + 1)):
                    net = self.resnet_block(net, net_channel)
                if i == 3:
                    (net, alpha) = self.dynamic_fusion(net, H[0])
                    h = tf.expand_dims(net, axis=0)
                    H_curr = h
            with tf.variable_scope("feat_out"):
                F = self.conv2d(net, self.weight_variable([filter_size,
                                                           filter_size,
                                                           self.get_shape(net, 3),
                                                           net_channel // 2],
                                                          name='conv_F'))
                F = tf.nn.relu(F)
            with tf.variable_scope("img_out"):
                filter_size = 4
                shape = [self.get_shape(X, 0),
                         self.get_shape(X, 1),
                         self.get_shape(X, 2),
                         net_channel]
                Y = self.conv2d_transpose(net, self.weight_variable([filter_size,
                                                                     filter_size,
                                                                     net_channel,
                                                                     net_channel],
                                                                    name="deconv"),
                                          shape,
                                          stride=2)
                Y = tf.nn.relu(Y)
                filter_size = 3
                Y = self.conv2d(Y, self.weight_variable([filter_size,
                                                         filter_size,
                                                         self.get_shape(Y, 3),
                                                         3],
                                                        name='conv'))
        return Y, F, H_curr

    ### resnet inference ###
    def resnet_inference(self,
                         input_images_blur,
                         batch_size,
                         num_classes=136,
                         num_blocks=[2, 2, 2, 2],
                         use_bias=False,
                         bottleneck=True,
                         dropout_ratio=1.0):
        ####resnet_model####
        with tf.variable_scope('resnet_model_'):
            c = Config()
            c['bottleneck'] = bottleneck
            c['is_training'] = tf.convert_to_tensor(self.resnet_is_train,
                                                    dtype='bool',
                                                    name='is_training')
            c['ksize'] = 3
            c['stride'] = 1
            c['use_bias'] = use_bias
            c['fc_units_out'] = num_classes
            c['num_blocks'] = num_blocks
            c['stack_stride'] = 2

            with tf.variable_scope('scale1'):
                c['conv_filters_out'] = 16
                c['ksize'] = 7
                c['stride'] = 2
                x = self.conv(input_images_blur, c)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('scale1_pool'):
                x = self._max_pool(x, ksize=3, stride=2)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('scale2'):
                x = self._max_pool(x, ksize=3, stride=2)
                c['num_blocks'] = num_blocks[0]
                c['stack_stride'] = 1
                c['block_filters_internal'] = 8
                x = self.stack(x, c)

            with tf.variable_scope('scale3'):
                c['num_blocks'] = num_blocks[1]
                c['block_filters_internal'] = 16
                assert c['stack_stride'] == 2
                x = self.stack(x, c)

            with tf.variable_scope('scale4'):
                c['num_blocks'] = num_blocks[2]
                c['block_filters_internal'] = 32
                x = self.stack(x, c)

            with tf.variable_scope('scale5'):
                c['num_blocks'] = num_blocks[3]
                c['block_filters_internal'] = 64
                x = self.stack(x, c)

            x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")

            if num_classes is not None:
                with tf.variable_scope('fc1'):
                    c['fc_units_out'] = 256
                    x = self.fc(x, c)

                with tf.variable_scope('dropout1'):
                    x = tf.nn.dropout(x, dropout_ratio)

                with tf.variable_scope('fc2'):
                    c['fc_units_out'] = 256
                    x = self.fc(x, c)

                with tf.variable_scope('dropout2'):
                    x = tf.nn.dropout(x, dropout_ratio)

                with tf.variable_scope('fc3'):
                    c['fc_units_out'] = 136
                    landmark_localization = self.fc(x, c)

        return landmark_localization

    def stack(self, x, c):
        for n in range(c['num_blocks']):
            s = c['stack_stride'] if n == 0 else 1
            c['block_stride'] = s
            with tf.variable_scope('block%d' % (n + 1)):
                x = self.block(x, c, n)
        return x

    def block(self, x, c, n):
        filters_in = x.get_shape()[-1]
        m = 4 if c['bottleneck'] else 1
        filters_out = m * c['block_filters_internal']
        c['conv_filters_out'] = c['block_filters_internal']

        shortcut = x

        if c['bottleneck']:
            if n == 1:
                with tf.variable_scope('pre_activation'):
                    x = self.resnet_bn(x, c)
                    x = self.activation(x)

            with tf.variable_scope('a'):
                c['ksize'] = 1
                c['stride'] = c['block_stride']
                x = self.conv(x, c)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('b'):
                x = self.conv(x, c)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('c'):
                c['conv_filters_out'] = filters_out
                c['ksize'] = 1
                assert c['stride'] == 1
                x = self.conv(x, c)
        else:
            with tf.variable_scope('A'):
                c['stride'] = c['block_stride']
                assert c['ksize'] == 3
                x = self.conv(x, c)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('B'):
                c['conv_filters_out'] = filters_out
                assert c['ksize'] == 3
                assert c['stride'] == 1
                x = self.conv(x, c)
                x = self.resnet_bn(x, c)

        with tf.variable_scope('shortcut'):
            if filters_out != filters_in or c['block_stride'] != 1:
                c['ksize'] = 1
                c['stride'] = c['block_stride']
                c['conv_filters_out'] = filters_out
                shortcut = self.conv(shortcut, c)

        if n == 0:
            return x + shortcut
        elif n == 1:
            x = self.resnet_bn(x + shortcut, c)
            return self.activation(x)

    def resnet_bn(self, x, c):
        x_shape = x.get_shape()
        params_shape = x_shape[-1:]

        if c['use_bias']:
            bias = self._get_variable('bias',
                                      params_shape,
                                      initializer=tf.zeros_initializer)
            return x + bias

        axis = list(range(len(x_shape) - 1))
        beta = self._get_variable('beta',
                                  params_shape,
                                  initializer=tf.zeros_initializer)
        gamma = self._get_variable('gamma',
                                   params_shape,
                                   initializer=tf.ones_initializer)

        moving_mean = self._get_variable('moving_mean',
                                         params_shape,
                                         initializer=tf.zeros_initializer,
                                         trainable=False)
        moving_variance = self._get_variable('moving_variance',
                                             params_shape,
                                             initializer=tf.ones_initializer,
                                             trainable=False)

        # These ops will only be performed when training.
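        # At inference time the cond below swaps the fresh batch statistics
        # for the accumulated moving averages.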
        mean, variance = tf.nn.moments(x, axis)
        update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, self.BN_DECAY)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, self.BN_DECAY)

        tf.add_to_collection(self.UPDATE_OPS_COLLECTION, update_moving_mean)
        tf.add_to_collection(self.UPDATE_OPS_COLLECTION, update_moving_variance)

        mean, variance = control_flow_ops.cond(
            c['is_training'], lambda: (mean, variance),
            lambda: (moving_mean, moving_variance))

        x = tf.nn.batch_normalization(x, mean, variance, beta, gamma, self.BN_EPSILON)

        return x

    def activation(self, x):
        alphas = tf.get_variable('alpha', x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.25),
                                 dtype=tf.float32)
        pos = tf.nn.relu(x)
        neg = alphas * (x - abs(x)) * 0.5

        return pos + neg

    def fc(self, x, c):
        num_units_in = x.get_shape()[1]
        num_units_out = c['fc_units_out']
        weights_initializer = tf.truncated_normal_initializer(
            stddev=self.FC_WEIGHT_STDDEV)
        weights = self._get_variable('weights',
                                     shape=[num_units_in, num_units_out],
                                     initializer=weights_initializer,
                                     weight_decay=self.FC_WEIGHT_DECAY)
        biases = self._get_variable('biases',
                                    shape=[num_units_out],
                                    initializer=tf.zeros_initializer)
        x = tf.nn.xw_plus_b(x, weights, biases)

        return x

    def stack_fc(self, x, c):
        num_units_in = x.get_shape()[1]

        weights_initializer = tf.truncated_normal_initializer(
            stddev=self.FC_WEIGHT_STDDEV)

        weights = self._get_variable('weights',
                                     shape=[num_units_in, 256],
                                     initializer=weights_initializer,
                                     weight_decay=self.FC_WEIGHT_DECAY)
        biases = self._get_variable('biases',
                                    shape=[256],
                                    initializer=tf.zeros_initializer)
        x = tf.nn.xw_plus_b(x, weights, biases)

        weights_2 = self._get_variable('weights_2',
                                       shape=[256, 256],
                                       initializer=weights_initializer,
                                       weight_decay=self.FC_WEIGHT_DECAY)
        biases_2 = self._get_variable('biases_2',
                                      shape=[256],
                                      initializer=tf.zeros_initializer)
        x = tf.nn.xw_plus_b(x, weights_2, biases_2)

        num_units_out = c['fc_units_out']

        weights_3 = self._get_variable('weights_3',
                                       shape=[256, num_units_out],
                                       initializer=weights_initializer,
                                       weight_decay=self.FC_WEIGHT_DECAY)
        biases_3 = self._get_variable('biases_3',
                                      shape=[num_units_out],
                                      initializer=tf.zeros_initializer)
        x = tf.nn.xw_plus_b(x, weights_3, biases_3)

        return x

    def _get_variable(self, name,
                      shape,
                      initializer,
                      weight_decay=0.0,
                      dtype='float',
                      trainable=True):
        if weight_decay > 0:
            regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
        else:
            regularizer = None
        collections = [tf.GraphKeys.GLOBAL_VARIABLES, self.RESNET_VARIABLES]

        return tf.get_variable(name,
                               shape=shape,
                               initializer=initializer,
                               dtype=dtype,
                               regularizer=regularizer,
                               collections=collections,
                               trainable=trainable)

    def conv(self, x, c):
        ksize = c['ksize']
        stride = c['stride']
        filters_out = c['conv_filters_out']

        filters_in = x.get_shape()[-1]
        shape = [ksize, ksize, filters_in, filters_out]
        initializer = tf.truncated_normal_initializer(stddev=self.CONV_WEIGHT_STDDEV)
        weights = self._get_variable('weights',
                                     shape=shape,
                                     dtype='float',
                                     initializer=initializer,
                                     weight_decay=self.CONV_WEIGHT_DECAY)

        return tf.nn.conv2d(x, weights, [1, stride, stride, 1], padding='SAME')

    def _max_pool(self, x, ksize=3, stride=2):
        return tf.nn.max_pool(x,
                              ksize=[1, ksize, ksize, 1],
                              strides=[1, stride, stride, 1],
                              padding='SAME')

    ### FAB model ###
    def FAB_inference(self,
                      input_images_boundary,
                      input_images_blur,
                      F, H,
                      batch_size,
                      net_channel=64,
                      num_classes=136,
                      num_blocks=[2, 2, 2, 2],
                      use_bias=False,
                      bottleneck=True,
                      dropout_ratio=1.0):

        ####structure_predictor_model####
        with tf.variable_scope('structure_predictor_model_'):
            with slim.arg_scope([slim.conv2d],
                                activation_fn=tf.nn.relu,
                                weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                                weights_regularizer=slim.l2_regularizer(0.0001)):

                batch_norm_params = {
                    'decay': 0.9997,
                    'epsilon': 0.001,
                    'is_training': self.structure_predictor_is_train,
                }
                with slim.arg_scope([slim.batch_norm],
                                    is_training=self.structure_predictor_is_train,
                                    updates_collections=None):
                    with slim.arg_scope([slim.conv2d], normalizer_fn=slim.batch_norm,
                                        normalizer_params=batch_norm_params):
                        net = slim.conv2d(input_images_boundary, 64, [5, 5], stride=1, scope='conv1')
                        net = slim.max_pool2d(net, [2, 2], scope='pool1')
                        net = slim.conv2d(net, 128, [5, 5], stride=1, scope='conv2')
                        net = slim.max_pool2d(net, [2, 2], scope='pool2')
                        net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv3')
                        net = slim.max_pool2d(net, [2, 2], scope='pool3')
                        net = tf.image.resize_bilinear(net, [64, 64])
                        net = slim.conv2d(net, 256, [3, 3], stride=1, scope='conv4')
                        net = tf.image.resize_bilinear(net, [128, 128])
                        net = slim.conv2d(net, 128, [3, 3], stride=1, scope='conv5')
                        net = tf.image.resize_bilinear(net, [256, 256])
                        net = slim.conv2d(net, 64, [5, 5], stride=1, scope='conv6')
                        net = slim.conv2d(net, 3, [5, 5], stride=1,
                                          activation_fn=tf.tanh, normalizer_fn=None, scope='conv7')
                        flow = net[:, :, :, 0:2]
                        mask = tf.expand_dims(net[:, :, :, 2], 3)

                        grid_x, grid_y = meshgrid(256, 256)
                        grid_x = tf.tile(grid_x, [batch_size, 1, 1])
                        grid_y = tf.tile(grid_y, [batch_size, 1, 1])

                        coor_x_1 = grid_x + flow[:, :, :, 0] * 2
                        coor_y_1 = grid_y + flow[:, :, :, 1] * 2
                        coor_x_2 = grid_x + flow[:, :, :, 0]
                        coor_y_2 = grid_y + flow[:, :, :, 1]

                        output_1 = bilinear_interp(input_images_boundary[:, :, :, 0:1],
                                                   coor_x_1, coor_y_1, 'extrapolate')
                        output_2 = bilinear_interp(input_images_boundary[:, :, :, 1:2],
                                                   coor_x_2, coor_y_2, 'extrapolate')

                        mask = 0.5 * (1.0 + mask)
                        mask = tf.tile(mask, [1, 1, 1, 3])
                        self.next_frame = tf.multiply(mask, output_1) + tf.multiply(1.0 - mask, output_2)
                        self.structure_predictor_output = tf.concat([self.next_frame, input_images_blur], 3)
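        # The three sub-networks are chained: the predicted boundary heatmap,
        # concatenated with the blurred inputs, drives the video deblurring
        # branch below, and its restored frame is what the ResNet regresses
        # the 68 landmarks (136 coordinates) from.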
        ####video_deblur_model####
        with tf.variable_scope('video_deblur_model_'):
            H_curr = []
            with tf.variable_scope("encoding"):

                with tf.variable_scope("conv1"):
                    filter_size = 5
                    net_X = self.conv2d(self.structure_predictor_output,
                                        self.weight_variable([filter_size,
                                                              filter_size,
                                                              self.get_shape(self.structure_predictor_output, 3),
                                                              net_channel]))
                    net_X = tf.nn.relu(net_X)

                with tf.variable_scope("conv2"):
                    filter_size = 3
                    net_X = self.conv2d(net_X, self.weight_variable([filter_size,
                                                                     filter_size,
                                                                     self.get_shape(net_X, 3),
                                                                     net_channel // 2]),
                                        stride=2)
                    net_X = tf.nn.relu(net_X)

            net = tf.concat([net_X, F], 3)
            f0 = net
            filter_size = 3
            num_resnet_layers = 8

            for i in range(num_resnet_layers):
                with tf.variable_scope('resnet_block%d' % (i + 1)):
                    net = self.resnet_block(net, net_channel)

                if i == 3:
                    (net, alpha) = self.dynamic_fusion(net, H[0])
                    h = tf.expand_dims(net, axis=0)
                    H_curr = h

            with tf.variable_scope("feat_out"):
                F = self.conv2d(net, self.weight_variable([filter_size,
                                                           filter_size,
                                                           self.get_shape(net, 3),
                                                           net_channel // 2],
                                                          name='conv_F'))
                F = tf.nn.relu(F)

            with tf.variable_scope("img_out"):
                filter_size = 4
                shape = [self.get_shape(self.structure_predictor_output, 0),
                         self.get_shape(self.structure_predictor_output, 1),
                         self.get_shape(self.structure_predictor_output, 2),
                         net_channel]
                Y = self.conv2d_transpose(net, self.weight_variable([filter_size,
                                                                     filter_size,
                                                                     net_channel,
                                                                     net_channel],
                                                                    name="deconv"),
                                          shape,
                                          stride=2)
                Y = tf.nn.relu(Y)
                filter_size = 3
                self.video_deblur_output = self.conv2d(Y, self.weight_variable([filter_size,
                                                                                filter_size,
                                                                                self.get_shape(Y, 3),
                                                                                3],
                                                                               name='conv'))

        ####resnet_model####
        with tf.variable_scope('resnet_model_'):
            c = Config()
            c['bottleneck'] = bottleneck
            c['is_training'] = tf.convert_to_tensor(self.resnet_is_train,
                                                    dtype='bool',
                                                    name='is_training')
            c['ksize'] = 3
            c['stride'] = 1
            c['use_bias'] = use_bias
            c['fc_units_out'] = num_classes
            c['num_blocks'] = num_blocks
            c['stack_stride'] = 2

            with tf.variable_scope('scale1'):
                c['conv_filters_out'] = 16
                c['ksize'] = 7
                c['stride'] = 2
                x = self.conv(self.video_deblur_output, c)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('scale1_pool'):
                x = self._max_pool(x, ksize=3, stride=2)
                x = self.resnet_bn(x, c)
                x = self.activation(x)

            with tf.variable_scope('scale2'):
                x = self._max_pool(x, ksize=3, stride=2)
                c['num_blocks'] = num_blocks[0]
                c['stack_stride'] = 1
                c['block_filters_internal'] = 8
                x = self.stack(x, c)

            with tf.variable_scope('scale3'):
                c['num_blocks'] = num_blocks[1]
                c['block_filters_internal'] = 16
                assert c['stack_stride'] == 2
                x = self.stack(x, c)

            with tf.variable_scope('scale4'):
                c['num_blocks'] = num_blocks[2]
                c['block_filters_internal'] = 32
                x = self.stack(x, c)

            with tf.variable_scope('scale5'):
                c['num_blocks'] = num_blocks[3]
                c['block_filters_internal'] = 64
                x = self.stack(x, c)

            x = tf.reduce_mean(x, reduction_indices=[1, 2], name="avg_pool")

            if num_classes is not None:
                with tf.variable_scope('fc1'):
                    c['fc_units_out'] = 256
                    x = self.fc(x, c)

                with tf.variable_scope('dropout1'):
                    x = tf.nn.dropout(x, dropout_ratio)

                with tf.variable_scope('fc2'):
                    c['fc_units_out'] = 256
                    x = self.fc(x, c)

                with tf.variable_scope('dropout2'):
                    x = tf.nn.dropout(x, dropout_ratio)

                with tf.variable_scope('fc3'):
                    c['fc_units_out'] = 136
                    self.logits = self.fc(x, c)

        return F, H_curr
--------------------------------------------------------------------------------