├── .gitignore ├── LICENSE ├── README.md ├── convert_weights_pb.py ├── sync_detection_yolo.py ├── utils.py ├── yolo_v4_tiny.json └── yolo_v4_tiny.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow-yolov4-tiny
2 | 
3 | Adapted from https://github.com/mystic123/tensorflow-yolo-v3
4 | 
5 | Refer to https://github.com/TNTWEN/OpenVINO-YOLOV4 for how that project handles the split operation
6 | 
7 | Tested with Python 3.6, TensorFlow 1.14.0, Ubuntu 18.04, and OpenVINO l_openvino_toolkit_p_2020.3.194/2020.4.287
8 | 
9 | Runs at roughly 30 FPS on an Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
10 | ## Todo list:
11 | - [x] Weights converter to pb
12 | - [x] Sync detection YOLO demo
13 | 
14 | ## How to use:
15 | 1. Download the COCO class names file:
16 | `wget https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names`
17 | 2. Download the YOLOv4-tiny weights:
18 | `wget https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v4_pre/yolov4-tiny.weights`
19 | 3. Convert the weights to a frozen .pb graph: `python convert_weights_pb.py`
20 | 4. Convert the .pb graph to OpenVINO IR:
21 | `cp ./yolo_v4_tiny.json /opt/intel/openvino/deployment_tools/model_optimizer/extensions/front/tf`
22 | `cd /opt/intel/openvino/deployment_tools/model_optimizer`
23 | `python mo.py --input_model yolov4-tiny.pb --transformations_config ./extensions/front/tf/yolo_v4_tiny.json --batch 1`
24 | 5. Run the OpenVINO Object Detection YOLO\* Python demo:
25 | `python sync_detection_yolo.py`
26 | 
27 | #### Optional Flags
28 | 1. convert_weights_pb.py:
29 |     1. `--class_names`
30 |         1. Path to the class names file
31 |     2. `--weights_file`
32 |         1. Path to the desired weights file
33 |     3. `--data_format`
34 |         1. `NCHW` (gpu only) or `NHWC`
35 |     4. `--tiny`
36 |         1. Use yolov4-tiny
37 |     5. `--output_graph`
38 |         1. Location to write the output .pb graph to
39 | 2. sync_detection_yolo.py:
40 |     1. `-m`
41 |         1. Path to an .xml file with a trained model.
42 |     2. `--labels`
43 |         1. Path to the coco.names file
44 | 
--------------------------------------------------------------------------------
/convert_weights_pb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import numpy as np
4 | import tensorflow as tf
5 | import yolo_v4_tiny
6 | from utils import load_weights, load_coco_names, detections_boxes, freeze_graph
7 | 
8 | FLAGS = tf.app.flags.FLAGS
9 | 
10 | tf.app.flags.DEFINE_string(
11 |     'class_names', 'coco.names', 'File with class names')
12 | tf.app.flags.DEFINE_string(
13 |     'weights_file', 'yolov4-tiny.weights', 'Binary file with detector weights')
14 | tf.app.flags.DEFINE_string(
15 |     'data_format', 'NHWC', 'Data format: NCHW (gpu only) / NHWC')
16 | tf.app.flags.DEFINE_string(
17 |     'output_graph', 'yolov4-tiny.pb', 'Frozen tensorflow protobuf model output path')
18 | tf.app.flags.DEFINE_bool(
19 |     'tiny', True, 'Use the tiny version of YOLOv4')
20 | tf.app.flags.DEFINE_integer(
21 |     'size', 416, 'Image size')
22 | 
23 | 
24 | def main(argv=None):
25 |     if FLAGS.tiny:
26 |         model = yolo_v4_tiny.yolo_v4_tiny
27 |     else:
28 |         # yolo_v3.py is not shipped with this repository; import it lazily so
29 |         # the default tiny path works without it
30 |         import yolo_v3
31 |         model = yolo_v3.yolo_v3
32 | 
33 |     classes = load_coco_names(FLAGS.class_names)
34 | 
35 |     # placeholder for detector inputs
36 |     inputs = tf.placeholder(tf.float32, [None, FLAGS.size, FLAGS.size, 3], "inputs")
37 | 
38 |     with tf.variable_scope('detector'):
39 |         detections = model(inputs, len(classes), data_format=FLAGS.data_format)
40 |         load_ops = load_weights(tf.global_variables(scope='detector'), FLAGS.weights_file)
41 | 
42 |     # name the combined box tensor "output_boxes" so freeze_graph can export it
43 |     boxes = detections_boxes(detections)
44 | 
45 |     with tf.Session() as sess:
46 |         sess.run(load_ops)
47 |         freeze_graph(sess, FLAGS.output_graph)
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     tf.app.run()
52 | 
--------------------------------------------------------------------------------
/sync_detection_yolo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | OpenVINO synchronous detection demo for YOLOv4-tiny
4 | """
5 | from __future__ import print_function, division
6 | 
7 | import logging
8 | import os
9 | import sys
10 | from argparse import ArgumentParser, SUPPRESS
11 | from math import exp as exp
12 | from time import time
13 | from time import perf_counter
14 | 
15 | import cv2
16 | from openvino.inference_engine import IENetwork, IECore
17 | 
18 | logging.basicConfig(format="[ %(levelname)s ] %(message)s", level=logging.INFO, stream=sys.stdout)
19 | log = logging.getLogger()
20 | 
21 | 
22 | def build_argparser():
23 |     parser = ArgumentParser(add_help=False)
24 |     args = parser.add_argument_group('Options')
25 |     args.add_argument("-m", "--model", default='yolov4-tiny.xml',
26 |                       help="Required. Path to an .xml file with a trained model.",
27 |                       type=str)
28 |     args.add_argument("--labels", help="Optional. Labels mapping file", default='coco.names', type=str)
29 | 
30 |     args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
31 |     args.add_argument("-d", "--device",
32 |                       help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL or MYRIAD is"
33 |                            " acceptable. The sample will look for a suitable plugin for device specified. "
34 |                            "Default value is CPU", default="CPU", type=str)
35 |     args.add_argument("-t", "--prob_threshold", help="Optional. 
Probability threshold for detections filtering", 36 | default=0.3, type=float) 37 | args.add_argument("-iout", "--iou_threshold", help="Optional. Intersection over union threshold for overlapping " 38 | "detections filtering", default=0.3, type=float) 39 | args.add_argument("-ni", "--number_iter", help="Optional. Number of inference iterations", default=1, type=int) 40 | args.add_argument("-pc", "--perf_counts", help="Optional. Report performance counters", default=False, 41 | action="store_true") 42 | args.add_argument("-r", "--raw_output_message", help="Optional. Output inference results raw values showing", 43 | default=False, action="store_true") 44 | return parser 45 | 46 | 47 | class YoloParams: 48 | # ------------------------------------------- Extracting layer parameters ------------------------------------------ 49 | # Magic numbers are copied from yolo samples 50 | def __init__(self, param, side): 51 | self.num = 3 if 'num' not in param else int(param['num']) 52 | self.coords = 4 if 'coords' not in param else int(param['coords']) 53 | self.classes = 80 if 'classes' not in param else int(param['classes']) 54 | # self.anchors = [10.0, 13.0, 16.0, 30.0, 33.0, 23.0, 30.0, 61.0, 62.0, 45.0, 59.0, 119.0, 116.0, 90.0, 156.0, 55 | # 198.0, 56 | # 373.0, 326.0] if 'anchors' not in param else [float(a) for a in param['anchors'].split(',')] 57 | self.anchors = [10.0,14.0, 23.0,27.0, 37.0,58.0, 81.0,82.0, 135.0,169.0, 58 | 344.0,319.0] if 'anchors' not in param else [float(a) for a in param['anchors'].split(',')] 59 | 60 | if 'mask' in param: 61 | mask = [int(idx) for idx in param['mask'].split(',')] 62 | self.num = len(mask) 63 | 64 | maskedAnchors = [] 65 | for idx in mask: 66 | maskedAnchors += [self.anchors[idx * 2], self.anchors[idx * 2 + 1]] 67 | self.anchors = maskedAnchors 68 | 69 | self.side = side 70 | self.isYoloV3 = 'mask' in param # Weak way to determine but the only one. 71 | 72 | 73 | def entry_index(side, coord, classes, location, entry): 74 | side_power_2 = side ** 2 75 | n = location // side_power_2 76 | loc = location % side_power_2 77 | return int(side_power_2 * (n * (coord + classes + 1) + entry) + loc) 78 | 79 | 80 | def scale_bbox(x, y, h, w, class_id, confidence, h_scale, w_scale): 81 | xmin = int((x - w / 2) * w_scale) 82 | ymin = int((y - h / 2) * h_scale) 83 | xmax = int(xmin + w * w_scale) 84 | ymax = int(ymin + h * h_scale) 85 | return dict(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, class_id=class_id, confidence=confidence) 86 | 87 | 88 | def parse_yolo_region(blob, resized_image_shape, original_im_shape, params, threshold): 89 | # ------------------------------------------ Validating output parameters ------------------------------------------ 90 | _, _, out_blob_h, out_blob_w = blob.shape 91 | assert out_blob_w == out_blob_h, "Invalid size of output blob. It sould be in NCHW layout and height should " \ 92 | "be equal to width. 
Current height = {}, current width = {}" \ 93 | "".format(out_blob_h, out_blob_w) 94 | 95 | # ------------------------------------------ Extracting layer parameters ------------------------------------------- 96 | orig_im_h, orig_im_w = original_im_shape 97 | resized_image_h, resized_image_w = resized_image_shape 98 | objects = list() 99 | predictions = blob.flatten() 100 | side_square = params.side * params.side 101 | 102 | # ------------------------------------------- Parsing YOLO Region output ------------------------------------------- 103 | for i in range(side_square): 104 | row = i // params.side 105 | col = i % params.side 106 | for n in range(params.num): 107 | obj_index = entry_index(params.side, params.coords, params.classes, n * side_square + i, params.coords) 108 | scale = predictions[obj_index] 109 | if scale < threshold: 110 | continue 111 | box_index = entry_index(params.side, params.coords, params.classes, n * side_square + i, 0) 112 | # Network produces location predictions in absolute coordinates of feature maps. 113 | # Scale it to relative coordinates. 114 | x = (col + predictions[box_index + 0 * side_square]) / params.side 115 | y = (row + predictions[box_index + 1 * side_square]) / params.side 116 | # Value for exp is very big number in some cases so following construction is using here 117 | try: 118 | w_exp = exp(predictions[box_index + 2 * side_square]) 119 | h_exp = exp(predictions[box_index + 3 * side_square]) 120 | except OverflowError: 121 | continue 122 | # Depends on topology we need to normalize sizes by feature maps (up to YOLOv3) or by input shape (YOLOv3) 123 | w = w_exp * params.anchors[2 * n] / (resized_image_w if params.isYoloV3 else params.side) 124 | h = h_exp * params.anchors[2 * n + 1] / (resized_image_h if params.isYoloV3 else params.side) 125 | for j in range(params.classes): 126 | class_index = entry_index(params.side, params.coords, params.classes, n * side_square + i, 127 | params.coords + 1 + j) 128 | confidence = scale * predictions[class_index] 129 | if confidence < threshold: 130 | continue 131 | objects.append(scale_bbox(x=x, y=y, h=h, w=w, class_id=j, confidence=confidence, 132 | h_scale=orig_im_h, w_scale=orig_im_w)) 133 | return objects 134 | 135 | 136 | def intersection_over_union(box_1, box_2): 137 | width_of_overlap_area = min(box_1['xmax'], box_2['xmax']) - max(box_1['xmin'], box_2['xmin']) 138 | height_of_overlap_area = min(box_1['ymax'], box_2['ymax']) - max(box_1['ymin'], box_2['ymin']) 139 | if width_of_overlap_area < 0 or height_of_overlap_area < 0: 140 | area_of_overlap = 0 141 | else: 142 | area_of_overlap = width_of_overlap_area * height_of_overlap_area 143 | box_1_area = (box_1['ymax'] - box_1['ymin']) * (box_1['xmax'] - box_1['xmin']) 144 | box_2_area = (box_2['ymax'] - box_2['ymin']) * (box_2['xmax'] - box_2['xmin']) 145 | area_of_union = box_1_area + box_2_area - area_of_overlap 146 | if area_of_union == 0: 147 | return 0 148 | return area_of_overlap / area_of_union 149 | 150 | 151 | class ObjectDetection(object): 152 | def __init__(self): 153 | self.args = build_argparser().parse_args() 154 | 155 | model_xml = self.args.model 156 | model_bin = os.path.splitext(model_xml)[0] + ".bin" 157 | 158 | # ------------- 1. Plugin initialization for specified device and load extensions library if specified ------------- 159 | log.info("Creating Inference Engine...") 160 | ie = IECore() 161 | 162 | # -------------------- 2. 
Reading the IR generated by the Model Optimizer (.xml and .bin files) -------------------- 163 | log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin)) 164 | self.net = ie.read_network(model=model_xml, weights=model_bin) 165 | 166 | # ---------------------------------- 3. Load CPU extension for support specific layer ------------------------------ 167 | if "CPU" in self.args.device: 168 | supported_layers = ie.query_network(self.net, "CPU") 169 | not_supported_layers = [l for l in self.net.layers.keys() if l not in supported_layers] 170 | if len(not_supported_layers) != 0: 171 | log.error("Following layers are not supported by the plugin for specified device {}:\n {}". 172 | format(self.args.device, ', '.join(not_supported_layers))) 173 | sys.exit(1) 174 | 175 | assert len(self.net.input_info.keys()) == 1, "Sample supports only YOLO V3 based single input topologies" 176 | 177 | # ---------------------------------------------- 4. Preparing inputs ----------------------------------------------- 178 | log.info("Preparing inputs") 179 | self.input_blob = next(iter(self.net.input_info)) 180 | 181 | # Defaulf batch_size is 1 182 | self.net.batch_size = 1 183 | 184 | if self.args.labels: 185 | with open(self.args.labels, 'r') as f: 186 | self.labels_map = [x.strip() for x in f] 187 | else: 188 | self.labels_map = None 189 | 190 | # ----------------------------------------- 5. Loading model to the plugin ----------------------------------------- 191 | log.info("Loading model to the plugin") 192 | self.exec_net = ie.load_network(network=self.net, num_requests=2, device_name=self.args.device) 193 | 194 | def inference(self, frame): 195 | ''' 196 | 197 | :param frame: 198 | :return: 199 | ''' 200 | cur_request_id = 0 201 | parsing_time = 0 202 | # ----------------------------------------------- 6. 
Doing inference ----------------------------------------------- 203 | is_async_mode = False 204 | while cap.isOpened(): 205 | # Here is the first asynchronous point: in the Async mode, we capture frame to populate the NEXT infer request 206 | # in the regular mode, we capture frame to the CURRENT infer request 207 | if not ret: 208 | break 209 | 210 | # Read and pre-process input images 211 | n, c, h, w = self.net.input_info[self.input_blob].input_data.shape 212 | 213 | request_id = cur_request_id 214 | in_frame = cv2.resize(frame, (w, h)) 215 | 216 | # resize input_frame to network size 217 | in_frame = in_frame.transpose((2, 0, 1)) # Change data layout from HWC to CHW 218 | in_frame = in_frame.reshape((n, c, h, w)) 219 | # Start inference 220 | infer_time = time() 221 | self.exec_net.start_async(request_id=request_id, inputs={self.input_blob: in_frame}) 222 | # exec_net.infer(inputs={self.input_blob: in_frame}) 223 | det_time = time() - infer_time 224 | 225 | # Collecting object detection results 226 | objects = list() 227 | if self.exec_net.requests[cur_request_id].wait(-1) == 0: 228 | output = self.exec_net.requests[cur_request_id].outputs 229 | # for layer_name, out_blob in output.items(): 230 | # print("-----------The layer name of collecting object detection results:----------") 231 | # print(layer_name) 232 | start_time = time() 233 | for layer_name, out_blob in output.items(): 234 | # if layer_name == 'detector/yolo-v4-tiny/strided_slice/Split.0' or layer_name == 'detector/yolo-v4-tiny/strided_slice_1/Split.0' \ 235 | # or layer_name == 'detector/yolo-v4-tiny/strided_slice_2/Split.0': 236 | # pass 237 | # else: 238 | out_blob = out_blob.reshape(self.net.layers[self.net.layers[layer_name].parents[0]].out_data[0].shape) 239 | layer_params = YoloParams(self.net.layers[layer_name].params, out_blob.shape[2]) 240 | objects += parse_yolo_region(out_blob, in_frame.shape[2:], 241 | frame.shape[:-1], layer_params, 242 | self.args.prob_threshold) 243 | parsing_time = time() - start_time 244 | 245 | # Filtering overlapping boxes with respect to the --iou_threshold CLI parameter 246 | objects = sorted(objects, key=lambda obj : obj['confidence'], reverse=True) 247 | for i in range(len(objects)): 248 | if objects[i]['confidence'] == 0: 249 | continue 250 | for j in range(i + 1, len(objects)): 251 | if intersection_over_union(objects[i], objects[j]) > self.args.iou_threshold: 252 | objects[j]['confidence'] = 0 253 | 254 | # Drawing objects with respect to the --prob_threshold CLI parameter 255 | objects = [obj for obj in objects if obj['confidence'] >= self.args.prob_threshold] 256 | 257 | if len(objects) and self.args.raw_output_message: 258 | log.info("\nDetected boxes for batch {}:".format(1)) 259 | log.info(" Class ID | Confidence | XMIN | YMIN | XMAX | YMAX | COLOR ") 260 | 261 | origin_im_size = frame.shape[:-1] 262 | for obj in objects: 263 | # Validation bbox of detected object 264 | if obj['xmax'] > origin_im_size[1] or obj['ymax'] > origin_im_size[0] or obj['xmin'] < 0 or obj['ymin'] < 0: 265 | continue 266 | color = (int(min(obj['class_id'] * 12.5, 255)), 267 | min(obj['class_id'] * 7, 255), min(obj['class_id'] * 5, 255)) 268 | det_label = self.labels_map[obj['class_id']] if self.labels_map and len(self.labels_map) >= obj['class_id'] else \ 269 | str(obj['class_id']) 270 | 271 | if self.args.raw_output_message: 272 | log.info( 273 | "{:^9} | {:10f} | {:4} | {:4} | {:4} | {:4} | {} ".format(det_label, obj['confidence'], obj['xmin'], 274 | obj['ymin'], obj['xmax'], obj['ymax'], 275 | 
color)) 276 | 277 | cv2.rectangle(frame, (obj['xmin'], obj['ymin']), (obj['xmax'], obj['ymax']), color, 2) 278 | cv2.putText(frame, 279 | "#" + det_label + ' ' + str(round(obj['confidence'] * 100, 1)) + ' %', 280 | (obj['xmin'], obj['ymin'] - 7), cv2.FONT_HERSHEY_COMPLEX, 0.6, color, 1) 281 | 282 | # Draw performance stats over frame 283 | inf_time_message = "" if is_async_mode else \ 284 | "Inference time: {:.3f} ms".format(det_time * 1e3) 285 | async_mode_message = "sync mode is on. Processing request {}".format(cur_request_id) if is_async_mode else \ 286 | '' 287 | parsing_message = "parsing time is {:.3f}".format(parsing_time * 1e3) 288 | 289 | cv2.putText(frame, inf_time_message, (15, 15), cv2.FONT_HERSHEY_COMPLEX, 0.5, (200, 10, 10), 1) 290 | cv2.putText(frame, async_mode_message, (10, int(origin_im_size[0] - 20)), cv2.FONT_HERSHEY_COMPLEX, 0.5, 291 | (10, 10, 200), 1) 292 | cv2.putText(frame, parsing_message, (15, 30), cv2.FONT_HERSHEY_COMPLEX, 0.5, (10, 10, 200), 1) 293 | return frame 294 | 295 | 296 | if __name__ == '__main__': 297 | yolo = ObjectDetection() 298 | cap = cv2.VideoCapture(0) 299 | cv2.namedWindow('frame', cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) 300 | last_start_time = perf_counter() 301 | count_frame = 0 302 | fps_time = 0 303 | while cap.isOpened(): 304 | ret, frame = cap.read() 305 | if not ret: 306 | break 307 | start_time = perf_counter() 308 | out_frame = yolo.inference(frame) 309 | all_time = perf_counter() - start_time 310 | print('The processing time of one frame is', all_time) 311 | cv2.imwrite("result.jpg", frame) 312 | count_frame = count_frame + 1 313 | print("FPS is", count_frame / (perf_counter() - last_start_time)) 314 | cv2.imshow("frame", frame) 315 | key = cv2.waitKey(3) 316 | if key == 27: 317 | break 318 | 319 | cv2.destroyAllWindows() 320 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from PIL import ImageDraw, Image 6 | 7 | 8 | def get_boxes_and_inputs_pb(frozen_graph): 9 | 10 | with frozen_graph.as_default(): 11 | boxes = tf.get_default_graph().get_tensor_by_name("output_boxes:0") 12 | inputs = tf.get_default_graph().get_tensor_by_name("inputs:0") 13 | 14 | return boxes, inputs 15 | 16 | 17 | def get_boxes_and_inputs(model, num_classes, size, data_format): 18 | 19 | inputs = tf.placeholder(tf.float32, [1, size, size, 3]) 20 | 21 | with tf.variable_scope('detector'): 22 | detections = model(inputs, num_classes, 23 | data_format=data_format) 24 | 25 | boxes = detections_boxes(detections) 26 | 27 | return boxes, inputs 28 | 29 | 30 | def load_graph(frozen_graph_filename): 31 | 32 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 33 | graph_def = tf.GraphDef() 34 | graph_def.ParseFromString(f.read()) 35 | 36 | with tf.Graph().as_default() as graph: 37 | tf.import_graph_def(graph_def, name="") 38 | 39 | return graph 40 | 41 | 42 | def freeze_graph(sess, output_graph): 43 | 44 | output_node_names = [ 45 | "output_boxes", 46 | "inputs", 47 | ] 48 | output_node_names = ",".join(output_node_names) 49 | 50 | output_graph_def = tf.graph_util.convert_variables_to_constants( 51 | sess, 52 | tf.get_default_graph().as_graph_def(), 53 | output_node_names.split(",") 54 | ) 55 | 56 | with tf.gfile.GFile(output_graph, "wb") as f: 57 | f.write(output_graph_def.SerializeToString()) 58 | print("{} ops written to 
{}.".format(len(output_graph_def.node), output_graph)) 59 | 60 | 61 | def load_weights(var_list, weights_file): 62 | """ 63 | Loads and converts pre-trained weights. 64 | :param var_list: list of network variables. 65 | :param weights_file: name of the binary file. 66 | :return: list of assign ops 67 | """ 68 | with open(weights_file, "rb") as fp: 69 | _ = np.fromfile(fp, dtype=np.int32, count=5) 70 | 71 | weights = np.fromfile(fp, dtype=np.float32) 72 | 73 | ptr = 0 74 | i = 0 75 | assign_ops = [] 76 | while i < len(var_list) - 1: 77 | var1 = var_list[i] 78 | var2 = var_list[i + 1] 79 | # do something only if we process conv layer 80 | if 'Conv' in var1.name.split('/')[-2]: 81 | # check type of next layer 82 | if 'BatchNorm' in var2.name.split('/')[-2]: 83 | # load batch norm params 84 | gamma, beta, mean, var = var_list[i + 1:i + 5] 85 | batch_norm_vars = [beta, gamma, mean, var] 86 | for var in batch_norm_vars: 87 | shape = var.shape.as_list() 88 | num_params = np.prod(shape) 89 | var_weights = weights[ptr:ptr + num_params].reshape(shape) 90 | ptr += num_params 91 | assign_ops.append( 92 | tf.assign(var, var_weights, validate_shape=True)) 93 | 94 | # we move the pointer by 4, because we loaded 4 variables 95 | i += 4 96 | elif 'Conv' in var2.name.split('/')[-2]: 97 | # load biases 98 | bias = var2 99 | bias_shape = bias.shape.as_list() 100 | bias_params = np.prod(bias_shape) 101 | bias_weights = weights[ptr:ptr + 102 | bias_params].reshape(bias_shape) 103 | ptr += bias_params 104 | assign_ops.append( 105 | tf.assign(bias, bias_weights, validate_shape=True)) 106 | 107 | # we loaded 1 variable 108 | i += 1 109 | # we can load weights of conv layer 110 | shape = var1.shape.as_list() 111 | num_params = np.prod(shape) 112 | print(shape) 113 | var_weights = weights[ptr:ptr + num_params].reshape( 114 | (shape[3], shape[2], shape[0], shape[1])) 115 | # remember to transpose to column-major 116 | var_weights = np.transpose(var_weights, (2, 3, 1, 0)) 117 | ptr += num_params 118 | assign_ops.append( 119 | tf.assign(var1, var_weights, validate_shape=True)) 120 | i += 1 121 | 122 | return assign_ops 123 | 124 | 125 | def detections_boxes(detections): 126 | """ 127 | Converts center x, center y, width and height values to coordinates of top left and bottom right points. 
128 | 129 | :param detections: outputs of YOLO v3 detector of shape (?, 10647, (num_classes + 5)) 130 | :return: converted detections of same shape as input 131 | """ 132 | center_x, center_y, width, height, attrs = tf.split( 133 | detections, [1, 1, 1, 1, -1], axis=-1) 134 | w2 = width / 2 135 | h2 = height / 2 136 | x0 = center_x - w2 137 | y0 = center_y - h2 138 | x1 = center_x + w2 139 | y1 = center_y + h2 140 | 141 | boxes = tf.concat([x0, y0, x1, y1], axis=-1) 142 | detections = tf.concat([boxes, attrs], axis=-1, name="output_boxes") 143 | return detections 144 | 145 | 146 | def _iou(box1, box2): 147 | """ 148 | Computes Intersection over Union value for 2 bounding boxes 149 | 150 | :param box1: array of 4 values (top left and bottom right coords): [x0, y0, x1, x2] 151 | :param box2: same as box1 152 | :return: IoU 153 | """ 154 | b1_x0, b1_y0, b1_x1, b1_y1 = box1 155 | b2_x0, b2_y0, b2_x1, b2_y1 = box2 156 | 157 | int_x0 = max(b1_x0, b2_x0) 158 | int_y0 = max(b1_y0, b2_y0) 159 | int_x1 = min(b1_x1, b2_x1) 160 | int_y1 = min(b1_y1, b2_y1) 161 | 162 | int_area = max(int_x1 - int_x0, 0) * max(int_y1 - int_y0, 0) 163 | 164 | b1_area = (b1_x1 - b1_x0) * (b1_y1 - b1_y0) 165 | b2_area = (b2_x1 - b2_x0) * (b2_y1 - b2_y0) 166 | 167 | # we add small epsilon of 1e-05 to avoid division by 0 168 | iou = int_area / (b1_area + b2_area - int_area + 1e-05) 169 | return iou 170 | 171 | 172 | def non_max_suppression(predictions_with_boxes, confidence_threshold, iou_threshold=0.4): 173 | """ 174 | Applies Non-max suppression to prediction boxes. 175 | 176 | :param predictions_with_boxes: 3D numpy array, first 4 values in 3rd dimension are bbox attrs, 5th is confidence 177 | :param confidence_threshold: the threshold for deciding if prediction is valid 178 | :param iou_threshold: the threshold for deciding if two boxes overlap 179 | :return: dict: class -> [(box, score)] 180 | """ 181 | conf_mask = np.expand_dims( 182 | (predictions_with_boxes[:, :, 4] > confidence_threshold), -1) 183 | predictions = predictions_with_boxes * conf_mask 184 | 185 | result = {} 186 | for i, image_pred in enumerate(predictions): 187 | shape = image_pred.shape 188 | non_zero_idxs = np.nonzero(image_pred) 189 | image_pred = image_pred[non_zero_idxs] 190 | image_pred = image_pred.reshape(-1, shape[-1]) 191 | 192 | bbox_attrs = image_pred[:, :5] 193 | classes = image_pred[:, 5:] 194 | classes = np.argmax(classes, axis=-1) 195 | 196 | unique_classes = list(set(classes.reshape(-1))) 197 | 198 | for cls in unique_classes: 199 | cls_mask = classes == cls 200 | cls_boxes = bbox_attrs[np.nonzero(cls_mask)] 201 | cls_boxes = cls_boxes[cls_boxes[:, -1].argsort()[::-1]] 202 | cls_scores = cls_boxes[:, -1] 203 | cls_boxes = cls_boxes[:, :-1] 204 | 205 | while len(cls_boxes) > 0: 206 | box = cls_boxes[0] 207 | score = cls_scores[0] 208 | if cls not in result: 209 | result[cls] = [] 210 | result[cls].append((box, score)) 211 | cls_boxes = cls_boxes[1:] 212 | cls_scores = cls_scores[1:] 213 | ious = np.array([_iou(box, x) for x in cls_boxes]) 214 | iou_mask = ious < iou_threshold 215 | cls_boxes = cls_boxes[np.nonzero(iou_mask)] 216 | cls_scores = cls_scores[np.nonzero(iou_mask)] 217 | 218 | return result 219 | 220 | 221 | def load_coco_names(file_name): 222 | names = {} 223 | with open(file_name) as f: 224 | for id, name in enumerate(f): 225 | names[id] = name 226 | return names 227 | 228 | 229 | def draw_boxes(boxes, img, cls_names, detection_size, is_letter_box_image): 230 | draw = ImageDraw.Draw(img) 231 | 232 | for cls, bboxs in 
boxes.items(): 233 | color = tuple(np.random.randint(0, 256, 3)) 234 | for box, score in bboxs: 235 | box = convert_to_original_size(box, np.array(detection_size), 236 | np.array(img.size), 237 | is_letter_box_image) 238 | draw.rectangle(box, outline=color) 239 | draw.text(box[:2], '{} {:.2f}%'.format( 240 | cls_names[cls], score * 100), fill=color) 241 | 242 | 243 | def convert_to_original_size(box, size, original_size, is_letter_box_image): 244 | if is_letter_box_image: 245 | box = box.reshape(2, 2) 246 | box[0, :] = letter_box_pos_to_original_pos(box[0, :], size, original_size) 247 | box[1, :] = letter_box_pos_to_original_pos(box[1, :], size, original_size) 248 | else: 249 | ratio = original_size / size 250 | box = box.reshape(2, 2) * ratio 251 | return list(box.reshape(-1)) 252 | 253 | 254 | def letter_box_image(image: Image.Image, output_height: int, output_width: int, fill_value)-> np.ndarray: 255 | """ 256 | Fit image with final image with output_width and output_height. 257 | :param image: PILLOW Image object. 258 | :param output_height: width of the final image. 259 | :param output_width: height of the final image. 260 | :param fill_value: fill value for empty area. Can be uint8 or np.ndarray 261 | :return: numpy image fit within letterbox. dtype=uint8, shape=(output_height, output_width) 262 | """ 263 | 264 | height_ratio = float(output_height)/image.size[1] 265 | width_ratio = float(output_width)/image.size[0] 266 | fit_ratio = min(width_ratio, height_ratio) 267 | fit_height = int(image.size[1] * fit_ratio) 268 | fit_width = int(image.size[0] * fit_ratio) 269 | fit_image = np.asarray(image.resize((fit_width, fit_height), resample=Image.BILINEAR)) 270 | 271 | if isinstance(fill_value, int): 272 | fill_value = np.full(fit_image.shape[2], fill_value, fit_image.dtype) 273 | 274 | to_return = np.tile(fill_value, (output_height, output_width, 1)) 275 | pad_top = int(0.5 * (output_height - fit_height)) 276 | pad_left = int(0.5 * (output_width - fit_width)) 277 | to_return[pad_top:pad_top+fit_height, pad_left:pad_left+fit_width] = fit_image 278 | return to_return 279 | 280 | 281 | def letter_box_pos_to_original_pos(letter_pos, current_size, ori_image_size)-> np.ndarray: 282 | """ 283 | Parameters should have same shape and dimension space. (Width, Height) or (Height, Width) 284 | :param letter_pos: The current position within letterbox image including fill value area. 285 | :param current_size: The size of whole image including fill value area. 286 | :param ori_image_size: The size of image before being letter boxed. 
287 | :return: 288 | """ 289 | letter_pos = np.asarray(letter_pos, dtype=np.float) 290 | current_size = np.asarray(current_size, dtype=np.float) 291 | ori_image_size = np.asarray(ori_image_size, dtype=np.float) 292 | final_ratio = min(current_size[0]/ori_image_size[0], current_size[1]/ori_image_size[1]) 293 | pad = 0.5 * (current_size - final_ratio * ori_image_size) 294 | pad = pad.astype(np.int32) 295 | to_return_pos = (letter_pos - pad) / final_ratio 296 | return to_return_pos 297 | -------------------------------------------------------------------------------- /yolo_v4_tiny.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "TFYOLOV3", 4 | "match_kind": "general", 5 | "custom_attributes": { 6 | "classes": 80, 7 | "anchors": [10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319], 8 | "coords": 4, 9 | "num": 6, 10 | "masks": [[3, 4, 5], [1, 2, 3]], 11 | "entry_points": ["detector/yolo-v4-tiny/Reshape", "detector/yolo-v4-tiny/Reshape_4"] 12 | } 13 | } 14 | ] 15 | -------------------------------------------------------------------------------- /yolo_v4_tiny.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from builtins import * 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | slim = tf.contrib.slim 8 | 9 | _BATCH_NORM_DECAY = 0.9 10 | _BATCH_NORM_EPSILON = 1e-05 11 | _LEAKY_RELU = 0.1 12 | 13 | _ANCHORS = [(10, 14), (23, 27), (37, 58), 14 | (81, 82), (135, 169), (344, 319)] 15 | 16 | 17 | def route_group(input_layer, in_channels, data_format): 18 | # convs = tf.split(input_layer, num_or_size_splits=groups, axis=-1) 19 | # return convs[group_id] 20 | split = input_layer[:, in_channels//2:, :, :] if data_format == "NCHW" else input_layer[:, :, :, in_channels // 2:] 21 | return split 22 | 23 | 24 | def _get_size(shape, data_format): 25 | if len(shape) == 4: 26 | shape = shape[1:] 27 | return shape[1:3] if data_format == 'NCHW' else shape[0:2] 28 | 29 | 30 | def _conv2d_fixed_padding(inputs, filters, kernel_size, strides=1): 31 | if strides > 1: 32 | inputs = _fixed_padding(inputs, kernel_size) 33 | inputs = slim.conv2d(inputs, filters, kernel_size, stride=strides, 34 | padding=('SAME' if strides == 1 else 'VALID')) 35 | return inputs 36 | 37 | 38 | def _detection_layer(inputs, num_classes, anchors, img_size, data_format): 39 | num_anchors = len(anchors) 40 | predictions = slim.conv2d(inputs, num_anchors * (5 + num_classes), 1, 41 | stride=1, normalizer_fn=None, 42 | activation_fn=None, 43 | biases_initializer=tf.zeros_initializer()) 44 | 45 | shape = predictions.get_shape().as_list() 46 | grid_size = _get_size(shape, data_format) 47 | dim = grid_size[0] * grid_size[1] 48 | bbox_attrs = 5 + num_classes 49 | 50 | if data_format == 'NCHW': 51 | predictions = tf.reshape( 52 | predictions, [-1, num_anchors * bbox_attrs, dim]) 53 | predictions = tf.transpose(predictions, [0, 2, 1]) 54 | 55 | predictions = tf.reshape(predictions, [-1, num_anchors * dim, bbox_attrs]) 56 | 57 | stride = (img_size[0] // grid_size[0], img_size[1] // grid_size[1]) 58 | 59 | anchors = [(a[0] / stride[0], a[1] / stride[1]) for a in anchors] 60 | 61 | box_centers, box_sizes, confidence, classes = tf.split( 62 | predictions, [2, 2, 1, num_classes], axis=-1) 63 | 64 | box_centers = tf.nn.sigmoid(box_centers) 65 | confidence = tf.nn.sigmoid(confidence) 66 | 67 | grid_x = tf.range(grid_size[0], dtype=tf.float32) 68 | grid_y = tf.range(grid_size[1], dtype=tf.float32) 69 | a, b = 
tf.meshgrid(grid_x, grid_y) 70 | 71 | x_offset = tf.reshape(a, (-1, 1)) 72 | y_offset = tf.reshape(b, (-1, 1)) 73 | 74 | x_y_offset = tf.concat([x_offset, y_offset], axis=-1) 75 | x_y_offset = tf.reshape(tf.tile(x_y_offset, [1, num_anchors]), [1, -1, 2]) 76 | 77 | box_centers = box_centers + x_y_offset 78 | box_centers = box_centers * stride 79 | 80 | anchors = tf.tile(anchors, [dim, 1]) 81 | box_sizes = tf.exp(box_sizes) * anchors 82 | box_sizes = box_sizes * stride 83 | 84 | detections = tf.concat([box_centers, box_sizes, confidence], axis=-1) 85 | 86 | classes = tf.nn.sigmoid(classes) 87 | predictions = tf.concat([detections, classes], axis=-1) 88 | return predictions 89 | 90 | 91 | def _upsample(inputs, out_shape, data_format='NCHW'): 92 | # tf.image.resize_nearest_neighbor accepts input in format NHWC 93 | if data_format == 'NCHW': 94 | inputs = tf.transpose(inputs, [0, 2, 3, 1]) 95 | 96 | if data_format == 'NCHW': 97 | new_height = out_shape[3] 98 | new_width = out_shape[2] 99 | else: 100 | new_height = out_shape[2] 101 | new_width = out_shape[1] 102 | 103 | inputs = tf.image.resize_nearest_neighbor(inputs, (new_height, new_width)) 104 | 105 | # back to NCHW if needed 106 | if data_format == 'NCHW': 107 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 108 | 109 | inputs = tf.identity(inputs, name='upsampled') 110 | return inputs 111 | 112 | 113 | @tf.contrib.framework.add_arg_scope 114 | def _fixed_padding(inputs, kernel_size, *args, mode='CONSTANT', **kwargs): 115 | """ 116 | Pads the input along the spatial dimensions independently of input size. 117 | 118 | Args: 119 | inputs: A tensor of size [batch, channels, height_in, width_in] or 120 | [batch, height_in, width_in, channels] depending on data_format. 121 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation. 122 | Should be a positive integer. 123 | data_format: The input format ('NHWC' or 'NCHW'). 124 | mode: The mode for tf.pad. 125 | 126 | Returns: 127 | A tensor with the same format as the input with the data either intact 128 | (if kernel_size == 1) or padded (if kernel_size > 1). 129 | """ 130 | pad_total = kernel_size - 1 131 | pad_beg = pad_total // 2 132 | pad_end = pad_total - pad_beg 133 | 134 | if kwargs['data_format'] == 'NCHW': 135 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], 136 | [pad_beg, pad_end], 137 | [pad_beg, pad_end]], 138 | mode=mode) 139 | else: 140 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 141 | [pad_beg, pad_end], [0, 0]], mode=mode) 142 | return padded_inputs 143 | 144 | 145 | def yolo_v4_tiny(inputs, num_classes, is_training=False, data_format='NCHW', reuse=False): 146 | """ 147 | Creates YOLO v4 tiny model. 148 | 149 | :param inputs: a 4-D tensor of size [batch_size, height, width, channels]. 150 | Dimension batch_size may be undefined. The channel order is RGB. 151 | :param num_classes: number of predicted classes. 152 | :param is_training: whether is training or not. 153 | :param data_format: data format NCHW or NHWC. 154 | :param reuse: whether or not the network and its variables should be reused. 
155 | :return: 156 | """ 157 | # it will be needed later on 158 | img_size = inputs.get_shape().as_list()[1:3] 159 | 160 | # transpose the inputs to NCHW 161 | if data_format == 'NCHW': 162 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 163 | 164 | # normalize values to range [0..1] 165 | inputs = inputs / 255 166 | 167 | # set batch norm params 168 | batch_norm_params = { 169 | 'decay': _BATCH_NORM_DECAY, 170 | 'epsilon': _BATCH_NORM_EPSILON, 171 | 'scale': True, 172 | 'is_training': is_training, 173 | 'fused': None, # Use fused batch norm if possible. 174 | } 175 | 176 | # Set activation_fn and parameters for conv2d, batch_norm. 177 | with slim.arg_scope([slim.conv2d, slim.batch_norm, _fixed_padding, slim.max_pool2d], data_format=data_format): 178 | with slim.arg_scope([slim.conv2d, slim.batch_norm, _fixed_padding], reuse=reuse): 179 | with slim.arg_scope([slim.conv2d], 180 | normalizer_fn=slim.batch_norm, 181 | normalizer_params=batch_norm_params, 182 | biases_initializer=None, 183 | activation_fn=lambda x: tf.nn.leaky_relu(x, alpha=_LEAKY_RELU)): 184 | 185 | with tf.variable_scope('yolo-v4-tiny'): 186 | for i in range(2): 187 | inputs = _conv2d_fixed_padding( 188 | inputs, 16 * pow(2, i+1), 3, 2) 189 | #cspdarknet 190 | inputs = _conv2d_fixed_padding( 191 | inputs, 64, 3) 192 | inputs_2 = inputs 193 | inputs = route_group(inputs, 64, data_format) 194 | inputs = _conv2d_fixed_padding( 195 | inputs, 32, 3) 196 | inputs_4 = inputs 197 | inputs = _conv2d_fixed_padding(inputs, 32, 3) 198 | inputs = tf.concat([inputs, inputs_4], 199 | axis=1 if data_format == 'NCHW' else 3) 200 | inputs = _conv2d_fixed_padding(inputs, 64, 1) 201 | inputs = tf.concat([inputs_2, inputs], 202 | axis=1 if data_format == 'NCHW' else 3) 203 | inputs = slim.max_pool2d( 204 | inputs, [2, 2], scope='pool2') 205 | 206 | 207 | inputs = _conv2d_fixed_padding( 208 | inputs, 128, 3) 209 | inputs_10 = inputs 210 | inputs = route_group(inputs, 128, data_format) 211 | inputs = _conv2d_fixed_padding( 212 | inputs, 64, 3) 213 | inputs_12 = inputs 214 | inputs = _conv2d_fixed_padding(inputs, 64, 3) 215 | inputs = tf.concat([inputs, inputs_12], 216 | axis=1 if data_format == 'NCHW' else 3) 217 | inputs = _conv2d_fixed_padding(inputs, 128, 1) 218 | inputs = tf.concat([inputs_10, inputs], 219 | axis=1 if data_format == 'NCHW' else 3) 220 | inputs = slim.max_pool2d( 221 | inputs, [2, 2], scope='pool2') 222 | 223 | 224 | inputs = _conv2d_fixed_padding( 225 | inputs, 256, 3) 226 | inputs_18 = inputs 227 | inputs = route_group(inputs, 256, data_format) 228 | inputs = _conv2d_fixed_padding( 229 | inputs, 128, 3) 230 | inputs_20 = inputs 231 | inputs = _conv2d_fixed_padding(inputs, 128, 3) 232 | inputs = tf.concat([inputs, inputs_20], 233 | axis=1 if data_format == 'NCHW' else 3) 234 | inputs = _conv2d_fixed_padding(inputs, 256, 1) 235 | inputs_23 = inputs 236 | inputs = tf.concat([inputs_18, inputs], 237 | axis=1 if data_format == 'NCHW' else 3) 238 | inputs = slim.max_pool2d( 239 | inputs, [2, 2], scope='pool2') 240 | 241 | inputs = _conv2d_fixed_padding(inputs, 512, 3) 242 | inputs = _conv2d_fixed_padding(inputs, 256, 1) 243 | inputs_27 = inputs 244 | inputs = _conv2d_fixed_padding(inputs, 512, 3) 245 | detect_1 = _detection_layer( 246 | inputs, num_classes, _ANCHORS[3:6], img_size, data_format) 247 | detect_1 = tf.identity(detect_1, name='detect_1') 248 | inputs = inputs_27 249 | inputs = _conv2d_fixed_padding(inputs, 128, 1) 250 | upsample_size = inputs.get_shape().as_list() 251 | upsample_size[1] = upsample_size[1] * 2 252 | 
upsample_size[2] = upsample_size[2] * 2 253 | inputs = _upsample(inputs, upsample_size, data_format) 254 | inputs = tf.concat([inputs_23, inputs], 255 | axis=1 if data_format == 'NCHW' else 3) 256 | inputs = _conv2d_fixed_padding(inputs, 256, 3) 257 | 258 | detect_2 = _detection_layer( 259 | inputs, num_classes, _ANCHORS[1:4], img_size, data_format) 260 | detect_2 = tf.identity(detect_2, name='detect_2') 261 | 262 | detections = tf.concat([detect_1, detect_2], axis=1) 263 | detections = tf.identity(detections, name='detections') 264 | return detections 265 | --------------------------------------------------------------------------------
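Usage note: the helpers in utils.py are enough to smoke-test the frozen graph in plain TensorFlow before the OpenVINO conversion (README step 4). The following is a minimal sketch, not a file in this repository; it assumes the default artifacts produced above (`yolov4-tiny.pb`, `coco.names`, the `inputs`/`output_boxes` node names from convert_weights_pb.py) and a hypothetical test image `dog.jpg`.

# -*- coding: utf-8 -*-
# Sanity-check sketch: run the frozen yolov4-tiny.pb directly in TensorFlow 1.x.
# File names are the defaults used by convert_weights_pb.py; "dog.jpg" and
# "out.jpg" are placeholders.
import numpy as np
import tensorflow as tf
from PIL import Image

from utils import (load_graph, get_boxes_and_inputs_pb, load_coco_names,
                   letter_box_image, non_max_suppression, draw_boxes)

SIZE = 416  # must match the --size flag used during conversion

classes = load_coco_names('coco.names')
img = Image.open('dog.jpg')                       # placeholder test image (RGB)
img_resized = letter_box_image(img, SIZE, SIZE, 128)  # pad to SIZE x SIZE
img_resized = img_resized.astype(np.float32)      # the graph divides by 255 itself

frozen_graph = load_graph('yolov4-tiny.pb')
boxes, inputs = get_boxes_and_inputs_pb(frozen_graph)

with tf.Session(graph=frozen_graph) as sess:
    detected_boxes = sess.run(boxes, feed_dict={inputs: [img_resized]})

# Keep boxes above 0.5 confidence and drop overlaps above 0.4 IoU.
filtered_boxes = non_max_suppression(detected_boxes,
                                     confidence_threshold=0.5,
                                     iou_threshold=0.4)
draw_boxes(filtered_boxes, img, classes, (SIZE, SIZE), True)
img.save('out.jpg')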
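For reference, the box decoding that `_detection_layer` performs (and that `parse_yolo_region` mirrors on the OpenVINO side) reduces to a little per-cell arithmetic: a sigmoid offset inside the cell for the centre and an exponential rescaling of the anchor for the size. Below is a small numpy sketch with invented values for one anchor of the 13x13 map; the two tiny heads together yield 13*13*3 + 26*26*3 = 2535 boxes.

import numpy as np

img_size = 416
side = 13                             # coarse output map: 416 / 32
stride = img_size / side              # 32 pixels per grid cell
anchor_w, anchor_h = 135.0, 169.0     # one anchor from _ANCHORS[3:6]

col, row = 6, 4                       # grid cell indices (invented)
tx, ty, tw, th = 0.2, -0.1, 0.3, 0.5  # raw network outputs (invented)

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

# centre: sigmoid keeps the offset inside the cell, stride maps it to pixels
cx = (col + sigmoid(tx)) * stride
cy = (row + sigmoid(ty)) * stride
# size: exp(t) rescales the anchor, which is given in input-image pixels
w = np.exp(tw) * anchor_w
h = np.exp(th) * anchor_h

# detections_boxes() then converts centre/size to corner coordinates
xmin, ymin = cx - w / 2, cy - h / 2
xmax, ymax = cx + w / 2, cy + h / 2
print(xmin, ymin, xmax, ymax)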