├── .gitignore ├── __init__.py ├── analytics ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── tracking.cpython-36.pyc └── tracking.py ├── detect_object.py ├── detection └── data │ └── mscoco_label_map.pbtxt ├── objection_detection_app.py ├── readme.md ├── requirements.txt ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── label_map_util.cpython-36.pyc │ ├── string_int_label_map_pb2.cpython-36.pyc │ └── webcam.cpython-36.pyc ├── label_map_util.py ├── string_int_label_map_pb2.py └── webcam.py └── video_writer.py /.gitignore: -------------------------------------------------------------------------------- 1 | tf_models/ 2 | *.csv 3 | *.mp4 4 | __pycache__/ -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/__init__.py -------------------------------------------------------------------------------- /analytics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/analytics/__init__.py -------------------------------------------------------------------------------- /analytics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/analytics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /analytics/__pycache__/tracking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/analytics/__pycache__/tracking.cpython-36.pyc -------------------------------------------------------------------------------- /analytics/tracking.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import csv 3 | import cv2 4 | import os 5 | 6 | # tracks and tallys what's present at any particular time on screen 7 | # puts stats in a csv file in another directory 8 | # [['person: 98%'], ['book: 55%'], ['book: 50%']] 9 | class ObjectTracker(object): 10 | def __init__(self, path, file_name): 11 | self.class_counts = {} 12 | self.occupancy = False 13 | self.fp = open(os.path.join(path, file_name), 'w') 14 | self.writer = csv.DictWriter(self.fp, fieldnames=['frame', 'detections']) 15 | self.writer.writeheader() 16 | self.prev = None 17 | 18 | 19 | def update_class_counts(self, class_names): 20 | frame_counts = defaultdict(int) 21 | for item in class_names: 22 | count_item = item[0].split(':')[0] 23 | frame_counts[count_item] += 1 24 | 25 | # sort this dictionary? 26 | self.class_counts = frame_counts 27 | 28 | def update_person_status(self, class_names): 29 | for item in class_names: 30 | if item[0].split(':')[0] == 'person': 31 | self.occupancy = True 32 | return 33 | self.occupancy = False 34 | 35 | 36 | def write_to_report(self, frame_number): 37 | self.writer.writerow({'frame': frame_number, 'detections': self.class_counts}) 38 | 39 | 40 | def __call__(self, context): 41 | self.update_class_counts(context['class_names']) 42 | self.update_person_status(context['class_names']) 43 | frame = context['frame'] 44 | font = cv2.FONT_HERSHEY_SIMPLEX 45 | for point, name, color in zip(context['rec_points'], context['class_names'], context['class_colors']): 46 | 47 | cv2.rectangle(frame, (int(point['xmin'] * context['width']), int(point['ymin'] * context['height'])), 48 | (int(point['xmax'] * context['width']), int(point['ymax'] * context['height'])), color, 3) 49 | cv2.rectangle(frame, (int(point['xmin'] * context['width']), int(point['ymin'] * context['height'])), 50 | (int(point['xmin'] * context['width']) + len(name[0]) * 6, 51 | int(point['ymin'] * context['height']) - 10), color, -1, cv2.LINE_AA) 52 | cv2.putText(frame, name[0], (int(point['xmin'] * context['width']), int(point['ymin'] * context['height'])), font, 53 | 0.3, (0, 0, 0), 1) 54 | 55 | cv2.rectangle(frame, (0, 0), (frame.shape[1], 50), (0, 0, 0), cv2.FILLED) 56 | cv2.putText(frame, ("Room occupied: {occupied}".format(occupied=self.occupancy)), (30, 30), 57 | font, 0.6, (255, 255, 255), 1) 58 | 59 | if len(list(self.class_counts.keys())) > 0: 60 | key_1 = str(list(self.class_counts.keys())[0]) 61 | cv2.putText(frame, (key_1 + ':' + str(self.class_counts[key_1])), (int(frame.shape[1] * 0.85), 30), font, 0.6, (255, 255, 255), 1) 62 | 63 | self.write_to_report(context['frame_number']) 64 | 65 | return frame 66 | -------------------------------------------------------------------------------- /detect_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from utils import label_map_util 4 | from utils.webcam import draw_boxes_and_labels 5 | 6 | 7 | CWD_PATH = os.getcwd() 8 | PATH_TO_LABELS = os.path.join(CWD_PATH, 'detection', 'data', 'mscoco_label_map.pbtxt') 9 | 10 | NUM_CLASSES = 90 11 | # label map 12 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 13 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, 14 | use_display_name=True) 15 | category_index = label_map_util.create_category_index(categories) 16 | 17 | # pass in image_np, returns 18 | def detect_objects(image_np, sess, detection_graph): 19 | 20 | image_np_expanded = np.expand_dims(image_np, axis=0) 21 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 22 | 23 | boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 24 | 25 | scores = detection_graph.get_tensor_by_name('detection_scores:0') 26 | classes = detection_graph.get_tensor_by_name('detection_classes:0') 27 | num_detections = detection_graph.get_tensor_by_name('num_detections:0') 28 | 29 | # Do the detection/model prediction here 30 | (boxes, scores, classes, num_detections) = sess.run( 31 | [boxes, scores, classes, num_detections], 32 | feed_dict={image_tensor: image_np_expanded}) 33 | 34 | rect_points, class_names, class_colors = draw_boxes_and_labels( 35 | boxes=np.squeeze(boxes), 36 | classes=np.squeeze(classes).astype(np.int32), 37 | scores=np.squeeze(scores), 38 | category_index=category_index, 39 | min_score_thresh=.5 40 | ) 41 | 42 | return dict(rect_points=rect_points, class_names=class_names, class_colors=class_colors) -------------------------------------------------------------------------------- /detection/data/mscoco_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "/m/01g317" 3 | id: 1 4 | display_name: "person" 5 | } 6 | item { 7 | name: "/m/0199g" 8 | id: 2 9 | display_name: "bicycle" 10 | } 11 | item { 12 | name: "/m/0k4j" 13 | id: 3 14 | display_name: "car" 15 | } 16 | item { 17 | name: "/m/04_sv" 18 | id: 4 19 | display_name: "motorcycle" 20 | } 21 | item { 22 | name: "/m/05czz6l" 23 | id: 5 24 | display_name: "airplane" 25 | } 26 | item { 27 | name: "/m/01bjv" 28 | id: 6 29 | display_name: "bus" 30 | } 31 | item { 32 | name: "/m/07jdr" 33 | id: 7 34 | display_name: "train" 35 | } 36 | item { 37 | name: "/m/07r04" 38 | id: 8 39 | display_name: "truck" 40 | } 41 | item { 42 | name: "/m/019jd" 43 | id: 9 44 | display_name: "boat" 45 | } 46 | item { 47 | name: "/m/015qff" 48 | id: 10 49 | display_name: "traffic light" 50 | } 51 | item { 52 | name: "/m/01pns0" 53 | id: 11 54 | display_name: "fire hydrant" 55 | } 56 | item { 57 | name: "/m/02pv19" 58 | id: 13 59 | display_name: "stop sign" 60 | } 61 | item { 62 | name: "/m/015qbp" 63 | id: 14 64 | display_name: "parking meter" 65 | } 66 | item { 67 | name: "/m/0cvnqh" 68 | id: 15 69 | display_name: "bench" 70 | } 71 | item { 72 | name: "/m/015p6" 73 | id: 16 74 | display_name: "bird" 75 | } 76 | item { 77 | name: "/m/01yrx" 78 | id: 17 79 | display_name: "cat" 80 | } 81 | item { 82 | name: "/m/0bt9lr" 83 | id: 18 84 | display_name: "dog" 85 | } 86 | item { 87 | name: "/m/03k3r" 88 | id: 19 89 | display_name: "horse" 90 | } 91 | item { 92 | name: "/m/07bgp" 93 | id: 20 94 | display_name: "sheep" 95 | } 96 | item { 97 | name: "/m/01xq0k1" 98 | id: 21 99 | display_name: "cow" 100 | } 101 | item { 102 | name: "/m/0bwd_0j" 103 | id: 22 104 | display_name: "elephant" 105 | } 106 | item { 107 | name: "/m/01dws" 108 | id: 23 109 | display_name: "bear" 110 | } 111 | item { 112 | name: "/m/0898b" 113 | id: 24 114 | display_name: "zebra" 115 | } 116 | item { 117 | name: "/m/03bk1" 118 | id: 25 119 | display_name: "giraffe" 120 | } 121 | item { 122 | name: "/m/01940j" 123 | id: 27 124 | display_name: "backpack" 125 | } 126 | item { 127 | name: "/m/0hnnb" 128 | id: 28 129 | display_name: "umbrella" 130 | } 131 | item { 132 | name: "/m/080hkjn" 133 | id: 31 134 | display_name: "handbag" 135 | } 136 | item { 137 | name: "/m/01rkbr" 138 | id: 32 139 | display_name: "tie" 140 | } 141 | item { 142 | name: "/m/01s55n" 143 | id: 33 144 | display_name: "suitcase" 145 | } 146 | item { 147 | name: "/m/02wmf" 148 | id: 34 149 | display_name: "frisbee" 150 | } 151 | item { 152 | name: "/m/071p9" 153 | id: 35 154 | display_name: "skis" 155 | } 156 | item { 157 | name: "/m/06__v" 158 | id: 36 159 | display_name: "snowboard" 160 | } 161 | item { 162 | name: "/m/018xm" 163 | id: 37 164 | display_name: "sports ball" 165 | } 166 | item { 167 | name: "/m/02zt3" 168 | id: 38 169 | display_name: "kite" 170 | } 171 | item { 172 | name: "/m/03g8mr" 173 | id: 39 174 | display_name: "baseball bat" 175 | } 176 | item { 177 | name: "/m/03grzl" 178 | id: 40 179 | display_name: "baseball glove" 180 | } 181 | item { 182 | name: "/m/06_fw" 183 | id: 41 184 | display_name: "skateboard" 185 | } 186 | item { 187 | name: "/m/019w40" 188 | id: 42 189 | display_name: "surfboard" 190 | } 191 | item { 192 | name: "/m/0dv9c" 193 | id: 43 194 | display_name: "tennis racket" 195 | } 196 | item { 197 | name: "/m/04dr76w" 198 | id: 44 199 | display_name: "bottle" 200 | } 201 | item { 202 | name: "/m/09tvcd" 203 | id: 46 204 | display_name: "wine glass" 205 | } 206 | item { 207 | name: "/m/08gqpm" 208 | id: 47 209 | display_name: "cup" 210 | } 211 | item { 212 | name: "/m/0dt3t" 213 | id: 48 214 | display_name: "fork" 215 | } 216 | item { 217 | name: "/m/04ctx" 218 | id: 49 219 | display_name: "knife" 220 | } 221 | item { 222 | name: "/m/0cmx8" 223 | id: 50 224 | display_name: "spoon" 225 | } 226 | item { 227 | name: "/m/04kkgm" 228 | id: 51 229 | display_name: "bowl" 230 | } 231 | item { 232 | name: "/m/09qck" 233 | id: 52 234 | display_name: "banana" 235 | } 236 | item { 237 | name: "/m/014j1m" 238 | id: 53 239 | display_name: "apple" 240 | } 241 | item { 242 | name: "/m/0l515" 243 | id: 54 244 | display_name: "sandwich" 245 | } 246 | item { 247 | name: "/m/0cyhj_" 248 | id: 55 249 | display_name: "orange" 250 | } 251 | item { 252 | name: "/m/0hkxq" 253 | id: 56 254 | display_name: "broccoli" 255 | } 256 | item { 257 | name: "/m/0fj52s" 258 | id: 57 259 | display_name: "carrot" 260 | } 261 | item { 262 | name: "/m/01b9xk" 263 | id: 58 264 | display_name: "hot dog" 265 | } 266 | item { 267 | name: "/m/0663v" 268 | id: 59 269 | display_name: "pizza" 270 | } 271 | item { 272 | name: "/m/0jy4k" 273 | id: 60 274 | display_name: "donut" 275 | } 276 | item { 277 | name: "/m/0fszt" 278 | id: 61 279 | display_name: "cake" 280 | } 281 | item { 282 | name: "/m/01mzpv" 283 | id: 62 284 | display_name: "chair" 285 | } 286 | item { 287 | name: "/m/02crq1" 288 | id: 63 289 | display_name: "couch" 290 | } 291 | item { 292 | name: "/m/03fp41" 293 | id: 64 294 | display_name: "potted plant" 295 | } 296 | item { 297 | name: "/m/03ssj5" 298 | id: 65 299 | display_name: "bed" 300 | } 301 | item { 302 | name: "/m/04bcr3" 303 | id: 67 304 | display_name: "dining table" 305 | } 306 | item { 307 | name: "/m/09g1w" 308 | id: 70 309 | display_name: "toilet" 310 | } 311 | item { 312 | name: "/m/07c52" 313 | id: 72 314 | display_name: "tv" 315 | } 316 | item { 317 | name: "/m/01c648" 318 | id: 73 319 | display_name: "laptop" 320 | } 321 | item { 322 | name: "/m/020lf" 323 | id: 74 324 | display_name: "mouse" 325 | } 326 | item { 327 | name: "/m/0qjjc" 328 | id: 75 329 | display_name: "remote" 330 | } 331 | item { 332 | name: "/m/01m2v" 333 | id: 76 334 | display_name: "keyboard" 335 | } 336 | item { 337 | name: "/m/050k8" 338 | id: 77 339 | display_name: "cell phone" 340 | } 341 | item { 342 | name: "/m/0fx9l" 343 | id: 78 344 | display_name: "microwave" 345 | } 346 | item { 347 | name: "/m/029bxz" 348 | id: 79 349 | display_name: "oven" 350 | } 351 | item { 352 | name: "/m/01k6s3" 353 | id: 80 354 | display_name: "toaster" 355 | } 356 | item { 357 | name: "/m/0130jx" 358 | id: 81 359 | display_name: "sink" 360 | } 361 | item { 362 | name: "/m/040b_t" 363 | id: 82 364 | display_name: "refrigerator" 365 | } 366 | item { 367 | name: "/m/0bt_c3" 368 | id: 84 369 | display_name: "book" 370 | } 371 | item { 372 | name: "/m/01x3z" 373 | id: 85 374 | display_name: "clock" 375 | } 376 | item { 377 | name: "/m/02s195" 378 | id: 86 379 | display_name: "vase" 380 | } 381 | item { 382 | name: "/m/01lsmm" 383 | id: 87 384 | display_name: "scissors" 385 | } 386 | item { 387 | name: "/m/0kmg4" 388 | id: 88 389 | display_name: "teddy bear" 390 | } 391 | item { 392 | name: "/m/03wvsk" 393 | id: 89 394 | display_name: "hair drier" 395 | } 396 | item { 397 | name: "/m/012xff" 398 | id: 90 399 | display_name: "toothbrush" 400 | } -------------------------------------------------------------------------------- /objection_detection_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import argparse 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from utils.webcam import FPS, WebcamVideoStream 9 | from queue import Queue 10 | from threading import Thread 11 | from analytics.tracking import ObjectTracker 12 | from video_writer import VideoWriter 13 | from detect_object import detect_objects 14 | 15 | CWD_PATH = os.getcwd() 16 | 17 | MODEL_NAME = 'ssd_mobilenet_v1_coco_2017_11_17' 18 | PATH_TO_MODEL = os.path.join(CWD_PATH, 'detection', 'tf_models', MODEL_NAME, 'frozen_inference_graph.pb') 19 | PATH_TO_VIDEO = os.path.join(CWD_PATH, 'input.mp4') 20 | 21 | 22 | def worker(input_q, output_q): 23 | # load the frozen tensorflow model into memory 24 | detection_graph = tf.Graph() 25 | with detection_graph.as_default(): 26 | od_graph_def = tf.GraphDef() 27 | with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: 28 | serialized_graph = fid.read() 29 | od_graph_def.ParseFromString(serialized_graph) 30 | tf.import_graph_def(od_graph_def, name='') 31 | 32 | sess = tf.Session(graph=detection_graph) 33 | 34 | fps = FPS().start() 35 | while True: 36 | fps.update() 37 | frame = input_q.get() 38 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 39 | output_q.put(detect_objects(frame, sess, detection_graph)) 40 | 41 | fps.stop() 42 | sess.close() 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('-src', '--source', dest='video_source', type=int, 48 | default=0, help='Device index of the camera.') 49 | parser.add_argument('-wd', '--width', dest='width', type=int, 50 | default=1280, help='Width of the frames in the video stream.') 51 | parser.add_argument('-ht', '--height', dest='height', type=int, 52 | default=720, help='Height of the frames in the video stream.') 53 | args = parser.parse_args() 54 | 55 | input_q = Queue(5) 56 | output_q = Queue() 57 | for i in range(1): 58 | t = Thread(target=worker, args=(input_q, output_q)) 59 | t.daemon = True 60 | t.start() 61 | 62 | video_capture = WebcamVideoStream(src=args.video_source, 63 | width=args.width, 64 | height=args.height).start() 65 | writer = VideoWriter('output.mp4', (args.width, args.height)) 66 | 67 | ''' 68 | stream = cv2.VideoCapture(0) 69 | stream.set(cv2.CAP_PROP_FRAME_WIDTH, args.width) 70 | stream.set(cv2.CAP_PROP_FRAME_HEIGHT, args.height) 71 | ''' 72 | 73 | fps = FPS().start() 74 | object_tracker = ObjectTracker(path='./', file_name='report.csv') 75 | while True: 76 | frame = video_capture.read() 77 | # (ret, frame) = stream.read() 78 | fps.update() 79 | 80 | if fps.get_numFrames() % 2 != 0: 81 | continue 82 | 83 | # put data into the input queue 84 | input_q.put(frame) 85 | 86 | t = time.time() 87 | 88 | if output_q.empty(): 89 | pass # fill up queue 90 | else: 91 | data = output_q.get() 92 | context = {'frame': frame, 'class_names': data['class_names'], 'rec_points': data['rect_points'], 'class_colors': data['class_colors'], 93 | 'width': args.width, 'height': args.height, 'frame_number': fps.get_numFrames()} 94 | new_frame = object_tracker(context) 95 | writer(new_frame) 96 | cv2.imshow('Video', new_frame) 97 | 98 | print('[INFO] elapsed time: {:.2f}'.format(time.time() - t)) 99 | 100 | if cv2.waitKey(1) & 0xFF == ord('q'): 101 | break 102 | 103 | fps.stop() 104 | print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed())) 105 | print('[INFO] approx. FPS: {:.2f}'.format(fps.fps())) 106 | 107 | video_capture.stop() 108 | writer.close() 109 | cv2.destroyAllWindows() 110 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Real time object detection application using tensorflow. 2 | 3 | ### Detecting the presence of people in a room in a live video feed 4 | 5 | Using single shot detection (SSD) deep learning model trained on the COCO image dataset. My project is inspired by this repo https://github.com/datitran/object_detector_app 6 | 7 | Please ensure you install all dependencies as per the requirements.txt file 8 | 9 | Running the app: python3 objection_detection_app.py 10 | 11 | 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==3.4.0 2 | matplotlib==2.0.0 3 | numpy==1.12.0 4 | tensorflow==1.8.0 5 | six==1.6.0 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/label_map_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/utils/__pycache__/label_map_util.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/string_int_label_map_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/utils/__pycache__/string_int_label_map_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/webcam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schumanzhang/object_detection_real_time/4ba599b530e9367d239c6f6fb989567923a05655/utils/__pycache__/webcam.cpython-36.pyc -------------------------------------------------------------------------------- /utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | """Label map utility functions.""" 2 | 3 | import logging 4 | 5 | import tensorflow as tf 6 | from google.protobuf import text_format 7 | from utils import string_int_label_map_pb2 8 | 9 | 10 | def create_category_index(categories): 11 | """Creates dictionary of COCO compatible categories keyed by category id. 12 | Args: 13 | categories: a list of dicts, each of which has the following keys: 14 | 'id': (required) an integer id uniquely identifying this category. 15 | 'name': (required) string representing category name 16 | e.g., 'cat', 'dog', 'pizza'. 17 | Returns: 18 | category_index: a dict containing the same entries as categories, but keyed 19 | by the 'id' field of each category. 20 | """ 21 | category_index = {} 22 | for cat in categories: 23 | category_index[cat['id']] = cat 24 | return category_index 25 | 26 | 27 | def convert_label_map_to_categories(label_map, 28 | max_num_classes, 29 | use_display_name=True): 30 | """Loads label map proto and returns categories list compatible with eval. 31 | This function loads a label map and returns a list of dicts, each of which 32 | has the following keys: 33 | 'id': (required) an integer id uniquely identifying this category. 34 | 'name': (required) string representing category name 35 | e.g., 'cat', 'dog', 'pizza'. 36 | We only allow class into the list if its id-label_id_offset is 37 | between 0 (inclusive) and max_num_classes (exclusive). 38 | If there are several items mapping to the same id in the label map, 39 | we will only keep the first one in the categories list. 40 | Args: 41 | label_map: a StringIntLabelMapProto or None. If None, a default categories 42 | list is created with max_num_classes categories. 43 | max_num_classes: maximum number of (consecutive) label indices to include. 44 | use_display_name: (boolean) choose whether to load 'display_name' field 45 | as category name. If False of if the display_name field does not exist, 46 | uses 'name' field as category names instead. 47 | Returns: 48 | categories: a list of dictionaries representing all possible categories. 49 | """ 50 | categories = [] 51 | list_of_ids_already_added = [] 52 | if not label_map: 53 | label_id_offset = 1 54 | for class_id in range(max_num_classes): 55 | categories.append({ 56 | 'id': class_id + label_id_offset, 57 | 'name': 'category_{}'.format(class_id + label_id_offset) 58 | }) 59 | return categories 60 | for item in label_map.item: 61 | if not 0 < item.id <= max_num_classes: 62 | logging.info('Ignore item %d since it falls outside of requested ' 63 | 'label range.', item.id) 64 | continue 65 | if use_display_name and item.HasField('display_name'): 66 | name = item.display_name 67 | else: 68 | name = item.name 69 | if item.id not in list_of_ids_already_added: 70 | list_of_ids_already_added.append(item.id) 71 | categories.append({'id': item.id, 'name': name}) 72 | return categories 73 | 74 | 75 | # TODO: double check documentaion. 76 | def load_labelmap(path): 77 | """Loads label map proto. 78 | Args: 79 | path: path to StringIntLabelMap proto text file. 80 | Returns: 81 | a StringIntLabelMapProto 82 | """ 83 | with tf.gfile.GFile(path, 'r') as fid: 84 | label_map_string = fid.read() 85 | label_map = string_int_label_map_pb2.StringIntLabelMap() 86 | try: 87 | text_format.Merge(label_map_string, label_map) 88 | except text_format.ParseError: 89 | label_map.ParseFromString(label_map_string) 90 | return label_map 91 | 92 | 93 | def get_label_map_dict(label_map_path): 94 | """Reads a label map and returns a dictionary of label names to id. 95 | Args: 96 | label_map_path: path to label_map. 97 | Returns: 98 | A dictionary mapping label names to id. 99 | """ 100 | label_map = load_labelmap(label_map_path) 101 | label_map_dict = {} 102 | for item in label_map.item: 103 | label_map_dict[item.name] = item.id 104 | return label_map_dict -------------------------------------------------------------------------------- /utils/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | import sys 4 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: \ 5 | x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | 12 | # @@protoc_insertion_point(imports) 13 | 14 | _sym_db = _symbol_database.Default() 15 | 16 | DESCRIPTOR = \ 17 | _descriptor.FileDescriptor(name='object_detection/protos/string_int_label_map.proto' 18 | , package='object_detection.protos', 19 | serialized_pb=_b(''' 20 | 2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G 21 | \x15StringIntLabelMapItem\x12\x0c 22 | \x04name\x18\x01 \x01(\t\x12 23 | 24 | \x02id\x18\x02 \x01(\x05\x12\x14 25 | \x0c\x64isplay_name\x18\x03 \x01(\t\"Q 26 | \x11StringIntLabelMap\x12< 27 | \x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem''')) 28 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 29 | 30 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 31 | name='StringIntLabelMapItem', 32 | full_name='object_detection.protos.StringIntLabelMapItem', 33 | filename=None, 34 | file=DESCRIPTOR, 35 | containing_type=None, 36 | fields=[_descriptor.FieldDescriptor( 37 | name='name', 38 | full_name='object_detection.protos.StringIntLabelMapItem.name', 39 | index=0, 40 | number=1, 41 | type=9, 42 | cpp_type=9, 43 | label=1, 44 | has_default_value=False, 45 | default_value=_b('').decode('utf-8'), 46 | message_type=None, 47 | enum_type=None, 48 | containing_type=None, 49 | is_extension=False, 50 | extension_scope=None, 51 | options=None, 52 | ), _descriptor.FieldDescriptor( 53 | name='id', 54 | full_name='object_detection.protos.StringIntLabelMapItem.id', 55 | index=1, 56 | number=2, 57 | type=5, 58 | cpp_type=1, 59 | label=1, 60 | has_default_value=False, 61 | default_value=0, 62 | message_type=None, 63 | enum_type=None, 64 | containing_type=None, 65 | is_extension=False, 66 | extension_scope=None, 67 | options=None, 68 | ), _descriptor.FieldDescriptor( 69 | name='display_name', 70 | full_name='object_detection.protos.StringIntLabelMapItem.display_name' 71 | , 72 | index=2, 73 | number=3, 74 | type=9, 75 | cpp_type=9, 76 | label=1, 77 | has_default_value=False, 78 | default_value=_b('').decode('utf-8'), 79 | message_type=None, 80 | enum_type=None, 81 | containing_type=None, 82 | is_extension=False, 83 | extension_scope=None, 84 | options=None, 85 | )], 86 | extensions=[], 87 | nested_types=[], 88 | enum_types=[], 89 | options=None, 90 | is_extendable=False, 91 | extension_ranges=[], 92 | oneofs=[], 93 | serialized_start=79, 94 | serialized_end=150, 95 | ) 96 | 97 | _STRINGINTLABELMAP = _descriptor.Descriptor( 98 | name='StringIntLabelMap', 99 | full_name='object_detection.protos.StringIntLabelMap', 100 | filename=None, 101 | file=DESCRIPTOR, 102 | containing_type=None, 103 | fields=[_descriptor.FieldDescriptor( 104 | name='item', 105 | full_name='object_detection.protos.StringIntLabelMap.item', 106 | index=0, 107 | number=1, 108 | type=11, 109 | cpp_type=10, 110 | label=3, 111 | has_default_value=False, 112 | default_value=[], 113 | message_type=None, 114 | enum_type=None, 115 | containing_type=None, 116 | is_extension=False, 117 | extension_scope=None, 118 | options=None, 119 | )], 120 | extensions=[], 121 | nested_types=[], 122 | enum_types=[], 123 | options=None, 124 | is_extendable=False, 125 | extension_ranges=[], 126 | oneofs=[], 127 | serialized_start=152, 128 | serialized_end=233, 129 | ) 130 | 131 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = \ 132 | _STRINGINTLABELMAPITEM 133 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = \ 134 | _STRINGINTLABELMAPITEM 135 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = \ 136 | _STRINGINTLABELMAP 137 | 138 | StringIntLabelMapItem = \ 139 | _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', 140 | (_message.Message, ), dict(DESCRIPTOR=_STRINGINTLABELMAPITEM, 141 | __module__='object_detection.protos.string_int_label_map_pb2')) 142 | 143 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 144 | 145 | _sym_db.RegisterMessage(StringIntLabelMapItem) 146 | 147 | StringIntLabelMap = \ 148 | _reflection.GeneratedProtocolMessageType('StringIntLabelMap', 149 | (_message.Message, ), dict(DESCRIPTOR=_STRINGINTLABELMAP, 150 | __module__='object_detection.protos.string_int_label_map_pb2')) 151 | 152 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 153 | 154 | _sym_db.RegisterMessage(StringIntLabelMap) 155 | 156 | # @@protoc_insertion_point(module_scope)............ 157 | -------------------------------------------------------------------------------- /utils/webcam.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import six 3 | import collections 4 | import cv2 5 | import datetime 6 | from threading import Thread 7 | from matplotlib import colors 8 | 9 | class FPS: 10 | def __init__(self): 11 | # start time, end time, total number of frames 12 | self._start = None 13 | self._end = None 14 | self._numFrames = 0 15 | 16 | def start(self): 17 | # start the timer 18 | self._start = datetime.datetime.now() 19 | return self 20 | 21 | def stop(self): 22 | # stop the timer 23 | self._end = datetime.datetime.now() 24 | 25 | def update(self): 26 | # increment the number of frames analysed during the start and end interval 27 | self._numFrames += 1 28 | 29 | def elapsed(self): 30 | # return the total number of seconds 31 | return (self._end - self._start).total_seconds() 32 | 33 | def fps(self): 34 | # approximate frames per second 35 | return self._numFrames / self.elapsed() 36 | 37 | def get_numFrames(self): 38 | return self._numFrames 39 | 40 | 41 | class WebcamVideoStream: 42 | def __init__(self, src, width, height): 43 | # init the video camera stream 44 | # read the first frame from the stream 45 | 46 | # VideoCapture provides API for capturing video from cameras, reading video files and image sequences 47 | # try passing in different sources 48 | self.stream = cv2.VideoCapture(src) 49 | self.stream.set(cv2.CAP_PROP_FRAME_WIDTH, width) 50 | self.stream.set(cv2.CAP_PROP_FRAME_HEIGHT, height) 51 | 52 | # read the first frame 53 | (self.grabbed, self.frame) = self.stream.read() 54 | self.stopped = False 55 | 56 | 57 | def start(self): 58 | # start the thread to read frames from video stream 59 | Thread(target=self.update, args=()).start() 60 | return self 61 | 62 | 63 | def update(self): 64 | # keep looping until thread is stopped 65 | while True: 66 | if self.stopped: 67 | return 68 | 69 | # keep reading the next frame from the video stream 70 | (self.grabbed, self.frame) = self.stream.read() 71 | 72 | 73 | def read(self): 74 | # get the most recently read frame 75 | return self.frame 76 | 77 | 78 | def stop(self): 79 | # stopping the thread 80 | self.stopped = True 81 | self.stream.release() 82 | 83 | 84 | def standard_colors(): 85 | colors = [ 86 | 'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque', 87 | 'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite', 88 | 'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan', 89 | 'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange', 90 | 'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet', 91 | 'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite', 92 | 'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod', 93 | 'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki', 94 | 'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue', 95 | 'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey', 96 | 'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue', 97 | 'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime', 98 | 'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid', 99 | 'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen', 100 | 'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin', 101 | 'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed', 102 | 'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed', 103 | 'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple', 104 | 'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown', 105 | 'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue', 106 | 'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow', 107 | 'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White', 108 | 'WhiteSmoke', 'Yellow', 'YellowGreen' 109 | ] 110 | return colors 111 | 112 | def color_name_to_rgb(): 113 | colors_rgb = [] 114 | for key, value in colors.cnames.items(): 115 | colors_rgb.append((key, struct.unpack('BBB', bytes.fromhex(value.replace('#', ''))))) 116 | return dict(colors_rgb) 117 | 118 | 119 | def draw_boxes_and_labels( 120 | boxes, 121 | classes, 122 | scores, 123 | category_index, 124 | instance_masks=None, 125 | keypoints=None, 126 | max_boxes_to_draw=20, 127 | min_score_thresh=.5, 128 | agnostic_mode=False): 129 | """Returns boxes coordinates, class names and colors 130 | Args: 131 | boxes: a numpy array of shape [N, 4] 132 | classes: a numpy array of shape [N] 133 | scores: a numpy array of shape [N] or None. If scores=None, then 134 | this function assumes that the boxes to be plotted are groundtruth 135 | boxes and plot all boxes as black with no classes or scores. 136 | category_index: a dict containing category dictionaries (each holding 137 | category index `id` and category name `name`) keyed by category indices. 138 | instance_masks: a numpy array of shape [N, image_height, image_width], can 139 | be None 140 | keypoints: a numpy array of shape [N, num_keypoints, 2], can 141 | be None 142 | max_boxes_to_draw: maximum number of boxes to visualize. If None, draw 143 | all boxes. 144 | min_score_thresh: minimum score threshold for a box to be visualized 145 | agnostic_mode: boolean (default: False) controlling whether to evaluate in 146 | class-agnostic mode or not. This mode will display scores but ignore 147 | classes. 148 | """ 149 | # Create a display string (and color) for every box location, group any boxes 150 | # that correspond to the same location. 151 | box_to_display_str_map = collections.defaultdict(list) 152 | box_to_color_map = collections.defaultdict(str) 153 | box_to_instance_masks_map = {} 154 | box_to_keypoints_map = collections.defaultdict(list) 155 | if not max_boxes_to_draw: 156 | max_boxes_to_draw = boxes.shape[0] 157 | for i in range(min(max_boxes_to_draw, boxes.shape[0])): 158 | if scores is None or scores[i] > min_score_thresh: 159 | box = tuple(boxes[i].tolist()) 160 | if instance_masks is not None: 161 | box_to_instance_masks_map[box] = instance_masks[i] 162 | if keypoints is not None: 163 | box_to_keypoints_map[box].extend(keypoints[i]) 164 | if scores is None: 165 | box_to_color_map[box] = 'black' 166 | else: 167 | if not agnostic_mode: 168 | if classes[i] in category_index.keys(): 169 | class_name = category_index[classes[i]]['name'] 170 | else: 171 | class_name = 'N/A' 172 | display_str = '{}: {}%'.format( 173 | class_name, 174 | int(100 * scores[i])) 175 | else: 176 | display_str = 'score: {}%'.format(int(100 * scores[i])) 177 | box_to_display_str_map[box].append(display_str) 178 | if agnostic_mode: 179 | box_to_color_map[box] = 'DarkOrange' 180 | else: 181 | box_to_color_map[box] = standard_colors()[ 182 | classes[i] % len(standard_colors())] 183 | 184 | # Store all the coordinates of the boxes, class names and colors 185 | color_rgb = color_name_to_rgb() 186 | rect_points = [] 187 | class_names = [] 188 | class_colors = [] 189 | for box, color in six.iteritems(box_to_color_map): 190 | ymin, xmin, ymax, xmax = box 191 | rect_points.append(dict(ymin=ymin, xmin=xmin, ymax=ymax, xmax=xmax)) 192 | class_names.append(box_to_display_str_map[box]) 193 | class_colors.append(color_rgb[color.lower()]) 194 | return rect_points, class_names, class_colors -------------------------------------------------------------------------------- /video_writer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | class VideoWriter(object): 4 | 5 | def __init__(self, path, size): 6 | self.path = path 7 | self.size = size 8 | self.writer = cv2.VideoWriter(self.path, 9 | cv2.VideoWriter_fourcc('F','M','P','4'), 10 | 20.0, self.size, True) 11 | 12 | def __call__(self, frame): 13 | self.writer.write(frame) 14 | 15 | def close(self): 16 | self.writer.release() 17 | 18 | 19 | 20 | --------------------------------------------------------------------------------