├── .gitignore
├── README.md
├── label_utils
│   ├── label_map_util.py
│   ├── mscoco_label_map.pbtxt
│   └── string_int_label_map_pb2.py
├── object_detection_tutorial.py
└── test_images
    ├── image1.jpg
    ├── image2.jpg
    ├── image3.png
    └── image_info.txt

/.gitignore:
--------------------------------------------------------------------------------
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Object Detection Tutorial
## v0.1.0

### Update

- draw bounding boxes and labels
- support multiple images

### Usage

- download a model from the detection model zoo: https://github.com/tensorflow/models/blob/v1.13.0/research/object_detection/g3doc/detection_model_zoo.md
- run `python object_detection_tutorial.py --model_frozen /PATH/TO/MODEL`
- optionally set `--image_path` (a single image or a directory of images, default `test_images/`) and `--label_path` (default `label_utils/mscoco_label_map.pbtxt`)

for example:
```
python3 object_detection_tutorial.py --model_frozen ../ssd_mobilenet_v1_coco_2018_01_28/frozen_inference_graph.pb
```

### ref:
- modified from the TensorFlow Object Detection API demo (v1.13.0): https://github.com/tensorflow/models/blob/v1.13.0/research/object_detection/object_detection_tutorial.ipynb

## v0.1.1
- optimization
- TFDetector class

### Update
- new TFDetector class that keeps the session as a class member
- remove PIL and matplotlib dependencies
- reformat the code
- combine loading of image paths and labels into the load_image_and_labels function
- remove the isimage function; add a types parameter to handle more image types in the future
- remove the main function

--------------------------------------------------------------------------------
/label_utils/label_map_util.py:
--------------------------------------------------------------------------------
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Label map utility functions."""

import logging

import tensorflow as tf
from google.protobuf import text_format
from label_utils import string_int_label_map_pb2


def _validate_label_map(label_map):
  """Checks if a label map is valid.

  Args:
    label_map: StringIntLabelMap to validate.

  Raises:
    ValueError: if label map is invalid.
  """
  for item in label_map.item:
    if item.id < 0:
      raise ValueError('Label map ids should be >= 0.')
    if (item.id == 0 and item.name != 'background' and
        item.display_name != 'background'):
      raise ValueError('Label map id 0 is reserved for the background label')


def create_category_index(categories):
  """Creates a dictionary of COCO-compatible categories keyed by category id.

  Args:
    categories: a list of dicts, each of which has the following keys:
      'id': (required) an integer id uniquely identifying this category.
      'name': (required) string representing category name
        e.g., 'cat', 'dog', 'pizza'.

  Returns:
    category_index: a dict containing the same entries as categories, but keyed
      by the 'id' field of each category.
  """
  category_index = {}
  for cat in categories:
    category_index[cat['id']] = cat
  return category_index


def get_max_label_map_index(label_map):
  """Get maximum index in label map.

  Args:
    label_map: a StringIntLabelMapProto

  Returns:
    an integer
  """
  return max([item.id for item in label_map.item])


def convert_label_map_to_categories(label_map,
                                    max_num_classes,
                                    use_display_name=True):
  """Given a label map proto, returns a categories list compatible with eval.

  This function converts a label map proto into a list of dicts, each of
  which has the following keys:
    'id': (required) an integer id uniquely identifying this category.
    'name': (required) string representing category name
      e.g., 'cat', 'dog', 'pizza'.
  An item is only added to the list if its id minus label_id_offset (1) is
  between 0 (inclusive) and max_num_classes (exclusive), i.e. if
  0 < id <= max_num_classes.
  If several items map to the same id in the label map, only the first one
  is kept in the categories list.

  Args:
    label_map: a StringIntLabelMapProto or None. If None, a default categories
      list is created with max_num_classes categories.
    max_num_classes: maximum number of (consecutive) label indices to include.
    use_display_name: (boolean) choose whether to load 'display_name' field as
      category name. If False or if the display_name field does not exist, uses
      'name' field as category names instead.

  Returns:
    categories: a list of dictionaries representing all possible categories.
  """
  categories = []
  list_of_ids_already_added = []
  if not label_map:
    label_id_offset = 1
    for class_id in range(max_num_classes):
      categories.append({
          'id': class_id + label_id_offset,
          'name': 'category_{}'.format(class_id + label_id_offset)
      })
    return categories
  for item in label_map.item:
    if not 0 < item.id <= max_num_classes:
      logging.info(
          'Ignore item %d since it falls outside of requested '
          'label range.', item.id)
      continue
    if use_display_name and item.HasField('display_name'):
      name = item.display_name
    else:
      name = item.name
    if item.id not in list_of_ids_already_added:
      list_of_ids_already_added.append(item.id)
      categories.append({'id': item.id, 'name': name})
  return categories


def load_labelmap(path):
  """Loads label map proto.

  Args:
    path: path to StringIntLabelMap proto text file.
  Returns:
    a StringIntLabelMapProto
  """
  with tf.gfile.GFile(path, 'r') as fid:
    label_map_string = fid.read()
    label_map = string_int_label_map_pb2.StringIntLabelMap()
    try:
      text_format.Merge(label_map_string, label_map)
    except text_format.ParseError:
      label_map.ParseFromString(label_map_string)
  _validate_label_map(label_map)
  return label_map


def get_label_map_dict(label_map_path,
                       use_display_name=False,
                       fill_in_gaps_and_background=False):
  """Reads a label map and returns a dictionary of label names to id.

  Args:
    label_map_path: path to StringIntLabelMap proto text file.
    use_display_name: whether to use the label map items' display names as keys.
    fill_in_gaps_and_background: whether to fill in gaps and background with
      respect to the id field in the proto. The id: 0 is reserved for the
      'background' class and will be added if it is missing.
      All other missing ids in range(1, max(id)) will be added with a dummy
      class name ("class_<id>") if they are missing.

  Returns:
    A dictionary mapping label names to id.

  Raises:
    ValueError: if fill_in_gaps_and_background and label_map has non-integer or
      negative values.
  """
  label_map = load_labelmap(label_map_path)
  label_map_dict = {}
  for item in label_map.item:
    if use_display_name:
      label_map_dict[item.display_name] = item.id
    else:
      label_map_dict[item.name] = item.id

  if fill_in_gaps_and_background:
    values = set(label_map_dict.values())

    if 0 not in values:
      label_map_dict['background'] = 0
    if not all(isinstance(value, int) for value in values):
      raise ValueError('The values in label map must be integers in order to '
                       'fill_in_gaps_and_background.')
    if not all(value >= 0 for value in values):
      raise ValueError('The values in the label map must be non-negative.')

    if len(values) != max(values) + 1:
      # there are gaps in the labels, fill in gaps.
      for value in range(1, max(values)):
        if value not in values:
          label_map_dict['class_' + str(value)] = value

  return label_map_dict


def create_categories_from_labelmap(label_map_path, use_display_name=True):
  """Reads a label map and returns categories list compatible with eval.

  This function converts a label map proto into a list of dicts, each of
  which has the following keys:
    'id': an integer id uniquely identifying this category.
    'name': string representing category name e.g., 'cat', 'dog'.

  Args:
    label_map_path: Path to `StringIntLabelMap` proto text file.
    use_display_name: (boolean) choose whether to load 'display_name' field
      as category name. If False or if the display_name field does not exist,
      uses 'name' field as category names instead.

  Returns:
    categories: a list of dictionaries representing all possible categories.
  """
  label_map = load_labelmap(label_map_path)
  max_num_classes = max(item.id for item in label_map.item)
  return convert_label_map_to_categories(label_map, max_num_classes,
                                         use_display_name)


def create_category_index_from_labelmap(label_map_path, use_display_name=True):
  """Reads a label map and returns a category index.

  Args:
    label_map_path: Path to `StringIntLabelMap` proto text file.
    use_display_name: (boolean) choose whether to load 'display_name' field
      as category name. If False or if the display_name field does not exist,
      uses 'name' field as category names instead.

  Returns:
    A category index, which is a dictionary that maps integer ids to dicts
    containing categories, e.g.
    {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
  """
  categories = create_categories_from_labelmap(label_map_path, use_display_name)
  return create_category_index(categories)


def create_class_agnostic_category_index():
  """Creates a category index with a single `object` class."""
  return {1: {'id': 1, 'name': 'object'}}

--------------------------------------------------------------------------------
/label_utils/mscoco_label_map.pbtxt:
--------------------------------------------------------------------------------
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
item {
  name: "/m/0k4j"
  id: 3
  display_name: "car"
}
item {
  name: "/m/04_sv"
  id: 4
  display_name: "motorcycle"
}
item {
  name: "/m/05czz6l"
  id: 5
  display_name: "airplane"
}
item {
  name: "/m/01bjv"
  id: 6
  display_name: "bus"
}
item {
  name: "/m/07jdr"
  id: 7
  display_name: "train"
}
item {
  name: "/m/07r04"
  id: 8
  display_name: "truck"
}
item {
  name: "/m/019jd"
  id: 9
  display_name: "boat"
}
item {
  name: "/m/015qff"
  id: 10
  display_name: "traffic light"
}
item {
  name: "/m/01pns0"
  id: 11
  display_name: "fire hydrant"
}
item {
  name: "/m/02pv19"
  id: 13
  display_name: "stop sign"
}
item {
  name: "/m/015qbp"
  id: 14
  display_name: "parking meter"
}
item {
  name: "/m/0cvnqh"
  id: 15
  display_name: "bench"
}
item {
  name: "/m/015p6"
  id: 16
  display_name: "bird"
}
item {
  name: "/m/01yrx"
  id: 17
  display_name: "cat"
}
item {
  name: "/m/0bt9lr"
  id: 18
  display_name: "dog"
}
item {
  name: "/m/03k3r"
  id: 19
  display_name: "horse"
}
item {
  name: "/m/07bgp"
  id: 20
  display_name: "sheep"
}
item {
  name: "/m/01xq0k1"
  id: 21
  display_name: "cow"
}
item {
  name: "/m/0bwd_0j"
  id: 22
  display_name: "elephant"
}
item {
  name: "/m/01dws"
  id: 23
  display_name: "bear"
}
item {
  name: "/m/0898b"
  id: 24
  display_name: "zebra"
}
item {
  name: "/m/03bk1"
  id: 25
  display_name: "giraffe"
}
item {
  name: "/m/01940j"
  id: 27
  display_name: "backpack"
}
item {
  name: "/m/0hnnb"
  id: 28
  display_name: "umbrella"
}
item {
  name: "/m/080hkjn"
  id: 31
  display_name: "handbag"
}
item {
  name: "/m/01rkbr"
  id: 32
  display_name: "tie"
}
item {
  name: "/m/01s55n"
  id: 33
  display_name: "suitcase"
}
item {
  name: "/m/02wmf"
  id: 34
  display_name: "frisbee"
}
item {
  name: "/m/071p9"
  id: 35
  display_name: "skis"
}
item {
  name: "/m/06__v"
  id: 36
  display_name: "snowboard"
}
item {
  name: "/m/018xm"
  id: 37
  display_name: "sports ball"
}
item {
  name: "/m/02zt3"
  id: 38
  display_name: "kite"
}
item {
  name: "/m/03g8mr"
  id: 39
  display_name: "baseball bat"
}
item {
  name: "/m/03grzl"
  id: 40
  display_name: "baseball glove"
}
item {
  name: "/m/06_fw"
  id: 41
  display_name: "skateboard"
}
item {
  name: "/m/019w40"
  id: 42
  display_name: "surfboard"
}
item {
  name: "/m/0dv9c"
  id: 43
  display_name: "tennis racket"
}
item {
  name: "/m/04dr76w"
  id: 44
  display_name: "bottle"
}
item {
  name: "/m/09tvcd"
  id: 46
  display_name: "wine glass"
}
item {
  name: "/m/08gqpm"
  id: 47
  display_name: "cup"
}
item {
  name: "/m/0dt3t"
  id: 48
  display_name: "fork"
}
item {
  name: "/m/04ctx"
  id: 49
  display_name: "knife"
}
item {
  name: "/m/0cmx8"
  id: 50
  display_name: "spoon"
}
item {
  name: "/m/04kkgm"
  id: 51
  display_name: "bowl"
}
item {
  name: "/m/09qck"
  id: 52
  display_name: "banana"
}
item {
  name: "/m/014j1m"
  id: 53
  display_name: "apple"
}
item {
  name: "/m/0l515"
  id: 54
  display_name: "sandwich"
}
item {
  name: "/m/0cyhj_"
  id: 55
  display_name: "orange"
}
item {
  name: "/m/0hkxq"
  id: 56
  display_name: "broccoli"
}
item {
  name: "/m/0fj52s"
  id: 57
  display_name: "carrot"
}
item {
  name: "/m/01b9xk"
  id: 58
  display_name: "hot dog"
}
item {
  name: "/m/0663v"
  id: 59
  display_name: "pizza"
}
item {
  name: "/m/0jy4k"
  id: 60
  display_name: "donut"
}
item {
  name: "/m/0fszt"
  id: 61
  display_name: "cake"
}
item {
  name: "/m/01mzpv"
  id: 62
  display_name: "chair"
}
item {
  name: "/m/02crq1"
  id: 63
  display_name: "couch"
}
item {
  name: "/m/03fp41"
  id: 64
  display_name: "potted plant"
}
item {
  name: "/m/03ssj5"
  id: 65
  display_name: "bed"
}
item {
  name: "/m/04bcr3"
  id: 67
  display_name: "dining table"
}
item {
  name: "/m/09g1w"
  id: 70
  display_name: "toilet"
}
item {
  name: "/m/07c52"
  id: 72
  display_name: "tv"
}
item {
  name: "/m/01c648"
  id: 73
  display_name: "laptop"
}
item {
  name: "/m/020lf"
  id: 74
  display_name: "mouse"
}
item {
  name: "/m/0qjjc"
  id: 75
  display_name: "remote"
}
item {
  name: "/m/01m2v"
  id: 76
  display_name: "keyboard"
}
item {
  name: "/m/050k8"
  id: 77
  display_name: "cell phone"
}
item {
  name: "/m/0fx9l"
  id: 78
  display_name: "microwave"
}
item {
  name: "/m/029bxz"
  id: 79
  display_name: "oven"
}
item {
  name: "/m/01k6s3"
  id: 80
  display_name: "toaster"
}
item {
  name: "/m/0130jx"
  id: 81
  display_name: "sink"
}
item {
  name: "/m/040b_t"
  id: 82
  display_name: "refrigerator"
}
item {
  name: "/m/0bt_c3"
  id: 84
  display_name: "book"
}
item {
  name: "/m/01x3z"
  id: 85
  display_name: "clock"
}
item {
  name: "/m/02s195"
  id: 86
  display_name: "vase"
}
item {
  name: "/m/01lsmm"
  id: 87
  display_name: "scissors"
}
item {
  name: "/m/0kmg4"
  id: 88
  display_name: "teddy bear"
}
item {
  name: "/m/03wvsk"
  id: 89
  display_name: "hair drier"
}
item {
  name: "/m/012xff"
  id: 90
  display_name: "toothbrush"
}

--------------------------------------------------------------------------------
/label_utils/string_int_label_map_pb2.py:
--------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/string_int_label_map.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/string_int_label_map.proto',
  package='object_detection.protos',
  syntax='proto2',
  serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




_STRINGINTLABELMAPITEM = _descriptor.Descriptor(
  name='StringIntLabelMapItem',
  full_name='object_detection.protos.StringIntLabelMapItem',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=79,
  serialized_end=150,
)


_STRINGINTLABELMAP = _descriptor.Descriptor(
  name='StringIntLabelMap',
  full_name='object_detection.protos.StringIntLabelMap',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=152,
  serialized_end=233,
)

_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP

StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAPITEM,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
  ))
_sym_db.RegisterMessage(StringIntLabelMapItem)

StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAP,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
  ))
_sym_db.RegisterMessage(StringIntLabelMap)


# @@protoc_insertion_point(module_scope)

--------------------------------------------------------------------------------
/object_detection_tutorial.py:
--------------------------------------------------------------------------------
from distutils.version import StrictVersion
from label_utils import label_map_util

import numpy as np
import tensorflow as tf
import argparse
import os
import time
import cv2

FLAGS = None

if StrictVersion(tf.__version__) < StrictVersion('1.9.0'):
    raise ImportError('Please upgrade your TensorFlow installation to v1.9.* or later!')


def load_image_and_labels(label_path,
                          image_path,
                          verbose=True,
                          types=('.jpg', '.png', '.jpeg')):
    """Build the category index from the label map and collect the image paths.

    image_path may be a single image file or a directory. For a directory, the
    files whose (case-insensitive) extension is listed in types are collected.
    """
    # labels
    category_index = label_map_util.create_category_index_from_labelmap(label_path, use_display_name=True)

    image_paths = []
    if os.path.isfile(image_path):
        image_paths.append(image_path)
    else:
        # sort the directory listing so images are processed in a stable order
        for file_or_dir in sorted(os.listdir(image_path)):
            file_path = os.path.join(image_path, file_or_dir)
            if os.path.isfile(file_path) and \
                    os.path.splitext(file_path)[1].lower() in types:
                image_paths.append(file_path)
    if verbose:
        print(image_paths)
    return category_index, image_paths


class TFDetector(object):
    """Wraps a frozen TF1 detection graph and one session reused for every image."""

    def __init__(self, model_path, category_index):
        self.graph = self.create_graph(model_path)
        self.sess = self.create_session()
        self.category_index = category_index

    def create_graph(self, model_path):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(model_path, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        self.graph = detection_graph
        return self.graph

    def create_session(self):
        with self.graph.as_default():
            self.sess = tf.Session()
        return self.sess

    def detect(self, image, mark=False):
        with self.graph.as_default():
            # Get handles to input and output tensors
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                'num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes', 'detection_masks'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                        tensor_name)

            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
            # Run inference
            start_time = time.time()
            output_dict = self.sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)})
            end_time = time.time()
            print("run time:", end_time - start_time)

        # all outputs are float32 numpy arrays, so convert types as appropriate;
        # drop the batch dimension and keep only the first num_detections entries
        num = int(output_dict['num_detections'][0])
        output_dict['num_detections'] = num
        # int64 instead of uint8 so class ids above 255 cannot overflow
        output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.int64)[:num]
        output_dict['detection_boxes'] = output_dict['detection_boxes'][0][:num]
        output_dict['detection_scores'] = output_dict['detection_scores'][0][:num]

        if mark:
            image_height, image_width = image.shape[:2]
            for i in range(num):
                cls_id = int(output_dict['detection_classes'][i])
                # fall back to the numeric id if the label map has no entry for it
                cls_name = self.category_index.get(cls_id, {}).get('name', str(cls_id))
                score = output_dict['detection_scores'][i]
                # boxes are normalized [ymin, xmin, ymax, xmax]; scale them to pixels
                box_ymin = int(output_dict['detection_boxes'][i][0] * image_height)
                box_xmin = int(output_dict['detection_boxes'][i][1] * image_width)
                box_ymax = int(output_dict['detection_boxes'][i][2] * image_height)
                box_xmax = int(output_dict['detection_boxes'][i][3] * image_width)
                cv2.rectangle(image, (box_xmin, box_ymin), (box_xmax, box_ymax), (0, 255, 0), 3)
                text = "%s:%.2f" % (cls_name, score)
                cv2.putText(image, text, (box_xmin, box_ymin - 4), cv2.FONT_HERSHEY_COMPLEX_SMALL, 0.8,
                            (255, 0, 0))
            # show image (cv2.imshow expects BGR channel order)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            cv2.imshow("img", image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return output_dict, image


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str,
                        default='test_images/',
                        help='path to an image file or a directory of images')
    parser.add_argument('--model_frozen', type=str,
                        default='TFMODEL/ssd_mobilenet_v1_coco_2018_01_28/frozen_inference_graph.pb',
                        help='path to the frozen inference graph (.pb)')
    parser.add_argument('--label_path', type=str,
                        default='label_utils/mscoco_label_map.pbtxt',
                        help='path to the label map (.pbtxt)')
    FLAGS, unparsed = parser.parse_known_args()

    category_index, image_paths = load_image_and_labels(FLAGS.label_path, FLAGS.image_path)
    detector = TFDetector(FLAGS.model_frozen, category_index)

    for image_path in image_paths:
        # load image (cv2.imread returns None if the file cannot be read)
        image = cv2.imread(image_path)
        if image is None:
            print("cannot read image:", image_path)
            continue
        # OpenCV loads images in BGR order but the model expects RGB; remove this
        # conversion to see how strongly the channel order changes the result for image1
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # detection
        detector.detect(image, mark=True)

--------------------------------------------------------------------------------
/test_images/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andylei77/object-detector/d9ddcb48bf339fc4ef604626e54be135e484d9e5/test_images/image1.jpg
--------------------------------------------------------------------------------
/test_images/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andylei77/object-detector/d9ddcb48bf339fc4ef604626e54be135e484d9e5/test_images/image2.jpg
--------------------------------------------------------------------------------
/test_images/image3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andylei77/object-detector/d9ddcb48bf339fc4ef604626e54be135e484d9e5/test_images/image3.png
--------------------------------------------------------------------------------
/test_images/image_info.txt:
--------------------------------------------------------------------------------

Image provenance:
image1.jpg: https://commons.wikimedia.org/wiki/File:Baegle_dwa.jpg
image2.jpg: Michael Miley,
  https://www.flickr.com/photos/mike_miley/4678754542/in/photolist-88rQHL-88oBVp-88oC2B-88rS6J-88rSqm-88oBLv-88oBC4

--------------------------------------------------------------------------------
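You can check the box arithmetic in `TFDetector.detect` without TensorFlow or OpenCV. Below is a minimal sketch built around a hypothetical helper, `to_pixel_detections`, which is not part of this repository. It assumes the detector outputs have already had their batch dimension removed and been cut to `num_detections`, as `detect` does. It scales the normalized `[ymin, xmin, ymax, xmax]` boxes to pixel corners and looks up names in a category index shaped like the one `create_category_index` builds. The `min_score` threshold and the `'unknown'` fallback name are assumptions: `detect` itself draws every returned detection.

```python
def to_pixel_detections(boxes, classes, scores, category_index,
                        image_height, image_width, min_score=0.5):
    """Hypothetical helper: convert normalized detections into labeled pixel boxes.

    boxes holds one [ymin, xmin, ymax, xmax] row per detection, with values in
    [0, 1], like the sliced 'detection_boxes' output of TFDetector.detect.
    min_score is an assumed threshold; detect() itself draws every detection.
    """
    results = []
    for (ymin, xmin, ymax, xmax), cls_id, score in zip(boxes, classes, scores):
        if score < min_score:
            continue
        # Same arithmetic as detect(): y scales with the height, x with the
        # width, and int() truncates toward zero.
        pixel_box = (int(xmin * image_width), int(ymin * image_height),
                     int(xmax * image_width), int(ymax * image_height))
        # Ids missing from the category index get an assumed placeholder name.
        name = category_index.get(int(cls_id), {}).get('name', 'unknown')
        results.append((name, float(score), pixel_box))
    return results


if __name__ == '__main__':
    # A two-entry excerpt of the COCO ids/display names in mscoco_label_map.pbtxt.
    category_index = {1: {'id': 1, 'name': 'person'}, 18: {'id': 18, 'name': 'dog'}}
    detections = to_pixel_detections(
        boxes=[[0.25, 0.5, 0.75, 1.0], [0.0, 0.0, 0.5, 0.5]],
        classes=[18, 1],
        scores=[0.9, 0.3],
        category_index=category_index,
        image_height=200,
        image_width=400)
    print(detections)  # [('dog', 0.9, (200, 50, 400, 150))]
```

The helper returns plain `(xmin, ymin, xmax, ymax)` tuples so it does not depend on OpenCV. Each tuple splits directly into the two corner points that `cv2.rectangle` takes.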