├── src ├── input_pipeline │ ├── __init__.py │ ├── other_augmentations.py │ ├── pipeline.py │ └── random_image_crop.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── nms.py │ └── box_utils.py ├── constants.py ├── network.py ├── training_target_creation.py ├── losses_and_ohem.py ├── anchor_generator.py └── detector.py ├── .gitignore ├── eval_results ├── roc.png └── plot_roc.ipynb ├── images ├── the_office.jpg ├── brockhampton.jpg └── training_loss.png ├── config.json ├── LICENSE ├── save.py ├── create_pb.py ├── train.py ├── face_detector.py ├── test_input_pipeline.ipynb ├── try_detector.ipynb ├── model.py ├── README.md ├── create_tfrecords.py ├── predict_for_FDDB.ipynb ├── visualize_densified_anchor_boxes.ipynb ├── prepare_data ├── explore_and_prepare_WIDER.ipynb └── explore_and_convert_FDDB.ipynb └── evaluation_utils.py /src/input_pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from src.input_pipeline.pipeline import Pipeline 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | data/ 4 | models/ 5 | export/ 6 | *.pb 7 | 8 | -------------------------------------------------------------------------------- /eval_results/roc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TropComplique/FaceBoxes-tensorflow/HEAD/eval_results/roc.png -------------------------------------------------------------------------------- /images/the_office.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TropComplique/FaceBoxes-tensorflow/HEAD/images/the_office.jpg -------------------------------------------------------------------------------- /images/brockhampton.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TropComplique/FaceBoxes-tensorflow/HEAD/images/brockhampton.jpg -------------------------------------------------------------------------------- /images/training_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TropComplique/FaceBoxes-tensorflow/HEAD/images/training_loss.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector import Detector 2 | from .anchor_generator import AnchorGenerator 3 | from .network import FeatureExtractor 4 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.utils.box_utils import iou, area, intersection, encode, batch_decode 2 | from src.utils.nms import batch_non_max_suppression 3 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_params": { 3 | "model_dir": "models/run00", 4 | 5 | "weight_decay": 1e-3, 6 | "score_threshold": 0.05, "iou_threshold": 0.3, "max_boxes": 200, 7 | 8 | "localization_loss_weight": 1.0, "classification_loss_weight": 1.0, 9 | 10 | "loss_to_use": "classification", 11 | "loc_loss_weight": 0.0, "cls_loss_weight": 1.0, 
12 | "num_hard_examples": 500, "nms_threshold": 0.99, 13 | "max_negatives_per_positive": 3.0, "min_negatives_per_image": 30, 14 | 15 | "lr_boundaries": [160000, 200000], 16 | "lr_values": [0.004, 0.0004, 0.00004] 17 | }, 18 | 19 | "input_pipeline_params": { 20 | "image_size": [1024, 1024], 21 | "batch_size": 16, 22 | "train_dataset": "data/train_shards/", 23 | "val_dataset": "data/val_shards/", 24 | "num_steps": 240000 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # a small value 4 | EPSILON = 1e-8 5 | 6 | # this is used when we are doing box encoding/decoding 7 | SCALE_FACTORS = [10.0, 10.0, 5.0, 5.0] 8 | # you can read about them here: 9 | # github.com/rykov8/ssd_keras/issues/53 10 | # github.com/weiliu89/caffe/issues/155 11 | 12 | # here are input pipeline settings. 13 | # you need to tweak these numbers for your system, 14 | # it can accelerate training 15 | SHUFFLE_BUFFER_SIZE = 15000 16 | NUM_THREADS = 8 17 | # read here about the buffer sizes: 18 | # stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-dataset-map-dataset-prefetch-and-dataset-shuffle 19 | 20 | # images are resized before feeding them to the network 21 | RESIZE_METHOD = tf.image.ResizeMethod.BILINEAR 22 | 23 | # threshold for IoU when creating training targets 24 | MATCHING_THRESHOLD = 0.35 25 | 26 | # this is used in tf.map_fn when creating training targets or doing NMS 27 | PARALLEL_ITERATIONS = 8 28 | 29 | # this can be important 30 | BATCH_NORM_MOMENTUM = 0.9 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Dan Antoshchenko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /save.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import json 3 | from model import model_fn 4 | 5 | 6 | """The purpose of this script is to export a savedmodel.""" 7 | 8 | 9 | CONFIG = 'config.json' 10 | OUTPUT_FOLDER = 'export/run00' 11 | GPU_TO_USE = '0' 12 | 13 | WIDTH, HEIGHT = None, None 14 | # size of an input image, 15 | # set (None, None) if you want inference 16 | # for images of variable size 17 | 18 | 19 | tf.logging.set_verbosity('INFO') 20 | params = json.load(open(CONFIG)) 21 | model_params = params['model_params'] 22 | 23 | config = tf.ConfigProto() 24 | config.gpu_options.visible_device_list = GPU_TO_USE 25 | run_config = tf.estimator.RunConfig() 26 | run_config = run_config.replace( 27 | model_dir=model_params['model_dir'], 28 | session_config=config 29 | ) 30 | estimator = tf.estimator.Estimator(model_fn, params=model_params, config=run_config) 31 | 32 | 33 | def serving_input_receiver_fn(): 34 | images = tf.placeholder(dtype=tf.uint8, shape=[None, HEIGHT, WIDTH, 3], name='image_tensor') 35 | features = {'images': tf.to_float(images)*(1.0/255.0)} 36 | return tf.estimator.export.ServingInputReceiver(features, {'images': images}) 37 | 38 | 39 | estimator.export_savedmodel( 40 | OUTPUT_FOLDER, serving_input_receiver_fn 41 | ) 42 | -------------------------------------------------------------------------------- /create_pb.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | 4 | """Create a .pb frozen inference graph from a SavedModel.""" 5 | 6 | 7 | def make_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | '-s', '--saved_model_folder', type=str 11 | ) 12 | parser.add_argument( 13 | '-o', '--output_pb', type=str, default='model.pb' 14 | ) 15 | return parser.parse_args() 16 | 17 | 18 | def main(): 19 | 20 | graph = tf.Graph() 21 | config = tf.ConfigProto() 22 | config.gpu_options.visible_device_list = '0' 23 | with graph.as_default(): 24 | with tf.Session(graph=graph, config=config) as sess: 25 | tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], ARGS.saved_model_folder) 26 | 27 | # output ops 28 | keep_nodes = ['boxes', 'scores', 'num_boxes'] 29 | 30 | input_graph_def = tf.graph_util.convert_variables_to_constants( 31 | sess, graph.as_graph_def(), 32 | output_node_names=keep_nodes 33 | ) 34 | output_graph_def = tf.graph_util.remove_training_nodes( 35 | input_graph_def, 36 | protected_nodes=keep_nodes + [n.name for n in input_graph_def.node if 'nms' in n.name] 37 | ) 38 | # ops in 'nms' scope must be protected for some reason, 39 | # but why? 40 | 41 | with tf.gfile.GFile(ARGS.output_pb, 'wb') as f: 42 | f.write(output_graph_def.SerializeToString()) 43 | print('%d ops in the final graph.' 
% len(output_graph_def.node)) 44 | 45 | 46 | ARGS = make_args() 47 | tf.logging.set_verbosity('INFO') 48 | main() 49 | -------------------------------------------------------------------------------- /eval_results/plot_roc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import seaborn as sns\n", 12 | "sns.set_style('whitegrid')\n", 13 | "%matplotlib inline" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "roc = pd.read_csv('discrete-ROC-step-240000.txt', sep=' ', header=None)\n", 23 | "roc.columns = ['tpr', 'fp', 'threshold']" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def plot_roc():\n", 33 | " _, axis = plt.subplots(nrows=1, ncols=1, figsize=(7, 4), dpi=120)\n", 34 | " axis.plot(roc.fp, roc.tpr, c='r', linewidth=2.0);\n", 35 | " axis.set_title('Discrete Score ROC')\n", 36 | " axis.set_xlim([0, 2000.0])\n", 37 | " axis.set_ylim([0.6, 1.0])\n", 38 | " axis.set_xlabel('False Positives')\n", 39 | " axis.set_ylabel('True Positive Rate');\n", 40 | "\n", 41 | "plot_roc()" 42 | ] 43 | } 44 | ], 45 | "metadata": { 46 | "kernelspec": { 47 | "display_name": "Python 3", 48 | "language": "python", 49 | "name": "python3" 50 | }, 51 | "language_info": { 52 | "codemirror_mode": { 53 | "name": "ipython", 54 | "version": 3 55 | }, 56 | "file_extension": ".py", 57 | "mimetype": "text/x-python", 58 | "name": "python", 59 | "nbconvert_exporter": "python", 60 | "pygments_lexer": "ipython3", 61 | "version": "3.6.3" 62 | } 63 | }, 64 | "nbformat": 4, 65 | "nbformat_minor": 2 66 | } 67 | -------------------------------------------------------------------------------- /src/utils/nms.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.constants import PARALLEL_ITERATIONS 3 | 4 | 5 | def batch_non_max_suppression( 6 | boxes, scores, 7 | score_threshold, iou_threshold, 8 | max_boxes): 9 | """ 10 | Arguments: 11 | boxes: a float tensor with shape [batch_size, N, 4]. 12 | scores: a float tensor with shape [batch_size, N]. 13 | score_threshold: a float number. 14 | iou_threshold: a float number, threshold for IoU. 15 | max_boxes: an integer, maximum number of retained boxes. 16 | Returns: 17 | boxes: a float tensor with shape [batch_size, max_boxes, 4]. 18 | scores: a float tensor with shape [batch_size, max_boxes]. 19 | num_detections: an int tensor with shape [batch_size]. 
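        A usage sketch (the threshold values here are just the ones from
        config.json, not hard-coded defaults):

            boxes, scores, num_detections = batch_non_max_suppression(
                boxes, scores,
                score_threshold=0.05, iou_threshold=0.3, max_boxes=200
            )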
20 | """ 21 | def fn(x): 22 | boxes, scores = x 23 | 24 | # low scoring boxes are removed 25 | ids = tf.where(tf.greater_equal(scores, score_threshold)) 26 | ids = tf.squeeze(ids, axis=1) 27 | boxes = tf.gather(boxes, ids) 28 | scores = tf.gather(scores, ids) 29 | 30 | selected_indices = tf.image.non_max_suppression( 31 | boxes, scores, max_boxes, iou_threshold 32 | ) 33 | boxes = tf.gather(boxes, selected_indices) 34 | scores = tf.gather(scores, selected_indices) 35 | num_boxes = tf.to_int32(tf.shape(boxes)[0]) 36 | 37 | zero_padding = max_boxes - num_boxes 38 | boxes = tf.pad(boxes, [[0, zero_padding], [0, 0]]) 39 | scores = tf.pad(scores, [[0, zero_padding]]) 40 | 41 | boxes.set_shape([max_boxes, 4]) 42 | scores.set_shape([max_boxes]) 43 | return boxes, scores, num_boxes 44 | 45 | boxes, scores, num_detections = tf.map_fn( 46 | fn, [boxes, scores], 47 | dtype=(tf.float32, tf.float32, tf.int32), 48 | parallel_iterations=PARALLEL_ITERATIONS, 49 | back_prop=False, swap_memory=False, infer_shape=True 50 | ) 51 | return boxes, scores, num_detections 52 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import json 3 | import os 4 | 5 | from model import model_fn 6 | from src.input_pipeline import Pipeline 7 | tf.logging.set_verbosity('INFO') 8 | 9 | 10 | CONFIG = 'config.json' 11 | GPU_TO_USE = '0' 12 | 13 | 14 | params = json.load(open(CONFIG)) 15 | model_params = params['model_params'] 16 | input_params = params['input_pipeline_params'] 17 | 18 | 19 | def get_input_fn(is_training=True): 20 | 21 | image_size = input_params['image_size'] if is_training else None 22 | # (for evaluation i use images of different sizes) 23 | dataset_path = input_params['train_dataset'] if is_training else input_params['val_dataset'] 24 | batch_size = input_params['batch_size'] if is_training else 1 25 | # for evaluation it's important to set batch_size to 1 26 | 27 | filenames = os.listdir(dataset_path) 28 | filenames = [n for n in filenames if n.endswith('.tfrecords')] 29 | filenames = [os.path.join(dataset_path, n) for n in sorted(filenames)] 30 | 31 | def input_fn(): 32 | with tf.device('/cpu:0'), tf.name_scope('input_pipeline'): 33 | pipeline = Pipeline( 34 | filenames, 35 | batch_size=batch_size, image_size=image_size, 36 | repeat=is_training, shuffle=is_training, 37 | augmentation=is_training 38 | ) 39 | features, labels = pipeline.get_batch() 40 | return features, labels 41 | 42 | return input_fn 43 | 44 | 45 | config = tf.ConfigProto() 46 | config.gpu_options.visible_device_list = GPU_TO_USE 47 | 48 | run_config = tf.estimator.RunConfig() 49 | run_config = run_config.replace( 50 | model_dir=model_params['model_dir'], 51 | session_config=config, 52 | save_summary_steps=200, 53 | save_checkpoints_secs=600, 54 | log_step_count_steps=100 55 | ) 56 | 57 | train_input_fn = get_input_fn(is_training=True) 58 | val_input_fn = get_input_fn(is_training=False) 59 | estimator = tf.estimator.Estimator(model_fn, params=model_params, config=run_config) 60 | 61 | train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=input_params['num_steps']) 62 | eval_spec = tf.estimator.EvalSpec(val_input_fn, steps=None, start_delay_secs=1800, throttle_secs=1800) 63 | tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) 64 | -------------------------------------------------------------------------------- /face_detector.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class FaceDetector: 6 | def __init__(self, model_path, gpu_memory_fraction=0.25, visible_device_list='0'): 7 | """ 8 | Arguments: 9 | model_path: a string, path to a pb file. 10 | gpu_memory_fraction: a float number. 11 | visible_device_list: a string. 12 | """ 13 | with tf.gfile.GFile(model_path, 'rb') as f: 14 | graph_def = tf.GraphDef() 15 | graph_def.ParseFromString(f.read()) 16 | 17 | graph = tf.Graph() 18 | with graph.as_default(): 19 | tf.import_graph_def(graph_def, name='import') 20 | 21 | self.input_image = graph.get_tensor_by_name('import/image_tensor:0') 22 | self.output_ops = [ 23 | graph.get_tensor_by_name('import/boxes:0'), 24 | graph.get_tensor_by_name('import/scores:0'), 25 | graph.get_tensor_by_name('import/num_boxes:0'), 26 | ] 27 | 28 | gpu_options = tf.GPUOptions( 29 | per_process_gpu_memory_fraction=gpu_memory_fraction, 30 | visible_device_list=visible_device_list 31 | ) 32 | config_proto = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False) 33 | self.sess = tf.Session(graph=graph, config=config_proto) 34 | 35 | def __call__(self, image, score_threshold=0.5): 36 | """Detect faces. 37 | 38 | Arguments: 39 | image: a numpy uint8 array with shape [height, width, 3], 40 | that represents a RGB image. 41 | score_threshold: a float number. 42 | Returns: 43 | boxes: a float numpy array of shape [num_faces, 4]. 44 | scores: a float numpy array of shape [num_faces]. 45 | 46 | Note that box coordinates are in the order: ymin, xmin, ymax, xmax! 47 | """ 48 | h, w, _ = image.shape 49 | image = np.expand_dims(image, 0) 50 | 51 | boxes, scores, num_boxes = self.sess.run( 52 | self.output_ops, feed_dict={self.input_image: image} 53 | ) 54 | num_boxes = num_boxes[0] 55 | boxes = boxes[0][:num_boxes] 56 | scores = scores[0][:num_boxes] 57 | 58 | to_keep = scores > score_threshold 59 | boxes = boxes[to_keep] 60 | scores = scores[to_keep] 61 | 62 | scaler = np.array([h, w, h, w], dtype='float32') 63 | boxes = boxes * scaler 64 | 65 | return boxes, scores 66 | -------------------------------------------------------------------------------- /test_input_pipeline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import tensorflow as tf\n", 20 | "import numpy as np\n", 21 | "from PIL import Image, ImageDraw\n", 22 | "\n", 23 | "from src.input_pipeline import Pipeline" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Get images and boxes" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": { 37 | "scrolled": false 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "tf.reset_default_graph()\n", 42 | "\n", 43 | "pipeline = Pipeline(\n", 44 | " ['data/train_shards/shard-0000.tfrecords'],\n", 45 | " batch_size=24, image_size=(1024, 1024),\n", 46 | " repeat=False, shuffle=False, \n", 47 | " augmentation=True\n", 48 | ")\n", 49 | "features, labels = pipeline.get_batch()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | 
"source": [ 58 | "with tf.Session() as sess:\n", 59 | " I, B, N = sess.run([\n", 60 | " features['images'], labels['boxes'], labels['num_boxes']\n", 61 | " ])" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "# Show an augmented image with boxes" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def draw_boxes(image, boxes):\n", 78 | " image_copy = image.copy()\n", 79 | " draw = ImageDraw.Draw(image_copy, 'RGBA')\n", 80 | " width, height = image.size\n", 81 | "\n", 82 | " for box in boxes:\n", 83 | " ymin, xmin, ymax, xmax = box\n", 84 | " xmin, xmax = width*xmin, width*xmax\n", 85 | " ymin, ymax = height*ymin, height*ymax\n", 86 | "\n", 87 | " fill = (255, 255, 255, 45)\n", 88 | " outline = 'black'\n", 89 | "\n", 90 | " draw.rectangle(\n", 91 | " [(xmin, ymin), (xmax, ymax)],\n", 92 | " fill=fill, outline=outline\n", 93 | " )\n", 94 | "\n", 95 | " return image_copy" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "# choose an image\n", 105 | "i = 0\n", 106 | "image = Image.fromarray((np.transpose(I[i], [1, 2, 0])*255.0).astype('uint8'))\n", 107 | "num_boxes = N[i]\n", 108 | "boxes = B[i][:num_boxes]\n", 109 | "\n", 110 | "draw_boxes(image, boxes)" 111 | ] 112 | } 113 | ], 114 | "metadata": { 115 | "kernelspec": { 116 | "display_name": "Python 3", 117 | "language": "python", 118 | "name": "python3" 119 | }, 120 | "language_info": { 121 | "codemirror_mode": { 122 | "name": "ipython", 123 | "version": 3 124 | }, 125 | "file_extension": ".py", 126 | "mimetype": "text/x-python", 127 | "name": "python", 128 | "nbconvert_exporter": "python", 129 | "pygments_lexer": "ipython3", 130 | "version": "3.6.3" 131 | } 132 | }, 133 | "nbformat": 4, 134 | "nbformat_minor": 2 135 | } 136 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | from src.constants import BATCH_NORM_MOMENTUM 4 | 5 | 6 | class FeatureExtractor: 7 | def __init__(self, is_training): 8 | self.is_training = is_training 9 | 10 | def __call__(self, images): 11 | """ 12 | Arguments: 13 | images: a float tensor with shape [batch_size, height, width, 3], 14 | a batch of RGB images with pixel values in the range [0, 1]. 15 | Returns: 16 | a list of float tensors where the ith tensor 17 | has shape [batch, height_i, width_i, channels_i]. 
18 | """ 19 | 20 | def batch_norm(x): 21 | x = tf.layers.batch_normalization( 22 | x, axis=3, center=True, scale=True, 23 | momentum=BATCH_NORM_MOMENTUM, epsilon=0.001, 24 | training=self.is_training, fused=True, 25 | name='batch_norm' 26 | ) 27 | return x 28 | 29 | with tf.name_scope('standardize_input'): 30 | x = preprocess(images) 31 | 32 | # rapidly digested convolutional layers 33 | params = { 34 | 'padding': 'SAME', 35 | 'activation_fn': lambda x: tf.nn.crelu(x, axis=3), 36 | 'normalizer_fn': batch_norm, 'data_format': 'NHWC' 37 | } 38 | with slim.arg_scope([slim.conv2d], **params): 39 | with slim.arg_scope([slim.max_pool2d], stride=2, padding='SAME', data_format='NHWC'): 40 | x = slim.conv2d(x, 24, (7, 7), stride=4, scope='conv1') 41 | x = slim.max_pool2d(x, (3, 3), scope='pool1') 42 | x = slim.conv2d(x, 64, (5, 5), stride=2, scope='conv2') 43 | x = slim.max_pool2d(x, (3, 3), scope='pool2') 44 | 45 | # multiple scale convolutional layers 46 | params = { 47 | 'padding': 'SAME', 'activation_fn': tf.nn.relu, 48 | 'normalizer_fn': batch_norm, 'data_format': 'NHWC' 49 | } 50 | with slim.arg_scope([slim.conv2d], **params): 51 | features = [] # extracted feature maps 52 | x = inception_module(x, scope='inception1') 53 | x = inception_module(x, scope='inception2') 54 | x = inception_module(x, scope='inception3') 55 | features.append(x) # scale 0 56 | x = slim.conv2d(x, 128, (1, 1), scope='conv3_1') 57 | x = slim.conv2d(x, 256, (3, 3), stride=2, scope='conv3_2') 58 | features.append(x) # scale 1 59 | x = slim.conv2d(x, 128, (1, 1), scope='conv4_1') 60 | x = slim.conv2d(x, 256, (3, 3), stride=2, scope='conv4_2') 61 | features.append(x) # scale 2 62 | 63 | return features 64 | 65 | 66 | def preprocess(images): 67 | """Transform images before feeding them to the network.""" 68 | return (2.0*images) - 1.0 69 | 70 | 71 | def inception_module(x, scope): 72 | # path 1 73 | x1 = slim.conv2d(x, 32, (1, 1), scope=scope + '/conv_1x1_path1') 74 | # path 2 75 | y = slim.max_pool2d(x, (3, 3), stride=1, padding='SAME', scope=scope + '/pool_3x3_path2') 76 | x2 = slim.conv2d(y, 32, (1, 1), scope=scope + '/conv_1x1_path2') 77 | # path 3 78 | y = slim.conv2d(x, 24, (1, 1), scope=scope + '/conv_1x1_path3') 79 | x3 = slim.conv2d(y, 32, (3, 3), scope=scope + '/conv_3x3_path3') 80 | # path 4 81 | y = slim.conv2d(x, 24, (1, 1), scope=scope + '/conv_1x1_path4') 82 | y = slim.conv2d(y, 32, (3, 3), scope=scope + '/conv_3x3_path4') 83 | x4 = slim.conv2d(y, 32, (3, 3), scope=scope + '/conv_3x3_second_path4') 84 | return tf.concat([x1, x2, x3, x4], axis=3, name=scope + '/concat') 85 | -------------------------------------------------------------------------------- /try_detector.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import os\n", 20 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'\n", 21 | "\n", 22 | "import numpy as np\n", 23 | "from PIL import Image, ImageDraw\n", 24 | "import os\n", 25 | "import cv2\n", 26 | "import time\n", 27 | "\n", 28 | "from face_detector import FaceDetector" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "MODEL_PATH = 'model.pb'\n", 38 
| "face_detector = FaceDetector(MODEL_PATH, gpu_memory_fraction=0.25, visible_device_list='0')" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "# Get an image" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "path = '/home/gpu2/hdd/dan/WIDER/WIDER_train/images/48--Parachutist_Paratrooper/48_Parachutist_Paratrooper_Parachutist_Paratrooper_48_972.jpg'\n", 55 | "\n", 56 | "image_array = cv2.imread(path)\n", 57 | "image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)\n", 58 | "image = Image.fromarray(image_array)\n", 59 | "image" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "# Show detections" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "def draw_boxes_on_image(image, boxes, scores):\n", 76 | "\n", 77 | " image_copy = image.copy()\n", 78 | " draw = ImageDraw.Draw(image_copy, 'RGBA')\n", 79 | " width, height = image.size\n", 80 | "\n", 81 | " for b, s in zip(boxes, scores):\n", 82 | " ymin, xmin, ymax, xmax = b\n", 83 | " fill = (255, 0, 0, 45)\n", 84 | " outline = 'red'\n", 85 | " draw.rectangle(\n", 86 | " [(xmin, ymin), (xmax, ymax)],\n", 87 | " fill=fill, outline=outline\n", 88 | " )\n", 89 | " draw.text((xmin, ymin), text='{:.3f}'.format(s))\n", 90 | " return image_copy" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "scrolled": false 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "boxes, scores = face_detector(image_array, score_threshold=0.3)\n", 102 | "draw_boxes_on_image(Image.fromarray(image_array), boxes, scores)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "# Measure speed" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "times = []\n", 119 | "for _ in range(110):\n", 120 | " start = time.perf_counter()\n", 121 | " boxes, scores = face_detector(image_array, score_threshold=0.25)\n", 122 | " times.append(time.perf_counter() - start)\n", 123 | " \n", 124 | "times = np.array(times)\n", 125 | "times = times[10:]\n", 126 | "print(times.mean(), times.std())" 127 | ] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.6.3" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 1 151 | } 152 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import tensorflow as tf 4 | 5 | from src import Detector, AnchorGenerator, FeatureExtractor 6 | from evaluation_utils import Evaluator 7 | 8 | 9 | def model_fn(features, labels, mode, params, config): 10 | """This is a function for creating a computational tensorflow graph. 11 | The function is in format required by tf.estimator. 
12 | """ 13 | 14 | # the base network 15 | is_training = mode == tf.estimator.ModeKeys.TRAIN 16 | feature_extractor = FeatureExtractor(is_training) 17 | 18 | # anchor maker 19 | anchor_generator = AnchorGenerator() 20 | 21 | # add box/label predictors to the feature extractor 22 | detector = Detector(features['images'], feature_extractor, anchor_generator) 23 | 24 | # add NMS to the graph 25 | if not is_training: 26 | predictions = detector.get_predictions( 27 | score_threshold=params['score_threshold'], 28 | iou_threshold=params['iou_threshold'], 29 | max_boxes=params['max_boxes'] 30 | ) 31 | 32 | if mode == tf.estimator.ModeKeys.PREDICT: 33 | # this is required for exporting a savedmodel 34 | export_outputs = tf.estimator.export.PredictOutput({ 35 | name: tf.identity(tensor, name) 36 | for name, tensor in predictions.items() 37 | }) 38 | return tf.estimator.EstimatorSpec( 39 | mode, predictions=predictions, 40 | export_outputs={'outputs': export_outputs} 41 | ) 42 | 43 | # add L2 regularization 44 | with tf.name_scope('weight_decay'): 45 | add_weight_decay(params['weight_decay']) 46 | regularization_loss = tf.losses.get_regularization_loss() 47 | 48 | # create localization and classification losses 49 | losses = detector.loss(labels, params) 50 | tf.losses.add_loss(params['localization_loss_weight'] * losses['localization_loss']) 51 | tf.losses.add_loss(params['classification_loss_weight'] * losses['classification_loss']) 52 | tf.summary.scalar('regularization_loss', regularization_loss) 53 | tf.summary.scalar('localization_loss', losses['localization_loss']) 54 | tf.summary.scalar('classification_loss', losses['classification_loss']) 55 | total_loss = tf.losses.get_total_loss(add_regularization_losses=True) 56 | 57 | if mode == tf.estimator.ModeKeys.EVAL: 58 | 59 | filenames = features['filenames'] 60 | batch_size = filenames.shape.as_list()[0] 61 | assert batch_size == 1 62 | 63 | with tf.name_scope('evaluator'): 64 | evaluator = Evaluator() 65 | eval_metric_ops = evaluator.get_metric_ops(filenames, labels, predictions) 66 | 67 | return tf.estimator.EstimatorSpec( 68 | mode, loss=total_loss, 69 | eval_metric_ops=eval_metric_ops 70 | ) 71 | 72 | assert mode == tf.estimator.ModeKeys.TRAIN 73 | with tf.variable_scope('learning_rate'): 74 | global_step = tf.train.get_global_step() 75 | learning_rate = tf.train.piecewise_constant(global_step, params['lr_boundaries'], params['lr_values']) 76 | tf.summary.scalar('learning_rate', learning_rate) 77 | 78 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 79 | with tf.control_dependencies(update_ops), tf.variable_scope('optimizer'): 80 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9, use_nesterov=True) 81 | grads_and_vars = optimizer.compute_gradients(total_loss) 82 | train_op = optimizer.apply_gradients(grads_and_vars, global_step) 83 | 84 | for g, v in grads_and_vars: 85 | tf.summary.histogram(v.name[:-2] + '_hist', v) 86 | tf.summary.histogram(v.name[:-2] + '_grad_hist', g) 87 | 88 | return tf.estimator.EstimatorSpec(mode, loss=total_loss, train_op=train_op) 89 | 90 | 91 | def add_weight_decay(weight_decay): 92 | """Add L2 regularization to all (or some) trainable kernel weights.""" 93 | weight_decay = tf.constant( 94 | weight_decay, tf.float32, 95 | [], 'weight_decay' 96 | ) 97 | trainable_vars = tf.trainable_variables() 98 | kernels = [v for v in trainable_vars if 'weights' in v.name] 99 | for K in kernels: 100 | x = tf.multiply(weight_decay, tf.nn.l2_loss(K)) 101 | 
tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, x) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FaceBoxes-tensorflow 2 | 3 | This is an implementation of [FaceBoxes: A CPU Real-time Face Detector with High Accuracy](https://arxiv.org/abs/1708.05234). 4 | I provide full training code, data preparation scripts, and a pretrained model. 5 | The detector has speed **~7 ms/image** (image size is 1024x1024, video card is NVIDIA GeForce GTX 1080). 6 | 7 | ## How to use the pretrained model 8 | 9 | To use the pretrained face detector you will need to download `face_detector.py` and 10 | a frozen inference graph (`.pb` file, it is [here](https://drive.google.com/drive/folders/1DYdxvMXm6n6BsOy4dOTbN9h43F0CoUoK?usp=sharing)). You can see an example of usage in `try_detector.ipynb`. 11 | 12 | Examples of face detections: 13 | ![example1](images/brockhampton.jpg) 14 | ![example1](images/the_office.jpg) 15 | 16 | ## Requirements 17 | 18 | * tensorflow 1.10 (inference was tested using tensorflow 1.12) 19 | * opencv-python, Pillow, tqdm 20 | 21 | ## Notes 22 | 23 | 1. *Warning:* This detector doesn't work well on small faces. 24 | But you can improve its performance if you upscale images before feeding them to the network. 25 | For example, resize an image keeping its aspect ratio so its smaller dimension is 768. 26 | 2. You can see how anchor densification works in `visualize_densified_anchor_boxes.ipynb`. 27 | 3. You can see how my data augmentation works in `test_input_pipeline.ipynb`. 28 | 4. The speed on a CPU is **~25 ms/image** (image size is 1024x768, model is i7-7700K CPU @ 4.20GHz). 29 | 30 | ## How to train 31 | 32 | For training I use `train`+`val` parts of the WIDER dataset. 33 | It is 16106 images in total (12880 + 3226). 34 | For evaluation during the training I use the FDDB dataset (2845 images) and `AP@IOU=0.5` metrics (it is not like in the original FDDB evaluation, but like in PASCAL VOC Challenge). 35 | 36 | 1. Run `prepare_data/explore_and_prepare_WIDER.ipynb` to prepare the WIDER dataset 37 | (also, you will need to combine the two created dataset parts using `cp train_part2/* train/ -a`). 38 | 2. Run `prepare_data/explore_and_prepare_FDDB.ipynb` to prepare the FDDB dataset. 39 | 3. Create tfrecords: 40 | ``` 41 | python create_tfrecords.py \ 42 | --image_dir=/home/gpu2/hdd/dan/WIDER/train/images/ \ 43 | --annotations_dir=/home/gpu2/hdd/dan/WIDER/train/annotations/ \ 44 | --output=data/train_shards/ \ 45 | --num_shards=150 46 | 47 | python create_tfrecords.py \ 48 | --image_dir=/home/gpu2/hdd/dan/FDDB/val/images/ \ 49 | --annotations_dir=/home/gpu2/hdd/dan/FDDB/val/annotations/ \ 50 | --output=data/val_shards/ \ 51 | --num_shards=20 52 | ``` 53 | 4. Run `python train.py` to train a face detector. Evaluation on FDDB will happen periodically. 54 | 5. Run `tensorboard --logdir=models/run00` to observe training and evaluation. 55 | 6. Run `python save.py` and `create_pb.py` to convert the trained model into a `.pb` file. 56 | 7. Use `class` in `face_detector.py` and `.pb` file to do inference. 57 | 8. Also, you can get my final training checkpoint [here](https://drive.google.com/drive/folders/1DYdxvMXm6n6BsOy4dOTbN9h43F0CoUoK?usp=sharing). 58 | 9. The training speed was `~2.6 batches/second` on one NVIDIA GeForce GTX 1080. 
So total training time is ~26 hours 59 | (and I believe that you can make it much shorter if you optimize the input pipeline). 60 | 61 | Training loss curve looks like this: 62 | ![loss](images/training_loss.png) 63 | 64 | ## How to evaluate on FDDB 65 | 66 | 1. Download the evaluation code from [here](http://vis-www.cs.umass.edu/fddb/results.html). 67 | 2. `tar -zxvf evaluation.tgz; cd evaluation`. 68 | Then compile it using `make` (it can be very tricky to make it work). 69 | 3. Run `predict_for_FDDB.ipynb` to make predictions on the evaluation dataset. 70 | You will get `ellipseList.txt`, `faceList.txt`, `detections.txt`, and `images/`. 71 | 4. Run `./evaluate -a result/ellipseList.txt -d result/detections.txt -i result/images/ -l result/faceList.txt -z .jpg -f 0`. 72 | 5. You will get something like `eval_results/discrete-ROC.txt`. 73 | 6. Run `eval_results/plot_roc.ipynb` to plot the curve. 74 | 75 | Also see this [repository](https://github.com/pkdogcom/fddb-evaluate) and the official [FAQ](http://vis-www.cs.umass.edu/fddb/faq.html) if you have questions about the evaluation. 76 | 77 | ## Results on FDDB 78 | True positive rate at 1000 false positives is `0.902`. 79 | Note that it is lower than in the original paper. 80 | Maybe it's because some hyperparameters are wrong. 81 | Or it's because I didn't do any upscaling of images when evaluating 82 | (I used the original image size). 83 | 84 | ![roc](eval_results/roc.png) 85 | 86 | You can see the whole curve in `eval_results/discrete-ROC.txt` (it's the output of the official evaluation script). 87 | -------------------------------------------------------------------------------- /src/input_pipeline/other_augmentations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | """ 4 | There are various data augmentations for training object detectors. 5 | 6 | `image` is assumed to be a float tensor with shape [height, width, 3], 7 | it is a RGB image with pixel values in range [0, 1]. 
8 | """ 9 | 10 | 11 | def random_color_manipulations(image, probability=0.5, grayscale_probability=0.1): 12 | 13 | def manipulate(image): 14 | # intensity and order of this operations are kinda random, 15 | # so you will need to tune this for you problem 16 | image = tf.image.random_brightness(image, 0.1) 17 | image = tf.image.random_contrast(image, 0.8, 1.2) 18 | image = tf.image.random_hue(image, 0.1) 19 | image = tf.image.random_saturation(image, 0.8, 1.2) 20 | image = tf.clip_by_value(image, 0.0, 1.0) 21 | return image 22 | 23 | def to_grayscale(image): 24 | image = tf.image.rgb_to_grayscale(image) 25 | image = tf.image.grayscale_to_rgb(image) 26 | return image 27 | 28 | with tf.name_scope('random_color_manipulations'): 29 | do_it = tf.less(tf.random_uniform([]), probability) 30 | image = tf.cond(do_it, lambda: manipulate(image), lambda: image) 31 | 32 | with tf.name_scope('to_grayscale'): 33 | make_gray = tf.less(tf.random_uniform([]), grayscale_probability) 34 | image = tf.cond(make_gray, lambda: to_grayscale(image), lambda: image) 35 | 36 | return image 37 | 38 | 39 | def random_flip_left_right(image, boxes): 40 | 41 | def flip(image, boxes): 42 | flipped_image = tf.image.flip_left_right(image) 43 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 44 | flipped_xmin = tf.subtract(1.0, xmax) 45 | flipped_xmax = tf.subtract(1.0, xmin) 46 | flipped_boxes = tf.stack([ymin, flipped_xmin, ymax, flipped_xmax], 1) 47 | return flipped_image, flipped_boxes 48 | 49 | with tf.name_scope('random_flip_left_right'): 50 | do_it = tf.less(tf.random_uniform([]), 0.5) 51 | image, boxes = tf.cond(do_it, lambda: flip(image, boxes), lambda: (image, boxes)) 52 | return image, boxes 53 | 54 | 55 | def random_pixel_value_scale(image, minval=0.9, maxval=1.1, probability=0.5): 56 | """This function scales each pixel independently of the other ones. 57 | 58 | Arguments: 59 | image: a float tensor with shape [height, width, 3], 60 | an image with pixel values varying between [0, 1]. 61 | minval: a float number, lower ratio of scaling pixel values. 62 | maxval: a float number, upper ratio of scaling pixel values. 63 | probability: a float number. 64 | Returns: 65 | a float tensor with shape [height, width, 3]. 66 | """ 67 | def random_value_scale(image): 68 | color_coefficient = tf.random_uniform( 69 | tf.shape(image), minval=minval, 70 | maxval=maxval, dtype=tf.float32 71 | ) 72 | image = tf.multiply(image, color_coefficient) 73 | image = tf.clip_by_value(image, 0.0, 1.0) 74 | return image 75 | 76 | with tf.name_scope('random_pixel_value_scale'): 77 | do_it = tf.less(tf.random_uniform([]), probability) 78 | image = tf.cond(do_it, lambda: random_value_scale(image), lambda: image) 79 | return image 80 | 81 | 82 | def random_jitter_boxes(boxes, ratio=0.05): 83 | """Randomly jitter bounding boxes. 84 | 85 | Arguments: 86 | boxes: a float tensor with shape [N, 4]. 87 | ratio: a float number. 88 | The ratio of the box width and height that the corners can jitter. 89 | For example if the width is 100 pixels and ratio is 0.05, 90 | the corners can jitter up to 5 pixels in the x direction. 91 | Returns: 92 | a float tensor with shape [N, 4]. 93 | """ 94 | def random_jitter_box(box, ratio): 95 | """Randomly jitter a box. 96 | Arguments: 97 | box: a float tensor with shape [4]. 98 | ratio: a float number. 99 | Returns: 100 | a float tensor with shape [4]. 
101 | """ 102 | ymin, xmin, ymax, xmax = [box[i] for i in range(4)] 103 | box_height, box_width = ymax - ymin, xmax - xmin 104 | hw_coefs = tf.stack([box_height, box_width, box_height, box_width]) 105 | 106 | rand_numbers = tf.random_uniform( 107 | [4], minval=-ratio, maxval=ratio, dtype=tf.float32 108 | ) 109 | hw_rand_coefs = tf.multiply(hw_coefs, rand_numbers) 110 | 111 | jittered_box = tf.add(box, hw_rand_coefs) 112 | return jittered_box 113 | 114 | with tf.name_scope('random_jitter_boxes'): 115 | distorted_boxes = tf.map_fn( 116 | lambda x: random_jitter_box(x, ratio), 117 | boxes, dtype=tf.float32, back_prop=False 118 | ) 119 | distorted_boxes = tf.clip_by_value(distorted_boxes, 0.0, 1.0) 120 | return distorted_boxes 121 | -------------------------------------------------------------------------------- /src/training_target_creation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.utils import encode, iou 3 | 4 | 5 | def get_training_targets(anchors, groundtruth_boxes, threshold=0.5): 6 | """ 7 | Arguments: 8 | anchors: a float tensor with shape [num_anchors, 4]. 9 | groundtruth_boxes: a float tensor with shape [N, 4]. 10 | threshold: a float number. 11 | Returns: 12 | reg_targets: a float tensor with shape [num_anchors, 4]. 13 | matches: an int tensor with shape [num_anchors], possible values 14 | that it can contain are [-1, 0, 1, 2, ..., (N - 1)]. 15 | """ 16 | with tf.name_scope('matching'): 17 | N = tf.shape(groundtruth_boxes)[0] 18 | num_anchors = tf.shape(anchors)[0] 19 | no_match_tensor = tf.fill([num_anchors], -1) 20 | matches = tf.cond( 21 | tf.greater(N, 0), 22 | lambda: _match(anchors, groundtruth_boxes, threshold), 23 | lambda: no_match_tensor 24 | ) 25 | matches = tf.to_int32(matches) 26 | 27 | with tf.name_scope('regression_target_creation'): 28 | reg_targets = _create_targets( 29 | anchors, groundtruth_boxes, matches 30 | ) 31 | 32 | return reg_targets, matches 33 | 34 | 35 | def _match(anchors, groundtruth_boxes, threshold=0.5): 36 | """Matching algorithm: 37 | 1) for each groundtruth box choose the anchor with largest iou, 38 | 2) remove this set of anchors from the set of all anchors, 39 | 3) for each remaining anchor choose the groundtruth box with largest iou, 40 | but only if this iou is larger than `threshold`. 41 | 42 | Note: after step 1, it could happen that for some two groundtruth boxes 43 | chosen anchors are the same. Let's hope this never happens. 44 | Also see the comments below. 45 | 46 | Arguments: 47 | anchors: a float tensor with shape [num_anchors, 4]. 48 | groundtruth_boxes: a float tensor with shape [N, 4]. 49 | threshold: a float number. 50 | Returns: 51 | an int tensor with shape [num_anchors]. 
52 | """ 53 | num_anchors = tf.shape(anchors)[0] 54 | 55 | # for each anchor box choose the groundtruth box with largest iou 56 | similarity_matrix = iou(groundtruth_boxes, anchors) # shape [N, num_anchors] 57 | matches = tf.argmax(similarity_matrix, axis=0, output_type=tf.int32) # shape [num_anchors] 58 | matched_vals = tf.reduce_max(similarity_matrix, axis=0) # shape [num_anchors] 59 | below_threshold = tf.to_int32(tf.greater(threshold, matched_vals)) 60 | matches = tf.add(tf.multiply(matches, 1 - below_threshold), -1 * below_threshold) 61 | # after this, it could happen that some groundtruth 62 | # boxes are not matched with any anchor box 63 | 64 | # now we must ensure that each row (groundtruth box) is matched to 65 | # at least one column (which is not guaranteed 66 | # otherwise if `threshold` is high) 67 | 68 | # for each groundtruth box choose the anchor box with largest iou 69 | # (force match for each groundtruth box) 70 | forced_matches_ids = tf.argmax(similarity_matrix, axis=1, output_type=tf.int32) # shape [N] 71 | # if all indices in forced_matches_ids are different then all rows will be matched 72 | 73 | forced_matches_indicators = tf.one_hot(forced_matches_ids, depth=num_anchors, dtype=tf.int32) # shape [N, num_anchors] 74 | forced_match_row_ids = tf.argmax(forced_matches_indicators, axis=0, output_type=tf.int32) # shape [num_anchors] 75 | forced_match_mask = tf.greater(tf.reduce_max(forced_matches_indicators, axis=0), 0) # shape [num_anchors] 76 | matches = tf.where(forced_match_mask, forced_match_row_ids, matches) 77 | # even after this it could happen that some rows aren't matched, 78 | # but i believe that this event has low probability 79 | 80 | return matches 81 | 82 | 83 | def _create_targets(anchors, groundtruth_boxes, matches): 84 | """Returns regression targets for each anchor. 85 | 86 | Arguments: 87 | anchors: a float tensor with shape [num_anchors, 4]. 88 | groundtruth_boxes: a float tensor with shape [N, 4]. 89 | matches: a int tensor with shape [num_anchors]. 90 | Returns: 91 | reg_targets: a float tensor with shape [num_anchors, 4]. 
92 | """ 93 | matched_anchor_indices = tf.where(tf.greater_equal(matches, 0)) # shape [num_matches, 1] 94 | matched_anchor_indices = tf.squeeze(matched_anchor_indices, axis=1) 95 | matched_gt_indices = tf.gather(matches, matched_anchor_indices) # shape [num_matches] 96 | 97 | matched_anchors = tf.gather(anchors, matched_anchor_indices) # shape [num_matches, 4] 98 | matched_gt_boxes = tf.gather(groundtruth_boxes, matched_gt_indices) # shape [num_matches, 4] 99 | matched_reg_targets = encode(matched_gt_boxes, matched_anchors) # shape [num_matches, 4] 100 | 101 | unmatched_anchor_indices = tf.where(tf.equal(matches, -1)) 102 | unmatched_anchor_indices = tf.squeeze(unmatched_anchor_indices, axis=1) 103 | # it has shape [num_anchors - num_matches] 104 | 105 | unmatched_reg_targets = tf.zeros([tf.size(unmatched_anchor_indices), 4]) 106 | # it has shape [num_anchors - num_matches, 4] 107 | 108 | matched_anchor_indices = tf.to_int32(matched_anchor_indices) 109 | unmatched_anchor_indices = tf.to_int32(unmatched_anchor_indices) 110 | 111 | reg_targets = tf.dynamic_stitch( 112 | [matched_anchor_indices, unmatched_anchor_indices], 113 | [matched_reg_targets, unmatched_reg_targets] 114 | ) 115 | return reg_targets 116 | -------------------------------------------------------------------------------- /create_tfrecords.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import PIL.Image 4 | import tensorflow as tf 5 | import json 6 | import shutil 7 | import random 8 | import math 9 | import argparse 10 | from tqdm import tqdm 11 | import sys 12 | 13 | 14 | """ 15 | The purpose of this script is to create a set of .tfrecords files 16 | from a folder of images and a folder of annotations. 17 | Annotations are in the json format. 18 | Images must have .jpg or .jpeg filename extension. 19 | 20 | Example of a json annotation (with filename "132416.json"): 21 | { 22 | "object": [ 23 | {"bndbox": {"ymin": 20, "ymax": 276, "xmax": 1219, "xmin": 1131}, "name": "face"}, 24 | {"bndbox": {"ymin": 1, "ymax": 248, "xmax": 1149, "xmin": 1014}, "name": "face"} 25 | ], 26 | "filename": "132416.jpg", 27 | "size": {"depth": 3, "width": 1920, "height": 1080} 28 | } 29 | 30 | Example of use: 31 | python create_tfrecords.py \ 32 | --image_dir=/home/gpu2/hdd/dan/WIDER/val/images/ \ 33 | --annotations_dir=/home/gpu2/hdd/dan/WIDER/val/annotations/ \ 34 | --output=data/train_shards/ \ 35 | --num_shards=100 36 | """ 37 | 38 | 39 | def make_args(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('-i', '--image_dir', type=str) 42 | parser.add_argument('-a', '--annotations_dir', type=str) 43 | parser.add_argument('-o', '--output', type=str) 44 | parser.add_argument('-s', '--num_shards', type=int, default=1) 45 | return parser.parse_args() 46 | 47 | 48 | def dict_to_tf_example(annotation, image_dir): 49 | """Convert dict to tf.Example proto. 50 | 51 | Notice that this function normalizes the bounding 52 | box coordinates provided by the raw data. 53 | 54 | Arguments: 55 | data: a dict. 56 | image_dir: a string, path to the image directory. 57 | Returns: 58 | an instance of tf.Example. 
59 | """ 60 | image_name = annotation['filename'] 61 | assert image_name.endswith('.jpg') or image_name.endswith('.jpeg') 62 | 63 | image_path = os.path.join(image_dir, image_name) 64 | with tf.gfile.GFile(image_path, 'rb') as f: 65 | encoded_jpg = f.read() 66 | 67 | # check image format 68 | encoded_jpg_io = io.BytesIO(encoded_jpg) 69 | image = PIL.Image.open(encoded_jpg_io) 70 | if image.format != 'JPEG': 71 | raise ValueError('Image format not JPEG!') 72 | 73 | width = int(annotation['size']['width']) 74 | height = int(annotation['size']['height']) 75 | assert width > 0 and height > 0 76 | assert image.size[0] == width and image.size[1] == height 77 | ymin, xmin, ymax, xmax = [], [], [], [] 78 | 79 | just_name = image_name[:-4] if image_name.endswith('.jpg') else image_name[:-5] 80 | annotation_name = just_name + '.json' 81 | if len(annotation['object']) == 0: 82 | print(annotation_name, 'is without any objects!') 83 | 84 | for obj in annotation['object']: 85 | a = float(obj['bndbox']['ymin'])/height 86 | b = float(obj['bndbox']['xmin'])/width 87 | c = float(obj['bndbox']['ymax'])/height 88 | d = float(obj['bndbox']['xmax'])/width 89 | assert (a < c) and (b < d) 90 | ymin.append(a) 91 | xmin.append(b) 92 | ymax.append(c) 93 | xmax.append(d) 94 | assert obj['name'] == 'face' 95 | 96 | example = tf.train.Example(features=tf.train.Features(feature={ 97 | 'filename': _bytes_feature(image_name.encode()), 98 | 'image': _bytes_feature(encoded_jpg), 99 | 'xmin': _float_list_feature(xmin), 100 | 'xmax': _float_list_feature(xmax), 101 | 'ymin': _float_list_feature(ymin), 102 | 'ymax': _float_list_feature(ymax), 103 | })) 104 | return example 105 | 106 | 107 | def _bytes_feature(value): 108 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 109 | 110 | 111 | def _float_list_feature(value): 112 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 113 | 114 | 115 | def main(): 116 | ARGS = make_args() 117 | 118 | image_dir = ARGS.image_dir 119 | annotations_dir = ARGS.annotations_dir 120 | print('Reading images from:', image_dir) 121 | print('Reading annotations from:', annotations_dir, '\n') 122 | 123 | examples_list = os.listdir(annotations_dir) 124 | num_examples = len(examples_list) 125 | print('Number of images:', num_examples) 126 | 127 | num_shards = ARGS.num_shards 128 | shard_size = math.ceil(num_examples/num_shards) 129 | print('Number of images per shard:', shard_size) 130 | 131 | output_dir = ARGS.output 132 | shutil.rmtree(output_dir, ignore_errors=True) 133 | os.mkdir(output_dir) 134 | 135 | shard_id = 0 136 | num_examples_written = 0 137 | for example in tqdm(examples_list): 138 | 139 | if num_examples_written == 0: 140 | shard_path = os.path.join(output_dir, 'shard-%04d.tfrecords' % shard_id) 141 | writer = tf.python_io.TFRecordWriter(shard_path) 142 | 143 | path = os.path.join(annotations_dir, example) 144 | annotation = json.load(open(path)) 145 | tf_example = dict_to_tf_example(annotation, image_dir) 146 | writer.write(tf_example.SerializeToString()) 147 | num_examples_written += 1 148 | 149 | if num_examples_written == shard_size: 150 | shard_id += 1 151 | num_examples_written = 0 152 | writer.close() 153 | 154 | if num_examples_written != shard_size and num_examples % num_shards != 0: 155 | writer.close() 156 | 157 | print('Result is here:', ARGS.output) 158 | 159 | 160 | main() 161 | -------------------------------------------------------------------------------- /predict_for_FDDB.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import os\n", 21 | "import cv2\n", 22 | "from tqdm import tqdm\n", 23 | "import random\n", 24 | "import shutil\n", 25 | "\n", 26 | "from face_detector import FaceDetector" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "IMAGES_DIR = '/home/gpu2/hdd/dan/FDDB/originalPics/'\n", 36 | "ANNOTATIONS_PATH = '/home/gpu2/hdd/dan/FDDB/FDDB-folds/'\n", 37 | "RESULT_DIR = 'result/'\n", 38 | "MODEL_PATH = 'model.pb'" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "# Collect annotated images" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "annotations = [s for s in os.listdir(ANNOTATIONS_PATH) if s.endswith('ellipseList.txt')]\n", 55 | "image_lists = [s for s in os.listdir(ANNOTATIONS_PATH) if not s.endswith('ellipseList.txt')]\n", 56 | "annotations = sorted(annotations)\n", 57 | "image_lists = sorted(image_lists)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "images_to_use = []\n", 67 | "for n in image_lists:\n", 68 | " with open(os.path.join(ANNOTATIONS_PATH, n)) as f:\n", 69 | " images_to_use.extend(f.readlines())" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "images_to_use = [s.strip() for s in images_to_use]\n", 79 | "with open(os.path.join(RESULT_DIR, 'faceList.txt'), 'w') as f:\n", 80 | " for p in images_to_use:\n", 81 | " f.write(p + '\\n')" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "ellipses = []\n", 91 | "for n in annotations:\n", 92 | " with open(os.path.join(ANNOTATIONS_PATH, n)) as f:\n", 93 | " ellipses.extend(f.readlines())" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "i = 0\n", 103 | "with open(os.path.join(RESULT_DIR, 'ellipseList.txt'), 'w') as f:\n", 104 | " for p in ellipses:\n", 105 | " \n", 106 | " # check image order\n", 107 | " if 'big/img' in p:\n", 108 | " assert images_to_use[i] in p\n", 109 | " i += 1\n", 110 | "\n", 111 | " f.write(p)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "# Predict using trained detector" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "face_detector = FaceDetector(MODEL_PATH, gpu_memory_fraction=0.25, visible_device_list='0')" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "predictions = []\n", 137 | "for n in tqdm(images_to_use):\n", 138 | " image_array = cv2.imread(os.path.join(IMAGES_DIR, n) + '.jpg')\n", 139 | " image_array = 
cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)\n", 140 | " # threshold is important to set low\n", 141 | " boxes, scores = face_detector(image_array, score_threshold=0.05)\n", 142 | " predictions.append((n, boxes, scores))" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "with open(os.path.join(RESULT_DIR, 'detections.txt'), 'w') as f:\n", 152 | " for n, boxes, scores in predictions:\n", 153 | " f.write(n + '\\n')\n", 154 | " f.write(str(len(boxes)) + '\\n')\n", 155 | " for b, s in zip(boxes, scores):\n", 156 | " ymin, xmin, ymax, xmax = b\n", 157 | " h, w = int(ymax - ymin), int(xmax - xmin)\n", 158 | " f.write('{0} {1} {2} {3} {4:.4f}\\n'.format(int(xmin), int(ymin), w, h, s))" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "# Copy images" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "for n in tqdm(images_to_use):\n", 175 | " p = os.path.join(RESULT_DIR, 'images', n + '.jpg')\n", 176 | " os.makedirs(os.path.dirname(p), exist_ok=True)\n", 177 | " shutil.copy(os.path.join(IMAGES_DIR, n) + '.jpg', p)" 178 | ] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "Python 3", 184 | "language": "python", 185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.6.3" 198 | } 199 | }, 200 | "nbformat": 4, 201 | "nbformat_minor": 1 202 | } 203 | -------------------------------------------------------------------------------- /visualize_densified_anchor_boxes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import tensorflow as tf\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "%matplotlib inline\n", 23 | "\n", 24 | "from src.anchor_generator import tile_anchors" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "WIDTH, HEIGHT = 512, 1024\n", 34 | "GRID_WIDTH, GRID_HEIGHT = 16, 32 # stride 32 or scale 0 in the face detector\n", 35 | "# GRID_WIDTH, GRID_HEIGHT = 8, 16 # stride 64 or scale 1 in the face detector\n", 36 | "# GRID_WIDTH, GRID_HEIGHT = 4, 8 # stride 128 or scale 2 in the face detector" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Generate anchors" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "tf.reset_default_graph()\n", 53 | "\n", 54 | "n = 4\n", 55 | "anchors = tile_anchors(\n", 56 | " (WIDTH, HEIGHT), GRID_HEIGHT, GRID_WIDTH,\n", 57 | " scale=32, aspect_ratio=1.0, \n", 58 | " anchor_stride=(1.0/GRID_HEIGHT, 1.0/GRID_WIDTH), \n", 59 | " anchor_offset=(0.5/GRID_HEIGHT, 0.5/GRID_WIDTH), \n", 60 | " n=n\n", 61 | ")" 
62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "with tf.Session() as sess:\n", 71 | " anchor_boxes = sess.run(anchors)\n", 72 | "\n", 73 | "scaler = np.array([HEIGHT, WIDTH, HEIGHT, WIDTH], dtype='float32')\n", 74 | "anchor_boxes = anchor_boxes*scaler # shape [GRID_HEIGHT, GRID_WIDTH, n*n, 4]" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "# Show some non clipped anchors" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def get_grid_centers():\n", 91 | " anchor_stride = (1.0/GRID_HEIGHT, 1.0/GRID_WIDTH)\n", 92 | " anchor_offset = (0.5/GRID_HEIGHT, 0.5/GRID_WIDTH)\n", 93 | "\n", 94 | " y_center = np.arange(GRID_HEIGHT, dtype='float32') * anchor_stride[0] + anchor_offset[0]\n", 95 | " x_center = np.arange(GRID_WIDTH, dtype='float32') * anchor_stride[1] + anchor_offset[1]\n", 96 | " x_center, y_center = np.meshgrid(x_center, y_center)\n", 97 | " # they have shape [grid_height, grid_width]\n", 98 | "\n", 99 | " centers = np.stack([y_center, x_center], axis=2)\n", 100 | " scaler = np.array([HEIGHT, WIDTH], dtype='float32')\n", 101 | " centers = centers*scaler\n", 102 | " return centers\n", 103 | "\n", 104 | "\n", 105 | "def plot(anchor_boxes, cell_to_show):\n", 106 | " fig, ax = plt.subplots(1, dpi=120, figsize=(int(8*WIDTH/HEIGHT), 8))\n", 107 | "\n", 108 | " grid_centers = get_grid_centers()\n", 109 | " for point in grid_centers.reshape(-1, 2):\n", 110 | " cy, cx = point\n", 111 | " ax.plot([cx], [cy], marker='.', markersize=1, color='r')\n", 112 | " \n", 113 | " iy, ix = cell_to_show\n", 114 | " cy, cx = grid_centers[iy, ix, :]\n", 115 | " ax.plot([cx], [cy], marker='.', markersize=5, color='r')\n", 116 | " \n", 117 | " cy, cx, h, w = [anchor_boxes[:, :, :, i] for i in range(4)]\n", 118 | " centers = np.stack([cy, cx], axis=3)\n", 119 | " anchor_sizes = np.stack([h, w], axis=3)\n", 120 | "\n", 121 | " centers = centers[iy, ix, :, :]\n", 122 | " anchor_sizes = anchor_sizes[iy, ix, :, :]\n", 123 | " \n", 124 | " to_show = [1, 4, 15]\n", 125 | " for i, center, box in zip(range(len(centers)), centers, anchor_sizes):\n", 126 | " \n", 127 | " h, w = box\n", 128 | " cy, cx = center\n", 129 | " xmin, ymin = cx - 0.5*w, cy - 0.5*h\n", 130 | "\n", 131 | " linestyle = '-' if i in to_show else '--' \n", 132 | " random_color = np.random.rand(3,)\n", 133 | " color = random_color if i in to_show else 'k'\n", 134 | " alpha = 1.0 if i in to_show else 0.5\n", 135 | " linewidth = 2.0 if i in to_show else 0.7\n", 136 | "\n", 137 | " rect = plt.Rectangle(\n", 138 | " (xmin, ymin), w, h,\n", 139 | " linewidth=linewidth, edgecolor=color, \n", 140 | " facecolor='none', linestyle=linestyle,\n", 141 | " alpha=alpha\n", 142 | " )\n", 143 | " if i in to_show:\n", 144 | " ax.plot([cx], [cy], marker='s', markersize=7, color=random_color)\n", 145 | " ax.add_patch(rect)\n", 146 | " \n", 147 | " # why not ax.axis('equal')?\n", 148 | " ax.set_ylim([0, HEIGHT])\n", 149 | " ax.set_xlim([0, WIDTH])\n", 150 | " ax.invert_yaxis()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "plot(anchor_boxes, cell_to_show=(2, 3))" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "Python 3", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | 
"language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.6.3" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /src/utils/box_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.constants import EPSILON, SCALE_FACTORS 3 | 4 | 5 | """ 6 | Tools for dealing with bounding boxes. 7 | All boxes are of the format [ymin, xmin, ymax, xmax] 8 | if not stated otherwise. 9 | And box coordinates are normalized to [0, 1] range. 10 | """ 11 | 12 | 13 | def iou(boxes1, boxes2): 14 | """Computes pairwise intersection-over-union between two box collections. 15 | 16 | Arguments: 17 | boxes1: a float tensor with shape [N, 4]. 18 | boxes2: a float tensor with shape [M, 4]. 19 | Returns: 20 | a float tensor with shape [N, M] representing pairwise iou scores. 21 | """ 22 | with tf.name_scope('iou'): 23 | intersections = intersection(boxes1, boxes2) 24 | areas1 = area(boxes1) 25 | areas2 = area(boxes2) 26 | unions = tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections 27 | return tf.clip_by_value(tf.divide(intersections, unions), 0.0, 1.0) 28 | 29 | 30 | def intersection(boxes1, boxes2): 31 | """Compute pairwise intersection areas between boxes. 32 | 33 | Arguments: 34 | boxes1: a float tensor with shape [N, 4]. 35 | boxes2: a float tensor with shape [M, 4]. 36 | Returns: 37 | a float tensor with shape [N, M] representing pairwise intersections. 38 | """ 39 | with tf.name_scope('intersection'): 40 | 41 | ymin1, xmin1, ymax1, xmax1 = tf.split(boxes1, num_or_size_splits=4, axis=1) 42 | ymin2, xmin2, ymax2, xmax2 = tf.split(boxes2, num_or_size_splits=4, axis=1) 43 | # they all have shapes like [None, 1] 44 | 45 | all_pairs_min_ymax = tf.minimum(ymax1, tf.transpose(ymax2)) 46 | all_pairs_max_ymin = tf.maximum(ymin1, tf.transpose(ymin2)) 47 | intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin) 48 | all_pairs_min_xmax = tf.minimum(xmax1, tf.transpose(xmax2)) 49 | all_pairs_max_xmin = tf.maximum(xmin1, tf.transpose(xmin2)) 50 | intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin) 51 | # they all have shape [N, M] 52 | 53 | return intersect_heights * intersect_widths 54 | 55 | 56 | def area(boxes): 57 | """Computes area of boxes. 58 | 59 | Arguments: 60 | boxes: a float tensor with shape [N, 4]. 61 | Returns: 62 | a float tensor with shape [N] representing box areas. 63 | """ 64 | with tf.name_scope('area'): 65 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 66 | return (ymax - ymin) * (xmax - xmin) 67 | 68 | 69 | def to_minmax_coordinates(boxes): 70 | """Convert bounding boxes of the format 71 | [cy, cx, h, w] to the format [ymin, xmin, ymax, xmax]. 72 | 73 | Arguments: 74 | boxes: a list of float tensors with shape [N] 75 | that represent cy, cx, h, w. 76 | Returns: 77 | a list of float tensors with shape [N] 78 | that represent ymin, xmin, ymax, xmax. 
79 | """ 80 | with tf.name_scope('to_minmax_coordinates'): 81 | cy, cx, h, w = boxes 82 | ymin, xmin = cy - 0.5*h, cx - 0.5*w 83 | ymax, xmax = cy + 0.5*h, cx + 0.5*w 84 | return [ymin, xmin, ymax, xmax] 85 | 86 | 87 | def to_center_coordinates(boxes): 88 | """Convert bounding boxes of the format 89 | [ymin, xmin, ymax, xmax] to the format [cy, cx, h, w]. 90 | 91 | Arguments: 92 | boxes: a list of float tensors with shape [N] 93 | that represent ymin, xmin, ymax, xmax. 94 | Returns: 95 | a list of float tensors with shape [N] 96 | that represent cy, cx, h, w. 97 | """ 98 | with tf.name_scope('to_center_coordinates'): 99 | ymin, xmin, ymax, xmax = boxes 100 | h = ymax - ymin 101 | w = xmax - xmin 102 | cy = ymin + 0.5*h 103 | cx = xmin + 0.5*w 104 | return [cy, cx, h, w] 105 | 106 | 107 | def encode(boxes, anchors): 108 | """Encode boxes with respect to anchors. 109 | 110 | Arguments: 111 | boxes: a float tensor with shape [N, 4]. 112 | anchors: a float tensor with shape [N, 4]. 113 | Returns: 114 | a float tensor with shape [N, 4], 115 | anchor-encoded boxes of the format [ty, tx, th, tw]. 116 | """ 117 | with tf.name_scope('encode_groundtruth'): 118 | 119 | ycenter_a, xcenter_a, ha, wa = to_center_coordinates(tf.unstack(anchors, axis=1)) 120 | ycenter, xcenter, h, w = to_center_coordinates(tf.unstack(boxes, axis=1)) 121 | 122 | # to avoid NaN in division and log below 123 | ha += EPSILON 124 | wa += EPSILON 125 | h += EPSILON 126 | w += EPSILON 127 | 128 | tx = (xcenter - xcenter_a)/wa 129 | ty = (ycenter - ycenter_a)/ha 130 | tw = tf.log(w / wa) 131 | th = tf.log(h / ha) 132 | 133 | ty *= SCALE_FACTORS[0] 134 | tx *= SCALE_FACTORS[1] 135 | th *= SCALE_FACTORS[2] 136 | tw *= SCALE_FACTORS[3] 137 | 138 | return tf.stack([ty, tx, th, tw], axis=1) 139 | 140 | 141 | def decode(codes, anchors): 142 | """Decode relative codes to boxes. 143 | 144 | Arguments: 145 | codes: a float tensor with shape [N, 4], 146 | anchor-encoded boxes of the format [ty, tx, th, tw]. 147 | anchors: a float tensor with shape [N, 4]. 148 | Returns: 149 | a float tensor with shape [N, 4], 150 | bounding boxes of the format [ymin, xmin, ymax, xmax]. 151 | """ 152 | with tf.name_scope('decode_predictions'): 153 | 154 | ycenter_a, xcenter_a, ha, wa = to_center_coordinates(tf.unstack(anchors, axis=1)) 155 | ty, tx, th, tw = tf.unstack(codes, axis=1) 156 | 157 | ty /= SCALE_FACTORS[0] 158 | tx /= SCALE_FACTORS[1] 159 | th /= SCALE_FACTORS[2] 160 | tw /= SCALE_FACTORS[3] 161 | w = tf.exp(tw) * wa 162 | h = tf.exp(th) * ha 163 | ycenter = ty * ha + ycenter_a 164 | xcenter = tx * wa + xcenter_a 165 | 166 | return tf.stack(to_minmax_coordinates([ycenter, xcenter, h, w]), axis=1) 167 | 168 | 169 | def batch_decode(box_encodings, anchors): 170 | """Decodes a batch of box encodings with respect to the anchors. 171 | 172 | Arguments: 173 | box_encodings: a float tensor with shape [batch_size, num_anchors, 4]. 174 | anchors: a float tensor with shape [num_anchors, 4]. 175 | Returns: 176 | a float tensor with shape [batch_size, num_anchors, 4]. 177 | It contains the decoded boxes. 
178 | """ 179 | batch_size = tf.shape(box_encodings)[0] 180 | num_anchors = tf.shape(box_encodings)[1] 181 | 182 | tiled_anchor_boxes = tf.tile( 183 | tf.expand_dims(anchors, 0), 184 | [batch_size, 1, 1] 185 | ) # shape [batch_size, num_anchors, 4] 186 | decoded_boxes = decode( 187 | tf.reshape(box_encodings, [-1, 4]), 188 | tf.reshape(tiled_anchor_boxes, [-1, 4]) 189 | ) # shape [batch_size * num_anchors, 4] 190 | 191 | decoded_boxes = tf.reshape( 192 | decoded_boxes, 193 | [batch_size, num_anchors, 4] 194 | ) 195 | decoded_boxes = tf.clip_by_value(decoded_boxes, 0.0, 1.0) 196 | return decoded_boxes 197 | -------------------------------------------------------------------------------- /src/input_pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from src.constants import SHUFFLE_BUFFER_SIZE, NUM_THREADS, RESIZE_METHOD 4 | from src.input_pipeline.random_image_crop import random_image_crop 5 | from src.input_pipeline.other_augmentations import random_color_manipulations,\ 6 | random_flip_left_right, random_pixel_value_scale, random_jitter_boxes 7 | 8 | 9 | class Pipeline: 10 | """Input pipeline for training or evaluating object detectors.""" 11 | 12 | def __init__(self, filenames, batch_size, image_size, 13 | repeat=False, shuffle=False, augmentation=False): 14 | """ 15 | Note: when evaluating set batch_size to 1. 16 | 17 | Arguments: 18 | filenames: a list of strings, paths to tfrecords files. 19 | batch_size: an integer. 20 | image_size: a list with two integers [width, height] or None, 21 | images of this size will be in a batch. 22 | If value is None then images will not be resized. 23 | In this case batch size must be 1. 24 | repeat: a boolean, whether repeat indefinitely. 25 | shuffle: whether to shuffle the dataset. 26 | augmentation: whether to do data augmentation. 
27 | """ 28 | if image_size is not None: 29 | self.image_width, self.image_height = image_size 30 | self.resize = True 31 | else: 32 | assert batch_size == 1 33 | self.image_width, self.image_height = None, None 34 | self.resize = False 35 | 36 | self.augmentation = augmentation 37 | self.batch_size = batch_size 38 | 39 | def get_num_samples(filename): 40 | return sum(1 for _ in tf.python_io.tf_record_iterator(filename)) 41 | 42 | num_examples = 0 43 | for filename in filenames: 44 | num_examples_in_file = get_num_samples(filename) 45 | assert num_examples_in_file > 0 46 | num_examples += num_examples_in_file 47 | self.num_examples = num_examples 48 | assert self.num_examples > 0 49 | 50 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 51 | num_shards = len(filenames) 52 | 53 | if shuffle: 54 | dataset = dataset.shuffle(buffer_size=num_shards) 55 | 56 | dataset = dataset.flat_map(tf.data.TFRecordDataset) 57 | dataset = dataset.prefetch(buffer_size=batch_size) 58 | 59 | if shuffle: 60 | dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER_SIZE) 61 | dataset = dataset.repeat(None if repeat else 1) 62 | dataset = dataset.map(self._parse_and_preprocess, num_parallel_calls=NUM_THREADS) 63 | 64 | # we need batches of fixed size 65 | padded_shapes = ([self.image_height, self.image_width, 3], [None, 4], [], []) 66 | dataset = dataset.apply( 67 | tf.contrib.data.padded_batch_and_drop_remainder(batch_size, padded_shapes) 68 | ) 69 | dataset = dataset.prefetch(buffer_size=1) 70 | 71 | self.iterator = dataset.make_one_shot_iterator() 72 | 73 | def get_batch(self): 74 | """ 75 | Returns: 76 | features: a dict with the following keys 77 | 'images': a float tensor with shape [batch_size, image_height, image_width, 3]. 78 | 'filenames': a string tensor with shape [batch_size]. 79 | labels: a dict with the following keys 80 | 'boxes': a float tensor with shape [batch_size, max_num_boxes, 4]. 81 | 'num_boxes': an int tensor with shape [batch_size]. 82 | where max_num_boxes = max(num_boxes). 83 | """ 84 | images, boxes, num_boxes, filenames = self.iterator.get_next() 85 | features = {'images': images, 'filenames': filenames} 86 | labels = {'boxes': boxes, 'num_boxes': num_boxes} 87 | return features, labels 88 | 89 | def _parse_and_preprocess(self, example_proto): 90 | """What this function does: 91 | 1. Parses one record from a tfrecords file and decodes it. 92 | 2. (optionally) Augments it. 93 | 94 | Returns: 95 | image: a float tensor with shape [image_height, image_width, 3], 96 | an RGB image with pixel values in the range [0, 1]. 97 | boxes: a float tensor with shape [num_boxes, 4]. 98 | num_boxes: an int tensor with shape []. 99 | filename: a string tensor with shape []. 
100 | """ 101 | features = { 102 | 'filename': tf.FixedLenFeature([], tf.string), 103 | 'image': tf.FixedLenFeature([], tf.string), 104 | 'ymin': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 105 | 'xmin': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 106 | 'ymax': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 107 | 'xmax': tf.FixedLenSequenceFeature([], tf.float32, allow_missing=True), 108 | } 109 | parsed_features = tf.parse_single_example(example_proto, features) 110 | 111 | # get image 112 | image = tf.image.decode_jpeg(parsed_features['image'], channels=3) 113 | image = tf.image.convert_image_dtype(image, tf.float32) 114 | # now pixel values are scaled to [0, 1] range 115 | 116 | # get groundtruth boxes, they must be in from-zero-to-one format 117 | boxes = tf.stack([ 118 | parsed_features['ymin'], parsed_features['xmin'], 119 | parsed_features['ymax'], parsed_features['xmax'] 120 | ], axis=1) 121 | boxes = tf.to_float(boxes) 122 | # it is important to clip here! 123 | boxes = tf.clip_by_value(boxes, clip_value_min=0.0, clip_value_max=1.0) 124 | 125 | if self.augmentation: 126 | image, boxes = self._augmentation_fn(image, boxes) 127 | else: 128 | image = tf.image.resize_images( 129 | image, [self.image_height, self.image_width], 130 | method=RESIZE_METHOD 131 | ) if self.resize else image 132 | 133 | num_boxes = tf.to_int32(tf.shape(boxes)[0]) 134 | filename = parsed_features['filename'] 135 | return image, boxes, num_boxes, filename 136 | 137 | def _augmentation_fn(self, image, boxes): 138 | # there are a lot of hyperparameters here, 139 | # you will need to tune them all, haha 140 | 141 | image, boxes = random_image_crop( 142 | image, boxes, probability=0.9, 143 | min_object_covered=0.9, 144 | aspect_ratio_range=(0.93, 1.07), 145 | area_range=(0.4, 0.9), 146 | overlap_thresh=0.4 147 | ) 148 | image = tf.image.resize_images( 149 | image, [self.image_height, self.image_width], 150 | method=RESIZE_METHOD 151 | ) if self.resize else image 152 | # if you do color augmentations before resizing, it will be very slow! 153 | 154 | image = random_color_manipulations(image, probability=0.45, grayscale_probability=0.05) 155 | image = random_pixel_value_scale(image, minval=0.85, maxval=1.15, probability=0.2) 156 | boxes = random_jitter_boxes(boxes, ratio=0.01) 157 | image, boxes = random_flip_left_right(image, boxes) 158 | return image, boxes 159 | -------------------------------------------------------------------------------- /src/losses_and_ohem.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.utils import batch_decode 3 | 4 | 5 | """ 6 | Note that we have only one label (it is 'face'), 7 | so num_classes = 1. 8 | """ 9 | 10 | 11 | def localization_loss(predictions, targets, weights): 12 | """A usual L1 smooth loss. 13 | 14 | Arguments: 15 | predictions: a float tensor with shape [batch_size, num_anchors, 4], 16 | representing the (encoded) predicted locations of objects. 17 | targets: a float tensor with shape [batch_size, num_anchors, 4], 18 | representing the regression targets. 19 | weights: a float tensor with shape [batch_size, num_anchors]. 20 | Returns: 21 | a float tensor with shape [batch_size, num_anchors]. 
22 | """ 23 | abs_diff = tf.abs(predictions - targets) 24 | abs_diff_lt_1 = tf.less(abs_diff, 1.0) 25 | return weights * tf.reduce_sum( 26 | tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5), axis=2 27 | ) 28 | 29 | 30 | def classification_loss(predictions, targets): 31 | """ 32 | Arguments: 33 | predictions: a float tensor with shape [batch_size, num_anchors, num_classes + 1], 34 | representing the predicted logits for each class. 35 | targets: an int tensor with shape [batch_size, num_anchors]. 36 | Returns: 37 | a float tensor with shape [batch_size, num_anchors]. 38 | """ 39 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( 40 | labels=targets, logits=predictions 41 | ) 42 | return cross_entropy 43 | 44 | 45 | def apply_hard_mining( 46 | location_losses, cls_losses, 47 | class_predictions_with_background, 48 | matches, decoded_boxes, 49 | loss_to_use='classification', 50 | loc_loss_weight=1.0, cls_loss_weight=1.0, 51 | num_hard_examples=3000, nms_threshold=0.99, 52 | max_negatives_per_positive=3, min_negatives_per_image=0): 53 | """Applies hard mining to anchorwise losses. 54 | 55 | Arguments: 56 | location_losses: a float tensor with shape [batch_size, num_anchors]. 57 | cls_losses: a float tensor with shape [batch_size, num_anchors]. 58 | class_predictions_with_background: a float tensor with shape [batch_size, num_anchors, num_classes + 1]. 59 | matches: an int tensor with shape [batch_size, num_anchors]. 60 | decoded_boxes: a float tensor with shape [batch_size, num_anchors, 4]. 61 | loss_to_use: a string, only possible values are ['classification', 'both']. 62 | loc_loss_weight: a float number. 63 | cls_loss_weight: a float number. 64 | num_hard_examples: an integer. 65 | nms_threshold: a float number. 66 | max_negatives_per_positive: a float number. 67 | min_negatives_per_image: an integer. 68 | Returns: 69 | two float tensors with shape []. 
70 | """ 71 | 72 | # when training it is important that 73 | # batch size is known 74 | batch_size, num_anchors = matches.shape.as_list() 75 | assert batch_size is not None 76 | decoded_boxes.set_shape([batch_size, num_anchors, 4]) 77 | location_losses.set_shape([batch_size, num_anchors]) 78 | cls_losses.set_shape([batch_size, num_anchors]) 79 | # all `set_shape` above are dirty tricks, 80 | # without them shape information is lost for some reason 81 | 82 | # all these tensors must have static first dimension (batch size) 83 | decoded_boxes_list = tf.unstack(decoded_boxes, axis=0) 84 | location_losses_list = tf.unstack(location_losses, axis=0) 85 | cls_losses_list = tf.unstack(cls_losses, axis=0) 86 | matches_list = tf.unstack(matches, axis=0) 87 | # they all lists with length = batch_size 88 | 89 | batch_size = len(decoded_boxes_list) 90 | num_positives_list, num_negatives_list = [], [] 91 | mined_location_losses, mined_cls_losses = [], [] 92 | 93 | # do OHEM for each image in the batch 94 | for i, box_locations in enumerate(decoded_boxes_list): 95 | image_losses = cls_losses_list[i] * cls_loss_weight 96 | if loss_to_use == 'both': 97 | image_losses += (location_losses_list[i] * loc_loss_weight) 98 | # it has shape [num_anchors] 99 | 100 | selected_indices = tf.image.non_max_suppression( 101 | box_locations, image_losses, num_hard_examples, nms_threshold 102 | ) 103 | selected_indices, num_positives, num_negatives = _subsample_selection_to_desired_neg_pos_ratio( 104 | selected_indices, matches_list[i], 105 | max_negatives_per_positive, min_negatives_per_image 106 | ) 107 | num_positives_list.append(num_positives) 108 | num_negatives_list.append(num_negatives) 109 | mined_location_losses.append( 110 | tf.reduce_sum(tf.gather(location_losses_list[i], selected_indices), axis=0) 111 | ) 112 | mined_cls_losses.append( 113 | tf.reduce_sum(tf.gather(cls_losses_list[i], selected_indices), axis=0) 114 | ) 115 | 116 | mean_num_positives = tf.reduce_mean(tf.stack(num_positives_list, axis=0), axis=0) 117 | mean_num_negatives = tf.reduce_mean(tf.stack(num_negatives_list, axis=0), axis=0) 118 | tf.summary.scalar('mean_num_positives', mean_num_positives) 119 | tf.summary.scalar('mean_num_negatives', mean_num_negatives) 120 | 121 | location_loss = tf.reduce_sum(tf.stack(mined_location_losses, axis=0), axis=0) 122 | cls_loss = tf.reduce_sum(tf.stack(mined_cls_losses, axis=0), axis=0) 123 | return location_loss, cls_loss 124 | 125 | 126 | def _subsample_selection_to_desired_neg_pos_ratio( 127 | indices, match, max_negatives_per_positive, min_negatives_per_image): 128 | """Subsample a collection of selected indices to a desired neg:pos ratio. 129 | 130 | Arguments: 131 | indices: an int or long tensor with shape [M], 132 | it represents a collection of selected anchor indices. 133 | match: an int tensor with shape [num_anchors]. 134 | max_negatives_per_positive: a float number, maximum number 135 | of negatives for each positive anchor. 136 | min_negatives_per_image: an integer, minimum number of negative anchors for a given 137 | image. Allows sampling negatives in image without any positive anchors. 138 | Returns: 139 | selected_indices: an int or long tensor with shape [M'] and with M' <= M. 140 | It represents a collection of selected anchor indices. 141 | num_positives: an int tensor with shape []. It represents the 142 | number of positive examples in selected set of indices. 143 | num_negatives: an int tensor with shape []. 
It represents the 144 | number of negative examples in selected set of indices. 145 | """ 146 | positives_indicator = tf.gather(tf.greater_equal(match, 0), indices) 147 | negatives_indicator = tf.logical_not(positives_indicator) 148 | # they have shape [num_hard_examples] 149 | 150 | # all positives in `indices` will be kept 151 | num_positives = tf.reduce_sum(tf.to_int32(positives_indicator), axis=0) 152 | max_negatives = tf.maximum( 153 | min_negatives_per_image, 154 | tf.to_int32(max_negatives_per_positive * tf.to_float(num_positives)) 155 | ) 156 | 157 | top_k_negatives_indicator = tf.less_equal( 158 | tf.cumsum(tf.to_int32(negatives_indicator), axis=0), 159 | max_negatives 160 | ) 161 | subsampled_selection_indices = tf.where( 162 | tf.logical_or(positives_indicator, top_k_negatives_indicator) 163 | ) # shape [num_hard_examples, 1] 164 | subsampled_selection_indices = tf.squeeze(subsampled_selection_indices, axis=1) 165 | selected_indices = tf.gather(indices, subsampled_selection_indices) 166 | 167 | num_negatives = tf.size(subsampled_selection_indices) - num_positives 168 | return selected_indices, num_positives, num_negatives 169 | -------------------------------------------------------------------------------- /src/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import math 3 | from src.utils.box_utils import to_minmax_coordinates 4 | 5 | 6 | ANCHOR_SPECIFICATIONS = [ 7 | [(32, 1.0, 4), (64, 1.0, 2), (128, 1.0, 1)], # scale 0 8 | [(256, 1.0, 1)], # scale 1 9 | [(512, 1.0, 1)], # scale 2 10 | ] 11 | # every tuple represents (box scale, aspect ratio, densification parameter) 12 | 13 | """ 14 | Notes: 15 | 1. Assume that size(image) = (image_width, image_height), 16 | then by definition 17 | image_aspect_ratio := image_width/image_height 18 | 19 | 2. All anchor boxes are in normalized coordinates (in [0, 1] range). 20 | So, for each box: 21 | (width * image_width)/(height * image_height) = aspect_ratio 22 | 23 | 3. Scale of an anchor box is defined like this: 24 | scale := sqrt(height * image_height * width * image_width) 25 | 26 | 4. Total number of anchor boxes depends on image size. 27 | 28 | 5. `scale` and `aspect_ratio` are independent of image size. 29 | `width` and `height` depend on image size. 30 | 31 | 6. If we change image size then normalized coordinates of 32 | the anchor boxes will change. 33 | """ 34 | 35 | 36 | class AnchorGenerator: 37 | def __init__(self): 38 | self.box_specs_list = ANCHOR_SPECIFICATIONS 39 | 40 | def __call__(self, image_features, image_size): 41 | """ 42 | Arguments: 43 | image_features: a list of float tensors where the ith tensor 44 | has shape [batch, height_i, width_i, channels_i]. 45 | image_size: a tuple of integers (int tensors with shape []) (width, height). 46 | Returns: 47 | a float tensor with shape [num_anchor, 4], 48 | boxes with normalized coordinates (and clipped to the unit square). 
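    Example (anchor counts per grid cell implied by ANCHOR_SPECIFICATIONS above,
    each spec contributing n*n densified boxes):

        scale 0: 4*4 + 2*2 + 1*1 = 21 anchors per cell
        scale 1: 1*1             = 1
        scale 2: 1*1             = 1
        # so num_anchors_per_location == [21, 1, 1]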
49 | """ 50 | feature_map_shape_list = [] 51 | for feature_map in image_features: 52 | 53 | height_i, width_i = feature_map.shape.as_list()[1:3] 54 | if height_i is None or width_i is None: 55 | height_i, width_i = tf.shape(feature_map)[1], tf.shape(feature_map)[2] 56 | 57 | feature_map_shape_list.append((height_i, width_i)) 58 | image_width, image_height = image_size 59 | 60 | # number of anchors per cell in a grid 61 | self.num_anchors_per_location = [ 62 | sum(n*n for _, _, n in layer_box_specs) 63 | for layer_box_specs in self.box_specs_list 64 | ] 65 | 66 | with tf.name_scope('anchor_generator'): 67 | anchor_grid_list, num_anchors_per_feature_map = [], [] 68 | for grid_size, box_spec in zip(feature_map_shape_list, self.box_specs_list): 69 | 70 | h, w = grid_size 71 | stride = (1.0/tf.to_float(h), 1.0/tf.to_float(w)) 72 | offset = (0.5/tf.to_float(h), 0.5/tf.to_float(w)) 73 | 74 | local_anchors = [] 75 | for scale, aspect_ratio, n in box_spec: 76 | local_anchors.append(tile_anchors( 77 | image_size=(image_width, image_height), 78 | grid_height=h, grid_width=w, scale=scale, 79 | aspect_ratio=aspect_ratio, anchor_stride=stride, 80 | anchor_offset=offset, n=n 81 | )) 82 | 83 | # reshaping in the right order is important 84 | local_anchors = tf.concat(local_anchors, axis=2) 85 | local_anchors = tf.reshape(local_anchors, [-1, 4]) 86 | anchor_grid_list.append(local_anchors) 87 | 88 | num_anchors_per_feature_map.append(h * w * sum(n*n for _, _, n in box_spec)) 89 | 90 | # constant tensors, anchors for each feature map 91 | self.anchor_grid_list = anchor_grid_list 92 | self.num_anchors_per_feature_map = num_anchors_per_feature_map 93 | 94 | with tf.name_scope('concatenate'): 95 | anchors = tf.concat(anchor_grid_list, axis=0) 96 | ymin, xmin, ymax, xmax = to_minmax_coordinates(tf.unstack(anchors, axis=1)) 97 | anchors = tf.stack([ymin, xmin, ymax, xmax], axis=1) 98 | anchors = tf.clip_by_value(anchors, 0.0, 1.0) 99 | return anchors 100 | 101 | 102 | def tile_anchors( 103 | image_size, grid_height, grid_width, 104 | scale, aspect_ratio, anchor_stride, anchor_offset, n): 105 | """ 106 | Arguments: 107 | image_size: a tuple of integers (width, height). 108 | grid_height: an integer, size of the grid in the y direction. 109 | grid_width: an integer, size of the grid in the x direction. 110 | scale: a float number. 111 | aspect_ratio: a float number. 112 | anchor_stride: a tuple of float numbers, difference in centers between 113 | anchors for adjacent grid positions. 114 | anchor_offset: a tuple of float numbers, 115 | center of the anchor on upper left element of the grid ((0, 0)-th anchor). 116 | n: an integer, densification parameter. 117 | Returns: 118 | a float tensor with shape [grid_height, grid_width, n*n, 4]. 
119 | """ 120 | ratio_sqrt = tf.sqrt(aspect_ratio) 121 | unnormalized_height = scale / ratio_sqrt 122 | unnormalized_width = scale * ratio_sqrt 123 | 124 | # to [0, 1] range 125 | image_width, image_height = image_size 126 | height = unnormalized_height/tf.to_float(image_height) 127 | width = unnormalized_width/tf.to_float(image_width) 128 | # (sometimes it could be outside the range, but we clip it) 129 | 130 | boxes = generate_anchors_at_upper_left_corner(height, width, anchor_offset, n) 131 | # shape [n*n, 4] 132 | 133 | y_translation = tf.to_float(tf.range(grid_height)) * anchor_stride[0] 134 | x_translation = tf.to_float(tf.range(grid_width)) * anchor_stride[1] 135 | x_translation, y_translation = tf.meshgrid(x_translation, y_translation) 136 | # they have shape [grid_height, grid_width] 137 | 138 | center_translations = tf.stack([y_translation, x_translation], axis=2) 139 | translations = tf.pad(center_translations, [[0, 0], [0, 0], [0, 2]]) 140 | translations = tf.expand_dims(translations, 2) 141 | translations = tf.tile(translations, [1, 1, n*n, 1]) 142 | # shape [grid_height, grid_width, n*n, 4] 143 | 144 | boxes = tf.reshape(boxes, [1, 1, n*n, 4]) 145 | boxes = boxes + translations # shape [grid_height, grid_width, n*n, 4] 146 | return boxes 147 | 148 | 149 | def generate_anchors_at_upper_left_corner(height, width, anchor_offset, n): 150 | """Generate densified anchor boxes at (0, 0) grid position.""" 151 | 152 | # a usual center, if n = 1 it will be returned 153 | cy, cx = anchor_offset[0], anchor_offset[1] 154 | 155 | # a usual left upper corner 156 | ymin, xmin = cy - 0.5*height, cx - 0.5*width 157 | 158 | # now i shift the usual center a little (densification) 159 | sy, sx = height/n, width/n 160 | 161 | center_ids = tf.to_float(tf.range(n)) 162 | # shape [n] 163 | 164 | # shifted centers 165 | new_centers_y = ymin + 0.5*sy + sy*center_ids 166 | new_centers_x = xmin + 0.5*sx + sx*center_ids 167 | # they have shape [n] 168 | 169 | # now i must get all pairs of y, x coordinates 170 | new_centers_y = tf.expand_dims(new_centers_y, 0) # shape [1, n] 171 | new_centers_x = tf.expand_dims(new_centers_x, 1) # shape [n, 1] 172 | 173 | new_centers_y = tf.tile(new_centers_y, [n, 1]) 174 | new_centers_x = tf.tile(new_centers_x, [1, n]) 175 | # they have shape [n, n] 176 | 177 | centers = tf.stack([new_centers_y, new_centers_x], axis=2) 178 | # shape [n, n, 2] 179 | 180 | sizes = tf.stack([height, width], axis=0) # shape [2] 181 | sizes = tf.expand_dims(sizes, 0) 182 | sizes = tf.expand_dims(sizes, 0) # shape [1, 1, 2] 183 | sizes = tf.tile(sizes, [n, n, 1]) 184 | 185 | boxes = tf.stack([centers, sizes], axis=2) 186 | boxes = tf.reshape(boxes, [-1, 4]) 187 | return boxes 188 | -------------------------------------------------------------------------------- /src/input_pipeline/random_image_crop.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from src.utils import area, intersection 3 | 4 | 5 | def random_image_crop( 6 | image, boxes, probability=0.5, 7 | min_object_covered=0.9, 8 | aspect_ratio_range=(0.75, 1.33), 9 | area_range=(0.5, 1.0), 10 | overlap_thresh=0.3): 11 | 12 | def crop(image, boxes): 13 | image, boxes, _ = _random_crop_image( 14 | image, boxes, min_object_covered, 15 | aspect_ratio_range, 16 | area_range, overlap_thresh 17 | ) 18 | return image, boxes 19 | 20 | do_it = tf.less(tf.random_uniform([]), probability) 21 | image, boxes = tf.cond( 22 | do_it, 23 | lambda: crop(image, boxes), 24 | lambda: (image, 
boxes) 25 | ) 26 | return image, boxes 27 | 28 | 29 | def _random_crop_image( 30 | image, boxes, min_object_covered=0.9, 31 | aspect_ratio_range=(0.75, 1.33), area_range=(0.5, 1.0), 32 | overlap_thresh=0.3): 33 | """Performs random crop. Given the input image and its bounding boxes, 34 | this op randomly crops a subimage. Given a user-provided set of input constraints, 35 | the crop window is resampled until it satisfies these constraints. 36 | If within 100 trials it is unable to find a valid crop, the original 37 | image is returned. Both input boxes and returned boxes are in normalized 38 | form (e.g., lie in the unit square [0, 1]). 39 | 40 | Arguments: 41 | image: a float tensor with shape [height, width, 3], 42 | with pixel values varying between [0, 1]. 43 | boxes: a float tensor containing bounding boxes. It has shape 44 | [num_boxes, 4]. Boxes are in normalized form, meaning 45 | their coordinates vary between [0, 1]. 46 | Each row is in the form of [ymin, xmin, ymax, xmax]. 47 | min_object_covered: the cropped image must cover at least this fraction of 48 | at least one of the input bounding boxes. 49 | aspect_ratio_range: allowed range for aspect ratio of cropped image. 50 | area_range: allowed range for area ratio between cropped image and the 51 | original image. 52 | overlap_thresh: minimum overlap thresh with new cropped 53 | image to keep the box. 54 | Returns: 55 | image: cropped image. 56 | boxes: remaining boxes. 57 | keep_ids: indices of remaining boxes in input boxes tensor. 58 | They are used to get a slice from the 'labels' tensor (if you have one). 59 | len(keep_ids) = len(boxes). 60 | """ 61 | with tf.name_scope('random_crop_image'): 62 | 63 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 64 | tf.shape(image), 65 | bounding_boxes=tf.expand_dims(boxes, 0), 66 | min_object_covered=min_object_covered, 67 | aspect_ratio_range=aspect_ratio_range, 68 | area_range=area_range, 69 | max_attempts=100, 70 | use_image_if_no_bounding_boxes=True 71 | ) 72 | begin, size, window = sample_distorted_bounding_box 73 | image = tf.slice(image, begin, size) 74 | window = tf.squeeze(window, axis=[0, 1]) 75 | 76 | # remove boxes that are completely outside cropped image 77 | boxes, inside_window_ids = _prune_completely_outside_window( 78 | boxes, window 79 | ) 80 | 81 | # remove boxes that are two much outside image 82 | boxes, keep_ids = _prune_non_overlapping_boxes( 83 | boxes, tf.expand_dims(window, 0), overlap_thresh 84 | ) 85 | 86 | # change coordinates of the remaining boxes 87 | boxes = _change_coordinate_frame(boxes, window) 88 | 89 | keep_ids = tf.gather(inside_window_ids, keep_ids) 90 | return image, boxes, keep_ids 91 | 92 | 93 | def _prune_completely_outside_window(boxes, window): 94 | """Prunes bounding boxes that fall completely outside of the given window. 95 | This function does not clip partially overflowing boxes. 96 | 97 | Arguments: 98 | boxes: a float tensor with shape [M_in, 4]. 99 | window: a float tensor with shape [4] representing [ymin, xmin, ymax, xmax] 100 | of the window. 101 | Returns: 102 | boxes: a float tensor with shape [M_out, 4] where 0 <= M_out <= M_in. 103 | valid_indices: a long tensor with shape [M_out] indexing the valid bounding boxes 104 | in the input 'boxes' tensor. 
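    Example (a toy window and two boxes):

        window = [0.25, 0.25, 0.75, 0.75]
        [0.0, 0.0, 0.2, 0.2]  -> pruned, it lies completely outside the window
        [0.1, 0.1, 0.4, 0.4]  -> kept as is, partial overlap is enough (no clipping here)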
105 | """ 106 | with tf.name_scope('prune_completely_outside_window'): 107 | 108 | y_min, x_min, y_max, x_max = tf.split(boxes, num_or_size_splits=4, axis=1) 109 | # they have shape [None, 1] 110 | win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window) 111 | # they have shape [] 112 | 113 | coordinate_violations = tf.concat([ 114 | tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max), 115 | tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min) 116 | ], axis=1) 117 | valid_indices = tf.squeeze( 118 | tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), 119 | axis=1 120 | ) 121 | boxes = tf.gather(boxes, valid_indices) 122 | return boxes, valid_indices 123 | 124 | 125 | def _prune_non_overlapping_boxes(boxes1, boxes2, min_overlap=0.0): 126 | """Prunes the boxes in boxes1 that overlap less than thresh with boxes2. 127 | For each box in boxes1, we want its IOA to be more than min_overlap with 128 | at least one of the boxes in boxes2. If it does not, we remove it. 129 | 130 | Arguments: 131 | boxes1: a float tensor with shape [N, 4]. 132 | boxes2: a float tensor with shape [M, 4]. 133 | min_overlap: minimum required overlap between boxes, 134 | to count them as overlapping. 135 | Returns: 136 | boxes: a float tensor with shape [N', 4]. 137 | keep_inds: a long tensor with shape [N'] indexing kept bounding boxes in the 138 | first input tensor ('boxes1'). 139 | """ 140 | with tf.name_scope('prune_non_overlapping_boxes'): 141 | ioa = _ioa(boxes2, boxes1) # [M, N] tensor 142 | ioa = tf.reduce_max(ioa, axis=0) # [N] tensor 143 | keep_bool = tf.greater_equal(ioa, tf.constant(min_overlap)) 144 | keep_inds = tf.squeeze(tf.where(keep_bool), axis=1) 145 | boxes = tf.gather(boxes1, keep_inds) 146 | return boxes, keep_inds 147 | 148 | 149 | def _change_coordinate_frame(boxes, window): 150 | """Change coordinate frame of the boxes to be relative to window's frame. 151 | 152 | Arguments: 153 | boxes: a float tensor with shape [N, 4]. 154 | window: a float tensor with shape [4]. 155 | Returns: 156 | a float tensor with shape [N, 4]. 157 | """ 158 | with tf.name_scope('change_coordinate_frame'): 159 | 160 | ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=1) 161 | ymin -= window[0] 162 | xmin -= window[1] 163 | ymax -= window[0] 164 | xmax -= window[1] 165 | 166 | win_height = window[2] - window[0] 167 | win_width = window[3] - window[1] 168 | boxes = tf.stack([ 169 | ymin/win_height, xmin/win_width, 170 | ymax/win_height, xmax/win_width 171 | ], axis=1) 172 | boxes = tf.clip_by_value(boxes, clip_value_min=0.0, clip_value_max=1.0) 173 | return boxes 174 | 175 | 176 | def _ioa(boxes1, boxes2): 177 | """Computes pairwise intersection-over-area between box collections. 178 | intersection-over-area (IOA) between two boxes box1 and box2 is defined as 179 | their intersection area over box2's area. Note that ioa is not symmetric, 180 | that is, ioa(box1, box2) != ioa(box2, box1). 181 | 182 | Arguments: 183 | boxes1: a float tensor with shape [N, 4]. 184 | boxes2: a float tensor with shape [M, 4]. 185 | Returns: 186 | a float tensor with shape [N, M] representing pairwise ioa scores. 
187 | """ 188 | with tf.name_scope('ioa'): 189 | intersections = intersection(boxes1, boxes2) # shape [N, M] 190 | areas = tf.expand_dims(area(boxes2), 0) # shape [1, M] 191 | return tf.divide(intersections, areas) 192 | -------------------------------------------------------------------------------- /prepare_data/explore_and_prepare_WIDER.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import json\n", 11 | "from PIL import Image, ImageDraw\n", 12 | "import os\n", 13 | "import cv2\n", 14 | "import pandas as pd\n", 15 | "from tqdm import tqdm\n", 16 | "import shutil\n", 17 | "import random" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "The purpose of this script is to explore images/annotations of the WIDER dataset. \n", 25 | "Also it converts annotations into json format." 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "# first run this script for this images:\n", 35 | "IMAGES_DIR = '/home/gpu2/hdd/dan/WIDER/WIDER_train/images/'\n", 36 | "BOXES_PATH = '/home/gpu2/hdd/dan/WIDER/wider_face_split/wider_face_train_bbx_gt.txt'\n", 37 | "RESULT_DIR = '/home/gpu2/hdd/dan/WIDER/train/'\n", 38 | "\n", 39 | "# then run for this images:\n", 40 | "# IMAGES_DIR = '/home/gpu2/hdd/dan/WIDER/WIDER_val/images/'\n", 41 | "# BOXES_PATH = '/home/gpu2/hdd/dan/WIDER/wider_face_split/wider_face_val_bbx_gt.txt'\n", 42 | "# RESULT_DIR = '/home/gpu2/hdd/dan/WIDER/train_part2/'" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Read data" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# collect paths to all images\n", 59 | "\n", 60 | "all_paths = []\n", 61 | "for path, subdirs, files in tqdm(os.walk(IMAGES_DIR)):\n", 62 | " for name in files:\n", 63 | " all_paths.append(os.path.join(path, name))\n", 64 | " \n", 65 | "metadata = pd.DataFrame(all_paths, columns=['full_path'])\n", 66 | "\n", 67 | "# strip root folder\n", 68 | "metadata['path'] = metadata.full_path.apply(lambda x: os.path.relpath(x, IMAGES_DIR))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "# see all unique endings\n", 78 | "metadata.path.apply(lambda x: x.split('.')[-1]).unique()" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# number of images\n", 88 | "len(metadata)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# read annotations\n", 98 | "with open(BOXES_PATH, 'r') as f:\n", 99 | " content = f.readlines()\n", 100 | " content = [s.strip() for s in content]" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# split annotations by image\n", 110 | "boxes = {}\n", 111 | "num_lines = len(content)\n", 112 | "i = 0\n", 113 | "name = None\n", 114 | "\n", 115 | "while i < num_lines:\n", 116 | " s = content[i]\n", 117 | " if s.endswith('.jpg'):\n", 118 | " if name is not None:\n", 119 | " 
assert len(boxes[name]) == num_boxes\n", 120 | " name = s\n", 121 | " boxes[name] = []\n", 122 | " i += 1\n", 123 | " num_boxes = int(content[i])\n", 124 | " i += 1\n", 125 | " else:\n", 126 | " xmin, ymin, w, h = s.split(' ')[:4]\n", 127 | " xmin, ymin, w, h = int(xmin), int(ymin), int(w), int(h)\n", 128 | " if h <= 0 or w <= 0:\n", 129 | " print(name) \n", 130 | " # some boxes are weird!\n", 131 | " # so i don't use them\n", 132 | " num_boxes -= 1\n", 133 | " else:\n", 134 | " boxes[name].append((xmin, ymin, w, h))\n", 135 | " i += 1" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "# check that all images have bounding boxes\n", 145 | "assert metadata.path.apply(lambda x: x in boxes).all()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# Show some bounding boxes" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "def draw_boxes_on_image(path, boxes):\n", 162 | "\n", 163 | " image = Image.open(path)\n", 164 | " draw = ImageDraw.Draw(image, 'RGBA')\n", 165 | " width, height = image.size\n", 166 | "\n", 167 | " for b in boxes:\n", 168 | " xmin, ymin, w, h = b\n", 169 | " xmax, ymax = xmin + w, ymin + h\n", 170 | "\n", 171 | " fill = (255, 255, 255, 45)\n", 172 | " outline = 'red'\n", 173 | " draw.rectangle(\n", 174 | " [(xmin, ymin), (xmax, ymax)],\n", 175 | " fill=fill, outline=outline\n", 176 | " )\n", 177 | " return image" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": { 184 | "scrolled": false 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "i = random.randint(0, len(metadata) - 1) # choose a random image\n", 189 | "some_boxes = boxes[metadata.path[i]]\n", 190 | "draw_boxes_on_image(metadata.full_path[i], some_boxes)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "markdown", 195 | "metadata": {}, 196 | "source": [ 197 | "# Convert" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "metadata": {}, 204 | "outputs": [], 205 | "source": [ 206 | "def get_annotation(path, width, height):\n", 207 | " name = path.split('/')[-1]\n", 208 | " annotation = {\n", 209 | " \"filename\": name,\n", 210 | " \"size\": {\"depth\": 3, \"width\": width, \"height\": height}\n", 211 | " }\n", 212 | " objects = []\n", 213 | " for b in boxes[path]:\n", 214 | " xmin, ymin, w, h = b\n", 215 | " xmax, ymax = xmin + w, ymin + h\n", 216 | " objects.append({\n", 217 | " \"bndbox\": {\"ymin\": ymin, \"ymax\": ymax, \"xmax\": xmax, \"xmin\": xmin}, \n", 218 | " \"name\": \"face\"\n", 219 | " })\n", 220 | " annotation[\"object\"] = objects\n", 221 | " return annotation" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "# create a folder for the converted dataset\n", 231 | "shutil.rmtree(RESULT_DIR, ignore_errors=True)\n", 232 | "os.mkdir(RESULT_DIR)\n", 233 | "os.mkdir(os.path.join(RESULT_DIR, 'images'))\n", 234 | "os.mkdir(os.path.join(RESULT_DIR, 'annotations'))" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "for T in tqdm(metadata.itertuples()):\n", 244 | " \n", 245 | " # get width and height of an image\n", 246 | " image = cv2.imread(T.full_path)\n", 247 | " h, w, c 
= image.shape\n", 248 | " assert c == 3\n", 249 | " \n", 250 | " # name of the image\n", 251 | " name = T.path.split('/')[-1]\n", 252 | " assert name.endswith('.jpg')\n", 253 | "\n", 254 | " # copy the image\n", 255 | " shutil.copy(T.full_path, os.path.join(RESULT_DIR, 'images', name))\n", 256 | " \n", 257 | " # save annotation for it\n", 258 | " d = get_annotation(T.path, w, h)\n", 259 | " json_name = name[:-4] + '.json'\n", 260 | " json.dump(d, open(os.path.join(RESULT_DIR, 'annotations', json_name), 'w')) " 261 | ] 262 | } 263 | ], 264 | "metadata": { 265 | "kernelspec": { 266 | "display_name": "Python 3", 267 | "language": "python", 268 | "name": "python3" 269 | }, 270 | "language_info": { 271 | "codemirror_mode": { 272 | "name": "ipython", 273 | "version": 3 274 | }, 275 | "file_extension": ".py", 276 | "mimetype": "text/x-python", 277 | "name": "python", 278 | "nbconvert_exporter": "python", 279 | "pygments_lexer": "ipython3", 280 | "version": "3.6.3" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 1 285 | } 286 | -------------------------------------------------------------------------------- /evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | """ 6 | For evaluation during the training i use average precision @ iou=0.5 7 | like in PASCAL VOC Challenge (detection task): 8 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/devkit_doc.pdf 9 | """ 10 | 11 | 12 | class Box: 13 | def __init__(self, image, box, score=None): 14 | """ 15 | Arguments: 16 | image: a string, identifier of a image. 17 | box: a numpy float array with shape [4]. 18 | score: a float number or None. 19 | """ 20 | self.image = image 21 | self.confidence = score 22 | self.is_matched = False 23 | 24 | # top left corner 25 | self.ymin = box[0] 26 | self.xmin = box[1] 27 | 28 | # bottom right corner 29 | self.ymax = box[2] 30 | self.xmax = box[3] 31 | 32 | 33 | class Evaluator: 34 | def __init__(self): 35 | self._initialize() 36 | 37 | def evaluate(self, iou_threshold=0.5): 38 | self.metrics = evaluate_detector( 39 | self.groundtruth_by_image, 40 | self.detections, iou_threshold 41 | ) 42 | 43 | def clear(self): 44 | self._initialize() 45 | 46 | def get_metric_ops(self, image_name, groundtruth, predictions): 47 | """ 48 | Arguments: 49 | image_name: a string tensor with shape [1]. 50 | groundtruth: a dict with the following keys 51 | 'boxes': a float tensor with shape [1, max_num_boxes, 4]. 52 | 'num_boxes': an int tensor with shape [1]. 53 | predictions: a dict with the following keys 54 | 'boxes': a float tensor with shape [1, max_num_boxes, 4]. 55 | 'scores': a float tensor with shape [1, max_num_boxes]. 56 | 'num_boxes': an int tensor with shape [1]. 
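    Example (a sketch of plugging these ops into an Estimator's evaluation;
    how model.py actually wires this up may differ, and total_loss / labels /
    predictions here are placeholders coming from the model function):

        evaluator = Evaluator()
        eval_metric_ops = evaluator.get_metric_ops(
            features['filenames'], labels, predictions
        )
        spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.EVAL,
            loss=total_loss, eval_metric_ops=eval_metric_ops
        )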
57 | """ 58 | 59 | def update_op_func(image_name, gt_boxes, gt_num_boxes, boxes, scores, num_boxes): 60 | self.add_groundtruth(image_name, gt_boxes, gt_num_boxes) 61 | self.add_detections(image_name, boxes, scores, num_boxes) 62 | 63 | tensors = [ 64 | image_name[0], groundtruth['boxes'][0], groundtruth['num_boxes'][0], 65 | predictions['boxes'][0], predictions['scores'][0], predictions['num_boxes'][0] 66 | ] 67 | update_op = tf.py_func(update_op_func, tensors, [], stateful=True) 68 | 69 | def evaluate_func(): 70 | self.evaluate() 71 | self.clear() 72 | evaluate_op = tf.py_func(evaluate_func, [], []) 73 | 74 | def get_value_func(measure): 75 | def value_func(): 76 | return np.float32(self.metrics[measure]) 77 | return value_func 78 | 79 | with tf.control_dependencies([evaluate_op]): 80 | 81 | metric_names = ['AP', 'precision', 'recall', 'mean_iou', 'threshold', 'FP', 'FN'] 82 | eval_metric_ops = { 83 | 'metrics/' + measure: 84 | (tf.py_func(get_value_func(measure), [], tf.float32), update_op) 85 | for measure in metric_names 86 | } 87 | 88 | return eval_metric_ops 89 | 90 | def _initialize(self): 91 | self.detections = [] 92 | self.groundtruth_by_image = {} 93 | 94 | def add_detections(self, image_name, boxes, scores, num_boxes): 95 | """ 96 | Arguments: 97 | images: a numpy string array with shape []. 98 | boxes: a numpy float array with shape [N, 4]. 99 | scores: a numpy float array with shape [N]. 100 | num_boxes: a numpy int array with shape []. 101 | """ 102 | boxes, scores = boxes[:num_boxes], scores[:num_boxes] 103 | for box, score in zip(boxes, scores): 104 | self.detections.append(Box(image_name, box, score)) 105 | 106 | def add_groundtruth(self, image_name, boxes, num_boxes): 107 | for box in boxes[:num_boxes]: 108 | if image_name in self.groundtruth_by_image: 109 | self.groundtruth_by_image[image_name] += [Box(image_name, box)] 110 | else: 111 | self.groundtruth_by_image[image_name] = [Box(image_name, box)] 112 | 113 | 114 | def evaluate_detector(groundtruth_by_img, all_detections, iou_threshold=0.5): 115 | """ 116 | Arguments: 117 | groundtruth_by_img: a dict of lists with boxes, 118 | image -> list of groundtruth boxes on the image. 119 | all_detections: a list of boxes. 120 | iou_threshold: a float number. 121 | Returns: 122 | a dict with seven values. 
123 | """ 124 | 125 | # each ground truth box is either TP or FN 126 | n_groundtruth_boxes = 0 127 | 128 | for boxes in groundtruth_by_img.values(): 129 | n_groundtruth_boxes += len(boxes) 130 | n_groundtruth_boxes = max(n_groundtruth_boxes, 1) 131 | 132 | # sort by confidence in decreasing order 133 | all_detections.sort(key=lambda box: box.confidence, reverse=True) 134 | 135 | n_correct_detections = 0 136 | n_detections = 0 137 | mean_iou = 0.0 138 | precision = [0.0]*len(all_detections) 139 | recall = [0.0]*len(all_detections) 140 | confidences = [box.confidence for box in all_detections] 141 | 142 | for k, detection in enumerate(all_detections): 143 | 144 | # each detection is either TP or FP 145 | n_detections += 1 146 | 147 | if detection.image in groundtruth_by_img: 148 | groundtruth_boxes = groundtruth_by_img[detection.image] 149 | else: 150 | groundtruth_boxes = [] 151 | 152 | best_groundtruth_i, max_iou = match(detection, groundtruth_boxes) 153 | mean_iou += max_iou 154 | 155 | if best_groundtruth_i >= 0 and max_iou >= iou_threshold: 156 | box = groundtruth_boxes[best_groundtruth_i] 157 | if not box.is_matched: 158 | box.is_matched = True 159 | n_correct_detections += 1 # increase number of TP 160 | 161 | precision[k] = float(n_correct_detections)/float(n_detections) # TP/(TP + FP) 162 | recall[k] = float(n_correct_detections)/float(n_groundtruth_boxes) # TP/(TP + FN) 163 | 164 | ap = compute_ap(precision, recall) 165 | best_threshold, best_precision, best_recall = compute_best_threshold( 166 | precision, recall, confidences 167 | ) 168 | mean_iou /= max(n_detections, 1) 169 | return { 170 | 'AP': ap, 'precision': best_precision, 171 | 'recall': best_recall, 'threshold': best_threshold, 172 | 'mean_iou': mean_iou, 'FP': n_detections - n_correct_detections, 173 | 'FN': n_groundtruth_boxes - n_correct_detections 174 | } 175 | 176 | 177 | def compute_best_threshold(precision, recall, confidences): 178 | """ 179 | Arguments: 180 | precision, recall, confidences: lists of floats of the same length. 181 | 182 | Returns: 183 | 1. a float number, best confidence threshold. 184 | 2. a float number, precision at the threshold. 185 | 3. a float number, recall at the threshold. 186 | """ 187 | if len(confidences) == 0: 188 | return 0.0, 0.0, 0.0 189 | 190 | precision = np.asarray(precision) 191 | recall = np.asarray(recall) 192 | confidences = np.asarray(confidences) 193 | 194 | diff = np.abs(precision - recall) 195 | prod = precision*recall 196 | best_i = np.argmax(prod*(1.0 - diff)) 197 | best_threshold = confidences[best_i] 198 | 199 | return best_threshold, precision[best_i], recall[best_i] 200 | 201 | 202 | def compute_iou(box1, box2): 203 | w = min(box1.xmax, box2.xmax) - max(box1.xmin, box2.xmin) 204 | if w > 0: 205 | h = min(box1.ymax, box2.ymax) - max(box1.ymin, box2.ymin) 206 | if h > 0: 207 | intersection = w*h 208 | w1 = box1.xmax - box1.xmin 209 | h1 = box1.ymax - box1.ymin 210 | w2 = box2.xmax - box2.xmin 211 | h2 = box2.ymax - box2.ymin 212 | union = (w1*h1 + w2*h2) - intersection 213 | return float(intersection)/float(union) 214 | return 0.0 215 | 216 | 217 | def match(detection, groundtruth_boxes): 218 | """ 219 | Arguments: 220 | detection: a box. 221 | groundtruth_boxes: a list of boxes. 222 | Returns: 223 | best_i: an integer, index of the best groundtruth box. 224 | max_iou: a float number. 
225 | """ 226 | best_i = -1 227 | max_iou = 0.0 228 | for i, box in enumerate(groundtruth_boxes): 229 | iou = compute_iou(detection, box) 230 | if iou > max_iou: 231 | best_i = i 232 | max_iou = iou 233 | return best_i, max_iou 234 | 235 | 236 | def compute_ap(precision, recall): 237 | previous_recall_value = 0.0 238 | ap = 0.0 239 | # recall is in increasing order 240 | for p, r in zip(precision, recall): 241 | delta = r - previous_recall_value 242 | ap += p*delta 243 | previous_recall_value = r 244 | return ap 245 | -------------------------------------------------------------------------------- /prepare_data/explore_and_convert_FDDB.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import json\n", 11 | "from PIL import Image, ImageDraw\n", 12 | "import os\n", 13 | "import cv2\n", 14 | "import pandas as pd\n", 15 | "from tqdm import tqdm\n", 16 | "import shutil\n", 17 | "import random" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "The purpose of this script is to explore images/annotations of the FDDB dataset. \n", 25 | "Also it converts face ellipses into face bounding boxes. \n", 26 | "Also it converts annotations into json format." 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "IMAGES_DIR = '/home/gpu2/hdd/dan/FDDB/originalPics/'\n", 36 | "BOXES_DIR = '/home/gpu2/hdd/dan/FDDB/FDDB-folds/'\n", 37 | "RESULT_DIR = '/home/gpu2/hdd/dan/FDDB/val/'" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "# Read data" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "# collect paths to all images\n", 54 | "\n", 55 | "all_paths = []\n", 56 | "for path, subdirs, files in tqdm(os.walk(IMAGES_DIR)):\n", 57 | " for name in files:\n", 58 | " all_paths.append(os.path.join(path, name))\n", 59 | " \n", 60 | "metadata = pd.DataFrame(all_paths, columns=['full_path'])\n", 61 | "\n", 62 | "# strip root folder\n", 63 | "metadata['path'] = metadata.full_path.apply(lambda x: os.path.relpath(x, IMAGES_DIR))" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# all unique endings\n", 73 | "metadata.path.apply(lambda x: x.split('.')[-1]).unique()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "# number of images\n", 83 | "len(metadata)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "annotation_files = os.listdir(BOXES_DIR)\n", 93 | "annotation_files = [f for f in annotation_files if f.endswith('ellipseList.txt')]\n", 94 | "annotation_files = [os.path.join(BOXES_DIR, f) for f in annotation_files]" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "def ellipse_to_box(major_axis_radius, minor_axis_radius, angle, center_x, center_y):\n", 104 | " half_h = major_axis_radius * np.sin(-angle)\n", 105 | " half_w = minor_axis_radius * np.sin(-angle)\n", 106 | " xmin, xmax = 
center_x - half_w, center_x + half_w\n", 107 | " ymin, ymax = center_y - half_h, center_y + half_h\n", 108 | " return xmin, ymin, xmax, ymax\n", 109 | "\n", 110 | "\n", 111 | "def get_boxes(path):\n", 112 | " \n", 113 | " with open(path, 'r') as f:\n", 114 | " content = f.readlines()\n", 115 | " content = [s.strip() for s in content]\n", 116 | "\n", 117 | " boxes = {}\n", 118 | " num_lines = len(content)\n", 119 | " i = 0\n", 120 | " name = None\n", 121 | "\n", 122 | " while i < num_lines:\n", 123 | " s = content[i]\n", 124 | " if 'big/img' in s:\n", 125 | " if name is not None:\n", 126 | " assert len(boxes[name]) == num_boxes\n", 127 | " name = s + '.jpg'\n", 128 | " boxes[name] = []\n", 129 | " i += 1\n", 130 | " num_boxes = int(content[i])\n", 131 | " i += 1\n", 132 | " else:\n", 133 | " numbers = [float(f) for f in s.split(' ')[:5]]\n", 134 | " major_axis_radius, minor_axis_radius, angle, center_x, center_y = numbers\n", 135 | "\n", 136 | " xmin, ymin, xmax, ymax = ellipse_to_box(\n", 137 | " major_axis_radius, minor_axis_radius, \n", 138 | " angle, center_x, center_y\n", 139 | " )\n", 140 | " if xmin == xmax or ymin == ymax:\n", 141 | " num_boxes -= 1\n", 142 | " else:\n", 143 | " boxes[name].append((\n", 144 | " min(xmin, xmax), min(ymin, ymax), \n", 145 | " max(xmin, xmax), max(ymin, ymax)\n", 146 | " ))\n", 147 | " i += 1\n", 148 | " return boxes" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "boxes = {}\n", 158 | "for p in annotation_files:\n", 159 | " boxes.update(get_boxes(p))" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "# check number of images with annotations\n", 169 | "# and number of boxes\n", 170 | "# (these values are taken from the official website) \n", 171 | "assert len(boxes) == 2845\n", 172 | "assert sum(len(b) for b in boxes.values()) == 5171 - 1 # one box is empty" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "metadata = metadata.loc[metadata.path.apply(lambda x: x in boxes)]\n", 182 | "metadata = metadata.reset_index(drop=True)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "metadata": {}, 188 | "source": [ 189 | "# Show bounding boxes" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "def draw_boxes_on_image(path, boxes):\n", 199 | "\n", 200 | " image = Image.open(path)\n", 201 | " draw = ImageDraw.Draw(image, 'RGBA')\n", 202 | " width, height = image.size\n", 203 | "\n", 204 | " for b in boxes:\n", 205 | " xmin, ymin, xmax, ymax = b\n", 206 | "\n", 207 | " fill = (255, 255, 255, 45)\n", 208 | " outline = 'red'\n", 209 | " draw.rectangle(\n", 210 | " [(xmin, ymin), (xmax, ymax)],\n", 211 | " fill=fill, outline=outline\n", 212 | " )\n", 213 | " return image" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "metadata": { 220 | "scrolled": false 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "i = random.randint(0, len(metadata) - 1) # choose a random image\n", 225 | "some_boxes = boxes[metadata.path[i]]\n", 226 | "draw_boxes_on_image(metadata.full_path[i], some_boxes)" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "# Convert" 234 | ] 235 | 
}, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "def get_annotation(path, name, width, height):\n", 243 | " annotation = {\n", 244 | " \"filename\": name,\n", 245 | " \"size\": {\"depth\": 3, \"width\": width, \"height\": height}\n", 246 | " }\n", 247 | " objects = []\n", 248 | " for b in boxes[path]:\n", 249 | " xmin, ymin, xmax, ymax = b\n", 250 | " objects.append({\"bndbox\": {\"ymin\": ymin, \"ymax\": ymax, \"xmax\": xmax, \"xmin\": xmin}, \"name\": \"face\"})\n", 251 | " annotation[\"object\"] = objects\n", 252 | " return annotation" 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "shutil.rmtree(RESULT_DIR, ignore_errors=True)\n", 262 | "os.mkdir(RESULT_DIR)\n", 263 | "os.mkdir(os.path.join(RESULT_DIR, 'images'))\n", 264 | "os.mkdir(os.path.join(RESULT_DIR, 'annotations'))" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": null, 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "for T in tqdm(metadata.itertuples()):\n", 274 | " \n", 275 | " # get width and height of an image\n", 276 | " image = cv2.imread(T.full_path)\n", 277 | " h, w, c = image.shape\n", 278 | " assert c == 3\n", 279 | " \n", 280 | " # name of the image\n", 281 | " name = '-'.join(T.path.split('/')[:3]) + '_' + T.path.split('/')[-1]\n", 282 | " assert name.endswith('.jpg')\n", 283 | "\n", 284 | " # copy the image\n", 285 | " shutil.copy(T.full_path, os.path.join(RESULT_DIR, 'images', name))\n", 286 | " \n", 287 | " # save annotation for it\n", 288 | " d = get_annotation(T.path, name, w, h)\n", 289 | " json_name = name[:-4] + '.json'\n", 290 | " json.dump(d, open(os.path.join(RESULT_DIR, 'annotations', json_name), 'w')) " 291 | ] 292 | } 293 | ], 294 | "metadata": { 295 | "kernelspec": { 296 | "display_name": "Python 3", 297 | "language": "python", 298 | "name": "python3" 299 | }, 300 | "language_info": { 301 | "codemirror_mode": { 302 | "name": "ipython", 303 | "version": 3 304 | }, 305 | "file_extension": ".py", 306 | "mimetype": "text/x-python", 307 | "name": "python", 308 | "nbconvert_exporter": "python", 309 | "pygments_lexer": "ipython3", 310 | "version": "3.6.3" 311 | } 312 | }, 313 | "nbformat": 4, 314 | "nbformat_minor": 1 315 | } 316 | -------------------------------------------------------------------------------- /src/detector.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | import math 4 | 5 | from src.constants import MATCHING_THRESHOLD, PARALLEL_ITERATIONS, BATCH_NORM_MOMENTUM, RESIZE_METHOD 6 | from src.utils import batch_non_max_suppression, batch_decode 7 | from src.training_target_creation import get_training_targets 8 | from src.losses_and_ohem import localization_loss, classification_loss, apply_hard_mining 9 | 10 | 11 | class Detector: 12 | def __init__(self, images, feature_extractor, anchor_generator): 13 | """ 14 | Arguments: 15 | images: a float tensor with shape [batch_size, height, width, 3], 16 | a batch of RGB images with pixel values in the range [0, 1]. 17 | feature_extractor: an instance of FeatureExtractor. 18 | anchor_generator: an instance of AnchorGenerator. 
19 | """ 20 | 21 | # sometimes images will be of different sizes, 22 | # so I need to use the dynamic shape 23 | h, w = images.shape.as_list()[1:3] 24 | 25 | # image padding here is a very tricky and important part of the detector, 26 | # if we don't do it then some bounding box 27 | # predictions will be badly shifted! 28 | 29 | x = 128 # the stride of the last layer 30 | # (the image size must be divisible by it) 31 | 32 | self.box_scaler = tf.ones([4], dtype=tf.float32) 33 | if h is None or w is None or h % x != 0 or w % x != 0: 34 | h, w = tf.shape(images)[1], tf.shape(images)[2] 35 | with tf.name_scope('image_padding'): 36 | 37 | # image size must be divisible by 128 38 | new_h = x * tf.to_int32(tf.ceil(h/x)) 39 | new_w = x * tf.to_int32(tf.ceil(w/x)) 40 | # also we will need to rescale bounding box coordinates 41 | self.box_scaler = tf.to_float(tf.stack([ 42 | h/new_h, w/new_w, h/new_h, w/new_w 43 | ])) 44 | # pad the images with zeros on the right and on the bottom 45 | images = tf.image.pad_to_bounding_box( 46 | images, offset_height=0, offset_width=0, 47 | target_height=new_h, target_width=new_w 48 | ) 49 | h, w = new_h, new_w 50 | 51 | feature_maps = feature_extractor(images) 52 | self.is_training = feature_extractor.is_training 53 | 54 | self.anchors = anchor_generator(feature_maps, image_size=(w, h)) 55 | self.num_anchors_per_location = anchor_generator.num_anchors_per_location 56 | self.num_anchors_per_feature_map = anchor_generator.num_anchors_per_feature_map 57 | self._add_box_predictions(feature_maps) 58 | 59 | def get_predictions(self, score_threshold=0.1, iou_threshold=0.6, max_boxes=20): 60 | """Postprocess outputs of the network. 61 | 62 | Returns: 63 | boxes: a float tensor with shape [batch_size, N, 4]. 64 | scores: a float tensor with shape [batch_size, N]. 65 | num_boxes: an int tensor with shape [batch_size], it 66 | represents the number of detections on an image. 67 | 68 | where N = max_boxes. 69 | """ 70 | with tf.name_scope('postprocessing'): 71 | boxes = batch_decode(self.box_encodings, self.anchors) 72 | # if the images were padded we need to rescale predicted boxes: 73 | boxes = boxes / self.box_scaler 74 | boxes = tf.clip_by_value(boxes, 0.0, 1.0) 75 | # it has shape [batch_size, num_anchors, 4] 76 | 77 | scores = tf.nn.softmax(self.class_predictions_with_background, axis=2)[:, :, 1] 78 | # it has shape [batch_size, num_anchors] 79 | 80 | with tf.name_scope('nms'): 81 | boxes, scores, num_detections = batch_non_max_suppression( 82 | boxes, scores, score_threshold, iou_threshold, max_boxes 83 | ) 84 | return {'boxes': boxes, 'scores': scores, 'num_boxes': num_detections} 85 | 86 | def loss(self, groundtruth, params): 87 | """Compute scalar loss tensors with respect to provided groundtruth. 88 | 89 | Arguments: 90 | groundtruth: a dict with the following keys 91 | 'boxes': a float tensor with shape [batch_size, max_num_boxes, 4]. 92 | 'num_boxes': an int tensor with shape [batch_size]. 93 | where max_num_boxes = max(num_boxes). 94 | params: a dict with parameters for OHEM. 95 | Returns: 96 | a dict with two scalar float tensors: 'localization_loss' and 'classification_loss'.
97 | """ 98 | reg_targets, matches = self._create_targets(groundtruth) 99 | 100 | with tf.name_scope('losses'): 101 | 102 | # whether anchor is matched 103 | is_matched = tf.greater_equal(matches, 0) 104 | weights = tf.to_float(is_matched) 105 | # shape [batch_size, num_anchors] 106 | 107 | # we have binary classification for each anchor 108 | cls_targets = tf.to_int32(is_matched) 109 | 110 | with tf.name_scope('classification_loss'): 111 | cls_losses = classification_loss( 112 | self.class_predictions_with_background, 113 | cls_targets 114 | ) 115 | with tf.name_scope('localization_loss'): 116 | location_losses = localization_loss( 117 | self.box_encodings, 118 | reg_targets, weights 119 | ) 120 | # they have shape [batch_size, num_anchors] 121 | 122 | with tf.name_scope('normalization'): 123 | matches_per_image = tf.reduce_sum(weights, axis=1) # shape [batch_size] 124 | num_matches = tf.reduce_sum(matches_per_image) # shape [] 125 | normalizer = tf.maximum(num_matches, 1.0) 126 | 127 | scores = tf.nn.softmax(self.class_predictions_with_background, axis=2) 128 | # it has shape [batch_size, num_anchors, 2] 129 | 130 | decoded_boxes = batch_decode(self.box_encodings, self.anchors) 131 | decoded_boxes = decoded_boxes / self.box_scaler 132 | # it has shape [batch_size, num_anchors, 4] 133 | 134 | # add summaries for predictions 135 | is_background = tf.equal(matches, -1) 136 | self._add_scalewise_histograms(tf.to_float(is_background) * scores[:, :, 0], 'background_probability') 137 | self._add_scalewise_histograms(weights * scores[:, :, 1], 'face_probability') 138 | ymin, xmin, ymax, xmax = tf.unstack(decoded_boxes, axis=2) 139 | h, w = ymax - ymin, xmax - xmin 140 | self._add_scalewise_histograms(weights * h, 'box_heights') 141 | self._add_scalewise_histograms(weights * w, 'box_widths') 142 | 143 | # add summaries for losses and matches 144 | self._add_scalewise_matches_summaries(weights) 145 | self._add_scalewise_summaries(cls_losses, name='classification_losses') 146 | self._add_scalewise_summaries(location_losses, name='localization_losses') 147 | tf.summary.scalar('total_mean_matches_per_image', tf.reduce_mean(matches_per_image)) 148 | 149 | with tf.name_scope('ohem'): 150 | location_loss, cls_loss = apply_hard_mining( 151 | location_losses, cls_losses, 152 | self.class_predictions_with_background, 153 | matches, decoded_boxes, 154 | loss_to_use=params['loss_to_use'], 155 | loc_loss_weight=params['loc_loss_weight'], 156 | cls_loss_weight=params['cls_loss_weight'], 157 | num_hard_examples=params['num_hard_examples'], 158 | nms_threshold=params['nms_threshold'], 159 | max_negatives_per_positive=params['max_negatives_per_positive'], 160 | min_negatives_per_image=params['min_negatives_per_image'] 161 | ) 162 | return {'localization_loss': location_loss/normalizer, 'classification_loss': cls_loss/normalizer} 163 | 164 | def _add_scalewise_summaries(self, tensor, name, percent=0.2): 165 | """Adds histograms of the biggest 20 percent of 166 | tensor's values for each scale (feature map). 167 | 168 | Arguments: 169 | tensor: a float tensor with shape [batch_size, num_anchors]. 170 | name: a string. 171 | percent: a float number, default value is 20%. 
172 | """ 173 | index = 0 174 | for i, n in enumerate(self.num_anchors_per_feature_map): 175 | k = tf.ceil(tf.to_float(n) * percent) 176 | k = tf.to_int32(k) 177 | biggest_values, _ = tf.nn.top_k(tensor[:, index:(index + n)], k, sorted=False) 178 | # it has shape [batch_size, k] 179 | tf.summary.histogram( 180 | name + '_on_scale_' + str(i), 181 | tf.reduce_mean(biggest_values, axis=0) 182 | ) 183 | index += n 184 | 185 | def _add_scalewise_histograms(self, tensor, name): 186 | """Adds histograms of the tensor's nonzero values for each scale (feature map). 187 | 188 | Arguments: 189 | tensor: a float tensor with shape [batch_size, num_anchors]. 190 | name: a string. 191 | """ 192 | index = 0 193 | for i, n in enumerate(self.num_anchors_per_feature_map): 194 | values = tf.reshape(tensor[:, index:(index + n)], [-1]) 195 | nonzero = tf.greater(values, 0.0) 196 | values = tf.boolean_mask(values, nonzero) 197 | tf.summary.histogram(name + '_on_scale_' + str(i), values) 198 | index += n 199 | 200 | def _add_scalewise_matches_summaries(self, weights): 201 | """Adds summaries for the number of matches on each scale.""" 202 | index = 0 203 | for i, n in enumerate(self.num_anchors_per_feature_map): 204 | matches_per_image = tf.reduce_sum(weights[:, index:(index + n)], axis=1) 205 | tf.summary.scalar( 206 | 'mean_matches_per_image_on_scale_' + str(i), 207 | tf.reduce_mean(matches_per_image, axis=0) 208 | ) 209 | index += n 210 | 211 | def _create_targets(self, groundtruth): 212 | """ 213 | Arguments: 214 | groundtruth: a dict with the following keys 215 | 'boxes': a float tensor with shape [batch_size, N, 4]. 216 | 'num_boxes': an int tensor with shape [batch_size]. 217 | Returns: 218 | reg_targets: a float tensor with shape [batch_size, num_anchors, 4]. 219 | matches: an int tensor with shape [batch_size, num_anchors]. 220 | """ 221 | def fn(x): 222 | boxes, num_boxes = x 223 | boxes = boxes[:num_boxes] 224 | # if the images are padded we need to rescale groundtruth boxes: 225 | boxes = boxes * self.box_scaler 226 | reg_targets, matches = get_training_targets( 227 | self.anchors, boxes, threshold=MATCHING_THRESHOLD 228 | ) 229 | return reg_targets, matches 230 | 231 | with tf.name_scope('target_creation'): 232 | reg_targets, matches = tf.map_fn( 233 | fn, [groundtruth['boxes'], groundtruth['num_boxes']], 234 | dtype=(tf.float32, tf.int32), 235 | parallel_iterations=PARALLEL_ITERATIONS, 236 | back_prop=False, swap_memory=False, infer_shape=True 237 | ) 238 | return reg_targets, matches 239 | 240 | def _add_box_predictions(self, feature_maps): 241 | """Adds box predictors to each feature map, reshapes, and returns concatenated results. 242 | 243 | Arguments: 244 | feature_maps: a list of float tensors where the ith tensor has shape 245 | [batch, height_i, width_i, channels_i]. 246 | 247 | It creates two tensors: 248 | box_encodings: a float tensor with shape [batch_size, num_anchors, 4]. 249 | class_predictions_with_background: a float tensor with shape 250 | [batch_size, num_anchors, 2]. 
251 | """ 252 | num_anchors_per_location = self.num_anchors_per_location 253 | num_feature_maps = len(feature_maps) 254 | box_encodings, class_predictions_with_background = [], [] 255 | 256 | with tf.variable_scope('prediction_layers'): 257 | for i in range(num_feature_maps): 258 | 259 | x = feature_maps[i] 260 | num_predictions_per_location = num_anchors_per_location[i] 261 | 262 | y = slim.conv2d( 263 | x, num_predictions_per_location * 4, 264 | [3, 3], activation_fn=None, scope='box_encoding_predictor_%d' % i, 265 | data_format='NHWC', padding='SAME' 266 | ) 267 | # it has shape [batch_size, height_i, width_i, num_predictions_per_location * 4] 268 | box_encodings.append(y) 269 | 270 | import numpy as np 271 | biases = np.zeros([num_predictions_per_location, 2], dtype='float32') 272 | biases[:, 0] = np.log(0.99) # background class 273 | biases[:, 1] = np.log(0.01) # object class 274 | biases = biases.reshape(num_predictions_per_location * 2) 275 | 276 | y = slim.conv2d( 277 | x, num_predictions_per_location * 2, 278 | [3, 3], activation_fn=None, scope='class_predictor_%d' % i, 279 | data_format='NHWC', padding='SAME', 280 | biases_initializer=tf.constant_initializer(biases) 281 | ) 282 | # it has shape [batch_size, height_i, width_i, num_predictions_per_location * 2] 283 | class_predictions_with_background.append(y) 284 | 285 | # it is important that reshaping here is the same as when anchors were generated 286 | with tf.name_scope('reshaping'): 287 | for i in range(num_feature_maps): 288 | 289 | x = feature_maps[i] 290 | num_predictions_per_location = num_anchors_per_location[i] 291 | batch_size = tf.shape(x)[0] 292 | height_i = tf.shape(x)[1] 293 | width_i = tf.shape(x)[2] 294 | num_anchors_on_feature_map = height_i * width_i * num_predictions_per_location 295 | 296 | y = box_encodings[i] 297 | y = tf.reshape(y, tf.stack([batch_size, height_i, width_i, num_predictions_per_location, 4])) 298 | box_encodings[i] = tf.reshape(y, [batch_size, num_anchors_on_feature_map, 4]) 299 | 300 | y = class_predictions_with_background[i] 301 | y = tf.reshape(y, [batch_size, height_i, width_i, num_predictions_per_location, 2]) 302 | class_predictions_with_background[i] = tf.reshape(y, tf.stack([batch_size, num_anchors_on_feature_map, 2])) 303 | 304 | self.box_encodings = tf.concat(box_encodings, axis=1) 305 | self.class_predictions_with_background = tf.concat(class_predictions_with_background, axis=1) 306 | --------------------------------------------------------------------------------
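A minimal usage sketch (not part of the repository) of how the Detector class above is meant to be wired together. Only the Detector interface, its constructor and get_predictions, is taken from the docstrings in src/detector.py; the import paths for FeatureExtractor and AnchorGenerator and their constructor arguments are assumptions and may differ from the actual code:

import tensorflow as tf
from src.detector import Detector
from src.network import FeatureExtractor  # assumed module path
from src.anchor_generator import AnchorGenerator  # assumed module path

# a batch of RGB images with pixel values in [0, 1];
# sizes that are not divisible by 128 are padded inside Detector
images = tf.placeholder(tf.float32, [None, 1024, 1024, 3])

# NOTE: the constructor arguments below are assumptions
feature_extractor = FeatureExtractor(is_training=False)
anchor_generator = AnchorGenerator()

detector = Detector(images, feature_extractor, anchor_generator)

# a dict with 'boxes' [batch_size, N, 4], 'scores' [batch_size, N]
# and 'num_boxes' [batch_size], where N = max_boxes
predictions = detector.get_predictions(
    score_threshold=0.05, iou_threshold=0.3, max_boxes=200
)

At training time, detector.loss(groundtruth, params) returns the normalized 'localization_loss' and 'classification_loss' described above, where params carries the OHEM settings ('loss_to_use', 'num_hard_examples', 'nms_threshold', and so on).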