├── Dockerfile
├── README.md
├── face_detection_webcam.py
├── img
│   ├── ex1.png
│   └── ex2.png
├── model
│   └── frozen_inference_graph.pb
├── proto
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   └── string_int_label_map_pb2.cpython-35.pyc
│   ├── label_map.pbtxt
│   └── string_int_label_map_pb2.py
├── requirements.txt
├── runDetection.sh
└── utils
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── label_map_util.cpython-35.pyc
    │   └── visualization_utils_color.cpython-35.pyc
    ├── label_map_util.py
    └── visualization_utils_color.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3
2 |
3 | COPY ./requirements.txt /requirements.txt
4 |
5 | WORKDIR /
6 |
7 | RUN pip install -r requirements.txt
8 |
9 | COPY . /
10 |
11 | CMD [ "python", "./face_detection_webcam.py"]
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow face detection
2 |
3 | TensorFlow face detection implementation based on MobileNet SSD v2, trained on the WIDER FACE dataset using the TensorFlow Object Detection API.
4 |
5 |
6 | ## Dependencies
7 |
8 | * TensorFlow >= 1.12
9 | * OpenCV
10 | * imutils
11 |
12 | ```
13 | pip install -r requirements.txt
14 | ```
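Note: the detection script uses the TensorFlow 1.x graph API (`tf.GraphDef`, `tf.gfile`, `tf.Session`), so install a 1.x release (>= 1.12). If you only have TensorFlow 2.x available, a minimal, untested workaround (not part of this repository) is to go through the v1 compatibility layer in `face_detection_webcam.py` and the two `utils` modules:

```
# Hypothetical TF 2.x compatibility shim: replace "import tensorflow as tf"
# with the two lines below so the TF 1.x graph/session calls keep working.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
```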
15 |
16 | ## Usage
17 | ```
18 | python face_detection_webcam.py
19 | ```
20 |
21 | ## Docker
22 |
23 | Build the image:
24 | ```
25 | docker build -t face_detection .
26 | ```
27 |
28 | Run the project with the pre-trained model:
29 |
30 | ```
31 | bash runDetection.sh
32 | ```
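For reference, `runDetection.sh` simply forwards your webcam device and the X11 socket so the container can grab frames and open a display window:

```
xhost +local:docker
docker run --privileged --device=/dev/video0 \
    -v /tmp/.X11-unix:/tmp/.X11-unix -e DISPLAY=unix$DISPLAY face_detection
```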
33 |
34 | ## Result
35 | Achieves 19 FPS at 640x480 resolution on an Intel Core i7-7600U CPU @ 2.80 GHz × 4.
36 |
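Example detections (the images shipped in the `img/` folder of this repository):

![Example detection 1](img/ex1.png)
![Example detection 2](img/ex2.png)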
37 |
38 |
39 |
40 |
41 |
42 |
43 | ## Train Model
44 |
45 | If you want to train your own model, I advise you to follow the TensorFlow Object Detection API tutorial; you will just need to download an annotated dataset such as WIDER FACE.
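A rough, untested outline with the TensorFlow 1.x Object Detection API (the config name, `training/` directory and step count below are placeholders, not files from this repository):

```
# Assumes tensorflow/models/research is on PYTHONPATH and that the pipeline
# config points at your TFRecords and at proto/label_map.pbtxt.
python object_detection/model_main.py \
    --pipeline_config_path=ssd_mobilenet_v2_face.config \
    --model_dir=training/ \
    --num_train_steps=200000 \
    --alsologtostderr

# Export the trained checkpoint to a frozen graph, like model/frozen_inference_graph.pb.
python object_detection/export_inference_graph.py \
    --input_type=image_tensor \
    --pipeline_config_path=ssd_mobilenet_v2_face.config \
    --trained_checkpoint_prefix=training/model.ckpt-200000 \
    --output_directory=model/
```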
46 |
47 | ## Reference
48 | * TensorFlow Object Detection API
49 |
50 | * WIDER FACE dataset
51 |
--------------------------------------------------------------------------------
/face_detection_webcam.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import time
5 | import numpy as np
6 | import tensorflow as tf
7 | import cv2
8 |
9 | from utils import label_map_util
10 | from utils import visualization_utils_color as vis_util
11 |
12 | from imutils.video import FPS
13 | from imutils.video import WebcamVideoStream
14 |
15 |
16 | # Path to frozen detection graph. This is the actual model that is used for the object detection.
17 | PATH_TO_CKPT = './model/frozen_inference_graph.pb'
18 |
19 | # List of the strings that are used to add the correct label to each box.
20 | PATH_TO_LABELS = './proto/label_map.pbtxt'
21 |
22 | NUM_CLASSES = 1
23 |
24 | # Loading label map
25 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
26 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
27 | category_index = label_map_util.create_category_index(categories)
28 |
29 |
30 | def face_detection():
31 |
32 | # Load Tensorflow model
33 | detection_graph = tf.Graph()
34 | with detection_graph.as_default():
35 | od_graph_def = tf.GraphDef()
36 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
37 | serialized_graph = fid.read()
38 | od_graph_def.ParseFromString(serialized_graph)
39 | tf.import_graph_def(od_graph_def, name='')
40 |
41 | sess = tf.Session(graph=detection_graph)
42 |
43 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
44 |
45 | # Each box represents a part of the image where a particular object was detected.
46 | detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
47 |
48 | # Each score represents the level of confidence for each of the detected objects.
49 | # The score is shown on the result image, together with the class label.
50 | detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
51 |
52 | # Detection classes and the number of detections returned for each frame.
53 | detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
54 | num_detections = detection_graph.get_tensor_by_name('num_detections:0')
55 |
56 | # Start video stream
57 | cap = WebcamVideoStream(0).start()
58 | fps = FPS().start()
59 |
60 | while True:
61 |
62 | frame = cap.read()
63 |
64 | # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
65 | expanded_frame = np.expand_dims(frame, axis=0)
66 | (boxes, scores, classes, num_c) = sess.run(
67 | [detection_boxes, detection_scores, detection_classes, num_detections],
68 | feed_dict={image_tensor: expanded_frame})
69 |
70 | # Visualization of the detection
71 | vis_util.visualize_boxes_and_labels_on_image_array(
72 | frame,
73 | np.squeeze(boxes),
74 | np.squeeze(classes).astype(np.int32),
75 | np.squeeze(scores),
76 | category_index,
77 | use_normalized_coordinates=True,
78 | line_thickness=2,
79 | min_score_thresh=0.40)
80 |
81 | cv2.imshow('Detection', frame)
82 | fps.update()
83 |
84 | if cv2.waitKey(1) == ord('q'):
85 | fps.stop()
86 | break
87 |
88 | print("Fps: {:.2f}".format(fps.fps()))
89 | fps.update()
90 | cap.stop()
91 | cv2.destroyAllWindows()
92 |
93 |
94 | if __name__ == '__main__':
95 | face_detection()
96 |
--------------------------------------------------------------------------------
/img/ex1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/img/ex1.png
--------------------------------------------------------------------------------
/img/ex2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/img/ex2.png
--------------------------------------------------------------------------------
/model/frozen_inference_graph.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/model/frozen_inference_graph.pb
--------------------------------------------------------------------------------
/proto/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__init__.py
--------------------------------------------------------------------------------
/proto/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/proto/__pycache__/string_int_label_map_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/proto/__pycache__/string_int_label_map_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/proto/label_map.pbtxt:
--------------------------------------------------------------------------------
1 | item {
2 | id: 1
3 | name: 'face'
4 | }
5 |
6 |
7 |
--------------------------------------------------------------------------------
/proto/string_int_label_map_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/string_int_label_map.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='object_detection/protos/string_int_label_map.proto',
20 | package='object_detection.protos',
21 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
27 |
28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor(
29 | name='StringIntLabelMapItem',
30 | full_name='object_detection.protos.StringIntLabelMapItem',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
37 | number=1, type=9, cpp_type=9, label=1,
38 | has_default_value=False, default_value=_b("").decode('utf-8'),
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=False, default_value=0,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
51 | number=3, type=9, cpp_type=9, label=1,
52 | has_default_value=False, default_value=_b("").decode('utf-8'),
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | ],
57 | extensions=[
58 | ],
59 | nested_types=[],
60 | enum_types=[
61 | ],
62 | options=None,
63 | is_extendable=False,
64 | extension_ranges=[],
65 | oneofs=[
66 | ],
67 | serialized_start=79,
68 | serialized_end=150,
69 | )
70 |
71 |
72 | _STRINGINTLABELMAP = _descriptor.Descriptor(
73 | name='StringIntLabelMap',
74 | full_name='object_detection.protos.StringIntLabelMap',
75 | filename=None,
76 | file=DESCRIPTOR,
77 | containing_type=None,
78 | fields=[
79 | _descriptor.FieldDescriptor(
80 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
81 | number=1, type=11, cpp_type=10, label=3,
82 | has_default_value=False, default_value=[],
83 | message_type=None, enum_type=None, containing_type=None,
84 | is_extension=False, extension_scope=None,
85 | options=None),
86 | ],
87 | extensions=[
88 | ],
89 | nested_types=[],
90 | enum_types=[
91 | ],
92 | options=None,
93 | is_extendable=False,
94 | extension_ranges=[],
95 | oneofs=[
96 | ],
97 | serialized_start=152,
98 | serialized_end=233,
99 | )
100 |
101 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
102 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
103 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP
104 |
105 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
106 | DESCRIPTOR = _STRINGINTLABELMAPITEM,
107 | __module__ = 'object_detection.protos.string_int_label_map_pb2'
108 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
109 | ))
110 | _sym_db.RegisterMessage(StringIntLabelMapItem)
111 |
112 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
113 | DESCRIPTOR = _STRINGINTLABELMAP,
114 | __module__ = 'object_detection.protos.string_int_label_map_pb2'
115 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
116 | ))
117 | _sym_db.RegisterMessage(StringIntLabelMap)
118 |
119 |
120 | # @@protoc_insertion_point(module_scope)
121 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | Pillow
3 | imutils
4 | notebook
5 | jupyter
6 | tensorflow>=1.12,<2.0
7 | moviepy
8 | autovizwidget
9 | opencv-python
10 |
--------------------------------------------------------------------------------
/runDetection.sh:
--------------------------------------------------------------------------------
1 | xhost +local:docker
2 | docker run --privileged --device=/dev/video0 -v /tmp/.X11-unix:/tmp/.X11-unix -e DISPLAY=unix$DISPLAY face_detection
3 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/label_map_util.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/label_map_util.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/visualization_utils_color.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Fszta/Tensorflow-face-detection/2cfc29788439a84f433cbe29067c953044d5ce91/utils/__pycache__/visualization_utils_color.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/label_map_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Label map utility functions."""
17 |
18 | import logging
19 |
20 | import tensorflow as tf
21 | from google.protobuf import text_format
22 | from proto import string_int_label_map_pb2
23 |
24 |
25 | def _validate_label_map(label_map):
26 | """Checks if a label map is valid.
27 |
28 | Args:
29 | label_map: StringIntLabelMap to validate.
30 |
31 | Raises:
32 | ValueError: if label map is invalid.
33 | """
34 | for item in label_map.item:
35 | if item.id < 1:
36 | raise ValueError('Label map ids should be >= 1.')
37 |
38 |
39 | def create_category_index(categories):
40 | """Creates dictionary of COCO compatible categories keyed by category id.
41 |
42 | Args:
43 | categories: a list of dicts, each of which has the following keys:
44 | 'id': (required) an integer id uniquely identifying this category.
45 | 'name': (required) string representing category name
46 | e.g., 'cat', 'dog', 'pizza'.
47 |
48 | Returns:
49 | category_index: a dict containing the same entries as categories, but keyed
50 | by the 'id' field of each category.
51 | """
52 | category_index = {}
53 | for cat in categories:
54 | category_index[cat['id']] = cat
55 | return category_index
56 |
57 |
58 | def convert_label_map_to_categories(label_map,
59 | max_num_classes,
60 | use_display_name=True):
61 | """Loads label map proto and returns categories list compatible with eval.
62 |
63 | This function loads a label map and returns a list of dicts, each of which
64 | has the following keys:
65 | 'id': (required) an integer id uniquely identifying this category.
66 | 'name': (required) string representing category name
67 | e.g., 'cat', 'dog', 'pizza'.
68 | We only allow a class into the list if its id - label_id_offset is
69 | between 0 (inclusive) and max_num_classes (exclusive).
70 | If there are several items mapping to the same id in the label map,
71 | we will only keep the first one in the categories list.
72 |
73 | Args:
74 | label_map: a StringIntLabelMapProto or None. If None, a default categories
75 | list is created with max_num_classes categories.
76 | max_num_classes: maximum number of (consecutive) label indices to include.
77 | use_display_name: (boolean) choose whether to load 'display_name' field
78 | as category name. If False or if the display_name field does not exist,
79 | uses 'name' field as category names instead.
80 | Returns:
81 | categories: a list of dictionaries representing all possible categories.
82 | """
83 | categories = []
84 | list_of_ids_already_added = []
85 | if not label_map:
86 | label_id_offset = 1
87 | for class_id in range(max_num_classes):
88 | categories.append({
89 | 'id': class_id + label_id_offset,
90 | 'name': 'category_{}'.format(class_id + label_id_offset)
91 | })
92 | return categories
93 | for item in label_map.item:
94 | if not 0 < item.id <= max_num_classes:
95 | logging.info('Ignore item %d since it falls outside of requested '
96 | 'label range.', item.id)
97 | continue
98 | if use_display_name and item.HasField('display_name'):
99 | name = item.display_name
100 | else:
101 | name = item.name
102 | if item.id not in list_of_ids_already_added:
103 | list_of_ids_already_added.append(item.id)
104 | categories.append({'id': item.id, 'name': name})
105 | return categories
106 |
107 |
108 | def load_labelmap(path):
109 | """Loads label map proto.
110 |
111 | Args:
112 | path: path to StringIntLabelMap proto text file.
113 | Returns:
114 | a StringIntLabelMapProto
115 | """
116 | with tf.gfile.GFile(path, 'r') as fid:
117 | label_map_string = fid.read()
118 | label_map = string_int_label_map_pb2.StringIntLabelMap()
119 | try:
120 | text_format.Merge(label_map_string, label_map)
121 | except text_format.ParseError:
122 | label_map.ParseFromString(label_map_string)
123 | _validate_label_map(label_map)
124 | return label_map
125 |
126 |
127 | def get_label_map_dict(label_map_path):
128 | """Reads a label map and returns a dictionary of label names to id.
129 |
130 | Args:
131 | label_map_path: path to label_map.
132 |
133 | Returns:
134 | A dictionary mapping label names to id.
135 | """
136 | label_map = load_labelmap(label_map_path)
137 | label_map_dict = {}
138 | for item in label_map.item:
139 | label_map_dict[item.name] = item.id
140 | return label_map_dict
141 |
--------------------------------------------------------------------------------
/utils/visualization_utils_color.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A set of functions that are used for visualization.
17 |
18 | These functions typically receive an image and perform some visualization on it.
19 | The functions do not return a value; instead, they modify the image in place.
20 |
21 | """
22 | import collections
23 | import numpy as np
24 | import PIL.Image as Image
25 | import PIL.ImageColor as ImageColor
26 | import PIL.ImageDraw as ImageDraw
27 | import PIL.ImageFont as ImageFont
28 | import six
29 | import tensorflow as tf
30 |
31 |
32 | _TITLE_LEFT_MARGIN = 10
33 | _TITLE_TOP_MARGIN = 10
34 | STANDARD_COLORS = [
35 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
36 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
37 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
38 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
39 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
40 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
41 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
42 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
43 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
44 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
45 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
46 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
47 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
48 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
49 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
50 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
51 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
52 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
53 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
54 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
55 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
56 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
57 | 'WhiteSmoke', 'Yellow', 'YellowGreen'
58 | ]
59 |
60 |
61 | def save_image_array_as_png(image, output_path):
62 | """Saves an image (represented as a numpy array) to PNG.
63 |
64 | Args:
65 | image: a numpy array with shape [height, width, 3].
66 | output_path: path to which image should be written.
67 | """
68 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
69 | with tf.gfile.Open(output_path, 'w') as fid:
70 | image_pil.save(fid, 'PNG')
71 |
72 |
73 | def encode_image_array_as_png_str(image):
74 | """Encodes a numpy array into a PNG string.
75 |
76 | Args:
77 | image: a numpy array with shape [height, width, 3].
78 |
79 | Returns:
80 | PNG encoded image string.
81 | """
82 | image_pil = Image.fromarray(np.uint8(image))
83 | output = six.BytesIO()
84 | image_pil.save(output, format='PNG')
85 | png_string = output.getvalue()
86 | output.close()
87 | return png_string
88 |
89 |
90 | def draw_bounding_box_on_image_array(image,
91 | ymin,
92 | xmin,
93 | ymax,
94 | xmax,
95 | color='red',
96 | thickness=4,
97 | display_str_list=(),
98 | use_normalized_coordinates=True):
99 | """Adds a bounding box to an image (numpy array).
100 |
101 | Args:
102 | image: a numpy array with shape [height, width, 3].
103 | ymin: ymin of bounding box in normalized coordinates (same below).
104 | xmin: xmin of bounding box.
105 | ymax: ymax of bounding box.
106 | xmax: xmax of bounding box.
107 | color: color to draw bounding box. Default is red.
108 | thickness: line thickness. Default value is 4.
109 | display_str_list: list of strings to display in box
110 | (each to be shown on its own line).
111 | use_normalized_coordinates: If True (default), treat coordinates
112 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
113 | coordinates as absolute.
114 | """
115 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
116 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
117 | thickness, display_str_list,
118 | use_normalized_coordinates)
119 | np.copyto(image, np.array(image_pil))
120 |
121 |
122 | def draw_bounding_box_on_image(image,
123 | ymin,
124 | xmin,
125 | ymax,
126 | xmax,
127 | color='red',
128 | thickness=4,
129 | display_str_list=(),
130 | use_normalized_coordinates=True):
131 | """Adds a bounding box to an image.
132 |
133 | Each string in display_str_list is displayed on a separate line above the
134 | bounding box in black text on a rectangle filled with the input 'color'.
135 |
136 | Args:
137 | image: a PIL.Image object.
138 | ymin: ymin of bounding box.
139 | xmin: xmin of bounding box.
140 | ymax: ymax of bounding box.
141 | xmax: xmax of bounding box.
142 | color: color to draw bounding box. Default is red.
143 | thickness: line thickness. Default value is 4.
144 | display_str_list: list of strings to display in box
145 | (each to be shown on its own line).
146 | use_normalized_coordinates: If True (default), treat coordinates
147 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
148 | coordinates as absolute.
149 | """
150 | draw = ImageDraw.Draw(image)
151 | im_width, im_height = image.size
152 | if use_normalized_coordinates:
153 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
154 | ymin * im_height, ymax * im_height)
155 | else:
156 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
157 | draw.line([(left, top), (left, bottom), (right, bottom),
158 | (right, top), (left, top)], width=thickness, fill=color)
159 | try:
160 | font = ImageFont.truetype('arial.ttf', 24)
161 | except IOError:
162 | font = ImageFont.load_default()
163 |
164 | text_bottom = top
165 | # Reverse list and print from bottom to top.
166 | for display_str in display_str_list[::-1]:
167 | text_width, text_height = font.getsize(display_str)
168 | margin = np.ceil(0.05 * text_height)
169 | draw.rectangle(
170 | [(left, text_bottom - text_height - 2 * margin), (left + text_width,
171 | text_bottom)],
172 | fill=color)
173 | draw.text(
174 | (left + margin, text_bottom - text_height - margin),
175 | display_str,
176 | fill='black',
177 | font=font)
178 | text_bottom -= text_height - 2 * margin
179 |
180 |
181 | def draw_bounding_boxes_on_image_array(image,
182 | boxes,
183 | color='red',
184 | thickness=4,
185 | display_str_list_list=()):
186 | """Draws bounding boxes on image (numpy array).
187 |
188 | Args:
189 | image: a numpy array object.
190 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
191 | The coordinates are in normalized format between [0, 1].
192 | color: color to draw bounding box. Default is red.
193 | thickness: line thickness. Default value is 4.
194 | display_str_list_list: list of list of strings.
195 | a list of strings for each bounding box.
196 | The reason to pass a list of strings for a
197 | bounding box is that it might contain
198 | multiple labels.
199 |
200 | Raises:
201 | ValueError: if boxes is not a [N, 4] array
202 | """
203 | image_pil = Image.fromarray(image)
204 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness,
205 | display_str_list_list)
206 | np.copyto(image, np.array(image_pil))
207 |
208 |
209 | def draw_bounding_boxes_on_image(image,
210 | boxes,
211 | color='red',
212 | thickness=4,
213 | display_str_list_list=()):
214 | """Draws bounding boxes on image.
215 |
216 | Args:
217 | image: a PIL.Image object.
218 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
219 | The coordinates are in normalized format between [0, 1].
220 | color: color to draw bounding box. Default is red.
221 | thickness: line thickness. Default value is 4.
222 | display_str_list_list: list of list of strings.
223 | a list of strings for each bounding box.
224 | The reason to pass a list of strings for a
225 | bounding box is that it might contain
226 | multiple labels.
227 |
228 | Raises:
229 | ValueError: if boxes is not a [N, 4] array
230 | """
231 | boxes_shape = boxes.shape
232 | if not boxes_shape:
233 | return
234 | if len(boxes_shape) != 2 or boxes_shape[1] != 4:
235 | raise ValueError('Input must be of size [N, 4]')
236 | for i in range(boxes_shape[0]):
237 | display_str_list = ()
238 | if display_str_list_list:
239 | display_str_list = display_str_list_list[i]
240 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2],
241 | boxes[i, 3], color, thickness, display_str_list)
242 |
243 |
244 | def draw_keypoints_on_image_array(image,
245 | keypoints,
246 | color='red',
247 | radius=2,
248 | use_normalized_coordinates=True):
249 | """Draws keypoints on an image (numpy array).
250 |
251 | Args:
252 | image: a numpy array with shape [height, width, 3].
253 | keypoints: a numpy array with shape [num_keypoints, 2].
254 | color: color to draw the keypoints with. Default is red.
255 | radius: keypoint radius. Default value is 2.
256 | use_normalized_coordinates: if True (default), treat keypoint values as
257 | relative to the image. Otherwise treat them as absolute.
258 | """
259 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
260 | draw_keypoints_on_image(image_pil, keypoints, color, radius,
261 | use_normalized_coordinates)
262 | np.copyto(image, np.array(image_pil))
263 |
264 |
265 | def draw_keypoints_on_image(image,
266 | keypoints,
267 | color='red',
268 | radius=2,
269 | use_normalized_coordinates=True):
270 | """Draws keypoints on an image.
271 |
272 | Args:
273 | image: a PIL.Image object.
274 | keypoints: a numpy array with shape [num_keypoints, 2].
275 | color: color to draw the keypoints with. Default is red.
276 | radius: keypoint radius. Default value is 2.
277 | use_normalized_coordinates: if True (default), treat keypoint values as
278 | relative to the image. Otherwise treat them as absolute.
279 | """
280 | draw = ImageDraw.Draw(image)
281 | im_width, im_height = image.size
282 | keypoints_x = [k[1] for k in keypoints]
283 | keypoints_y = [k[0] for k in keypoints]
284 | if use_normalized_coordinates:
285 | keypoints_x = tuple([im_width * x for x in keypoints_x])
286 | keypoints_y = tuple([im_height * y for y in keypoints_y])
287 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
288 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
289 | (keypoint_x + radius, keypoint_y + radius)],
290 | outline=color, fill=color)
291 |
292 |
293 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.7):
294 | """Draws mask on an image.
295 |
296 | Args:
297 | image: uint8 numpy array with shape (img_height, img_width, 3)
298 | mask: a float numpy array of shape (img_height, img_width) with
299 | values between 0 and 1
300 | color: color to draw the mask with. Default is red.
301 | alpha: transparency value between 0 and 1. (default: 0.7)
302 |
303 | Raises:
304 | ValueError: On incorrect data type for image or masks.
305 | """
306 | if image.dtype != np.uint8:
307 | raise ValueError('`image` not of type np.uint8')
308 | if mask.dtype != np.float32:
309 | raise ValueError('`mask` not of type np.float32')
310 | if np.any(np.logical_or(mask > 1.0, mask < 0.0)):
311 | raise ValueError('`mask` elements should be in [0, 1]')
312 | rgb = ImageColor.getrgb(color)
313 | pil_image = Image.fromarray(image)
314 |
315 | solid_color = np.expand_dims(
316 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
317 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
318 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
319 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
320 | np.copyto(image, np.array(pil_image.convert('RGB')))
321 |
322 |
323 | def visualize_boxes_and_labels_on_image_array(image,
324 | boxes,
325 | classes,
326 | scores,
327 | category_index,
328 | instance_masks=None,
329 | keypoints=None,
330 | use_normalized_coordinates=False,
331 | max_boxes_to_draw=20,
332 | min_score_thresh=.7,
333 | agnostic_mode=False,
334 | line_thickness=4):
335 | """Overlay labeled boxes on an image with formatted scores and label names.
336 |
337 | This function groups boxes that correspond to the same location
338 | and creates a display string for each detection and overlays these
339 | on the image. Note that this function modifies the image array in-place
340 | and does not return anything.
341 |
342 | Args:
343 | image: uint8 numpy array with shape (img_height, img_width, 3)
344 | boxes: a numpy array of shape [N, 4]
345 | classes: a numpy array of shape [N]
346 | scores: a numpy array of shape [N] or None. If scores=None, then
347 | this function assumes that the boxes to be plotted are groundtruth
348 | boxes and plots all boxes as black with no classes or scores.
349 | category_index: a dict containing category dictionaries (each holding
350 | category index `id` and category name `name`) keyed by category indices.
351 | instance_masks: a numpy array of shape [N, image_height, image_width], can
352 | be None
353 | keypoints: a numpy array of shape [N, num_keypoints, 2], can
354 | be None
355 | use_normalized_coordinates: whether boxes is to be interpreted as
356 | normalized coordinates or not.
357 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
358 | all boxes.
359 | min_score_thresh: minimum score threshold for a box to be visualized
360 | agnostic_mode: boolean (default: False) controlling whether to evaluate in
361 | class-agnostic mode or not. This mode will display scores but ignore
362 | classes.
363 | line_thickness: integer (default: 4) controlling line width of the boxes.
364 | """
365 | # Create a display string (and color) for every box location, group any boxes
366 | # that correspond to the same location.
367 | box_to_display_str_map = collections.defaultdict(list)
368 | box_to_color_map = collections.defaultdict(str)
369 | box_to_instance_masks_map = {}
370 | box_to_keypoints_map = collections.defaultdict(list)
371 | if not max_boxes_to_draw:
372 | max_boxes_to_draw = boxes.shape[0]
373 | for i in range(min(max_boxes_to_draw, boxes.shape[0])):
374 | if scores is None or scores[i] > min_score_thresh:
375 | box = tuple(boxes[i].tolist())
376 | if instance_masks is not None:
377 | box_to_instance_masks_map[box] = instance_masks[i]
378 | if keypoints is not None:
379 | box_to_keypoints_map[box].extend(keypoints[i])
380 | if scores is None:
381 | box_to_color_map[box] = 'black'
382 | else:
383 | if not agnostic_mode:
384 | if classes[i] in category_index.keys():
385 | class_name = category_index[classes[i]]['name']
386 | else:
387 | class_name = 'N/A'
388 | display_str = '{}: {}%'.format(
389 | class_name,
390 | int(100*scores[i]))
391 | else:
392 | display_str = 'score: {}%'.format(int(100 * scores[i]))
393 | box_to_display_str_map[box].append(display_str)
394 | if agnostic_mode:
395 | box_to_color_map[box] = 'DarkOrange'
396 | else:
397 | box_to_color_map[box] = STANDARD_COLORS[
398 | classes[i] % len(STANDARD_COLORS)]
399 |
400 | # Draw all boxes onto image.
401 | for box, color in box_to_color_map.items():
402 | color = 'Violet'  # Override the per-class colors: draw every box in violet.
403 | ymin, xmin, ymax, xmax = box
404 | if instance_masks is not None:
405 | draw_mask_on_image_array(
406 | image,
407 | box_to_instance_masks_map[box],
408 | color=color
409 | )
410 | draw_bounding_box_on_image_array(
411 | image,
412 | ymin,
413 | xmin,
414 | ymax,
415 | xmax,
416 | color=color,
417 | thickness=line_thickness,
418 | display_str_list=box_to_display_str_map[box],
419 | use_normalized_coordinates=use_normalized_coordinates)
420 | if keypoints is not None:
421 | draw_keypoints_on_image_array(
422 | image,
423 | box_to_keypoints_map[box],
424 | color=color,
425 | radius=line_thickness / 2,
426 | use_normalized_coordinates=use_normalized_coordinates)
427 |
--------------------------------------------------------------------------------