├── Dockerfile
├── README.md
├── face_detection_webcam.py
├── img
│   ├── ex1.png
│   └── ex2.png
├── model
│   └── frozen_inference_graph.pb
├── proto
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   └── string_int_label_map_pb2.cpython-35.pyc
│   ├── label_map.pbtxt
│   └── string_int_label_map_pb2.py
├── requirements.txt
├── runDetection.sh
└── utils
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── label_map_util.cpython-35.pyc
    │   └── visualization_utils_color.cpython-35.pyc
    ├── label_map_util.py
    └── visualization_utils_color.py

/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3
2 | 
3 | COPY ./requirements.txt /requirements.txt
4 | 
5 | WORKDIR /
6 | 
7 | RUN pip install -r requirements.txt
8 | 
9 | COPY . /
10 | 
11 | CMD ["python", "./face_detection_webcam.py"]
12 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow face detection
2 | 
3 | A TensorFlow face detection implementation based on MobileNet SSD v2, trained on the WIDER FACE dataset with the TensorFlow Object Detection API.
4 | 
5 | 
6 | ## Dependencies
7 | 
8 | * TensorFlow >= 1.12, < 2.0 (the code uses the TF1 graph/session API)
9 | * OpenCV
10 | * imutils
11 | 
12 | ```
13 | pip install -r requirements.txt
14 | ```
15 | 
16 | ## Usage
17 | ```
18 | python face_detection_webcam.py
19 | ```
20 | 
21 | ## Docker
22 | 
23 | Build the image:
24 | ```
25 | docker build -t face_detection .
26 | ```
27 | 
28 | Run the project with the pre-trained model:
29 | 
30 | ```
31 | bash runDetection.sh
32 | ```
33 | 
34 | ## Result
35 | Achieves 19 FPS at 640x480 resolution on an Intel Core i7-7600U CPU @ 2.80GHz × 4.
36 | 
37 | <div align="center">

38 | <img src="img/ex1.png" alt="ex1">
39 | <img src="img/ex2.png" alt="ex2">
40 | </div>
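41 | 
42 | Only detections with a confidence score above 0.40 are drawn; the threshold is passed to the visualizer in `face_detection_webcam.py` and can be lowered to display more (but noisier) detections:
43 | 
44 | ```
45 | # Excerpt from face_detection_webcam.py
46 | vis_util.visualize_boxes_and_labels_on_image_array(
47 |     frame, np.squeeze(boxes), np.squeeze(classes).astype(np.int32),
48 |     np.squeeze(scores), category_index, use_normalized_coordinates=True,
49 |     line_thickness=2, min_score_thresh=0.40)
50 | ```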

51 | 
52 | 
53 | ## Train Model
54 | 
55 | If you want to train your own model, I advise you to follow the TensorFlow Object Detection API tutorial; you'll just need to download an annotated dataset.
56 | 
57 | ## References
58 | * TensorFlow Object Detection API
59 | 
60 | * WIDER FACE dataset
61 | 
--------------------------------------------------------------------------------
/face_detection_webcam.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import sys
4 | import time
5 | import numpy as np
6 | import tensorflow as tf
7 | import cv2
8 | 
9 | from utils import label_map_util
10 | from utils import visualization_utils_color as vis_util
11 | 
12 | from imutils.video import FPS
13 | from imutils.video import WebcamVideoStream
14 | 
15 | 
16 | # Path to the frozen detection graph. This is the actual model that is used for the object detection.
17 | PATH_TO_CKPT = './model/frozen_inference_graph.pb'
18 | 
19 | # List of the strings that are used to add the correct label to each box.
20 | PATH_TO_LABELS = './proto/label_map.pbtxt'
21 | 
22 | NUM_CLASSES = 1
23 | 
24 | # Load the label map.
25 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
26 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
27 | category_index = label_map_util.create_category_index(categories)
28 | 
29 | 
30 | def face_detection():
31 | 
32 |     # Load the Tensorflow model
33 |     detection_graph = tf.Graph()
34 |     with detection_graph.as_default():
35 |         od_graph_def = tf.GraphDef()
36 |         with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
37 |             serialized_graph = fid.read()
38 |             od_graph_def.ParseFromString(serialized_graph)
39 |             tf.import_graph_def(od_graph_def, name='')
40 | 
41 |     sess = tf.Session(graph=detection_graph)
42 | 
43 |     image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
44 | 
45 |     # Each box represents a part of the image where a particular object was detected.
46 |     detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
47 | 
48 |     # Each score represents the level of confidence for each of the objects.
49 |     # The score is shown on the result image, together with the class label.
50 |     detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
51 | 
52 |     # Detection classes and number of detections.
53 |     detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
54 |     num_detections = detection_graph.get_tensor_by_name('num_detections:0')
55 | 
56 |     # Start the video stream
57 |     cap = WebcamVideoStream(0).start()
58 |     fps = FPS().start()
59 | 
60 |     while True:
61 | 
62 |         frame = cap.read()
63 | 
64 |         # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
65 |         expanded_frame = np.expand_dims(frame, axis=0)
66 |         (boxes, scores, classes, num_c) = sess.run(
67 |             [detection_boxes, detection_scores, detection_classes, num_detections],
68 |             feed_dict={image_tensor: expanded_frame})
69 | 
70 |         # Visualization of the detections
71 |         vis_util.visualize_boxes_and_labels_on_image_array(
72 |             frame,
73 |             np.squeeze(boxes),
74 |             np.squeeze(classes).astype(np.int32),
75 |             np.squeeze(scores),
76 |             category_index,
77 |             use_normalized_coordinates=True,
78 |             line_thickness=2,
79 |             min_score_thresh=0.40)
80 | 
81 |         cv2.imshow('Detection', frame)
82 |         fps.update()
83 | 
84 |         if cv2.waitKey(1) == ord('q'):
85 |             fps.stop()
86 |             break
87 | 
88 |     print("FPS: {:.2f}".format(fps.fps()))
89 |     cap.stop()
90 |     cv2.destroyAllWindows()
91 | 
92 | 
93 | if __name__ == '__main__':
94 |     face_detection()
95 | 
--------------------------------------------------------------------------------
/img/ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/img/ex1.png
--------------------------------------------------------------------------------
/img/ex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/img/ex2.png
--------------------------------------------------------------------------------
/model/frozen_inference_graph.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/model/frozen_inference_graph.pb
--------------------------------------------------------------------------------
/proto/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__init__.py
--------------------------------------------------------------------------------
/proto/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/proto/__pycache__/string_int_label_map_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__pycache__/string_int_label_map_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/proto/label_map.pbtxt:
--------------------------------------------------------------------------------
1 | item {
2 |   id: 1
3 |   name: 'face'
4 | }
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
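Note: the single 'face' item above (id: 1) is the class that NUM_CLASSES = 1 in face_detection_webcam.py refers to. As a minimal sketch (assuming a TensorFlow 1.x install, running from the repository root, and the helpers defined in utils/label_map_util.py further below), this is how the label map becomes the category_index handed to the visualization utilities:

```python
from utils import label_map_util

# Parse proto/label_map.pbtxt into a StringIntLabelMap message.
label_map = label_map_util.load_labelmap('./proto/label_map.pbtxt')

# With a single item (id: 1, name: 'face') this yields [{'id': 1, 'name': 'face'}].
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=1, use_display_name=True)

# Keyed by id: {1: {'id': 1, 'name': 'face'}}; this is what
# visualize_boxes_and_labels_on_image_array uses to look up class names.
category_index = label_map_util.create_category_index(categories)
```
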
/proto/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/string_int_label_map.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=79, 68 | serialized_end=150, 69 | ) 70 | 71 | 72 | _STRINGINTLABELMAP = _descriptor.Descriptor( 73 | name='StringIntLabelMap', 74 | full_name='object_detection.protos.StringIntLabelMap', 75 | filename=None, 76 | file=DESCRIPTOR, 77 | containing_type=None, 78 | fields=[ 79 | _descriptor.FieldDescriptor( 80 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 81 | number=1, type=11, cpp_type=10, label=3, 82 | has_default_value=False, default_value=[], 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | ], 87 | extensions=[ 88 | ], 89 | nested_types=[], 90 | enum_types=[ 91 | ], 92 | 
options=None, 93 | is_extendable=False, 94 | extension_ranges=[], 95 | oneofs=[ 96 | ], 97 | serialized_start=152, 98 | serialized_end=233, 99 | ) 100 | 101 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 102 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 103 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 104 | 105 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 106 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 107 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 108 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 109 | )) 110 | _sym_db.RegisterMessage(StringIntLabelMapItem) 111 | 112 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 113 | DESCRIPTOR = _STRINGINTLABELMAP, 114 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 115 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 116 | )) 117 | _sym_db.RegisterMessage(StringIntLabelMap) 118 | 119 | 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | Pillow 3 | imutils 4 | notebook 5 | jupyter 6 | tensorflow 7 | moviepy 8 | autovizwidget 9 | opencv-python 10 | -------------------------------------------------------------------------------- /runDetection.sh: -------------------------------------------------------------------------------- 1 | xhost + local:docker 2 | docker run --privileged --device=/dev/video0 -v /tmp/.X11-unix:/tmp/.X11-unix -e DISPLAY=unix$DISPLAY face_detection 3 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/label_map_util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/label_map_util.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/visualization_utils_color.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/visualization_utils_color.cpython-35.pyc -------------------------------------------------------------------------------- /utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | from proto import string_int_label_map_pb2 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 33 | """ 34 | for item in label_map.item: 35 | if item.id < 1: 36 | raise ValueError('Label map ids should be >= 1.') 37 | 38 | 39 | def create_category_index(categories): 40 | """Creates dictionary of COCO compatible categories keyed by category id. 41 | 42 | Args: 43 | categories: a list of dicts, each of which has the following keys: 44 | 'id': (required) an integer id uniquely identifying this category. 45 | 'name': (required) string representing category name 46 | e.g., 'cat', 'dog', 'pizza'. 47 | 48 | Returns: 49 | category_index: a dict containing the same entries as categories, but keyed 50 | by the 'id' field of each category. 51 | """ 52 | category_index = {} 53 | for cat in categories: 54 | category_index[cat['id']] = cat 55 | return category_index 56 | 57 | 58 | def convert_label_map_to_categories(label_map, 59 | max_num_classes, 60 | use_display_name=True): 61 | """Loads label map proto and returns categories list compatible with eval. 62 | 63 | This function loads a label map and returns a list of dicts, each of which 64 | has the following keys: 65 | 'id': (required) an integer id uniquely identifying this category. 66 | 'name': (required) string representing category name 67 | e.g., 'cat', 'dog', 'pizza'. 68 | We only allow class into the list if its id-label_id_offset is 69 | between 0 (inclusive) and max_num_classes (exclusive). 70 | If there are several items mapping to the same id in the label map, 71 | we will only keep the first one in the categories list. 72 | 73 | Args: 74 | label_map: a StringIntLabelMapProto or None. If None, a default categories 75 | list is created with max_num_classes categories. 76 | max_num_classes: maximum number of (consecutive) label indices to include. 77 | use_display_name: (boolean) choose whether to load 'display_name' field 78 | as category name. If False or if the display_name field does not exist, 79 | uses 'name' field as category names instead. 80 | Returns: 81 | categories: a list of dictionaries representing all possible categories. 
82 | """ 83 | categories = [] 84 | list_of_ids_already_added = [] 85 | if not label_map: 86 | label_id_offset = 1 87 | for class_id in range(max_num_classes): 88 | categories.append({ 89 | 'id': class_id + label_id_offset, 90 | 'name': 'category_{}'.format(class_id + label_id_offset) 91 | }) 92 | return categories 93 | for item in label_map.item: 94 | if not 0 < item.id <= max_num_classes: 95 | logging.info('Ignore item %d since it falls outside of requested ' 96 | 'label range.', item.id) 97 | continue 98 | if use_display_name and item.HasField('display_name'): 99 | name = item.display_name 100 | else: 101 | name = item.name 102 | if item.id not in list_of_ids_already_added: 103 | list_of_ids_already_added.append(item.id) 104 | categories.append({'id': item.id, 'name': name}) 105 | return categories 106 | 107 | 108 | def load_labelmap(path): 109 | """Loads label map proto. 110 | 111 | Args: 112 | path: path to StringIntLabelMap proto text file. 113 | Returns: 114 | a StringIntLabelMapProto 115 | """ 116 | with tf.gfile.GFile(path, 'r') as fid: 117 | label_map_string = fid.read() 118 | label_map = string_int_label_map_pb2.StringIntLabelMap() 119 | try: 120 | text_format.Merge(label_map_string, label_map) 121 | except text_format.ParseError: 122 | label_map.ParseFromString(label_map_string) 123 | _validate_label_map(label_map) 124 | return label_map 125 | 126 | 127 | def get_label_map_dict(label_map_path): 128 | """Reads a label map and returns a dictionary of label names to id. 129 | 130 | Args: 131 | label_map_path: path to label_map. 132 | 133 | Returns: 134 | A dictionary mapping label names to id. 135 | """ 136 | label_map = load_labelmap(label_map_path) 137 | label_map_dict = {} 138 | for item in label_map.item: 139 | label_map_dict[item.name] = item.id 140 | return label_map_dict 141 | -------------------------------------------------------------------------------- /utils/visualization_utils_color.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A set of functions that are used for visualization. 17 | 18 | These functions often receive an image, perform some visualization on the image. 19 | The functions do not return a value, instead they modify the image itself. 
20 | 21 | """ 22 | import collections 23 | import numpy as np 24 | import PIL.Image as Image 25 | import PIL.ImageColor as ImageColor 26 | import PIL.ImageDraw as ImageDraw 27 | import PIL.ImageFont as ImageFont 28 | import six 29 | import tensorflow as tf 30 | 31 | 32 | _TITLE_LEFT_MARGIN = 10 33 | _TITLE_TOP_MARGIN = 10 34 | STANDARD_COLORS = [ 35 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', 36 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', 37 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', 38 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', 39 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', 40 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', 41 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', 42 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', 43 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', 44 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', 45 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', 46 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', 47 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', 48 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', 49 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', 50 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', 51 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 52 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', 53 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', 54 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', 55 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', 56 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', 57 | 'WhiteSmoke', 'Yellow', 'YellowGreen' 58 | ] 59 | 60 | 61 | def save_image_array_as_png(image, output_path): 62 | """Saves an image (represented as a numpy array) to PNG. 63 | 64 | Args: 65 | image: a numpy array with shape [height, width, 3]. 66 | output_path: path to which image should be written. 67 | """ 68 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 69 | with tf.gfile.Open(output_path, 'w') as fid: 70 | image_pil.save(fid, 'PNG') 71 | 72 | 73 | def encode_image_array_as_png_str(image): 74 | """Encodes a numpy array into a PNG string. 75 | 76 | Args: 77 | image: a numpy array with shape [height, width, 3]. 78 | 79 | Returns: 80 | PNG encoded image string. 81 | """ 82 | image_pil = Image.fromarray(np.uint8(image)) 83 | output = six.BytesIO() 84 | image_pil.save(output, format='PNG') 85 | png_string = output.getvalue() 86 | output.close() 87 | return png_string 88 | 89 | 90 | def draw_bounding_box_on_image_array(image, 91 | ymin, 92 | xmin, 93 | ymax, 94 | xmax, 95 | color='red', 96 | thickness=4, 97 | display_str_list=(), 98 | use_normalized_coordinates=True): 99 | """Adds a bounding box to an image (numpy array). 100 | 101 | Args: 102 | image: a numpy array with shape [height, width, 3]. 103 | ymin: ymin of bounding box in normalized coordinates (same below). 104 | xmin: xmin of bounding box. 105 | ymax: ymax of bounding box. 106 | xmax: xmax of bounding box. 107 | color: color to draw bounding box. Default is red. 
108 | thickness: line thickness. Default value is 4. 109 | display_str_list: list of strings to display in box 110 | (each to be shown on its own line). 111 | use_normalized_coordinates: If True (default), treat coordinates 112 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 113 | coordinates as absolute. 114 | """ 115 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 116 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, 117 | thickness, display_str_list, 118 | use_normalized_coordinates) 119 | np.copyto(image, np.array(image_pil)) 120 | 121 | 122 | def draw_bounding_box_on_image(image, 123 | ymin, 124 | xmin, 125 | ymax, 126 | xmax, 127 | color='red', 128 | thickness=4, 129 | display_str_list=(), 130 | use_normalized_coordinates=True): 131 | """Adds a bounding box to an image. 132 | 133 | Each string in display_str_list is displayed on a separate line above the 134 | bounding box in black text on a rectangle filled with the input 'color'. 135 | 136 | Args: 137 | image: a PIL.Image object. 138 | ymin: ymin of bounding box. 139 | xmin: xmin of bounding box. 140 | ymax: ymax of bounding box. 141 | xmax: xmax of bounding box. 142 | color: color to draw bounding box. Default is red. 143 | thickness: line thickness. Default value is 4. 144 | display_str_list: list of strings to display in box 145 | (each to be shown on its own line). 146 | use_normalized_coordinates: If True (default), treat coordinates 147 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 148 | coordinates as absolute. 149 | """ 150 | draw = ImageDraw.Draw(image) 151 | im_width, im_height = image.size 152 | if use_normalized_coordinates: 153 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width, 154 | ymin * im_height, ymax * im_height) 155 | else: 156 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax) 157 | draw.line([(left, top), (left, bottom), (right, bottom), 158 | (right, top), (left, top)], width=thickness, fill=color) 159 | try: 160 | font = ImageFont.truetype('arial.ttf', 24) 161 | except IOError: 162 | font = ImageFont.load_default() 163 | 164 | text_bottom = top 165 | # Reverse list and print from bottom to top. 166 | for display_str in display_str_list[::-1]: 167 | text_width, text_height = font.getsize(display_str) 168 | margin = np.ceil(0.05 * text_height) 169 | draw.rectangle( 170 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, 171 | text_bottom)], 172 | fill=color) 173 | draw.text( 174 | (left + margin, text_bottom - text_height - margin), 175 | display_str, 176 | fill='black', 177 | font=font) 178 | text_bottom -= text_height - 2 * margin 179 | 180 | 181 | def draw_bounding_boxes_on_image_array(image, 182 | boxes, 183 | color='red', 184 | thickness=4, 185 | display_str_list_list=()): 186 | """Draws bounding boxes on image (numpy array). 187 | 188 | Args: 189 | image: a numpy array object. 190 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 191 | The coordinates are in normalized format between [0, 1]. 192 | color: color to draw bounding box. Default is red. 193 | thickness: line thickness. Default value is 4. 194 | display_str_list_list: list of list of strings. 195 | a list of strings for each bounding box. 196 | The reason to pass a list of strings for a 197 | bounding box is that it might contain 198 | multiple labels. 
199 | 200 | Raises: 201 | ValueError: if boxes is not a [N, 4] array 202 | """ 203 | image_pil = Image.fromarray(image) 204 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, 205 | display_str_list_list) 206 | np.copyto(image, np.array(image_pil)) 207 | 208 | 209 | def draw_bounding_boxes_on_image(image, 210 | boxes, 211 | color='red', 212 | thickness=4, 213 | display_str_list_list=()): 214 | """Draws bounding boxes on image. 215 | 216 | Args: 217 | image: a PIL.Image object. 218 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 219 | The coordinates are in normalized format between [0, 1]. 220 | color: color to draw bounding box. Default is red. 221 | thickness: line thickness. Default value is 4. 222 | display_str_list_list: list of list of strings. 223 | a list of strings for each bounding box. 224 | The reason to pass a list of strings for a 225 | bounding box is that it might contain 226 | multiple labels. 227 | 228 | Raises: 229 | ValueError: if boxes is not a [N, 4] array 230 | """ 231 | boxes_shape = boxes.shape 232 | if not boxes_shape: 233 | return 234 | if len(boxes_shape) != 2 or boxes_shape[1] != 4: 235 | raise ValueError('Input must be of size [N, 4]') 236 | for i in range(boxes_shape[0]): 237 | display_str_list = () 238 | if display_str_list_list: 239 | display_str_list = display_str_list_list[i] 240 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], 241 | boxes[i, 3], color, thickness, display_str_list) 242 | 243 | 244 | def draw_keypoints_on_image_array(image, 245 | keypoints, 246 | color='red', 247 | radius=2, 248 | use_normalized_coordinates=True): 249 | """Draws keypoints on an image (numpy array). 250 | 251 | Args: 252 | image: a numpy array with shape [height, width, 3]. 253 | keypoints: a numpy array with shape [num_keypoints, 2]. 254 | color: color to draw the keypoints with. Default is red. 255 | radius: keypoint radius. Default value is 2. 256 | use_normalized_coordinates: if True (default), treat keypoint values as 257 | relative to the image. Otherwise treat them as absolute. 258 | """ 259 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 260 | draw_keypoints_on_image(image_pil, keypoints, color, radius, 261 | use_normalized_coordinates) 262 | np.copyto(image, np.array(image_pil)) 263 | 264 | 265 | def draw_keypoints_on_image(image, 266 | keypoints, 267 | color='red', 268 | radius=2, 269 | use_normalized_coordinates=True): 270 | """Draws keypoints on an image. 271 | 272 | Args: 273 | image: a PIL.Image object. 274 | keypoints: a numpy array with shape [num_keypoints, 2]. 275 | color: color to draw the keypoints with. Default is red. 276 | radius: keypoint radius. Default value is 2. 277 | use_normalized_coordinates: if True (default), treat keypoint values as 278 | relative to the image. Otherwise treat them as absolute. 279 | """ 280 | draw = ImageDraw.Draw(image) 281 | im_width, im_height = image.size 282 | keypoints_x = [k[1] for k in keypoints] 283 | keypoints_y = [k[0] for k in keypoints] 284 | if use_normalized_coordinates: 285 | keypoints_x = tuple([im_width * x for x in keypoints_x]) 286 | keypoints_y = tuple([im_height * y for y in keypoints_y]) 287 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): 288 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius), 289 | (keypoint_x + radius, keypoint_y + radius)], 290 | outline=color, fill=color) 291 | 292 | 293 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.7): 294 | """Draws mask on an image. 
295 | 296 | Args: 297 | image: uint8 numpy array with shape (img_height, img_height, 3) 298 | mask: a float numpy array of shape (img_height, img_height) with 299 | values between 0 and 1 300 | color: color to draw the keypoints with. Default is red. 301 | alpha: transparency value between 0 and 1. (default: 0.7) 302 | 303 | Raises: 304 | ValueError: On incorrect data type for image or masks. 305 | """ 306 | if image.dtype != np.uint8: 307 | raise ValueError('`image` not of type np.uint8') 308 | if mask.dtype != np.float32: 309 | raise ValueError('`mask` not of type np.float32') 310 | if np.any(np.logical_or(mask > 1.0, mask < 0.0)): 311 | raise ValueError('`mask` elements should be in [0, 1]') 312 | rgb = ImageColor.getrgb(color) 313 | pil_image = Image.fromarray(image) 314 | 315 | solid_color = np.expand_dims( 316 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) 317 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') 318 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') 319 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) 320 | np.copyto(image, np.array(pil_image.convert('RGB'))) 321 | 322 | 323 | def visualize_boxes_and_labels_on_image_array(image, 324 | boxes, 325 | classes, 326 | scores, 327 | category_index, 328 | instance_masks=None, 329 | keypoints=None, 330 | use_normalized_coordinates=False, 331 | max_boxes_to_draw=20, 332 | min_score_thresh=.7, 333 | agnostic_mode=False, 334 | line_thickness=4): 335 | """Overlay labeled boxes on an image with formatted scores and label names. 336 | 337 | This function groups boxes that correspond to the same location 338 | and creates a display string for each detection and overlays these 339 | on the image. Note that this function modifies the image array in-place 340 | and does not return anything. 341 | 342 | Args: 343 | image: uint8 numpy array with shape (img_height, img_width, 3) 344 | boxes: a numpy array of shape [N, 4] 345 | classes: a numpy array of shape [N] 346 | scores: a numpy array of shape [N] or None. If scores=None, then 347 | this function assumes that the boxes to be plotted are groundtruth 348 | boxes and plot all boxes as black with no classes or scores. 349 | category_index: a dict containing category dictionaries (each holding 350 | category index `id` and category name `name`) keyed by category indices. 351 | instance_masks: a numpy array of shape [N, image_height, image_width], can 352 | be None 353 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 354 | be None 355 | use_normalized_coordinates: whether boxes is to be interpreted as 356 | normalized coordinates or not. 357 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 358 | all boxes. 359 | min_score_thresh: minimum score threshold for a box to be visualized 360 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 361 | class-agnostic mode or not. This mode will display scores but ignore 362 | classes. 363 | line_thickness: integer (default: 4) controlling line width of the boxes. 364 | """ 365 | # Create a display string (and color) for every box location, group any boxes 366 | # that correspond to the same location. 
367 | box_to_display_str_map = collections.defaultdict(list) 368 | box_to_color_map = collections.defaultdict(str) 369 | box_to_instance_masks_map = {} 370 | box_to_keypoints_map = collections.defaultdict(list) 371 | if not max_boxes_to_draw: 372 | max_boxes_to_draw = boxes.shape[0] 373 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 374 | if scores is None or scores[i] > min_score_thresh: 375 | box = tuple(boxes[i].tolist()) 376 | if instance_masks is not None: 377 | box_to_instance_masks_map[box] = instance_masks[i] 378 | if keypoints is not None: 379 | box_to_keypoints_map[box].extend(keypoints[i]) 380 | if scores is None: 381 | box_to_color_map[box] = 'black' 382 | else: 383 | if not agnostic_mode: 384 | if classes[i] in category_index.keys(): 385 | class_name = category_index[classes[i]]['name'] 386 | else: 387 | class_name = 'N/A' 388 | display_str = '{}: {}%'.format( 389 | class_name, 390 | int(100*scores[i])) 391 | else: 392 | display_str = 'score: {}%'.format(int(100 * scores[i])) 393 | box_to_display_str_map[box].append(display_str) 394 | if agnostic_mode: 395 | box_to_color_map[box] = 'DarkOrange' 396 | else: 397 | box_to_color_map[box] = STANDARD_COLORS[ 398 | classes[i] % len(STANDARD_COLORS)] 399 | 400 | # Draw all boxes onto image. 401 | for box, color in box_to_color_map.items(): 402 | color = 'Violet' 403 | ymin, xmin, ymax, xmax = box 404 | if instance_masks is not None: 405 | draw_mask_on_image_array( 406 | image, 407 | box_to_instance_masks_map[box], 408 | color=color 409 | ) 410 | draw_bounding_box_on_image_array( 411 | image, 412 | ymin, 413 | xmin, 414 | ymax, 415 | xmax, 416 | color=color, 417 | thickness=line_thickness, 418 | display_str_list=box_to_display_str_map[box], 419 | use_normalized_coordinates=use_normalized_coordinates) 420 | if keypoints is not None: 421 | draw_keypoints_on_image_array( 422 | image, 423 | box_to_keypoints_map[box], 424 | color=color, 425 | radius=line_thickness / 2, 426 | use_normalized_coordinates=use_normalized_coordinates) 427 | --------------------------------------------------------------------------------
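
For reference, here is a minimal single-image sketch that ties the pieces of this repository together, using the same frozen graph, label map helpers, and visualization utilities as face_detection_webcam.py. It assumes TensorFlow 1.x and that it is run from the repository root; img/ex1.png is reused here only as a convenient sample input, and detection_output.png is a hypothetical output path.

```python
import cv2
import numpy as np
import tensorflow as tf

from utils import label_map_util
from utils import visualization_utils_color as vis_util

# Load the frozen inference graph shipped in model/.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile('./model/frozen_inference_graph.pb', 'rb') as fid:
        od_graph_def.ParseFromString(fid.read())
    tf.import_graph_def(od_graph_def, name='')

# Build the category index ({1: {'id': 1, 'name': 'face'}}) from proto/label_map.pbtxt.
label_map = label_map_util.load_labelmap('./proto/label_map.pbtxt')
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=1, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Any image readable by OpenCV works here.
image = cv2.imread('./img/ex1.png')

with tf.Session(graph=detection_graph) as sess:
    # Fetch the detection tensors by name, feeding a [1, H, W, 3] batch.
    boxes, scores, classes, _ = sess.run(
        ['detection_boxes:0', 'detection_scores:0',
         'detection_classes:0', 'num_detections:0'],
        feed_dict={'image_tensor:0': np.expand_dims(image, axis=0)})

# Draw the boxes in place, exactly like the webcam script does per frame.
vis_util.visualize_boxes_and_labels_on_image_array(
    image,
    np.squeeze(boxes),
    np.squeeze(classes).astype(np.int32),
    np.squeeze(scores),
    category_index,
    use_normalized_coordinates=True,
    line_thickness=2,
    min_score_thresh=0.40)

cv2.imwrite('detection_output.png', image)
```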