├── LICENSE ├── NOTICE ├── README.md ├── backend ├── .dockerignore ├── .gitignore ├── Dockerfile ├── config.py ├── docker-compose.yml ├── main.py ├── notification │ ├── __init__.py │ ├── console.py │ ├── factory.py │ ├── provider.py │ └── spark.py └── requirements.txt ├── data ├── alert.jpg └── ppe.jpg └── inference ├── config.py ├── demo_video.mp4 ├── model └── frozen_inference_graph.pb ├── requirements.txt ├── shape_utils.py ├── standard_fields.py ├── static_shape.py ├── video_demo.py └── visualization_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | ppe-detection 2 | 3 | Copyright (c) 2019 Cisco Systems, Inc. and/or its affiliates 4 | 5 | This project includes software developed at Cisco Systems, Inc. and/or its affiliates. 
6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Personal Protection Equipment Detection based on Deep Learning 2 | 3 | Real-time Personal Protection Equipment (PPE) detection running on the NVIDIA Jetson TX2 and Ubuntu 16.04 4 | 5 | - Person, HardHat and Vest detection 6 | - Input from a video file or a USB camera 7 | - A backend service that can push messages to the console or a Cisco® Webex Teams space when an abnormal event is detected. 8 | 9 | ![PPE Image](data/ppe.jpg) 10 | 11 | # Requirements 12 | - NVIDIA Jetson TX2 or Ubuntu 16.04 13 | - An NVIDIA GPU on Ubuntu 16.04 is optional 14 | - Python 3 15 | 16 | # How to run 17 | 18 | ## Video Inference Service 19 | 20 | ```sh 21 | $ cd inference 22 | $ pip3 install -r requirements.txt 23 | $ python3 video_demo.py --model_dir=xxx --video_file_name=xxx --show_video_window=xxx --camera_id=xxx 24 | ``` 25 | * model_dir: the path to the model directory 26 | * video_file_name: the input video file name or USB camera device name; you can list camera device names on Ubuntu or the NVIDIA Jetson by running 27 | ```sh 28 | $ ls /dev/video* 29 | ``` 30 | * show_video_window: whether to show the video window; the options are {0, 1} 31 | * camera_id: a label that helps humans distinguish between cameras; you can assign any value, such as camera001 32 | 33 | ## Backend Service 34 | Run the following commands: 35 | ``` 36 | $ cd backend 37 | $ pip3 install -r requirements.txt 38 | $ python3 main.py 39 | ``` 40 | 41 | Or run the application with Docker: 42 | ``` 43 | docker-compose up 44 | or 45 | docker-compose up --build 46 | ``` 47 | 48 | Sending notifications 49 | 50 | By default, the backend uses the console notification provider, which simply prints the notification to stdout. 51 | If you want to use Cisco® Webex Teams instead, change the configuration as described in `config.py`. 52 | You can also write your own provider by inheriting from `notification.Provider` (a minimal sketch appears later in this README). 53 | 54 | Setting up Cisco® Webex Teams 55 | 56 | * create a bot as described at https://developer.cisco.com/webex-teams/; you will get a token 57 | * create a Webex Teams room and add the bot to that room 58 | * go to https://developer.webex.com/docs/api/v1/rooms/list-rooms to get the newly created room ID 59 | * put the above information into `config.py` 60 | 61 | Alert Message Format 62 | 63 | ![PPE Image](data/alert.jpg) 64 | 65 | * total_person: number of people detected 66 | * without_hardhat: number of people without a hard hat 67 | * without_vest: number of people without a vest 68 | * without_both: number of people without both a hard hat and a vest 69 | 70 | # Training Program 71 | Based on the TensorFlow Object Detection API, using a pretrained ssd_mobilenet_v1 COCO model to initialize the weights. 72 | 73 | # Training Data 74 | Coming soon!
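Custom notification provider example

The backend selects its notification provider in `notification/factory.py`, and a provider only needs a `send(msg)` method as defined in `notification/provider.py`. The sketch below is illustrative only and is not part of the shipped code: the `FileProvider` class name and the `path` configuration key are hypothetical, and you would still need to register the new provider in `NotificationFactory.instance()` and add a matching entry to `config.py`.

```python
# Hypothetical sketch of a custom provider -- not part of this repository.
from notification.provider import Provider


class FileProvider(Provider):
    """Appends each alert message to a local log file."""

    def __init__(self, config):
        # "path" is an assumed config key, e.g. {"path": "alerts.log"}
        self.path = config["path"]

    def send(self, msg):
        # msg is the markdown alert text built by the backend
        with open(self.path, "a") as f:
            f.write(msg + "\n")
```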
75 | 76 | # Reference work 77 | * TensorFlow Object Detection: https://github.com/tensorflow/models/tree/master/research/object_detection 78 | -------------------------------------------------------------------------------- /backend/.dockerignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | npm-debug.log 3 | Dockerfile* 4 | docker-compose* 5 | .dockerignore 6 | .git 7 | .gitignore 8 | .env 9 | */bin 10 | */obj 11 | README.md 12 | LICENSE 13 | .vscode 14 | __pycache__ 15 | -------------------------------------------------------------------------------- /backend/.gitignore: -------------------------------------------------------------------------------- 1 | vendor 2 | debug 3 | node_modules 4 | __pycache__ 5 | 6 | *.swp 7 | *.tar 8 | release 9 | .vscode 10 | .images 11 | images 12 | tmp 13 | tmp* 14 | -------------------------------------------------------------------------------- /backend/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:alpine 2 | 3 | EXPOSE 8080 4 | 5 | WORKDIR /app 6 | ADD . /app 7 | 8 | # Using pip: 9 | RUN python3 -m pip install -r requirements.txt 10 | CMD ["python3", "main.py"] 11 | -------------------------------------------------------------------------------- /backend/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | config = { 4 | "port": int(os.getenv("PORT", 8080)), 5 | "provider": os.getenv("PROVIDER", "spark"), # spark, console 6 | "spark": { 7 | "url": os.getenv("SPARK_URL", "https://api.ciscospark.com/v1"), 8 | "token": os.getenv("SPARK_TOKEN", "your token"), 9 | "room_id": os.getenv("SPARK_ROOM_ID", "your room id"), 10 | }, 11 | "console": { 12 | "room_id": "test", 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /backend/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2.1' 2 | 3 | services: 4 | ppe-backend: 5 | image: ppe-backend 6 | build: . 
7 | environment: 8 | #- PROVIDER=spark 9 | - PROVIDER=console 10 | - PYTHONUNBUFFERED=1 11 | ports: 12 | - 8080:8080 13 | -------------------------------------------------------------------------------- /backend/main.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from flask import Flask, jsonify, request 4 | 5 | from config import config 6 | from notification.factory import NotificationFactory 7 | 8 | 9 | notification_svc = NotificationFactory.instance(config) 10 | 11 | 12 | MSG_TEMPLATE = """ 13 | ### ppe demo 14 | **Alert** at **{point}** at {time} 15 | > total_person={total_person} without_hardhat={without_hardhat} without_vest={without_vest} without_both={without_both} 16 | """ 17 | 18 | 19 | def _construct_msg(ts, point, total_person, without_hardhat, without_vest, without_both): 20 | t = datetime.utcfromtimestamp(ts / 1000).strftime("%Y-%m-%d %H:%M:%S UTC") 21 | return MSG_TEMPLATE.format( 22 | time=t, point=point, total_person=total_person, without_hardhat=without_hardhat, without_vest=without_vest, without_both=without_both) 23 | 24 | 25 | class HttpError(Exception): 26 | status_code = 400 27 | 28 | def __init__(self, message, status_code=None, payload=None): 29 | Exception.__init__(self) 30 | self.message = message 31 | if status_code is not None: 32 | self.status_code = status_code 33 | self.payload = payload 34 | 35 | def to_dict(self): 36 | rv = dict(self.payload or ()) 37 | rv['message'] = self.message 38 | return rv 39 | 40 | 41 | app = Flask(__name__) 42 | 43 | 44 | @app.errorhandler(HttpError) 45 | def handle_http_error(error): 46 | response = jsonify(error.to_dict()) 47 | response.status_code = error.status_code 48 | print("[Error]:", error.message) 49 | return response 50 | 51 | 52 | @app.route("/") 53 | def home(): 54 | return "ok" 55 | 56 | 57 | # { 58 | # "id": "331404be-7c57-11e9-a345-dca90488d3b9", 59 | # "cameraId": "camera1", 60 | # "timestamp": 1558506692, 61 | # "persons": [ 62 | # { 63 | # "hardhat": true, 64 | # "vest": true 65 | # }, 66 | # { 67 | # "hardhat": true, 68 | # "vest": true 69 | # } 70 | # ], 71 | # "image": { 72 | # "height": 200, 73 | # "width": 300, 74 | # "format": "jpeg", 75 | # "raw": "base64 encoded data", 76 | # "url": "http://ppe-backend:7200/images/uuid1" 77 | # }, 78 | # "createdAt": 1558506697000, 79 | # "updatedAt": 1558506697000 80 | # } 81 | @app.route("/v1/detections", methods=["POST"]) 82 | def create_detections_v1(): 83 | js = request.json 84 | js["image"]["raw"] = "omitted" 85 | cameraId = js.get("cameraId") 86 | if cameraId is None: 87 | print("json field missing") 88 | raise HttpError("cameraId missing", status_code=400) 89 | print("[Info] received:", js["cameraId"], js["timestamp"]) 90 | 91 | without_hardhat = len(list(filter(lambda p: not p["hardhat"], js["persons"]))) 92 | without_vest = len(list(filter(lambda p: not p["vest"], js["persons"]))) 93 | without_both = len(list(filter(lambda p: not p["vest"] and not p["hardhat"], js["persons"]))) 94 | if without_hardhat > 0 or without_vest > 0 or without_both > 0: 95 | print("[Warn]", "someone violated the rule") 96 | msg = _construct_msg(js["timestamp"], js["cameraId"], len(js["persons"]), without_hardhat - without_both, without_vest - without_both, without_both) 97 | notification_svc.send(msg) 98 | else: 99 | print("[Info]", "no one violated the rule") 100 | 101 | return jsonify(request.json), 201 102 | 103 | 104 | if __name__ == "__main__": 105 | app.run(host="0.0.0.0", port=config["port"]) 106 | 
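As a quick sanity check of the endpoint above, the following is a minimal client sketch (not part of the repository) that posts one detection event to `/v1/detections`, assuming the backend is running locally on its default port 8080; all field values are illustrative:

```python
# Hypothetical client sketch -- not part of this repository.
import time

import requests

event = {
    "cameraId": "camera001",               # any label; used to identify the camera
    "timestamp": int(time.time() * 1000),  # milliseconds, as expected by _construct_msg()
    "persons": [
        {"hardhat": True, "vest": True},
        {"hardhat": False, "vest": True},  # one violation, so a notification is sent
    ],
    "image": {"height": 360, "width": 640, "format": "jpeg", "raw": ""},
}

resp = requests.post("http://localhost:8080/v1/detections", json=event)
print(resp.status_code)  # 201 on success
```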
-------------------------------------------------------------------------------- /backend/notification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CiscoDevNet/ppe-detection/0312466ed731f5f20e0771180a2d49065aea8b15/backend/notification/__init__.py -------------------------------------------------------------------------------- /backend/notification/console.py: -------------------------------------------------------------------------------- 1 | from notification.provider import Provider 2 | 3 | 4 | class Console(Provider): 5 | def __init__(self, config): 6 | self.room_id = config["room_id"] 7 | 8 | def send(self, msg): 9 | print("sending msg to", self.room_id) 10 | print(msg) 11 | -------------------------------------------------------------------------------- /backend/notification/factory.py: -------------------------------------------------------------------------------- 1 | from notification.console import Console 2 | from notification.provider import Noop 3 | from notification.spark import Spark 4 | 5 | 6 | class NotificationFactory: 7 | @staticmethod 8 | def instance(config): 9 | if config["provider"] == "spark": 10 | return Spark(config["spark"]) 11 | elif config["provider"] == "console": 12 | return Console(config["console"]) 13 | return Noop() 14 | -------------------------------------------------------------------------------- /backend/notification/provider.py: -------------------------------------------------------------------------------- 1 | class Provider: 2 | def send(self, msg): 3 | pass 4 | 5 | 6 | class Noop(Provider): 7 | pass 8 | -------------------------------------------------------------------------------- /backend/notification/spark.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | from notification.provider import Provider 4 | 5 | 6 | class Spark(Provider): 7 | def __init__(self, config): 8 | self.token = config["token"] 9 | self.room_id = config["room_id"] 10 | self.url = config["url"] 11 | self.msg_url = "{}/messages".format(self.url) 12 | self.headers = { 13 | "Accept": "application/json", 14 | "Content-Type": "application/json", 15 | "Authorization": "Bearer {}".format(self.token) 16 | } 17 | 18 | def send(self, msg): 19 | body = { 20 | "roomId": self.room_id, 21 | "markdown": msg, 22 | } 23 | 24 | r = requests.post(self.msg_url, json.dumps(body), headers=self.headers) 25 | if r.status_code != 200: 26 | print("[Error]", "send spark message failed with code={}, resp={}".format(r.status_code, r.content)) 27 | return False 28 | 29 | print("[Debug]", msg) 30 | print("[Info]", "send spark message successfully") 31 | return True 32 | -------------------------------------------------------------------------------- /backend/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.0.2 2 | requests==2.18.4 3 | -------------------------------------------------------------------------------- /data/alert.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CiscoDevNet/ppe-detection/0312466ed731f5f20e0771180a2d49065aea8b15/data/alert.jpg -------------------------------------------------------------------------------- /data/ppe.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CiscoDevNet/ppe-detection/0312466ed731f5f20e0771180a2d49065aea8b15/data/ppe.jpg -------------------------------------------------------------------------------- /inference/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration file 3 | 4 | Author: zmingen 5 | """ 6 | import os 7 | 8 | # Send detection results to this URL 9 | detection_api = os.getenv("PPE_DETECTION_URL", "http://localhost:8080/v1/detections") 10 | 11 | # Message sending interval, in milliseconds 12 | message_send_interval = int(os.getenv("PPE_MESSAGE_SEND_INTERVAL", 1000)) 13 | 14 | # Object confidence threshold; objects whose confidence is smaller than this threshold are filtered out 15 | object_confidence_threshold = float(os.getenv("PPE_OBJECT_CONFIDENCE_THRESHOLD", .5)) 16 | 17 | # Capture Image Size, Allowed Resolutions: (640, 480), (1280, 720), (1920, 1080) 18 | supported_video_resolution = [(640, 480), (1280, 720), (1920, 1080)] 19 | capture_image_width = int(os.getenv("PPE_CAPTURE_IMAGE_WIDTH", 1280)) 20 | capture_image_height = int(os.getenv("PPE_CAPTURE_IMAGE_HEIGHT", 720)) 21 | 22 | # Display Window Size 23 | display_full_screen = os.getenv("PPE_DISPLAY_FULL_SCREEN", "True").lower() == "true" 24 | display_window_width = int(os.getenv("PPE_DISPLAY_WINDOW_WIDTH", 1280)) 25 | display_window_height = int(os.getenv("PPE_DISPLAY_WINDOW_HEIGHT", 720)) 26 | 27 | # Image Size of Storage 28 | storage_image_width = int(os.getenv("PPE_STORAGE_IMAGE_WIDTH", 640)) 29 | storage_image_height = int(os.getenv("PPE_STORAGE_IMAGE_HEIGHT", 360)) 30 | 31 | # Input Type, one of ["camera", "file"] 32 | input_type = os.getenv("PPE_INPUT_TYPE", "file") 33 | 34 | 35 | -------------------------------------------------------------------------------- /inference/demo_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CiscoDevNet/ppe-detection/0312466ed731f5f20e0771180a2d49065aea8b15/inference/demo_video.mp4 -------------------------------------------------------------------------------- /inference/model/frozen_inference_graph.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CiscoDevNet/ppe-detection/0312466ed731f5f20e0771180a2d49065aea8b15/inference/model/frozen_inference_graph.pb -------------------------------------------------------------------------------- /inference/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.14.5 2 | tensorflow-gpu>1.2.0 3 | opencv-python 4 | pillow 5 | requests 6 | matplotlib -------------------------------------------------------------------------------- /inference/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utils used to manipulate tensor shapes.""" 17 | 18 | import tensorflow as tf 19 | 20 | import static_shape 21 | 22 | 23 | def _is_tensor(t): 24 | """Returns a boolean indicating whether the input is a tensor. 25 | 26 | Args: 27 | t: the input to be tested. 28 | 29 | Returns: 30 | a boolean that indicates whether t is a tensor. 31 | """ 32 | return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable)) 33 | 34 | 35 | def _set_dim_0(t, d0): 36 | """Sets the 0-th dimension of the input tensor. 37 | 38 | Args: 39 | t: the input tensor, assuming the rank is at least 1. 40 | d0: an integer indicating the 0-th dimension of the input tensor. 41 | 42 | Returns: 43 | the tensor t with the 0-th dimension set. 44 | """ 45 | t_shape = t.get_shape().as_list() 46 | t_shape[0] = d0 47 | t.set_shape(t_shape) 48 | return t 49 | 50 | 51 | def pad_tensor(t, length): 52 | """Pads the input tensor with 0s along the first dimension up to the length. 53 | 54 | Args: 55 | t: the input tensor, assuming the rank is at least 1. 56 | length: a tensor of shape [1] or an integer, indicating the first dimension 57 | of the input tensor t after padding, assuming length <= t.shape[0]. 58 | 59 | Returns: 60 | padded_t: the padded tensor, whose first dimension is length. If the length 61 | is an integer, the first dimension of padded_t is set to length 62 | statically. 63 | """ 64 | t_rank = tf.rank(t) 65 | t_shape = tf.shape(t) 66 | t_d0 = t_shape[0] 67 | pad_d0 = tf.expand_dims(length - t_d0, 0) 68 | pad_shape = tf.cond( 69 | tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), 70 | lambda: tf.expand_dims(length - t_d0, 0)) 71 | padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) 72 | if not _is_tensor(length): 73 | padded_t = _set_dim_0(padded_t, length) 74 | return padded_t 75 | 76 | 77 | def clip_tensor(t, length): 78 | """Clips the input tensor along the first dimension up to the length. 79 | 80 | Args: 81 | t: the input tensor, assuming the rank is at least 1. 82 | length: a tensor of shape [1] or an integer, indicating the first dimension 83 | of the input tensor t after clipping, assuming length <= t.shape[0]. 84 | 85 | Returns: 86 | clipped_t: the clipped tensor, whose first dimension is length. If the 87 | length is an integer, the first dimension of clipped_t is set to length 88 | statically. 89 | """ 90 | clipped_t = tf.gather(t, tf.range(length)) 91 | if not _is_tensor(length): 92 | clipped_t = _set_dim_0(clipped_t, length) 93 | return clipped_t 94 | 95 | 96 | def pad_or_clip_tensor(t, length): 97 | """Pad or clip the input tensor along the first dimension. 98 | 99 | Args: 100 | t: the input tensor, assuming the rank is at least 1. 101 | length: a tensor of shape [1] or an integer, indicating the first dimension 102 | of the input tensor t after processing. 103 | 104 | Returns: 105 | processed_t: the processed tensor, whose first dimension is length. If the 106 | length is an integer, the first dimension of the processed tensor is set 107 | to length statically. 108 | """ 109 | return pad_or_clip_nd(t, [length] + t.shape.as_list()[1:]) 110 | 111 | 112 | def pad_or_clip_nd(tensor, output_shape): 113 | """Pad or Clip given tensor to the output shape. 114 | 115 | Args: 116 | tensor: Input tensor to pad or clip. 117 | output_shape: A list of integers / scalar tensors (or None for dynamic dim) 118 | representing the size to pad or clip each dimension of the input tensor. 
119 | 120 | Returns: 121 | Input tensor padded and clipped to the output shape. 122 | """ 123 | tensor_shape = tf.shape(tensor) 124 | clip_size = [ 125 | tf.where(tensor_shape[i] - shape > 0, shape, -1) 126 | if shape is not None else -1 for i, shape in enumerate(output_shape) 127 | ] 128 | clipped_tensor = tf.slice( 129 | tensor, 130 | begin=tf.zeros(len(clip_size), dtype=tf.int32), 131 | size=clip_size) 132 | 133 | # Pad tensor if the shape of clipped tensor is smaller than the expected 134 | # shape. 135 | clipped_tensor_shape = tf.shape(clipped_tensor) 136 | trailing_paddings = [ 137 | shape - clipped_tensor_shape[i] if shape is not None else 0 138 | for i, shape in enumerate(output_shape) 139 | ] 140 | paddings = tf.stack( 141 | [ 142 | tf.zeros(len(trailing_paddings), dtype=tf.int32), 143 | trailing_paddings 144 | ], 145 | axis=1) 146 | padded_tensor = tf.pad(clipped_tensor, paddings=paddings) 147 | output_static_shape = [ 148 | dim if not isinstance(dim, tf.Tensor) else None for dim in output_shape 149 | ] 150 | padded_tensor.set_shape(output_static_shape) 151 | return padded_tensor 152 | 153 | 154 | def combined_static_and_dynamic_shape(tensor): 155 | """Returns a list containing static and dynamic values for the dimensions. 156 | 157 | Returns a list of static and dynamic values for shape dimensions. This is 158 | useful to preserve static shapes when available in reshape operation. 159 | 160 | Args: 161 | tensor: A tensor of any type. 162 | 163 | Returns: 164 | A list of size tensor.shape.ndims containing integers or a scalar tensor. 165 | """ 166 | static_tensor_shape = tensor.shape.as_list() 167 | dynamic_tensor_shape = tf.shape(tensor) 168 | combined_shape = [] 169 | for index, dim in enumerate(static_tensor_shape): 170 | if dim is not None: 171 | combined_shape.append(dim) 172 | else: 173 | combined_shape.append(dynamic_tensor_shape[index]) 174 | return combined_shape 175 | 176 | 177 | def static_or_dynamic_map_fn(fn, elems, dtype=None, 178 | parallel_iterations=32, back_prop=True): 179 | """Runs map_fn as a (static) for loop when possible. 180 | 181 | This function rewrites the map_fn as an explicit unstack input -> for loop 182 | over function calls -> stack result combination. This allows our graphs to 183 | be acyclic when the batch size is static. 184 | For comparison, see https://www.tensorflow.org/api_docs/python/tf/map_fn. 185 | 186 | Note that `static_or_dynamic_map_fn` currently is not *fully* interchangeable 187 | with the default tf.map_fn function as it does not accept nested inputs (only 188 | Tensors or lists of Tensors). Likewise, the output of `fn` can only be a 189 | Tensor or list of Tensors. 190 | 191 | TODO(jonathanhuang): make this function fully interchangeable with tf.map_fn. 192 | 193 | Args: 194 | fn: The callable to be performed. It accepts one argument, which will have 195 | the same structure as elems. Its output must have the 196 | same structure as elems. 197 | elems: A tensor or list of tensors, each of which will 198 | be unpacked along their first dimension. The sequence of the 199 | resulting slices will be applied to fn. 200 | dtype: (optional) The output type(s) of fn. If fn returns a structure of 201 | Tensors differing from the structure of elems, then dtype is not optional 202 | and must have the same structure as the output of fn. 203 | parallel_iterations: (optional) number of batch items to process in 204 | parallel. 
This flag is only used if the native tf.map_fn is used 205 | and defaults to 32 instead of 10 (unlike the standard tf.map_fn default). 206 | back_prop: (optional) True enables support for back propagation. 207 | This flag is only used if the native tf.map_fn is used. 208 | 209 | Returns: 210 | A tensor or sequence of tensors. Each tensor packs the 211 | results of applying fn to tensors unpacked from elems along the first 212 | dimension, from first to last. 213 | Raises: 214 | ValueError: if `elems` a Tensor or a list of Tensors. 215 | ValueError: if `fn` does not return a Tensor or list of Tensors 216 | """ 217 | if isinstance(elems, list): 218 | for elem in elems: 219 | if not isinstance(elem, tf.Tensor): 220 | raise ValueError('`elems` must be a Tensor or list of Tensors.') 221 | 222 | elem_shapes = [elem.shape.as_list() for elem in elems] 223 | # Fall back on tf.map_fn if shapes of each entry of `elems` are None or fail 224 | # to all be the same size along the batch dimension. 225 | for elem_shape in elem_shapes: 226 | if (not elem_shape or not elem_shape[0] 227 | or elem_shape[0] != elem_shapes[0][0]): 228 | return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop) 229 | arg_tuples = zip(*[tf.unstack(elem) for elem in elems]) 230 | outputs = [fn(arg_tuple) for arg_tuple in arg_tuples] 231 | else: 232 | if not isinstance(elems, tf.Tensor): 233 | raise ValueError('`elems` must be a Tensor or list of Tensors.') 234 | elems_shape = elems.shape.as_list() 235 | if not elems_shape or not elems_shape[0]: 236 | return tf.map_fn(fn, elems, dtype, parallel_iterations, back_prop) 237 | outputs = [fn(arg) for arg in tf.unstack(elems)] 238 | # Stack `outputs`, which is a list of Tensors or list of lists of Tensors 239 | if all([isinstance(output, tf.Tensor) for output in outputs]): 240 | return tf.stack(outputs) 241 | else: 242 | if all([isinstance(output, list) for output in outputs]): 243 | if all([all( 244 | [isinstance(entry, tf.Tensor) for entry in output_list]) 245 | for output_list in outputs]): 246 | return [tf.stack(output_tuple) for output_tuple in zip(*outputs)] 247 | raise ValueError('`fn` should return a Tensor or a list of Tensors.') 248 | 249 | 250 | def check_min_image_dim(min_dim, image_tensor): 251 | """Checks that the image width/height are greater than some number. 252 | 253 | This function is used to check that the width and height of an image are above 254 | a certain value. If the image shape is static, this function will perform the 255 | check at graph construction time. Otherwise, if the image shape varies, an 256 | Assertion control dependency will be added to the graph. 257 | 258 | Args: 259 | min_dim: The minimum number of pixels along the width and height of the 260 | image. 261 | image_tensor: The image tensor to check size for. 262 | 263 | Returns: 264 | If `image_tensor` has dynamic size, return `image_tensor` with a Assert 265 | control dependency. Otherwise returns image_tensor. 266 | 267 | Raises: 268 | ValueError: if `image_tensor`'s' width or height is smaller than `min_dim`. 
269 | """ 270 | image_shape = image_tensor.get_shape() 271 | image_height = static_shape.get_height(image_shape) 272 | image_width = static_shape.get_width(image_shape) 273 | if image_height is None or image_width is None: 274 | shape_assert = tf.Assert( 275 | tf.logical_and(tf.greater_equal(tf.shape(image_tensor)[1], min_dim), 276 | tf.greater_equal(tf.shape(image_tensor)[2], min_dim)), 277 | ['image size must be >= {} in both height and width.'.format(min_dim)]) 278 | with tf.control_dependencies([shape_assert]): 279 | return tf.identity(image_tensor) 280 | 281 | if image_height < min_dim or image_width < min_dim: 282 | raise ValueError( 283 | 'image size must be >= %d in both height and width; image dim = %d,%d' % 284 | (min_dim, image_height, image_width)) 285 | 286 | return image_tensor 287 | 288 | 289 | def assert_shape_equal(shape_a, shape_b): 290 | """Asserts that shape_a and shape_b are equal. 291 | 292 | If the shapes are static, raises a ValueError when the shapes 293 | mismatch. 294 | 295 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes 296 | mismatch. 297 | 298 | Args: 299 | shape_a: a list containing shape of the first tensor. 300 | shape_b: a list containing shape of the second tensor. 301 | 302 | Returns: 303 | Either a tf.no_op() when shapes are all static and a tf.assert_equal() op 304 | when the shapes are dynamic. 305 | 306 | Raises: 307 | ValueError: When shapes are both static and unequal. 308 | """ 309 | if (all(isinstance(dim, int) for dim in shape_a) and 310 | all(isinstance(dim, int) for dim in shape_b)): 311 | if shape_a != shape_b: 312 | raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) 313 | else: return tf.no_op() 314 | else: 315 | return tf.assert_equal(shape_a, shape_b) 316 | 317 | 318 | def assert_shape_equal_along_first_dimension(shape_a, shape_b): 319 | """Asserts that shape_a and shape_b are the same along the 0th-dimension. 320 | 321 | If the shapes are static, raises a ValueError when the shapes 322 | mismatch. 323 | 324 | If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes 325 | mismatch. 326 | 327 | Args: 328 | shape_a: a list containing shape of the first tensor. 329 | shape_b: a list containing shape of the second tensor. 330 | 331 | Returns: 332 | Either a tf.no_op() when shapes are all static and a tf.assert_equal() op 333 | when the shapes are dynamic. 334 | 335 | Raises: 336 | ValueError: When shapes are both static and unequal. 337 | """ 338 | if isinstance(shape_a[0], int) and isinstance(shape_b[0], int): 339 | if shape_a[0] != shape_b[0]: 340 | raise ValueError('Unequal first dimension {}, {}'.format( 341 | shape_a[0], shape_b[0])) 342 | else: return tf.no_op() 343 | else: 344 | return tf.assert_equal(shape_a[0], shape_b[0]) 345 | 346 | 347 | def assert_box_normalized(boxes, maximum_normalized_coordinate=1.1): 348 | """Asserts the input box tensor is normalized. 349 | 350 | Args: 351 | boxes: a tensor of shape [N, 4] where N is the number of boxes. 352 | maximum_normalized_coordinate: Maximum coordinate value to be considered 353 | as normalized, default to 1.1. 354 | 355 | Returns: 356 | a tf.Assert op which fails when the input box tensor is not normalized. 357 | 358 | Raises: 359 | ValueError: When the input box tensor is not normalized. 
360 | """ 361 | box_minimum = tf.reduce_min(boxes) 362 | box_maximum = tf.reduce_max(boxes) 363 | return tf.Assert( 364 | tf.logical_and( 365 | tf.less_equal(box_maximum, maximum_normalized_coordinate), 366 | tf.greater_equal(box_minimum, 0)), 367 | [boxes]) 368 | -------------------------------------------------------------------------------- /inference/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | class InputDataFields(object): 28 | """Names for the input tensors. 29 | 30 | Holds the standard data field names to use for identifying input tensors. This 31 | should be used by the decoder to identify keys for the returned tensor_dict 32 | containing input tensors. And it should be used by the model to identify the 33 | tensors it needs. 34 | 35 | Attributes: 36 | image: image. 37 | image_additional_channels: additional channels. 38 | original_image: image in the original input size. 39 | original_image_spatial_shape: image in the original input size. 40 | key: unique key corresponding to image. 41 | source_id: source of the original image. 42 | filename: original filename of the dataset (without common path). 43 | groundtruth_image_classes: image-level class labels. 44 | groundtruth_image_confidences: image-level class confidences. 45 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 46 | groundtruth_classes: box-level class labels. 47 | groundtruth_confidences: box-level class confidences. The shape should be 48 | the same as the shape of groundtruth_classes. 49 | groundtruth_label_types: box-level label types (e.g. explicit negative). 50 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 51 | is the groundtruth a single object or a crowd. 52 | groundtruth_area: area of a groundtruth segment. 53 | groundtruth_difficult: is a `difficult` object 54 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 55 | same class, forming a connected group, where instances are heavily 56 | occluding each other. 57 | proposal_boxes: coordinates of object proposal boxes. 58 | proposal_objectness: objectness score of each proposal. 59 | groundtruth_instance_masks: ground truth instance masks. 60 | groundtruth_instance_boundaries: ground truth instance boundaries. 61 | groundtruth_instance_classes: instance mask-level class labels. 
62 | groundtruth_keypoints: ground truth keypoints. 63 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 64 | groundtruth_label_weights: groundtruth label weights. 65 | groundtruth_weights: groundtruth weight factor for bounding boxes. 66 | num_groundtruth_boxes: number of groundtruth boxes. 67 | is_annotated: whether an image has been labeled or not. 68 | true_image_shapes: true shapes of images in the resized images, as resized 69 | images can be padded with zeros. 70 | multiclass_scores: the label score per class for each box. 71 | """ 72 | image = 'image' 73 | image_additional_channels = 'image_additional_channels' 74 | original_image = 'original_image' 75 | original_image_spatial_shape = 'original_image_spatial_shape' 76 | key = 'key' 77 | source_id = 'source_id' 78 | filename = 'filename' 79 | groundtruth_image_classes = 'groundtruth_image_classes' 80 | groundtruth_image_confidences = 'groundtruth_image_confidences' 81 | groundtruth_boxes = 'groundtruth_boxes' 82 | groundtruth_classes = 'groundtruth_classes' 83 | groundtruth_confidences = 'groundtruth_confidences' 84 | groundtruth_label_types = 'groundtruth_label_types' 85 | groundtruth_is_crowd = 'groundtruth_is_crowd' 86 | groundtruth_area = 'groundtruth_area' 87 | groundtruth_difficult = 'groundtruth_difficult' 88 | groundtruth_group_of = 'groundtruth_group_of' 89 | proposal_boxes = 'proposal_boxes' 90 | proposal_objectness = 'proposal_objectness' 91 | groundtruth_instance_masks = 'groundtruth_instance_masks' 92 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' 93 | groundtruth_instance_classes = 'groundtruth_instance_classes' 94 | groundtruth_keypoints = 'groundtruth_keypoints' 95 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' 96 | groundtruth_label_weights = 'groundtruth_label_weights' 97 | groundtruth_weights = 'groundtruth_weights' 98 | num_groundtruth_boxes = 'num_groundtruth_boxes' 99 | is_annotated = 'is_annotated' 100 | true_image_shape = 'true_image_shape' 101 | multiclass_scores = 'multiclass_scores' 102 | 103 | 104 | class DetectionResultFields(object): 105 | """Naming conventions for storing the output of the detector. 106 | 107 | Attributes: 108 | source_id: source of the original image. 109 | key: unique key corresponding to image. 110 | detection_boxes: coordinates of the detection boxes in the image. 111 | detection_scores: detection scores for the detection boxes in the image. 112 | detection_classes: detection-level class labels. 113 | detection_masks: contains a segmentation mask for each detection box. 114 | detection_boundaries: contains an object boundary for each detection box. 115 | detection_keypoints: contains detection keypoints for each detection box. 116 | num_detections: number of detections in the batch. 117 | raw_detection_boxes: contains decoded detection boxes without Non-Max 118 | suppression. 119 | raw_detection_scores: contains class score logits for raw detection boxes. 
120 | """ 121 | 122 | source_id = 'source_id' 123 | key = 'key' 124 | detection_boxes = 'detection_boxes' 125 | detection_scores = 'detection_scores' 126 | detection_classes = 'detection_classes' 127 | detection_masks = 'detection_masks' 128 | detection_boundaries = 'detection_boundaries' 129 | detection_keypoints = 'detection_keypoints' 130 | num_detections = 'num_detections' 131 | raw_detection_boxes = 'raw_detection_boxes' 132 | raw_detection_scores = 'raw_detection_scores' 133 | 134 | 135 | class BoxListFields(object): 136 | """Naming conventions for BoxLists. 137 | 138 | Attributes: 139 | boxes: bounding box coordinates. 140 | classes: classes per bounding box. 141 | scores: scores per bounding box. 142 | weights: sample weights per bounding box. 143 | objectness: objectness score per bounding box. 144 | masks: masks per bounding box. 145 | boundaries: boundaries per bounding box. 146 | keypoints: keypoints per bounding box. 147 | keypoint_heatmaps: keypoint heatmaps per bounding box. 148 | is_crowd: is_crowd annotation per bounding box. 149 | """ 150 | boxes = 'boxes' 151 | classes = 'classes' 152 | scores = 'scores' 153 | weights = 'weights' 154 | confidences = 'confidences' 155 | objectness = 'objectness' 156 | masks = 'masks' 157 | boundaries = 'boundaries' 158 | keypoints = 'keypoints' 159 | keypoint_heatmaps = 'keypoint_heatmaps' 160 | is_crowd = 'is_crowd' 161 | 162 | 163 | class TfExampleFields(object): 164 | """TF-example proto feature names for object detection. 165 | 166 | Holds the standard feature names to load from an Example proto for object 167 | detection. 168 | 169 | Attributes: 170 | image_encoded: JPEG encoded string 171 | image_format: image format, e.g. "JPEG" 172 | filename: filename 173 | channels: number of channels of image 174 | colorspace: colorspace, e.g. "RGB" 175 | height: height of image in pixels, e.g. 462 176 | width: width of image in pixels, e.g. 581 177 | source_id: original source of the image 178 | image_class_text: image-level label in text format 179 | image_class_label: image-level label in numerical format 180 | object_class_text: labels in text format, e.g. ["person", "cat"] 181 | object_class_label: labels in numbers, e.g. [16, 8] 182 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 183 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 184 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 185 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 186 | object_view: viewpoint of object, e.g. ["frontal", "left"] 187 | object_truncated: is object truncated, e.g. [true, false] 188 | object_occluded: is object occluded, e.g. [true, false] 189 | object_difficult: is object difficult, e.g. [true, false] 190 | object_group_of: is object a single object or a group of objects 191 | object_depiction: is object a depiction 192 | object_is_crowd: [DEPRECATED, use object_group_of instead] 193 | is the object a single object or a crowd 194 | object_segment_area: the area of the segment. 195 | object_weight: a weight factor for the object's bounding box. 196 | instance_masks: instance segmentation masks. 197 | instance_boundaries: instance boundaries. 198 | instance_classes: Classes for each instance segmentation mask. 199 | detection_class_label: class label in numbers. 200 | detection_bbox_ymin: ymin coordinates of a detection box. 201 | detection_bbox_xmin: xmin coordinates of a detection box. 202 | detection_bbox_ymax: ymax coordinates of a detection box. 
203 | detection_bbox_xmax: xmax coordinates of a detection box. 204 | detection_score: detection score for the class label and box. 205 | """ 206 | image_encoded = 'image/encoded' 207 | image_format = 'image/format' # format is reserved keyword 208 | filename = 'image/filename' 209 | channels = 'image/channels' 210 | colorspace = 'image/colorspace' 211 | height = 'image/height' 212 | width = 'image/width' 213 | source_id = 'image/source_id' 214 | image_class_text = 'image/class/text' 215 | image_class_label = 'image/class/label' 216 | object_class_text = 'image/object/class/text' 217 | object_class_label = 'image/object/class/label' 218 | object_bbox_ymin = 'image/object/bbox/ymin' 219 | object_bbox_xmin = 'image/object/bbox/xmin' 220 | object_bbox_ymax = 'image/object/bbox/ymax' 221 | object_bbox_xmax = 'image/object/bbox/xmax' 222 | object_view = 'image/object/view' 223 | object_truncated = 'image/object/truncated' 224 | object_occluded = 'image/object/occluded' 225 | object_difficult = 'image/object/difficult' 226 | object_group_of = 'image/object/group_of' 227 | object_depiction = 'image/object/depiction' 228 | object_is_crowd = 'image/object/is_crowd' 229 | object_segment_area = 'image/object/segment/area' 230 | object_weight = 'image/object/weight' 231 | instance_masks = 'image/segmentation/object' 232 | instance_boundaries = 'image/boundaries/object' 233 | instance_classes = 'image/segmentation/object/class' 234 | detection_class_label = 'image/detection/label' 235 | detection_bbox_ymin = 'image/detection/bbox/ymin' 236 | detection_bbox_xmin = 'image/detection/bbox/xmin' 237 | detection_bbox_ymax = 'image/detection/bbox/ymax' 238 | detection_bbox_xmax = 'image/detection/bbox/xmax' 239 | detection_score = 'image/detection/score' 240 | -------------------------------------------------------------------------------- /inference/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions to access TensorShape values. 17 | 18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 19 | """ 20 | 21 | 22 | def get_batch_size(tensor_shape): 23 | """Returns batch size from the tensor shape. 24 | 25 | Args: 26 | tensor_shape: A rank 4 TensorShape. 27 | 28 | Returns: 29 | An integer representing the batch size of the tensor. 30 | """ 31 | tensor_shape.assert_has_rank(rank=4) 32 | return tensor_shape[0].value 33 | 34 | 35 | def get_height(tensor_shape): 36 | """Returns height from the tensor shape. 37 | 38 | Args: 39 | tensor_shape: A rank 4 TensorShape. 40 | 41 | Returns: 42 | An integer representing the height of the tensor. 
43 | """ 44 | tensor_shape.assert_has_rank(rank=4) 45 | return tensor_shape[1].value 46 | 47 | 48 | def get_width(tensor_shape): 49 | """Returns width from the tensor shape. 50 | 51 | Args: 52 | tensor_shape: A rank 4 TensorShape. 53 | 54 | Returns: 55 | An integer representing the width of the tensor. 56 | """ 57 | tensor_shape.assert_has_rank(rank=4) 58 | return tensor_shape[2].value 59 | 60 | 61 | def get_depth(tensor_shape): 62 | """Returns depth from the tensor shape. 63 | 64 | Args: 65 | tensor_shape: A rank 4 TensorShape. 66 | 67 | Returns: 68 | An integer representing the depth of the tensor. 69 | """ 70 | tensor_shape.assert_has_rank(rank=4) 71 | return tensor_shape[3].value 72 | -------------------------------------------------------------------------------- /inference/video_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This Module is ppe 4 | Example: 5 | $python video_demo.py 6 | Author: Ming'en Zheng 7 | """ 8 | import os 9 | import time 10 | from multiprocessing import Process, Queue, Value 11 | import queue 12 | import numpy as np 13 | import tensorflow as tf 14 | import cv2 15 | import argparse 16 | import requests 17 | from distutils.version import StrictVersion 18 | import visualization_utils as vis_utils 19 | import config 20 | import base64 21 | 22 | 23 | if StrictVersion(tf.__version__) < StrictVersion('1.12.0'): 24 | raise ImportError('Please upgrade your TensorFlow installation to v1.12.*') 25 | 26 | 27 | def load_model(inference_model_path): 28 | detection_graph = tf.Graph() 29 | with detection_graph.as_default(): 30 | od_graph_def = tf.GraphDef() 31 | with tf.gfile.GFile(inference_model_path, 'rb') as fid: 32 | serialized_graph = fid.read() 33 | od_graph_def.ParseFromString(serialized_graph) 34 | tf.import_graph_def(od_graph_def, name='') 35 | return detection_graph 36 | 37 | 38 | def load_image_into_numpy_array(image): 39 | (im_width, im_height) = image.size 40 | return np.array(image.getdata()).reshape((im_height, im_width, 3)).astype(np.uint8) 41 | 42 | 43 | def run_inference_for_single_image(image, sess, tensor_dict): 44 | image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0') 45 | 46 | output_dict = sess.run(tensor_dict, feed_dict={image_tensor: image}) 47 | 48 | output_dict['num_detections'] = int(output_dict['num_detections'][0]) 49 | output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.int64) 50 | output_dict['detection_boxes'] = output_dict['detection_boxes'][0] 51 | output_dict['detection_scores'] = output_dict['detection_scores'][0] 52 | 53 | return output_dict 54 | 55 | 56 | def is_wearing_hardhat(person_box, hardhat_box, intersection_ratio): 57 | xA = max(person_box[0], hardhat_box[0]) 58 | yA = max(person_box[1], hardhat_box[1]) 59 | xB = min(person_box[2], hardhat_box[2]) 60 | yB = min(person_box[3], hardhat_box[3]) 61 | 62 | interArea = max(0, xB - xA ) * max(0, yB - yA ) 63 | 64 | hardhat_size = (hardhat_box[2] - hardhat_box[0]) * (hardhat_box[3] - hardhat_box[1]) 65 | 66 | if interArea / hardhat_size > intersection_ratio: 67 | return True 68 | else: 69 | return False 70 | 71 | 72 | def is_wearing_vest(person_box, vest_box, vest_intersection_ratio): 73 | xA = max(person_box[0], vest_box[0]) 74 | yA = max(person_box[1], vest_box[1]) 75 | xB = min(person_box[2], vest_box[2]) 76 | yB = min(person_box[3], vest_box[3]) 77 | 78 | interArea = max(0, xB - xA) * max(0, yB - yA) 79 | 80 | vest_size = (vest_box[2] - vest_box[0]) * 
(vest_box[3] - vest_box[1]) 81 | 82 | if interArea / vest_size > vest_intersection_ratio: 83 | return True 84 | else: 85 | return False 86 | 87 | 88 | def is_wearing_hardhat_vest(hardhat_boxes, vest_boxes, person_box): 89 | hardhat_flag = False 90 | vest_flag = False 91 | hardhat_intersection_ratio = 0.6 92 | vest_intersection_ratio = 0.6 93 | 94 | for hardhat_box in hardhat_boxes: 95 | hardhat_flag = is_wearing_hardhat(person_box, hardhat_box, hardhat_intersection_ratio) 96 | if hardhat_flag: 97 | break 98 | 99 | for vest_box in vest_boxes: 100 | vest_flag = is_wearing_vest(person_box, vest_box, vest_intersection_ratio) 101 | if vest_flag: 102 | break 103 | 104 | return hardhat_flag, vest_flag 105 | 106 | 107 | def post_message_process(run_flag, message_queue): 108 | 109 | while run_flag.value: 110 | try: 111 | camera_id, output_dict, image, min_score_thresh = message_queue.get(block=True, timeout=5) 112 | post_message(camera_id, output_dict, image, min_score_thresh) 113 | except queue.Empty: 114 | continue 115 | 116 | 117 | def post_message(camera_id, output_dict, image, min_score_thresh): 118 | message = dict() 119 | message["timestamp"] = int(time.time() * 1000) 120 | message["cameraId"] = camera_id 121 | 122 | image_info = {} 123 | image_info["height"] = image.shape[0] 124 | image_info["width"] = image.shape[1] 125 | image_info["format"] = "jpeg" 126 | 127 | success, encoded_image = cv2.imencode('.jpg', image) 128 | content = encoded_image.tobytes() 129 | image_info["raw"] = base64.b64encode(content).decode('utf-8') 130 | 131 | message["image"] = image_info 132 | 133 | detection_scores = np.where(output_dict["detection_scores"] > min_score_thresh, True, False) 134 | 135 | detection_boxes = output_dict["detection_boxes"][detection_scores] 136 | detection_classes = output_dict["detection_classes"][detection_scores] 137 | 138 | hardhat_boxes = detection_boxes[np.where(detection_classes == 1)] 139 | vest_boxes = detection_boxes[np.where(detection_classes == 2)] 140 | person_boxes = detection_boxes[np.where(detection_classes == 3)] 141 | 142 | persons = [] 143 | for person_box in person_boxes: 144 | person = dict() 145 | person["hardhat"], person["vest"] = is_wearing_hardhat_vest(hardhat_boxes, vest_boxes, person_box) 146 | persons.append(person) 147 | 148 | message["persons"] = persons 149 | 150 | if len(persons) == 0: 151 | return False 152 | 153 | print(message["persons"]) 154 | try: 155 | headers = {'Content-type': 'application/json'} 156 | if len(persons): 157 | result = requests.post(config.detection_api, json=message, headers=headers) 158 | print(result) 159 | return True 160 | except requests.exceptions.ConnectionError: 161 | print("Connect to backend failed") 162 | return False 163 | 164 | 165 | def image_processing(graph, category_index, image_file_name, show_video_window): 166 | 167 | img = cv2.imread(image_file_name) 168 | image_expanded = np.expand_dims(img, axis=0) 169 | 170 | with graph.as_default(): 171 | ops = tf.get_default_graph().get_operations() 172 | all_tensor_names = {output.name for op in ops for output in op.outputs} 173 | tensor_dict = {} 174 | for key in [ 175 | 'num_detections', 'detection_boxes', 'detection_scores', 176 | 'detection_classes', 'detection_masks' 177 | ]: 178 | tensor_name = key + ':0' 179 | if tensor_name in all_tensor_names: 180 | tensor_dict[key] = tf.get_default_graph().get_tensor_by_name( 181 | tensor_name) 182 | with tf.Session() as sess: 183 | output_dict = run_inference_for_single_image(image_expanded, sess, tensor_dict) 184 | 185 | 
vis_utils.visualize_boxes_and_labels_on_image_array( 186 | img, 187 | output_dict['detection_boxes'], 188 | output_dict['detection_classes'], 189 | output_dict['detection_scores'], 190 | category_index, 191 | instance_masks=output_dict.get('detection_masks'), 192 | use_normalized_coordinates=True, 193 | line_thickness=4) 194 | 195 | if show_video_window: 196 | cv2.imshow('ppe', img) 197 | cv2.waitKey(5000) 198 | 199 | 200 | def video_processing(graph, category_index, video_file_name, show_video_window, camera_id, run_flag, message_queue): 201 | cap = cv2.VideoCapture(video_file_name) 202 | 203 | if show_video_window: 204 | cv2.namedWindow('ppe', cv2.WINDOW_NORMAL) 205 | if config.display_full_screen: 206 | cv2.setWindowProperty('ppe', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN) 207 | else: 208 | cv2.setWindowProperty('ppe', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_NORMAL) 209 | 210 | if (config.capture_image_width, config.capture_image_height) in config.supported_video_resolution: 211 | print("video_processing:", "supported video resolution") 212 | cap.set(cv2.CAP_PROP_FRAME_WIDTH, config.capture_image_width) 213 | cap.set(cv2.CAP_PROP_FRAME_HEIGHT, config.capture_image_height) 214 | 215 | with graph.as_default(): 216 | print("video_processing:", "default tensorflow graph") 217 | ops = tf.get_default_graph().get_operations() 218 | all_tensor_names = {output.name for op in ops for output in op.outputs} 219 | tensor_dict = {} 220 | for key in [ 221 | 'num_detections', 'detection_boxes', 'detection_scores', 222 | 'detection_classes', 'detection_masks' 223 | ]: 224 | tensor_name = key + ':0' 225 | if tensor_name in all_tensor_names: 226 | tensor_dict[key] = tf.get_default_graph().get_tensor_by_name( 227 | tensor_name) 228 | with tf.Session() as sess: 229 | print("video_processing:", "tensorflow session") 230 | send_message_time = time.time() 231 | frame_counter = 0 232 | while True: 233 | ret, frame = cap.read() 234 | 235 | if config.input_type.lower() == "file": 236 | frame_counter += 1 237 | if frame_counter == int(cap.get(cv2.CAP_PROP_FRAME_COUNT)): 238 | frame_counter = 0 239 | cap.set(cv2.CAP_PROP_POS_FRAMES, 0) 240 | continue 241 | 242 | if frame is None: 243 | print("video_processing:", "null frame") 244 | break 245 | 246 | image_expanded = np.expand_dims(frame, axis=0) 247 | output_dict = run_inference_for_single_image(image_expanded, sess, tensor_dict) 248 | 249 | vis_utils.visualize_boxes_and_labels_on_image_array( 250 | frame, 251 | output_dict['detection_boxes'], 252 | output_dict['detection_classes'], 253 | output_dict['detection_scores'], 254 | category_index, 255 | instance_masks=output_dict.get('detection_masks'), 256 | use_normalized_coordinates=True, 257 | line_thickness=4) 258 | 259 | if time.time() - send_message_time > config.message_send_interval / 1000.0: 260 | resized_frame = cv2.resize(frame, dsize=(config.storage_image_width, config.storage_image_height)) 261 | try: 262 | message_queue.put_nowait((camera_id, output_dict, resized_frame, config.object_confidence_threshold)) 263 | except queue.Full: 264 | print("message queue is full") 265 | else: 266 | send_message_time = time.time() 267 | 268 | if show_video_window: 269 | resized_frame = cv2.resize(frame, dsize=(config.display_window_width, config.display_window_height)) 270 | cv2.imshow('ppe', resized_frame) 271 | if cv2.waitKey(1) & 0xFF == ord('q'): 272 | run_flag.value = 0 273 | break 274 | 275 | print("video_processing:", "releasing video capture") 276 | cap.release() 277 | cv2.destroyAllWindows() 278 | 279 | 280 
| def main(): 281 | parser = argparse.ArgumentParser(description="Hardhat and Vest Detection", add_help=True) 282 | parser.add_argument("--model_dir", type=str, required=True, help="path to model directory") 283 | parser.add_argument("--video_file_name", type=str, required=True, help="path to video file, or camera device, e.g. /dev/video1") 284 | parser.add_argument("--show_video_window", type=int, required=True, help="flag for showing the video window: 0 to hide, 1 to display") 285 | parser.add_argument("--camera_id", type=str, required=True, help="camera identifier") 286 | args = parser.parse_args() 287 | 288 | frozen_model_path = os.path.join(args.model_dir, "frozen_inference_graph.pb") 289 | if not os.path.exists(frozen_model_path): 290 | print("frozen_inference_graph.pb file does not exist in the model directory") 291 | exit(-1) 292 | print("loading model") 293 | graph = load_model(frozen_model_path) 294 | category_index = {1: {'id': 1, 'name': 'hardhat'}, 295 | 2: {'id': 2, 'name': 'vest'}, 296 | 3: {'id': 3, 'name': 'person'}} 297 | 298 | print("start message queue") 299 | run_flag = Value('i', 1) 300 | message_queue = Queue(1) 301 | p = Process(target=post_message_process, args=(run_flag, message_queue)) 302 | p.start() 303 | print("video processing") 304 | video_processing(graph, category_index, args.video_file_name, args.show_video_window, args.camera_id, run_flag, message_queue) 305 | p.join() 306 | 307 | #image_processing(graph, category_index, './examples/002.jpg', True) 308 | 309 | if __name__ == '__main__': 310 | main() 311 | 312 | -------------------------------------------------------------------------------- /inference/visualization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A set of functions that are used for visualization. 17 | 18 | These functions often receive an image, perform some visualization on the image. 19 | The functions do not return a value, instead they modify the image itself. 20 | 21 | """ 22 | import abc 23 | import collections 24 | import functools 25 | # Set headless-friendly backend. 
26 | import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements 27 | import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top 28 | import numpy as np 29 | import PIL.Image as Image 30 | import PIL.ImageColor as ImageColor 31 | import PIL.ImageDraw as ImageDraw 32 | import PIL.ImageFont as ImageFont 33 | import six 34 | import tensorflow as tf 35 | 36 | import standard_fields as fields 37 | import shape_utils 38 | 39 | _TITLE_LEFT_MARGIN = 10 40 | _TITLE_TOP_MARGIN = 10 41 | STANDARD_COLORS = [ 42 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', 43 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', 44 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', 45 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', 46 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', 47 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', 48 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', 49 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', 50 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', 51 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', 52 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', 53 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', 54 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', 55 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', 56 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', 57 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', 58 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 59 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', 60 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', 61 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', 62 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', 63 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', 64 | 'WhiteSmoke', 'Yellow', 'YellowGreen' 65 | ] 66 | 67 | 68 | def save_image_array_as_png(image, output_path): 69 | """Saves an image (represented as a numpy array) to PNG. 70 | 71 | Args: 72 | image: a numpy array with shape [height, width, 3]. 73 | output_path: path to which image should be written. 74 | """ 75 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 76 | with tf.gfile.Open(output_path, 'w') as fid: 77 | image_pil.save(fid, 'PNG') 78 | 79 | 80 | def encode_image_array_as_png_str(image): 81 | """Encodes a numpy array into a PNG string. 82 | 83 | Args: 84 | image: a numpy array with shape [height, width, 3]. 85 | 86 | Returns: 87 | PNG encoded image string. 88 | """ 89 | image_pil = Image.fromarray(np.uint8(image)) 90 | output = six.BytesIO() 91 | image_pil.save(output, format='PNG') 92 | png_string = output.getvalue() 93 | output.close() 94 | return png_string 95 | 96 | 97 | def draw_bounding_box_on_image_array(image, 98 | ymin, 99 | xmin, 100 | ymax, 101 | xmax, 102 | color='red', 103 | thickness=4, 104 | display_str_list=(), 105 | use_normalized_coordinates=True): 106 | """Adds a bounding box to an image (numpy array). 
107 | 108 | Bounding box coordinates can be specified in either absolute (pixel) or 109 | normalized coordinates by setting the use_normalized_coordinates argument. 110 | 111 | Args: 112 | image: a numpy array with shape [height, width, 3]. 113 | ymin: ymin of bounding box. 114 | xmin: xmin of bounding box. 115 | ymax: ymax of bounding box. 116 | xmax: xmax of bounding box. 117 | color: color to draw bounding box. Default is red. 118 | thickness: line thickness. Default value is 4. 119 | display_str_list: list of strings to display in box 120 | (each to be shown on its own line). 121 | use_normalized_coordinates: If True (default), treat coordinates 122 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 123 | coordinates as absolute. 124 | """ 125 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 126 | draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, 127 | thickness, display_str_list, 128 | use_normalized_coordinates) 129 | np.copyto(image, np.array(image_pil)) 130 | 131 | 132 | def draw_bounding_box_on_image(image, 133 | ymin, 134 | xmin, 135 | ymax, 136 | xmax, 137 | color='red', 138 | thickness=4, 139 | display_str_list=(), 140 | use_normalized_coordinates=True): 141 | """Adds a bounding box to an image. 142 | 143 | Bounding box coordinates can be specified in either absolute (pixel) or 144 | normalized coordinates by setting the use_normalized_coordinates argument. 145 | 146 | Each string in display_str_list is displayed on a separate line above the 147 | bounding box in black text on a rectangle filled with the input 'color'. 148 | If the top of the bounding box extends to the edge of the image, the strings 149 | are displayed below the bounding box. 150 | 151 | Args: 152 | image: a PIL.Image object. 153 | ymin: ymin of bounding box. 154 | xmin: xmin of bounding box. 155 | ymax: ymax of bounding box. 156 | xmax: xmax of bounding box. 157 | color: color to draw bounding box. Default is red. 158 | thickness: line thickness. Default value is 4. 159 | display_str_list: list of strings to display in box 160 | (each to be shown on its own line). 161 | use_normalized_coordinates: If True (default), treat coordinates 162 | ymin, xmin, ymax, xmax as relative to the image. Otherwise treat 163 | coordinates as absolute. 164 | """ 165 | draw = ImageDraw.Draw(image) 166 | im_width, im_height = image.size 167 | if use_normalized_coordinates: 168 | (left, right, top, bottom) = (xmin * im_width, xmax * im_width, 169 | ymin * im_height, ymax * im_height) 170 | else: 171 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax) 172 | draw.line([(left, top), (left, bottom), (right, bottom), 173 | (right, top), (left, top)], width=thickness, fill=color) 174 | try: 175 | font = ImageFont.truetype('arial.ttf', 24) 176 | except IOError: 177 | font = ImageFont.load_default() 178 | 179 | # If the total height of the display strings added to the top of the bounding 180 | # box exceeds the top of the image, stack the strings below the bounding box 181 | # instead of above. 182 | display_str_heights = [font.getsize(ds)[1] for ds in display_str_list] 183 | # Each display_str has a top and bottom margin of 0.05x. 184 | total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights) 185 | 186 | if top > total_display_str_height: 187 | text_bottom = top 188 | else: 189 | text_bottom = bottom + total_display_str_height 190 | # Reverse list and print from bottom to top. 
191 | for display_str in display_str_list[::-1]: 192 | text_width, text_height = font.getsize(display_str) 193 | margin = np.ceil(0.05 * text_height) 194 | draw.rectangle( 195 | [(left, text_bottom - text_height - 2 * margin), (left + text_width, 196 | text_bottom)], 197 | fill=color) 198 | draw.text( 199 | (left + margin, text_bottom - text_height - margin), 200 | display_str, 201 | fill='black', 202 | font=font) 203 | text_bottom -= text_height - 2 * margin 204 | 205 | 206 | def draw_bounding_boxes_on_image_array(image, 207 | boxes, 208 | color='red', 209 | thickness=4, 210 | display_str_list_list=()): 211 | """Draws bounding boxes on image (numpy array). 212 | 213 | Args: 214 | image: a numpy array object. 215 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 216 | The coordinates are in normalized format between [0, 1]. 217 | color: color to draw bounding box. Default is red. 218 | thickness: line thickness. Default value is 4. 219 | display_str_list_list: list of list of strings. 220 | a list of strings for each bounding box. 221 | The reason to pass a list of strings for a 222 | bounding box is that it might contain 223 | multiple labels. 224 | 225 | Raises: 226 | ValueError: if boxes is not a [N, 4] array 227 | """ 228 | image_pil = Image.fromarray(image) 229 | draw_bounding_boxes_on_image(image_pil, boxes, color, thickness, 230 | display_str_list_list) 231 | np.copyto(image, np.array(image_pil)) 232 | 233 | 234 | def draw_bounding_boxes_on_image(image, 235 | boxes, 236 | color='red', 237 | thickness=4, 238 | display_str_list_list=()): 239 | """Draws bounding boxes on image. 240 | 241 | Args: 242 | image: a PIL.Image object. 243 | boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). 244 | The coordinates are in normalized format between [0, 1]. 245 | color: color to draw bounding box. Default is red. 246 | thickness: line thickness. Default value is 4. 247 | display_str_list_list: list of list of strings. 248 | a list of strings for each bounding box. 249 | The reason to pass a list of strings for a 250 | bounding box is that it might contain 251 | multiple labels. 
252 | 253 | Raises: 254 | ValueError: if boxes is not a [N, 4] array 255 | """ 256 | boxes_shape = boxes.shape 257 | if not boxes_shape: 258 | return 259 | if len(boxes_shape) != 2 or boxes_shape[1] != 4: 260 | raise ValueError('Input must be of size [N, 4]') 261 | for i in range(boxes_shape[0]): 262 | display_str_list = () 263 | if display_str_list_list: 264 | display_str_list = display_str_list_list[i] 265 | draw_bounding_box_on_image(image, boxes[i, 0], boxes[i, 1], boxes[i, 2], 266 | boxes[i, 3], color, thickness, display_str_list) 267 | 268 | 269 | def _visualize_boxes(image, boxes, classes, scores, category_index, **kwargs): 270 | return visualize_boxes_and_labels_on_image_array( 271 | image, boxes, classes, scores, category_index=category_index, **kwargs) 272 | 273 | 274 | def _visualize_boxes_and_masks(image, boxes, classes, scores, masks, 275 | category_index, **kwargs): 276 | return visualize_boxes_and_labels_on_image_array( 277 | image, 278 | boxes, 279 | classes, 280 | scores, 281 | category_index=category_index, 282 | instance_masks=masks, 283 | **kwargs) 284 | 285 | 286 | def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, 287 | category_index, **kwargs): 288 | return visualize_boxes_and_labels_on_image_array( 289 | image, 290 | boxes, 291 | classes, 292 | scores, 293 | category_index=category_index, 294 | keypoints=keypoints, 295 | **kwargs) 296 | 297 | 298 | def _visualize_boxes_and_masks_and_keypoints( 299 | image, boxes, classes, scores, masks, keypoints, category_index, **kwargs): 300 | return visualize_boxes_and_labels_on_image_array( 301 | image, 302 | boxes, 303 | classes, 304 | scores, 305 | category_index=category_index, 306 | instance_masks=masks, 307 | keypoints=keypoints, 308 | **kwargs) 309 | 310 | 311 | def _resize_original_image(image, image_shape): 312 | image = tf.expand_dims(image, 0) 313 | image = tf.image.resize_images( 314 | image, 315 | image_shape, 316 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, 317 | align_corners=True) 318 | return tf.cast(tf.squeeze(image, 0), tf.uint8) 319 | 320 | 321 | def draw_bounding_boxes_on_image_tensors(images, 322 | boxes, 323 | classes, 324 | scores, 325 | category_index, 326 | original_image_spatial_shape=None, 327 | true_image_shape=None, 328 | instance_masks=None, 329 | keypoints=None, 330 | max_boxes_to_draw=20, 331 | min_score_thresh=0.2, 332 | use_normalized_coordinates=True): 333 | """Draws bounding boxes, masks, and keypoints on batch of image tensors. 334 | 335 | Args: 336 | images: A 4D uint8 image tensor of shape [N, H, W, C]. If C > 3, additional 337 | channels will be ignored. If C = 1, then we convert the images to RGB 338 | images. 339 | boxes: [N, max_detections, 4] float32 tensor of detection boxes. 340 | classes: [N, max_detections] int tensor of detection classes. Note that 341 | classes are 1-indexed. 342 | scores: [N, max_detections] float32 tensor of detection scores. 343 | category_index: a dict that maps integer ids to category dicts. e.g. 344 | {1: {1: 'dog'}, 2: {2: 'cat'}, ...} 345 | original_image_spatial_shape: [N, 2] tensor containing the spatial size of 346 | the original image. 347 | true_image_shape: [N, 3] tensor containing the spatial size of unpadded 348 | original_image. 349 | instance_masks: A 4D uint8 tensor of shape [N, max_detection, H, W] with 350 | instance masks. 351 | keypoints: A 4D float32 tensor of shape [N, max_detection, num_keypoints, 2] 352 | with keypoints. 353 | max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. 
354 | min_score_thresh: Minimum score threshold for visualization. Default 0.2. 355 | use_normalized_coordinates: Whether to assume boxes and kepoints are in 356 | normalized coordinates (as opposed to absolute coordiantes). 357 | Default is True. 358 | 359 | Returns: 360 | 4D image tensor of type uint8, with boxes drawn on top. 361 | """ 362 | # Additional channels are being ignored. 363 | if images.shape[3] > 3: 364 | images = images[:, :, :, 0:3] 365 | elif images.shape[3] == 1: 366 | images = tf.image.grayscale_to_rgb(images) 367 | visualization_keyword_args = { 368 | 'use_normalized_coordinates': use_normalized_coordinates, 369 | 'max_boxes_to_draw': max_boxes_to_draw, 370 | 'min_score_thresh': min_score_thresh, 371 | 'agnostic_mode': False, 372 | 'line_thickness': 4 373 | } 374 | if true_image_shape is None: 375 | true_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 3]) 376 | else: 377 | true_shapes = true_image_shape 378 | if original_image_spatial_shape is None: 379 | original_shapes = tf.constant(-1, shape=[images.shape.as_list()[0], 2]) 380 | else: 381 | original_shapes = original_image_spatial_shape 382 | 383 | if instance_masks is not None and keypoints is None: 384 | visualize_boxes_fn = functools.partial( 385 | _visualize_boxes_and_masks, 386 | category_index=category_index, 387 | **visualization_keyword_args) 388 | elems = [ 389 | true_shapes, original_shapes, images, boxes, classes, scores, 390 | instance_masks 391 | ] 392 | elif instance_masks is None and keypoints is not None: 393 | visualize_boxes_fn = functools.partial( 394 | _visualize_boxes_and_keypoints, 395 | category_index=category_index, 396 | **visualization_keyword_args) 397 | elems = [ 398 | true_shapes, original_shapes, images, boxes, classes, scores, keypoints 399 | ] 400 | elif instance_masks is not None and keypoints is not None: 401 | visualize_boxes_fn = functools.partial( 402 | _visualize_boxes_and_masks_and_keypoints, 403 | category_index=category_index, 404 | **visualization_keyword_args) 405 | elems = [ 406 | true_shapes, original_shapes, images, boxes, classes, scores, 407 | instance_masks, keypoints 408 | ] 409 | else: 410 | visualize_boxes_fn = functools.partial( 411 | _visualize_boxes, 412 | category_index=category_index, 413 | **visualization_keyword_args) 414 | elems = [ 415 | true_shapes, original_shapes, images, boxes, classes, scores 416 | ] 417 | 418 | def draw_boxes(image_and_detections): 419 | """Draws boxes on image.""" 420 | true_shape = image_and_detections[0] 421 | original_shape = image_and_detections[1] 422 | if true_image_shape is not None: 423 | image = shape_utils.pad_or_clip_nd(image_and_detections[2], 424 | [true_shape[0], true_shape[1], 3]) 425 | if original_image_spatial_shape is not None: 426 | image_and_detections[2] = _resize_original_image(image, original_shape) 427 | 428 | image_with_boxes = tf.py_func(visualize_boxes_fn, image_and_detections[2:], 429 | tf.uint8) 430 | return image_with_boxes 431 | 432 | images = tf.map_fn(draw_boxes, elems, dtype=tf.uint8, back_prop=False) 433 | return images 434 | 435 | 436 | def draw_side_by_side_evaluation_image(eval_dict, 437 | category_index, 438 | max_boxes_to_draw=20, 439 | min_score_thresh=0.2, 440 | use_normalized_coordinates=True): 441 | """Creates a side-by-side image with detections and groundtruth. 442 | 443 | Bounding boxes (and instance masks, if available) are visualized on both 444 | subimages. 
445 | 446 | Args: 447 | eval_dict: The evaluation dictionary returned by 448 | eval_util.result_dict_for_batched_example() or 449 | eval_util.result_dict_for_single_example(). 450 | category_index: A category index (dictionary) produced from a labelmap. 451 | max_boxes_to_draw: The maximum number of boxes to draw for detections. 452 | min_score_thresh: The minimum score threshold for showing detections. 453 | use_normalized_coordinates: Whether to assume boxes and kepoints are in 454 | normalized coordinates (as opposed to absolute coordiantes). 455 | Default is True. 456 | 457 | Returns: 458 | A list of [1, H, 2 * W, C] uint8 tensor. The subimage on the left 459 | corresponds to detections, while the subimage on the right corresponds to 460 | groundtruth. 461 | """ 462 | detection_fields = fields.DetectionResultFields() 463 | input_data_fields = fields.InputDataFields() 464 | 465 | images_with_detections_list = [] 466 | 467 | # Add the batch dimension if the eval_dict is for single example. 468 | if len(eval_dict[detection_fields.detection_classes].shape) == 1: 469 | for key in eval_dict: 470 | if key != input_data_fields.original_image: 471 | eval_dict[key] = tf.expand_dims(eval_dict[key], 0) 472 | 473 | for indx in range(eval_dict[input_data_fields.original_image].shape[0]): 474 | instance_masks = None 475 | if detection_fields.detection_masks in eval_dict: 476 | instance_masks = tf.cast( 477 | tf.expand_dims( 478 | eval_dict[detection_fields.detection_masks][indx], axis=0), 479 | tf.uint8) 480 | keypoints = None 481 | if detection_fields.detection_keypoints in eval_dict: 482 | keypoints = tf.expand_dims( 483 | eval_dict[detection_fields.detection_keypoints][indx], axis=0) 484 | groundtruth_instance_masks = None 485 | if input_data_fields.groundtruth_instance_masks in eval_dict: 486 | groundtruth_instance_masks = tf.cast( 487 | tf.expand_dims( 488 | eval_dict[input_data_fields.groundtruth_instance_masks][indx], 489 | axis=0), tf.uint8) 490 | 491 | images_with_detections = draw_bounding_boxes_on_image_tensors( 492 | tf.expand_dims( 493 | eval_dict[input_data_fields.original_image][indx], axis=0), 494 | tf.expand_dims( 495 | eval_dict[detection_fields.detection_boxes][indx], axis=0), 496 | tf.expand_dims( 497 | eval_dict[detection_fields.detection_classes][indx], axis=0), 498 | tf.expand_dims( 499 | eval_dict[detection_fields.detection_scores][indx], axis=0), 500 | category_index, 501 | original_image_spatial_shape=tf.expand_dims( 502 | eval_dict[input_data_fields.original_image_spatial_shape][indx], 503 | axis=0), 504 | true_image_shape=tf.expand_dims( 505 | eval_dict[input_data_fields.true_image_shape][indx], axis=0), 506 | instance_masks=instance_masks, 507 | keypoints=keypoints, 508 | max_boxes_to_draw=max_boxes_to_draw, 509 | min_score_thresh=min_score_thresh, 510 | use_normalized_coordinates=use_normalized_coordinates) 511 | images_with_groundtruth = draw_bounding_boxes_on_image_tensors( 512 | tf.expand_dims( 513 | eval_dict[input_data_fields.original_image][indx], axis=0), 514 | tf.expand_dims( 515 | eval_dict[input_data_fields.groundtruth_boxes][indx], axis=0), 516 | tf.expand_dims( 517 | eval_dict[input_data_fields.groundtruth_classes][indx], axis=0), 518 | tf.expand_dims( 519 | tf.ones_like( 520 | eval_dict[input_data_fields.groundtruth_classes][indx], 521 | dtype=tf.float32), 522 | axis=0), 523 | category_index, 524 | original_image_spatial_shape=tf.expand_dims( 525 | eval_dict[input_data_fields.original_image_spatial_shape][indx], 526 | axis=0), 527 | 
true_image_shape=tf.expand_dims( 528 | eval_dict[input_data_fields.true_image_shape][indx], axis=0), 529 | instance_masks=groundtruth_instance_masks, 530 | keypoints=None, 531 | max_boxes_to_draw=None, 532 | min_score_thresh=0.0, 533 | use_normalized_coordinates=use_normalized_coordinates) 534 | images_with_detections_list.append( 535 | tf.concat([images_with_detections, images_with_groundtruth], axis=2)) 536 | return images_with_detections_list 537 | 538 | 539 | def draw_keypoints_on_image_array(image, 540 | keypoints, 541 | color='red', 542 | radius=2, 543 | use_normalized_coordinates=True): 544 | """Draws keypoints on an image (numpy array). 545 | 546 | Args: 547 | image: a numpy array with shape [height, width, 3]. 548 | keypoints: a numpy array with shape [num_keypoints, 2]. 549 | color: color to draw the keypoints with. Default is red. 550 | radius: keypoint radius. Default value is 2. 551 | use_normalized_coordinates: if True (default), treat keypoint values as 552 | relative to the image. Otherwise treat them as absolute. 553 | """ 554 | image_pil = Image.fromarray(np.uint8(image)).convert('RGB') 555 | draw_keypoints_on_image(image_pil, keypoints, color, radius, 556 | use_normalized_coordinates) 557 | np.copyto(image, np.array(image_pil)) 558 | 559 | 560 | def draw_keypoints_on_image(image, 561 | keypoints, 562 | color='red', 563 | radius=2, 564 | use_normalized_coordinates=True): 565 | """Draws keypoints on an image. 566 | 567 | Args: 568 | image: a PIL.Image object. 569 | keypoints: a numpy array with shape [num_keypoints, 2]. 570 | color: color to draw the keypoints with. Default is red. 571 | radius: keypoint radius. Default value is 2. 572 | use_normalized_coordinates: if True (default), treat keypoint values as 573 | relative to the image. Otherwise treat them as absolute. 574 | """ 575 | draw = ImageDraw.Draw(image) 576 | im_width, im_height = image.size 577 | keypoints_x = [k[1] for k in keypoints] 578 | keypoints_y = [k[0] for k in keypoints] 579 | if use_normalized_coordinates: 580 | keypoints_x = tuple([im_width * x for x in keypoints_x]) 581 | keypoints_y = tuple([im_height * y for y in keypoints_y]) 582 | for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): 583 | draw.ellipse([(keypoint_x - radius, keypoint_y - radius), 584 | (keypoint_x + radius, keypoint_y + radius)], 585 | outline=color, fill=color) 586 | 587 | 588 | def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): 589 | """Draws mask on an image. 590 | 591 | Args: 592 | image: uint8 numpy array with shape (img_height, img_height, 3) 593 | mask: a uint8 numpy array of shape (img_height, img_height) with 594 | values between either 0 or 1. 595 | color: color to draw the keypoints with. Default is red. 596 | alpha: transparency value between 0 and 1. (default: 0.4) 597 | 598 | Raises: 599 | ValueError: On incorrect data type for image or masks. 
600 | """ 601 | if image.dtype != np.uint8: 602 | raise ValueError('`image` not of type np.uint8') 603 | if mask.dtype != np.uint8: 604 | raise ValueError('`mask` not of type np.uint8') 605 | if np.any(np.logical_and(mask != 1, mask != 0)): 606 | raise ValueError('`mask` elements should be in [0, 1]') 607 | if image.shape[:2] != mask.shape: 608 | raise ValueError('The image has spatial dimensions %s but the mask has ' 609 | 'dimensions %s' % (image.shape[:2], mask.shape)) 610 | rgb = ImageColor.getrgb(color) 611 | pil_image = Image.fromarray(image) 612 | 613 | solid_color = np.expand_dims( 614 | np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) 615 | pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') 616 | pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') 617 | pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) 618 | np.copyto(image, np.array(pil_image.convert('RGB'))) 619 | 620 | 621 | def visualize_boxes_and_labels_on_image_array( 622 | image, 623 | boxes, 624 | classes, 625 | scores, 626 | category_index, 627 | instance_masks=None, 628 | instance_boundaries=None, 629 | keypoints=None, 630 | use_normalized_coordinates=False, 631 | max_boxes_to_draw=20, 632 | min_score_thresh=.5, 633 | agnostic_mode=False, 634 | line_thickness=4, 635 | groundtruth_box_visualization_color='black', 636 | skip_scores=False, 637 | skip_labels=False): 638 | """Overlay labeled boxes on an image with formatted scores and label names. 639 | 640 | This function groups boxes that correspond to the same location 641 | and creates a display string for each detection and overlays these 642 | on the image. Note that this function modifies the image in place, and returns 643 | that same image. 644 | 645 | Args: 646 | image: uint8 numpy array with shape (img_height, img_width, 3) 647 | boxes: a numpy array of shape [N, 4] 648 | classes: a numpy array of shape [N]. Note that class indices are 1-based, 649 | and match the keys in the label map. 650 | scores: a numpy array of shape [N] or None. If scores=None, then 651 | this function assumes that the boxes to be plotted are groundtruth 652 | boxes and plot all boxes as black with no classes or scores. 653 | category_index: a dict containing category dictionaries (each holding 654 | category index `id` and category name `name`) keyed by category indices. 655 | instance_masks: a numpy array of shape [N, image_height, image_width] with 656 | values ranging between 0 and 1, can be None. 657 | instance_boundaries: a numpy array of shape [N, image_height, image_width] 658 | with values ranging between 0 and 1, can be None. 659 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 660 | be None 661 | use_normalized_coordinates: whether boxes is to be interpreted as 662 | normalized coordinates or not. 663 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 664 | all boxes. 665 | min_score_thresh: minimum score threshold for a box to be visualized 666 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 667 | class-agnostic mode or not. This mode will display scores but ignore 668 | classes. 669 | line_thickness: integer (default: 4) controlling line width of the boxes. 
670 | groundtruth_box_visualization_color: box color for visualizing groundtruth 671 | boxes 672 | skip_scores: whether to skip score when drawing a single detection 673 | skip_labels: whether to skip label when drawing a single detection 674 | 675 | Returns: 676 | uint8 numpy array with shape (img_height, img_width, 3) with overlaid boxes. 677 | """ 678 | # Create a display string (and color) for every box location, group any boxes 679 | # that correspond to the same location. 680 | box_to_display_str_map = collections.defaultdict(list) 681 | box_to_color_map = collections.defaultdict(str) 682 | box_to_instance_masks_map = {} 683 | box_to_instance_boundaries_map = {} 684 | box_to_keypoints_map = collections.defaultdict(list) 685 | if not max_boxes_to_draw: 686 | max_boxes_to_draw = boxes.shape[0] 687 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 688 | if scores is None or scores[i] > min_score_thresh: 689 | box = tuple(boxes[i].tolist()) 690 | if instance_masks is not None: 691 | box_to_instance_masks_map[box] = instance_masks[i] 692 | if instance_boundaries is not None: 693 | box_to_instance_boundaries_map[box] = instance_boundaries[i] 694 | if keypoints is not None: 695 | box_to_keypoints_map[box].extend(keypoints[i]) 696 | if scores is None: 697 | box_to_color_map[box] = groundtruth_box_visualization_color 698 | else: 699 | display_str = '' 700 | if not skip_labels: 701 | if not agnostic_mode: 702 | if classes[i] in category_index.keys(): 703 | class_name = category_index[classes[i]]['name'] 704 | else: 705 | class_name = 'N/A' 706 | display_str = str(class_name) 707 | if not skip_scores: 708 | if not display_str: 709 | display_str = '{}%'.format(int(100*scores[i])) 710 | else: 711 | display_str = '{}: {}%'.format(display_str, int(100*scores[i])) 712 | box_to_display_str_map[box].append(display_str) 713 | if agnostic_mode: 714 | box_to_color_map[box] = 'DarkOrange' 715 | else: 716 | box_to_color_map[box] = STANDARD_COLORS[ 717 | classes[i] % len(STANDARD_COLORS)] 718 | 719 | # Draw all boxes onto image. 720 | for box, color in box_to_color_map.items(): 721 | ymin, xmin, ymax, xmax = box 722 | if instance_masks is not None: 723 | draw_mask_on_image_array( 724 | image, 725 | box_to_instance_masks_map[box], 726 | color=color 727 | ) 728 | if instance_boundaries is not None: 729 | draw_mask_on_image_array( 730 | image, 731 | box_to_instance_boundaries_map[box], 732 | color='red', 733 | alpha=1.0 734 | ) 735 | draw_bounding_box_on_image_array( 736 | image, 737 | ymin, 738 | xmin, 739 | ymax, 740 | xmax, 741 | color=color, 742 | thickness=line_thickness, 743 | display_str_list=box_to_display_str_map[box], 744 | use_normalized_coordinates=use_normalized_coordinates) 745 | if keypoints is not None: 746 | draw_keypoints_on_image_array( 747 | image, 748 | box_to_keypoints_map[box], 749 | color=color, 750 | radius=line_thickness / 2, 751 | use_normalized_coordinates=use_normalized_coordinates) 752 | 753 | return image 754 | 755 | 756 | def add_cdf_image_summary(values, name): 757 | """Adds a tf.summary.image for a CDF plot of the values. 758 | 759 | Normalizes `values` such that they sum to 1, plots the cumulative distribution 760 | function and creates a tf image summary. 761 | 762 | Args: 763 | values: a 1-D float32 tensor containing the values. 764 | name: name for the image summary. 
765 | """ 766 | def cdf_plot(values): 767 | """Numpy function to plot CDF.""" 768 | normalized_values = values / np.sum(values) 769 | sorted_values = np.sort(normalized_values) 770 | cumulative_values = np.cumsum(sorted_values) 771 | fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) 772 | / cumulative_values.size) 773 | fig = plt.figure(frameon=False) 774 | ax = fig.add_subplot('111') 775 | ax.plot(fraction_of_examples, cumulative_values) 776 | ax.set_ylabel('cumulative normalized values') 777 | ax.set_xlabel('fraction of examples') 778 | fig.canvas.draw() 779 | width, height = fig.get_size_inches() * fig.get_dpi() 780 | image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( 781 | 1, int(height), int(width), 3) 782 | return image 783 | cdf_plot = tf.py_func(cdf_plot, [values], tf.uint8) 784 | tf.summary.image(name, cdf_plot) 785 | 786 | 787 | def add_hist_image_summary(values, bins, name): 788 | """Adds a tf.summary.image for a histogram plot of the values. 789 | 790 | Plots the histogram of values and creates a tf image summary. 791 | 792 | Args: 793 | values: a 1-D float32 tensor containing the values. 794 | bins: bin edges which will be directly passed to np.histogram. 795 | name: name for the image summary. 796 | """ 797 | 798 | def hist_plot(values, bins): 799 | """Numpy function to plot hist.""" 800 | fig = plt.figure(frameon=False) 801 | ax = fig.add_subplot('111') 802 | y, x = np.histogram(values, bins=bins) 803 | ax.plot(x[:-1], y) 804 | ax.set_ylabel('count') 805 | ax.set_xlabel('value') 806 | fig.canvas.draw() 807 | width, height = fig.get_size_inches() * fig.get_dpi() 808 | image = np.fromstring( 809 | fig.canvas.tostring_rgb(), dtype='uint8').reshape( 810 | 1, int(height), int(width), 3) 811 | return image 812 | hist_plot = tf.py_func(hist_plot, [values, bins], tf.uint8) 813 | tf.summary.image(name, hist_plot) 814 | 815 | 816 | class EvalMetricOpsVisualization(object): 817 | """Abstract base class responsible for visualizations during evaluation. 818 | 819 | Currently, summary images are not run during evaluation. One way to produce 820 | evaluation images in Tensorboard is to provide tf.summary.image strings as 821 | `value_ops` in tf.estimator.EstimatorSpec's `eval_metric_ops`. This class is 822 | responsible for accruing images (with overlaid detections and groundtruth) 823 | and returning a dictionary that can be passed to `eval_metric_ops`. 824 | """ 825 | __metaclass__ = abc.ABCMeta 826 | 827 | def __init__(self, 828 | category_index, 829 | max_examples_to_draw=5, 830 | max_boxes_to_draw=20, 831 | min_score_thresh=0.2, 832 | use_normalized_coordinates=True, 833 | summary_name_prefix='evaluation_image'): 834 | """Creates an EvalMetricOpsVisualization. 835 | 836 | Args: 837 | category_index: A category index (dictionary) produced from a labelmap. 838 | max_examples_to_draw: The maximum number of example summaries to produce. 839 | max_boxes_to_draw: The maximum number of boxes to draw for detections. 840 | min_score_thresh: The minimum score threshold for showing detections. 841 | use_normalized_coordinates: Whether to assume boxes and kepoints are in 842 | normalized coordinates (as opposed to absolute coordiantes). 843 | Default is True. 844 | summary_name_prefix: A string prefix for each image summary. 
845 | """ 846 | 847 | self._category_index = category_index 848 | self._max_examples_to_draw = max_examples_to_draw 849 | self._max_boxes_to_draw = max_boxes_to_draw 850 | self._min_score_thresh = min_score_thresh 851 | self._use_normalized_coordinates = use_normalized_coordinates 852 | self._summary_name_prefix = summary_name_prefix 853 | self._images = [] 854 | 855 | def clear(self): 856 | self._images = [] 857 | 858 | def add_images(self, images): 859 | """Store a list of images, each with shape [1, H, W, C].""" 860 | if len(self._images) >= self._max_examples_to_draw: 861 | return 862 | 863 | # Store images and clip list if necessary. 864 | self._images.extend(images) 865 | if len(self._images) > self._max_examples_to_draw: 866 | self._images[self._max_examples_to_draw:] = [] 867 | 868 | def get_estimator_eval_metric_ops(self, eval_dict): 869 | """Returns metric ops for use in tf.estimator.EstimatorSpec. 870 | 871 | Args: 872 | eval_dict: A dictionary that holds an image, groundtruth, and detections 873 | for a batched example. Note that, we use only the first example for 874 | visualization. See eval_util.result_dict_for_batched_example() for a 875 | convenient method for constructing such a dictionary. The dictionary 876 | contains 877 | fields.InputDataFields.original_image: [batch_size, H, W, 3] image. 878 | fields.InputDataFields.original_image_spatial_shape: [batch_size, 2] 879 | tensor containing the size of the original image. 880 | fields.InputDataFields.true_image_shape: [batch_size, 3] 881 | tensor containing the spatial size of the upadded original image. 882 | fields.InputDataFields.groundtruth_boxes - [batch_size, num_boxes, 4] 883 | float32 tensor with groundtruth boxes in range [0.0, 1.0]. 884 | fields.InputDataFields.groundtruth_classes - [batch_size, num_boxes] 885 | int64 tensor with 1-indexed groundtruth classes. 886 | fields.InputDataFields.groundtruth_instance_masks - (optional) 887 | [batch_size, num_boxes, H, W] int64 tensor with instance masks. 888 | fields.DetectionResultFields.detection_boxes - [batch_size, 889 | max_num_boxes, 4] float32 tensor with detection boxes in range [0.0, 890 | 1.0]. 891 | fields.DetectionResultFields.detection_classes - [batch_size, 892 | max_num_boxes] int64 tensor with 1-indexed detection classes. 893 | fields.DetectionResultFields.detection_scores - [batch_size, 894 | max_num_boxes] float32 tensor with detection scores. 895 | fields.DetectionResultFields.detection_masks - (optional) [batch_size, 896 | max_num_boxes, H, W] float32 tensor of binarized masks. 897 | fields.DetectionResultFields.detection_keypoints - (optional) 898 | [batch_size, max_num_boxes, num_keypoints, 2] float32 tensor with 899 | keypoints. 900 | 901 | Returns: 902 | A dictionary of image summary names to tuple of (value_op, update_op). The 903 | `update_op` is the same for all items in the dictionary, and is 904 | responsible for saving a single side-by-side image with detections and 905 | groundtruth. Each `value_op` holds the tf.summary.image string for a given 906 | image. 
907 | """ 908 | if self._max_examples_to_draw == 0: 909 | return {} 910 | images = self.images_from_evaluation_dict(eval_dict) 911 | 912 | def get_images(): 913 | """Returns a list of images, padded to self._max_images_to_draw.""" 914 | images = self._images 915 | while len(images) < self._max_examples_to_draw: 916 | images.append(np.array(0, dtype=np.uint8)) 917 | self.clear() 918 | return images 919 | 920 | def image_summary_or_default_string(summary_name, image): 921 | """Returns image summaries for non-padded elements.""" 922 | return tf.cond( 923 | tf.equal(tf.size(tf.shape(image)), 4), 924 | lambda: tf.summary.image(summary_name, image), 925 | lambda: tf.constant('')) 926 | 927 | update_op = tf.py_func(self.add_images, [[images[0]]], []) 928 | image_tensors = tf.py_func( 929 | get_images, [], [tf.uint8] * self._max_examples_to_draw) 930 | eval_metric_ops = {} 931 | for i, image in enumerate(image_tensors): 932 | summary_name = self._summary_name_prefix + '/' + str(i) 933 | value_op = image_summary_or_default_string(summary_name, image) 934 | eval_metric_ops[summary_name] = (value_op, update_op) 935 | return eval_metric_ops 936 | 937 | @abc.abstractmethod 938 | def images_from_evaluation_dict(self, eval_dict): 939 | """Converts evaluation dictionary into a list of image tensors. 940 | 941 | To be overridden by implementations. 942 | 943 | Args: 944 | eval_dict: A dictionary with all the necessary information for producing 945 | visualizations. 946 | 947 | Returns: 948 | A list of [1, H, W, C] uint8 tensors. 949 | """ 950 | raise NotImplementedError 951 | 952 | 953 | class VisualizeSingleFrameDetections(EvalMetricOpsVisualization): 954 | """Class responsible for single-frame object detection visualizations.""" 955 | 956 | def __init__(self, 957 | category_index, 958 | max_examples_to_draw=5, 959 | max_boxes_to_draw=20, 960 | min_score_thresh=0.2, 961 | use_normalized_coordinates=True, 962 | summary_name_prefix='Detections_Left_Groundtruth_Right'): 963 | super(VisualizeSingleFrameDetections, self).__init__( 964 | category_index=category_index, 965 | max_examples_to_draw=max_examples_to_draw, 966 | max_boxes_to_draw=max_boxes_to_draw, 967 | min_score_thresh=min_score_thresh, 968 | use_normalized_coordinates=use_normalized_coordinates, 969 | summary_name_prefix=summary_name_prefix) 970 | 971 | def images_from_evaluation_dict(self, eval_dict): 972 | return draw_side_by_side_evaluation_image( 973 | eval_dict, self._category_index, self._max_boxes_to_draw, 974 | self._min_score_thresh, self._use_normalized_coordinates) 975 | --------------------------------------------------------------------------------
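Usage note (not part of the original sources): video_demo.py is started with the four required flags defined in main(), and post_message() sends one JSON document per detection interval to the backend URL configured as config.detection_api. The sketch below is illustrative only; the flag values, image dimensions, and camera id are assumptions rather than values taken from the repository, and it assumes the script is run from the inference directory, where model/frozen_inference_graph.pb and demo_video.mp4 live.

# Example invocation (flag values are illustrative):
#   python video_demo.py --model_dir=model --video_file_name=demo_video.mp4 --show_video_window=1 --camera_id=camera_001
#
# Sketch of the message body that post_message() POSTs to config.detection_api:
example_message = {
    "timestamp": 1546300800000,            # int(time.time() * 1000), milliseconds since the epoch
    "cameraId": "camera_001",              # illustrative camera identifier
    "image": {
        "height": 360,                     # config.storage_image_height (assumed value)
        "width": 640,                      # config.storage_image_width (assumed value)
        "format": "jpeg",
        "raw": "<base64-encoded JPEG>",    # base64 string of the resized, JPEG-encoded frame
    },
    "persons": [
        {"hardhat": True, "vest": False},  # one entry per detected person
    ],
}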