├── .gitignore ├── README.md ├── data ├── box_priors.txt ├── coco_labels.txt └── mscoco_label_map.pbtxt ├── detect.py ├── detect.tflite ├── detection_stream.py ├── dog.jpg ├── jetson_stream.py ├── object_detection ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc └── protos │ ├── BUILD │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── string_int_label_map_pb2.cpython-36.pyc │ ├── anchor_generator.proto │ ├── anchor_generator_pb2.py │ ├── argmax_matcher.proto │ ├── argmax_matcher_pb2.py │ ├── bipartite_matcher.proto │ ├── bipartite_matcher_pb2.py │ ├── box_coder.proto │ ├── box_coder_pb2.py │ ├── box_predictor.proto │ ├── box_predictor_pb2.py │ ├── eval.proto │ ├── eval_pb2.py │ ├── faster_rcnn.proto │ ├── faster_rcnn_box_coder.proto │ ├── faster_rcnn_box_coder_pb2.py │ ├── faster_rcnn_pb2.py │ ├── grid_anchor_generator.proto │ ├── grid_anchor_generator_pb2.py │ ├── hyperparams.proto │ ├── hyperparams_pb2.py │ ├── image_resizer.proto │ ├── image_resizer_pb2.py │ ├── input_reader.proto │ ├── input_reader_pb2.py │ ├── losses.proto │ ├── losses_pb2.py │ ├── matcher.proto │ ├── matcher_pb2.py │ ├── mean_stddev_box_coder.proto │ ├── mean_stddev_box_coder_pb2.py │ ├── model.proto │ ├── model_pb2.py │ ├── optimizer.proto │ ├── optimizer_pb2.py │ ├── pipeline.proto │ ├── pipeline_pb2.py │ ├── post_processing.proto │ ├── post_processing_pb2.py │ ├── preprocessor.proto │ ├── preprocessor_pb2.py │ ├── region_similarity_calculator.proto │ ├── region_similarity_calculator_pb2.py │ ├── square_box_coder.proto │ ├── square_box_coder_pb2.py │ ├── ssd.proto │ ├── ssd_anchor_generator.proto │ ├── ssd_anchor_generator_pb2.py │ ├── ssd_pb2.py │ ├── string_int_label_map.proto │ ├── string_int_label_map_pb2.py │ ├── train.proto │ └── train_pb2.py ├── object_detector.py ├── object_detector_detection_api.py ├── object_detector_detection_api_lite.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── label_map_util.cpython-36.pyc │ └── utils.cpython-36.pyc ├── label_map_util.py └── utils.py └── yolo_darfklow.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | results 3 | __pycache__ 4 | *.pyc 5 | *.ipynb 6 | *.pb 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Mobile detector for embedded platforms 2 | 3 | Use the `master` branch for Raspberry Pi and the `jetson-tx2` branch for NVIDIA Jetson TX2. 4 | 5 | Use `detect.py` for object detection on a single image. 6 | 7 | Use `detection_stream.py` for detection from the Raspberry Pi camera. 8 | 9 | Use `jetson_stream.py` for detection from the Jetson onboard camera. 10 | 11 | The `*.pb` files with the SSDLite model graph are available at `https://drive.google.com/file/d/1_DSRMQB6oTaPifqAeDz1MIHVmfIaEB4g/view?usp=sharing`. -------------------------------------------------------------------------------- /data/coco_labels.txt: -------------------------------------------------------------------------------- 1 | ??? 2 | person 3 | bicycle 4 | car 5 | motorcycle 6 | airplane 7 | bus 8 | train 9 | truck 10 | boat 11 | traffic light 12 | fire hydrant 13 | ??? 14 | stop sign 15 | parking meter 16 | bench 17 | bird 18 | cat 19 | dog 20 | horse 21 | sheep 22 | cow 23 | elephant 24 | bear 25 | zebra 26 | giraffe 27 | ??? 28 | backpack 29 | umbrella 30 | ??? 31 | ???
32 | handbag 33 | tie 34 | suitcase 35 | frisbee 36 | skis 37 | snowboard 38 | sports ball 39 | kite 40 | baseball bat 41 | baseball glove 42 | skateboard 43 | surfboard 44 | tennis racket 45 | bottle 46 | ??? 47 | wine glass 48 | cup 49 | fork 50 | knife 51 | spoon 52 | bowl 53 | banana 54 | apple 55 | sandwich 56 | orange 57 | broccoli 58 | carrot 59 | hot dog 60 | pizza 61 | donut 62 | cake 63 | chair 64 | couch 65 | potted plant 66 | bed 67 | ??? 68 | dining table 69 | ??? 70 | ??? 71 | toilet 72 | ??? 73 | tv 74 | laptop 75 | mouse 76 | remote 77 | keyboard 78 | cell phone 79 | microwave 80 | oven 81 | toaster 82 | sink 83 | refrigerator 84 | ??? 85 | book 86 | clock 87 | vase 88 | scissors 89 | teddy bear 90 | hair drier 91 | toothbrush -------------------------------------------------------------------------------- /data/mscoco_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "/m/01g317" 3 | id: 1 4 | display_name: "person" 5 | } 6 | item { 7 | name: "/m/0199g" 8 | id: 2 9 | display_name: "bicycle" 10 | } 11 | item { 12 | name: "/m/0k4j" 13 | id: 3 14 | display_name: "car" 15 | } 16 | item { 17 | name: "/m/04_sv" 18 | id: 4 19 | display_name: "motorcycle" 20 | } 21 | item { 22 | name: "/m/05czz6l" 23 | id: 5 24 | display_name: "airplane" 25 | } 26 | item { 27 | name: "/m/01bjv" 28 | id: 6 29 | display_name: "bus" 30 | } 31 | item { 32 | name: "/m/07jdr" 33 | id: 7 34 | display_name: "train" 35 | } 36 | item { 37 | name: "/m/07r04" 38 | id: 8 39 | display_name: "truck" 40 | } 41 | item { 42 | name: "/m/019jd" 43 | id: 9 44 | display_name: "boat" 45 | } 46 | item { 47 | name: "/m/015qff" 48 | id: 10 49 | display_name: "traffic light" 50 | } 51 | item { 52 | name: "/m/01pns0" 53 | id: 11 54 | display_name: "fire hydrant" 55 | } 56 | item { 57 | name: "/m/02pv19" 58 | id: 13 59 | display_name: "stop sign" 60 | } 61 | item { 62 | name: "/m/015qbp" 63 | id: 14 64 | display_name: "parking meter" 65 | } 66 | item { 67 | name: "/m/0cvnqh" 68 | id: 15 69 | display_name: "bench" 70 | } 71 | item { 72 | name: "/m/015p6" 73 | id: 16 74 | display_name: "bird" 75 | } 76 | item { 77 | name: "/m/01yrx" 78 | id: 17 79 | display_name: "cat" 80 | } 81 | item { 82 | name: "/m/0bt9lr" 83 | id: 18 84 | display_name: "dog" 85 | } 86 | item { 87 | name: "/m/03k3r" 88 | id: 19 89 | display_name: "horse" 90 | } 91 | item { 92 | name: "/m/07bgp" 93 | id: 20 94 | display_name: "sheep" 95 | } 96 | item { 97 | name: "/m/01xq0k1" 98 | id: 21 99 | display_name: "cow" 100 | } 101 | item { 102 | name: "/m/0bwd_0j" 103 | id: 22 104 | display_name: "elephant" 105 | } 106 | item { 107 | name: "/m/01dws" 108 | id: 23 109 | display_name: "bear" 110 | } 111 | item { 112 | name: "/m/0898b" 113 | id: 24 114 | display_name: "zebra" 115 | } 116 | item { 117 | name: "/m/03bk1" 118 | id: 25 119 | display_name: "giraffe" 120 | } 121 | item { 122 | name: "/m/01940j" 123 | id: 27 124 | display_name: "backpack" 125 | } 126 | item { 127 | name: "/m/0hnnb" 128 | id: 28 129 | display_name: "umbrella" 130 | } 131 | item { 132 | name: "/m/080hkjn" 133 | id: 31 134 | display_name: "handbag" 135 | } 136 | item { 137 | name: "/m/01rkbr" 138 | id: 32 139 | display_name: "tie" 140 | } 141 | item { 142 | name: "/m/01s55n" 143 | id: 33 144 | display_name: "suitcase" 145 | } 146 | item { 147 | name: "/m/02wmf" 148 | id: 34 149 | display_name: "frisbee" 150 | } 151 | item { 152 | name: "/m/071p9" 153 | id: 35 154 | display_name: "skis" 155 | } 156 | item { 157 | name: 
"/m/06__v" 158 | id: 36 159 | display_name: "snowboard" 160 | } 161 | item { 162 | name: "/m/018xm" 163 | id: 37 164 | display_name: "sports ball" 165 | } 166 | item { 167 | name: "/m/02zt3" 168 | id: 38 169 | display_name: "kite" 170 | } 171 | item { 172 | name: "/m/03g8mr" 173 | id: 39 174 | display_name: "baseball bat" 175 | } 176 | item { 177 | name: "/m/03grzl" 178 | id: 40 179 | display_name: "baseball glove" 180 | } 181 | item { 182 | name: "/m/06_fw" 183 | id: 41 184 | display_name: "skateboard" 185 | } 186 | item { 187 | name: "/m/019w40" 188 | id: 42 189 | display_name: "surfboard" 190 | } 191 | item { 192 | name: "/m/0dv9c" 193 | id: 43 194 | display_name: "tennis racket" 195 | } 196 | item { 197 | name: "/m/04dr76w" 198 | id: 44 199 | display_name: "bottle" 200 | } 201 | item { 202 | name: "/m/09tvcd" 203 | id: 46 204 | display_name: "wine glass" 205 | } 206 | item { 207 | name: "/m/08gqpm" 208 | id: 47 209 | display_name: "cup" 210 | } 211 | item { 212 | name: "/m/0dt3t" 213 | id: 48 214 | display_name: "fork" 215 | } 216 | item { 217 | name: "/m/04ctx" 218 | id: 49 219 | display_name: "knife" 220 | } 221 | item { 222 | name: "/m/0cmx8" 223 | id: 50 224 | display_name: "spoon" 225 | } 226 | item { 227 | name: "/m/04kkgm" 228 | id: 51 229 | display_name: "bowl" 230 | } 231 | item { 232 | name: "/m/09qck" 233 | id: 52 234 | display_name: "banana" 235 | } 236 | item { 237 | name: "/m/014j1m" 238 | id: 53 239 | display_name: "apple" 240 | } 241 | item { 242 | name: "/m/0l515" 243 | id: 54 244 | display_name: "sandwich" 245 | } 246 | item { 247 | name: "/m/0cyhj_" 248 | id: 55 249 | display_name: "orange" 250 | } 251 | item { 252 | name: "/m/0hkxq" 253 | id: 56 254 | display_name: "broccoli" 255 | } 256 | item { 257 | name: "/m/0fj52s" 258 | id: 57 259 | display_name: "carrot" 260 | } 261 | item { 262 | name: "/m/01b9xk" 263 | id: 58 264 | display_name: "hot dog" 265 | } 266 | item { 267 | name: "/m/0663v" 268 | id: 59 269 | display_name: "pizza" 270 | } 271 | item { 272 | name: "/m/0jy4k" 273 | id: 60 274 | display_name: "donut" 275 | } 276 | item { 277 | name: "/m/0fszt" 278 | id: 61 279 | display_name: "cake" 280 | } 281 | item { 282 | name: "/m/01mzpv" 283 | id: 62 284 | display_name: "chair" 285 | } 286 | item { 287 | name: "/m/02crq1" 288 | id: 63 289 | display_name: "couch" 290 | } 291 | item { 292 | name: "/m/03fp41" 293 | id: 64 294 | display_name: "potted plant" 295 | } 296 | item { 297 | name: "/m/03ssj5" 298 | id: 65 299 | display_name: "bed" 300 | } 301 | item { 302 | name: "/m/04bcr3" 303 | id: 67 304 | display_name: "dining table" 305 | } 306 | item { 307 | name: "/m/09g1w" 308 | id: 70 309 | display_name: "toilet" 310 | } 311 | item { 312 | name: "/m/07c52" 313 | id: 72 314 | display_name: "tv" 315 | } 316 | item { 317 | name: "/m/01c648" 318 | id: 73 319 | display_name: "laptop" 320 | } 321 | item { 322 | name: "/m/020lf" 323 | id: 74 324 | display_name: "mouse" 325 | } 326 | item { 327 | name: "/m/0qjjc" 328 | id: 75 329 | display_name: "remote" 330 | } 331 | item { 332 | name: "/m/01m2v" 333 | id: 76 334 | display_name: "keyboard" 335 | } 336 | item { 337 | name: "/m/050k8" 338 | id: 77 339 | display_name: "cell phone" 340 | } 341 | item { 342 | name: "/m/0fx9l" 343 | id: 78 344 | display_name: "microwave" 345 | } 346 | item { 347 | name: "/m/029bxz" 348 | id: 79 349 | display_name: "oven" 350 | } 351 | item { 352 | name: "/m/01k6s3" 353 | id: 80 354 | display_name: "toaster" 355 | } 356 | item { 357 | name: "/m/0130jx" 358 | id: 81 359 | display_name: "sink" 360 
| } 361 | item { 362 | name: "/m/040b_t" 363 | id: 82 364 | display_name: "refrigerator" 365 | } 366 | item { 367 | name: "/m/0bt_c3" 368 | id: 84 369 | display_name: "book" 370 | } 371 | item { 372 | name: "/m/01x3z" 373 | id: 85 374 | display_name: "clock" 375 | } 376 | item { 377 | name: "/m/02s195" 378 | id: 86 379 | display_name: "vase" 380 | } 381 | item { 382 | name: "/m/01lsmm" 383 | id: 87 384 | display_name: "scissors" 385 | } 386 | item { 387 | name: "/m/0kmg4" 388 | id: 88 389 | display_name: "teddy bear" 390 | } 391 | item { 392 | name: "/m/03wvsk" 393 | id: 89 394 | display_name: "hair drier" 395 | } 396 | item { 397 | name: "/m/012xff" 398 | id: 90 399 | display_name: "toothbrush" 400 | } 401 | -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path 3 | import logging 4 | import sys 5 | import time 6 | 7 | import cv2 8 | 9 | from utils.utils import load_image_into_numpy_array, Models 10 | from object_detector_detection_api import ObjectDetectorDetectionAPI 11 | from yolo_darfklow import YOLODarkflowDetector 12 | from object_detector_detection_api_lite import ObjectDetectorLite 13 | 14 | 15 | logging.basicConfig( 16 | stream=sys.stdout, 17 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 18 | datefmt=' %I:%M:%S ', 19 | level="INFO" 20 | ) 21 | logger = logging.getLogger('detector') 22 | 23 | 24 | basepath = path.dirname(__file__) 25 | 26 | 27 | if __name__ == '__main__': 28 | # initiate the parser 29 | parser = argparse.ArgumentParser(prog='detect.py') 30 | 31 | # add arguments 32 | parser.add_argument("--image_path", "-ip", type=str, required=True, 33 | help="path to image") 34 | parser.add_argument("--model_name", "-mn", type=Models.from_string, 35 | required=True, choices=list(Models), 36 | help="name of detection model: {}".format( 37 | list(Models))) 38 | parser.add_argument("--cfg_path", "-cfg", type=str, required=False, 39 | default=path.join(basepath, "tiny-yolo-voc.cfg"), 40 | help="path to yolo *.cfg file") 41 | parser.add_argument("--graph_path", "-gp", type=str, required=False, 42 | default=path.join(basepath, "frozen_inference_graph.pb"), 43 | help="path to model frozen graph *.pb file") 44 | parser.add_argument("--result_path", "-rp", type=str, required=False, 45 | default='result.jpg', help="path to result image") 46 | parser.add_argument("--weights_path", "-w", type=str, required=False, default=path.join(basepath, "tiny-yolo-voc.weights"), help="path to yolo weights *.weights file") 47 | # read arguments from the command line 48 | args = parser.parse_args() 49 | 50 | for k, v in vars(args).items(): 51 | logger.info('Arguments. {}: {}'.format(k, v)) 52 | 53 | # initialize detector 54 | logger.info('Model loading...') 55 | if args.model_name == Models.ssd_lite: 56 | predictor = ObjectDetectorDetectionAPI(args.graph_path) 57 | elif args.model_name == Models.tiny_yolo: 58 | predictor = YOLODarkflowDetector(args.cfg_path, args.weights_path) 59 | elif args.model_name == Models.tf_lite: 60 | predictor = ObjectDetectorLite() 61 | 62 | image = load_image_into_numpy_array(args.image_path) 63 | h, w, _ = image.shape 64 | 65 | start_time = time.time() 66 | result = predictor.detect(image) 67 | finish_time = time.time() 68 | logger.info("time spent: {:.4f}".format(finish_time - start_time)) 69 | 70 | for obj in result: 71 | logger.info('coordinates: {} {}. class: "{}". confidence: {:.2f}'.
72 | format(obj[0], obj[1], obj[3], obj[2])) 73 | 74 | cv2.rectangle(image, obj[0], obj[1], (0, 255, 0), 2) 75 | cv2.putText(image, '{}: {:.2f}'.format(obj[3], obj[2]), 76 | (obj[0][0], obj[0][1] - 5), 77 | cv2.FONT_HERSHEY_PLAIN, 1, (0, 255, 0), 2) 78 | 79 | cv2.imwrite(args.result_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 80 | -------------------------------------------------------------------------------- /detect.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/detect.tflite -------------------------------------------------------------------------------- /detection_stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path 3 | import time 4 | import logging 5 | import sys 6 | import numpy as np 7 | import cv2 8 | from picamera.array import PiRGBArray 9 | from picamera import PiCamera 10 | 11 | from object_detector_detection_api import ObjectDetectorDetectionAPI 12 | from yolo_darfklow import YOLODarkflowDetector 13 | from object_detector_detection_api_lite import ObjectDetectorLite 14 | from utils.utils import Models 15 | 16 | 17 | logging.basicConfig( 18 | stream=sys.stdout, 19 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 20 | datefmt=' %I:%M:%S ', 21 | level="INFO" 22 | ) 23 | logger = logging.getLogger('detector') 24 | 25 | 26 | basepath = path.dirname(__file__) 27 | 28 | if __name__ == '__main__': 29 | # initiate the parser 30 | parser = argparse.ArgumentParser(prog='test_models.py') 31 | 32 | # add arguments 33 | parser.add_argument("--model_name", "-mn", type=Models.from_string, 34 | required=True, choices=list(Models), 35 | help="name of detection model: {}".format(list(Models))) 36 | parser.add_argument("--graph_path", "-gp", type=str, required=False, 37 | default=path.join(basepath, "frozen_inference_graph.pb"), 38 | help="path to ssdlight model frozen graph *.pb file") 39 | parser.add_argument("--cfg_path", "-cfg", type=str, required=False, 40 | default=path.join(basepath, "tiny-yolo-voc.cfg"), 41 | help="path to yolo *.cfg file") 42 | parser.add_argument("--weights_path", "-w", type=str, required=False, 43 | default=path.join(basepath, "tiny-yolo-voc.weights"), 44 | help="path to yolo weights *.weights file") 45 | 46 | # read arguments from the command line 47 | args = parser.parse_args() 48 | 49 | for k, v in vars(args).items(): 50 | logger.info('Arguments. 
{}: {}'.format(k, v)) 51 | 52 | # initialize detector 53 | logger.info('Model loading...') 54 | if args.model_name == Models.ssd_lite: 55 | predictor = ObjectDetectorDetectionAPI(args.graph_path) 56 | elif args.model_name == Models.tiny_yolo: 57 | predictor = YOLODarkflowDetector(args.cfg_path, args.weights_path) 58 | elif args.model_name == Models.tf_lite: 59 | predictor = ObjectDetectorLite() 60 | 61 | # initialize the camera and grab a reference to the raw camera capture 62 | camera = PiCamera() 63 | camera.resolution = (640, 480) 64 | camera.framerate = 32 65 | rawCapture = PiRGBArray(camera, size=(640, 480)) 66 | 67 | # allow the camera to warmup 68 | time.sleep(0.1) 69 | 70 | frame_rate_calc = 1 71 | freq = cv2.getTickFrequency() 72 | 73 | # capture frames from the camera 74 | for frame in camera.capture_continuous(rawCapture, format="bgr", 75 | use_video_port=True): 76 | t1 = cv2.getTickCount() 77 | 78 | # grab the raw NumPy array representing the image, then initialize the timestamp 79 | # and occupied/unoccupied text 80 | image = frame.array 81 | 82 | logger.info("FPS: {0:.2f}".format(frame_rate_calc)) 83 | cv2.putText(image, "FPS: {0:.2f}".format(frame_rate_calc), (20, 20), 84 | cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 0), 2, cv2.LINE_AA) 85 | 86 | result = predictor.detect(image) 87 | 88 | for obj in result: 89 | logger.info('coordinates: {} {}. class: "{}". confidence: {:.2f}'. 90 | format(obj[0], obj[1], obj[3], obj[2])) 91 | 92 | cv2.rectangle(image, obj[0], obj[1], (0, 255, 0), 2) 93 | cv2.putText(image, '{}: {:.2f}'.format(obj[3], obj[2]), 94 | (obj[0][0], obj[0][1] - 5), 95 | cv2.FONT_HERSHEY_PLAIN, 1, (0, 255, 0), 2) 96 | 97 | 98 | # show the frame 99 | cv2.imshow("Stream", image) 100 | key = cv2.waitKey(1) & 0xFF 101 | 102 | t2 = cv2.getTickCount() 103 | time1 = (t2 - t1) / freq 104 | frame_rate_calc = 1 / time1 105 | 106 | # clear the stream in preparation for the next frame 107 | rawCapture.truncate(0) 108 | 109 | # if the `q` key was pressed, break from the loop 110 | if key == ord("q"): 111 | break -------------------------------------------------------------------------------- /dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/dog.jpg -------------------------------------------------------------------------------- /jetson_stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path 3 | import time 4 | import logging 5 | import sys 6 | import cv2 7 | 8 | # from object_detector_detection_api_lite import ObjectDetectorLite 9 | from object_detector_trt import ObjectDetectorTRT 10 | 11 | 12 | logging.basicConfig( 13 | stream=sys.stdout, 14 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 15 | datefmt=' %I:%M:%S ', 16 | level="INFO" 17 | ) 18 | logger = logging.getLogger('detector') 19 | 20 | 21 | basepath = path.dirname(__file__) 22 | 23 | 24 | def open_cam_onboard(width, height): 25 | # On versions of L4T prior to 28.1, add 'flip-method=2' into gst_str 26 | gst_str = ('nvcamerasrc ! ' 27 | 'video/x-raw(memory:NVMM), ' 28 | 'width=(int)2592, height=(int)1458, ' 29 | 'format=(string)I420, framerate=(fraction)30/1 ! ' 30 | 'nvvidconv ! ' 31 | 'video/x-raw, width=(int){}, height=(int){}, ' 32 | 'format=(string)BGRx ! ' 33 | 'videoconvert ! 
appsink').format(width, height) 34 | return cv2.VideoCapture(gst_str, cv2.CAP_GSTREAMER) 35 | 36 | 37 | if __name__ == '__main__': 38 | # initiate the parser 39 | parser = argparse.ArgumentParser(prog='test_models.py') 40 | 41 | # add arguments 42 | 43 | # read arguments from the command line 44 | args = parser.parse_args() 45 | 46 | # initialize detector 47 | logger.info('Model loading...') 48 | predictor = ObjectDetectorTRT() 49 | 50 | cap = open_cam_onboard(640, 480) 51 | 52 | if not cap.isOpened(): 53 | sys.exit('Failed to open camera!') 54 | 55 | # allow the camera to warmup 56 | time.sleep(0.1) 57 | 58 | frame_rate_calc = 1 59 | freq = cv2.getTickFrequency() 60 | 61 | while (cap.isOpened()): 62 | t1 = cv2.getTickCount() 63 | 64 | ret, frame = cap.read() 65 | 66 | logger.info("FPS: {0:.2f}".format(frame_rate_calc)) 67 | cv2.putText(frame, "FPS: {0:.2f}".format(frame_rate_calc), (20, 20), 68 | cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 0), 2, cv2.LINE_AA) 69 | 70 | result = predictor.detect(frame) 71 | 72 | for obj in result: 73 | logger.info('coordinates: {} {}. class: "{}". confidence: {:.2f}'. 74 | format(obj[0], obj[1], obj[3], obj[2])) 75 | 76 | cv2.rectangle(frame, obj[0], obj[1], (0, 255, 0), 2) 77 | cv2.putText(frame, '{}: {:.2f}'.format(obj[3], obj[2]), 78 | (obj[0][0], obj[0][1] - 5), 79 | cv2.FONT_HERSHEY_PLAIN, 1, (0, 255, 0), 2) 80 | 81 | # show the frame 82 | cv2.imshow("Stream", frame) 83 | key = cv2.waitKey(1) & 0xFF 84 | 85 | t2 = cv2.getTickCount() 86 | time1 = (t2 - t1) / freq 87 | frame_rate_calc = 1 / time1 88 | 89 | # if the `q` key was pressed, break from the loop 90 | if key == ord("q"): 91 | break 92 | -------------------------------------------------------------------------------- /object_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/object_detection/__init__.py -------------------------------------------------------------------------------- /object_detection/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/object_detection/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /object_detection/protos/BUILD: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection API: Configuration protos. 
2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | licenses(["notice"]) 8 | 9 | proto_library( 10 | name = "argmax_matcher_proto", 11 | srcs = ["argmax_matcher.proto"], 12 | ) 13 | 14 | py_proto_library( 15 | name = "argmax_matcher_py_pb2", 16 | api_version = 2, 17 | deps = [":argmax_matcher_proto"], 18 | ) 19 | 20 | proto_library( 21 | name = "bipartite_matcher_proto", 22 | srcs = ["bipartite_matcher.proto"], 23 | ) 24 | 25 | py_proto_library( 26 | name = "bipartite_matcher_py_pb2", 27 | api_version = 2, 28 | deps = [":bipartite_matcher_proto"], 29 | ) 30 | 31 | proto_library( 32 | name = "matcher_proto", 33 | srcs = ["matcher.proto"], 34 | deps = [ 35 | ":argmax_matcher_proto", 36 | ":bipartite_matcher_proto", 37 | ], 38 | ) 39 | 40 | py_proto_library( 41 | name = "matcher_py_pb2", 42 | api_version = 2, 43 | deps = [":matcher_proto"], 44 | ) 45 | 46 | proto_library( 47 | name = "faster_rcnn_box_coder_proto", 48 | srcs = ["faster_rcnn_box_coder.proto"], 49 | ) 50 | 51 | py_proto_library( 52 | name = "faster_rcnn_box_coder_py_pb2", 53 | api_version = 2, 54 | deps = [":faster_rcnn_box_coder_proto"], 55 | ) 56 | 57 | proto_library( 58 | name = "mean_stddev_box_coder_proto", 59 | srcs = ["mean_stddev_box_coder.proto"], 60 | ) 61 | 62 | py_proto_library( 63 | name = "mean_stddev_box_coder_py_pb2", 64 | api_version = 2, 65 | deps = [":mean_stddev_box_coder_proto"], 66 | ) 67 | 68 | proto_library( 69 | name = "square_box_coder_proto", 70 | srcs = ["square_box_coder.proto"], 71 | ) 72 | 73 | py_proto_library( 74 | name = "square_box_coder_py_pb2", 75 | api_version = 2, 76 | deps = [":square_box_coder_proto"], 77 | ) 78 | 79 | proto_library( 80 | name = "box_coder_proto", 81 | srcs = ["box_coder.proto"], 82 | deps = [ 83 | ":faster_rcnn_box_coder_proto", 84 | ":mean_stddev_box_coder_proto", 85 | ":square_box_coder_proto", 86 | ], 87 | ) 88 | 89 | py_proto_library( 90 | name = "box_coder_py_pb2", 91 | api_version = 2, 92 | deps = [":box_coder_proto"], 93 | ) 94 | 95 | proto_library( 96 | name = "grid_anchor_generator_proto", 97 | srcs = ["grid_anchor_generator.proto"], 98 | ) 99 | 100 | py_proto_library( 101 | name = "grid_anchor_generator_py_pb2", 102 | api_version = 2, 103 | deps = [":grid_anchor_generator_proto"], 104 | ) 105 | 106 | proto_library( 107 | name = "ssd_anchor_generator_proto", 108 | srcs = ["ssd_anchor_generator.proto"], 109 | ) 110 | 111 | py_proto_library( 112 | name = "ssd_anchor_generator_py_pb2", 113 | api_version = 2, 114 | deps = [":ssd_anchor_generator_proto"], 115 | ) 116 | 117 | proto_library( 118 | name = "anchor_generator_proto", 119 | srcs = ["anchor_generator.proto"], 120 | deps = [ 121 | ":grid_anchor_generator_proto", 122 | ":ssd_anchor_generator_proto", 123 | ], 124 | ) 125 | 126 | py_proto_library( 127 | name = "anchor_generator_py_pb2", 128 | api_version = 2, 129 | deps = [":anchor_generator_proto"], 130 | ) 131 | 132 | proto_library( 133 | name = "input_reader_proto", 134 | srcs = ["input_reader.proto"], 135 | ) 136 | 137 | py_proto_library( 138 | name = "input_reader_py_pb2", 139 | api_version = 2, 140 | deps = [":input_reader_proto"], 141 | ) 142 | 143 | proto_library( 144 | name = "losses_proto", 145 | srcs = ["losses.proto"], 146 | ) 147 | 148 | py_proto_library( 149 | name = "losses_py_pb2", 150 | api_version = 2, 151 | deps = [":losses_proto"], 152 | ) 153 | 154 | proto_library( 155 | name = "optimizer_proto", 156 | srcs = ["optimizer.proto"], 157 | ) 158 | 159 | py_proto_library( 160 | name = "optimizer_py_pb2", 161 | 
api_version = 2, 162 | deps = [":optimizer_proto"], 163 | ) 164 | 165 | proto_library( 166 | name = "post_processing_proto", 167 | srcs = ["post_processing.proto"], 168 | ) 169 | 170 | py_proto_library( 171 | name = "post_processing_py_pb2", 172 | api_version = 2, 173 | deps = [":post_processing_proto"], 174 | ) 175 | 176 | proto_library( 177 | name = "hyperparams_proto", 178 | srcs = ["hyperparams.proto"], 179 | ) 180 | 181 | py_proto_library( 182 | name = "hyperparams_py_pb2", 183 | api_version = 2, 184 | deps = [":hyperparams_proto"], 185 | ) 186 | 187 | proto_library( 188 | name = "box_predictor_proto", 189 | srcs = ["box_predictor.proto"], 190 | deps = [":hyperparams_proto"], 191 | ) 192 | 193 | py_proto_library( 194 | name = "box_predictor_py_pb2", 195 | api_version = 2, 196 | deps = [":box_predictor_proto"], 197 | ) 198 | 199 | proto_library( 200 | name = "region_similarity_calculator_proto", 201 | srcs = ["region_similarity_calculator.proto"], 202 | deps = [], 203 | ) 204 | 205 | py_proto_library( 206 | name = "region_similarity_calculator_py_pb2", 207 | api_version = 2, 208 | deps = [":region_similarity_calculator_proto"], 209 | ) 210 | 211 | proto_library( 212 | name = "preprocessor_proto", 213 | srcs = ["preprocessor.proto"], 214 | ) 215 | 216 | py_proto_library( 217 | name = "preprocessor_py_pb2", 218 | api_version = 2, 219 | deps = [":preprocessor_proto"], 220 | ) 221 | 222 | proto_library( 223 | name = "train_proto", 224 | srcs = ["train.proto"], 225 | deps = [ 226 | ":optimizer_proto", 227 | ":preprocessor_proto", 228 | ], 229 | ) 230 | 231 | py_proto_library( 232 | name = "train_py_pb2", 233 | api_version = 2, 234 | deps = [":train_proto"], 235 | ) 236 | 237 | proto_library( 238 | name = "eval_proto", 239 | srcs = ["eval.proto"], 240 | ) 241 | 242 | py_proto_library( 243 | name = "eval_py_pb2", 244 | api_version = 2, 245 | deps = [":eval_proto"], 246 | ) 247 | 248 | proto_library( 249 | name = "image_resizer_proto", 250 | srcs = ["image_resizer.proto"], 251 | ) 252 | 253 | py_proto_library( 254 | name = "image_resizer_py_pb2", 255 | api_version = 2, 256 | deps = [":image_resizer_proto"], 257 | ) 258 | 259 | proto_library( 260 | name = "faster_rcnn_proto", 261 | srcs = ["faster_rcnn.proto"], 262 | deps = [ 263 | ":box_predictor_proto", 264 | "//object_detection/protos:anchor_generator_proto", 265 | "//object_detection/protos:hyperparams_proto", 266 | "//object_detection/protos:image_resizer_proto", 267 | "//object_detection/protos:losses_proto", 268 | "//object_detection/protos:post_processing_proto", 269 | ], 270 | ) 271 | 272 | proto_library( 273 | name = "ssd_proto", 274 | srcs = ["ssd.proto"], 275 | deps = [ 276 | ":anchor_generator_proto", 277 | ":box_coder_proto", 278 | ":box_predictor_proto", 279 | ":hyperparams_proto", 280 | ":image_resizer_proto", 281 | ":losses_proto", 282 | ":matcher_proto", 283 | ":post_processing_proto", 284 | ":region_similarity_calculator_proto", 285 | ], 286 | ) 287 | 288 | proto_library( 289 | name = "model_proto", 290 | srcs = ["model.proto"], 291 | deps = [ 292 | ":faster_rcnn_proto", 293 | ":ssd_proto", 294 | ], 295 | ) 296 | 297 | py_proto_library( 298 | name = "model_py_pb2", 299 | api_version = 2, 300 | deps = [":model_proto"], 301 | ) 302 | 303 | proto_library( 304 | name = "pipeline_proto", 305 | srcs = ["pipeline.proto"], 306 | deps = [ 307 | ":eval_proto", 308 | ":input_reader_proto", 309 | ":model_proto", 310 | ":train_proto", 311 | ], 312 | ) 313 | 314 | py_proto_library( 315 | name = "pipeline_py_pb2", 316 | api_version = 2, 317 
| deps = [":pipeline_proto"], 318 | ) 319 | 320 | proto_library( 321 | name = "string_int_label_map_proto", 322 | srcs = ["string_int_label_map.proto"], 323 | ) 324 | 325 | py_proto_library( 326 | name = "string_int_label_map_py_pb2", 327 | api_version = 2, 328 | deps = [":string_int_label_map_proto"], 329 | ) 330 | -------------------------------------------------------------------------------- /object_detection/protos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/object_detection/protos/__init__.py -------------------------------------------------------------------------------- /object_detection/protos/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/object_detection/protos/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /object_detection/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/object_detection/protos/__pycache__/string_int_label_map_pb2.cpython-36.pyc -------------------------------------------------------------------------------- /object_detection/protos/anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/grid_anchor_generator.proto"; 6 | import "object_detection/protos/ssd_anchor_generator.proto"; 7 | 8 | // Configuration proto for the anchor generator to use in the object detection 9 | // pipeline. See core/anchor_generator.py for details. 10 | message AnchorGenerator { 11 | oneof anchor_generator_oneof { 12 | GridAnchorGenerator grid_anchor_generator = 1; 13 | SsdAnchorGenerator ssd_anchor_generator = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /object_detection/protos/anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import grid_anchor_generator_pb2 as object__detection_dot_protos_dot_grid__anchor__generator__pb2 17 | from object_detection.protos import ssd_anchor_generator_pb2 as object__detection_dot_protos_dot_ssd__anchor__generator__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/anchor_generator.proto', 22 | package='object_detection.protos', 23 | syntax='proto2', 24 | serialized_pb=_b('\n.object_detection/protos/anchor_generator.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/grid_anchor_generator.proto\x1a\x32object_detection/protos/ssd_anchor_generator.proto\"\xc7\x01\n\x0f\x41nchorGenerator\x12M\n\x15grid_anchor_generator\x18\x01 \x01(\x0b\x32,.object_detection.protos.GridAnchorGeneratorH\x00\x12K\n\x14ssd_anchor_generator\x18\x02 \x01(\x0b\x32+.object_detection.protos.SsdAnchorGeneratorH\x00\x42\x18\n\x16\x61nchor_generator_oneof') 25 | , 26 | dependencies=[object__detection_dot_protos_dot_grid__anchor__generator__pb2.DESCRIPTOR,object__detection_dot_protos_dot_ssd__anchor__generator__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _ANCHORGENERATOR = _descriptor.Descriptor( 32 | name='AnchorGenerator', 33 | full_name='object_detection.protos.AnchorGenerator', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='grid_anchor_generator', full_name='object_detection.protos.AnchorGenerator.grid_anchor_generator', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd_anchor_generator', full_name='object_detection.protos.AnchorGenerator.ssd_anchor_generator', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | syntax='proto2', 61 | extension_ranges=[], 62 | oneofs=[ 63 | _descriptor.OneofDescriptor( 64 | name='anchor_generator_oneof', full_name='object_detection.protos.AnchorGenerator.anchor_generator_oneof', 65 | index=0, containing_type=None, fields=[]), 66 | ], 67 | serialized_start=181, 68 | serialized_end=380, 69 | ) 70 | 71 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].message_type = object__detection_dot_protos_dot_grid__anchor__generator__pb2._GRIDANCHORGENERATOR 72 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].message_type = object__detection_dot_protos_dot_ssd__anchor__generator__pb2._SSDANCHORGENERATOR 73 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 74 | 
_ANCHORGENERATOR.fields_by_name['grid_anchor_generator']) 75 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 76 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 77 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator']) 78 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 79 | DESCRIPTOR.message_types_by_name['AnchorGenerator'] = _ANCHORGENERATOR 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | AnchorGenerator = _reflection.GeneratedProtocolMessageType('AnchorGenerator', (_message.Message,), dict( 83 | DESCRIPTOR = _ANCHORGENERATOR, 84 | __module__ = 'object_detection.protos.anchor_generator_pb2' 85 | # @@protoc_insertion_point(class_scope:object_detection.protos.AnchorGenerator) 86 | )) 87 | _sym_db.RegisterMessage(AnchorGenerator) 88 | 89 | 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /object_detection/protos/argmax_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for ArgMaxMatcher. See 6 | // matchers/argmax_matcher.py for details. 7 | message ArgMaxMatcher { 8 | // Threshold for positive matches. 9 | optional float matched_threshold = 1 [default = 0.5]; 10 | 11 | // Threshold for negative matches. 12 | optional float unmatched_threshold = 2 [default = 0.5]; 13 | 14 | // Whether to construct ArgMaxMatcher without thresholds. 15 | optional bool ignore_thresholds = 3 [default = false]; 16 | 17 | // If True then negative matches are the ones below the unmatched_threshold, 18 | // whereas ignored matches are in between the matched and umatched 19 | // threshold. If False, then negative matches are in between the matched 20 | // and unmatched threshold, and everything lower than unmatched is ignored. 21 | optional bool negatives_lower_than_unmatched = 4 [default = true]; 22 | 23 | // Whether to ensure each row is matched to at least one column. 24 | optional bool force_match_for_each_row = 5 [default = false]; 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/argmax_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/argmax_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/argmax_matcher.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n,object_detection/protos/argmax_matcher.proto\x12\x17object_detection.protos\"\xca\x01\n\rArgMaxMatcher\x12\x1e\n\x11matched_threshold\x18\x01 \x01(\x02:\x03\x30.5\x12 \n\x13unmatched_threshold\x18\x02 \x01(\x02:\x03\x30.5\x12 \n\x11ignore_thresholds\x18\x03 \x01(\x08:\x05\x66\x61lse\x12,\n\x1enegatives_lower_than_unmatched\x18\x04 \x01(\x08:\x04true\x12\'\n\x18\x66orce_match_for_each_row\x18\x05 \x01(\x08:\x05\x66\x61lse') 23 | ) 24 | 25 | 26 | 27 | 28 | _ARGMAXMATCHER = _descriptor.Descriptor( 29 | name='ArgMaxMatcher', 30 | full_name='object_detection.protos.ArgMaxMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='matched_threshold', full_name='object_detection.protos.ArgMaxMatcher.matched_threshold', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=float(0.5), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='unmatched_threshold', full_name='object_detection.protos.ArgMaxMatcher.unmatched_threshold', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=float(0.5), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='ignore_thresholds', full_name='object_detection.protos.ArgMaxMatcher.ignore_thresholds', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=True, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='negatives_lower_than_unmatched', full_name='object_detection.protos.ArgMaxMatcher.negatives_lower_than_unmatched', index=3, 58 | number=4, type=8, cpp_type=7, label=1, 59 | has_default_value=True, default_value=True, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='force_match_for_each_row', full_name='object_detection.protos.ArgMaxMatcher.force_match_for_each_row', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=False, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | syntax='proto2', 79 | extension_ranges=[], 80 | oneofs=[ 81 | ], 82 | serialized_start=74, 83 | serialized_end=276, 84 | ) 85 | 86 | 
DESCRIPTOR.message_types_by_name['ArgMaxMatcher'] = _ARGMAXMATCHER 87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 88 | 89 | ArgMaxMatcher = _reflection.GeneratedProtocolMessageType('ArgMaxMatcher', (_message.Message,), dict( 90 | DESCRIPTOR = _ARGMAXMATCHER, 91 | __module__ = 'object_detection.protos.argmax_matcher_pb2' 92 | # @@protoc_insertion_point(class_scope:object_detection.protos.ArgMaxMatcher) 93 | )) 94 | _sym_db.RegisterMessage(ArgMaxMatcher) 95 | 96 | 97 | # @@protoc_insertion_point(module_scope) 98 | -------------------------------------------------------------------------------- /object_detection/protos/bipartite_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for bipartite matcher. See 6 | // matchers/bipartite_matcher.py for details. 7 | message BipartiteMatcher { 8 | } 9 | -------------------------------------------------------------------------------- /object_detection/protos/bipartite_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/bipartite_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/bipartite_matcher.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n/object_detection/protos/bipartite_matcher.proto\x12\x17object_detection.protos\"\x12\n\x10\x42ipartiteMatcher') 23 | ) 24 | 25 | 26 | 27 | 28 | _BIPARTITEMATCHER = _descriptor.Descriptor( 29 | name='BipartiteMatcher', 30 | full_name='object_detection.protos.BipartiteMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | syntax='proto2', 44 | extension_ranges=[], 45 | oneofs=[ 46 | ], 47 | serialized_start=76, 48 | serialized_end=94, 49 | ) 50 | 51 | DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER 52 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 53 | 54 | BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), dict( 55 | DESCRIPTOR = _BIPARTITEMATCHER, 56 | __module__ = 'object_detection.protos.bipartite_matcher_pb2' 57 | # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher) 58 | )) 59 | _sym_db.RegisterMessage(BipartiteMatcher) 60 | 61 | 62 | # @@protoc_insertion_point(module_scope) 63 | -------------------------------------------------------------------------------- /object_detection/protos/box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn_box_coder.proto"; 6 | import "object_detection/protos/mean_stddev_box_coder.proto"; 7 | import 
"object_detection/protos/square_box_coder.proto"; 8 | 9 | // Configuration proto for the box coder to be used in the object detection 10 | // pipeline. See core/box_coder.py for details. 11 | message BoxCoder { 12 | oneof box_coder_oneof { 13 | FasterRcnnBoxCoder faster_rcnn_box_coder = 1; 14 | MeanStddevBoxCoder mean_stddev_box_coder = 2; 15 | SquareBoxCoder square_box_coder = 3; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /object_detection/protos/box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import faster_rcnn_box_coder_pb2 as object__detection_dot_protos_dot_faster__rcnn__box__coder__pb2 17 | from object_detection.protos import mean_stddev_box_coder_pb2 as object__detection_dot_protos_dot_mean__stddev__box__coder__pb2 18 | from object_detection.protos import square_box_coder_pb2 as object__detection_dot_protos_dot_square__box__coder__pb2 19 | 20 | 21 | DESCRIPTOR = _descriptor.FileDescriptor( 22 | name='object_detection/protos/box_coder.proto', 23 | package='object_detection.protos', 24 | syntax='proto2', 25 | serialized_pb=_b('\n\'object_detection/protos/box_coder.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/faster_rcnn_box_coder.proto\x1a\x33object_detection/protos/mean_stddev_box_coder.proto\x1a.object_detection/protos/square_box_coder.proto\"\xfe\x01\n\x08\x42oxCoder\x12L\n\x15\x66\x61ster_rcnn_box_coder\x18\x01 \x01(\x0b\x32+.object_detection.protos.FasterRcnnBoxCoderH\x00\x12L\n\x15mean_stddev_box_coder\x18\x02 \x01(\x0b\x32+.object_detection.protos.MeanStddevBoxCoderH\x00\x12\x43\n\x10square_box_coder\x18\x03 \x01(\x0b\x32\'.object_detection.protos.SquareBoxCoderH\x00\x42\x11\n\x0f\x62ox_coder_oneof') 26 | , 27 | dependencies=[object__detection_dot_protos_dot_faster__rcnn__box__coder__pb2.DESCRIPTOR,object__detection_dot_protos_dot_mean__stddev__box__coder__pb2.DESCRIPTOR,object__detection_dot_protos_dot_square__box__coder__pb2.DESCRIPTOR,]) 28 | 29 | 30 | 31 | 32 | _BOXCODER = _descriptor.Descriptor( 33 | name='BoxCoder', 34 | full_name='object_detection.protos.BoxCoder', 35 | filename=None, 36 | file=DESCRIPTOR, 37 | containing_type=None, 38 | fields=[ 39 | _descriptor.FieldDescriptor( 40 | name='faster_rcnn_box_coder', full_name='object_detection.protos.BoxCoder.faster_rcnn_box_coder', index=0, 41 | number=1, type=11, cpp_type=10, label=1, 42 | has_default_value=False, default_value=None, 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | _descriptor.FieldDescriptor( 47 | name='mean_stddev_box_coder', full_name='object_detection.protos.BoxCoder.mean_stddev_box_coder', index=1, 48 | number=2, type=11, cpp_type=10, label=1, 49 | has_default_value=False, default_value=None, 50 | message_type=None, enum_type=None, containing_type=None, 51 | 
is_extension=False, extension_scope=None, 52 | options=None), 53 | _descriptor.FieldDescriptor( 54 | name='square_box_coder', full_name='object_detection.protos.BoxCoder.square_box_coder', index=2, 55 | number=3, type=11, cpp_type=10, label=1, 56 | has_default_value=False, default_value=None, 57 | message_type=None, enum_type=None, containing_type=None, 58 | is_extension=False, extension_scope=None, 59 | options=None), 60 | ], 61 | extensions=[ 62 | ], 63 | nested_types=[], 64 | enum_types=[ 65 | ], 66 | options=None, 67 | is_extendable=False, 68 | syntax='proto2', 69 | extension_ranges=[], 70 | oneofs=[ 71 | _descriptor.OneofDescriptor( 72 | name='box_coder_oneof', full_name='object_detection.protos.BoxCoder.box_coder_oneof', 73 | index=0, containing_type=None, fields=[]), 74 | ], 75 | serialized_start=223, 76 | serialized_end=477, 77 | ) 78 | 79 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].message_type = object__detection_dot_protos_dot_faster__rcnn__box__coder__pb2._FASTERRCNNBOXCODER 80 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].message_type = object__detection_dot_protos_dot_mean__stddev__box__coder__pb2._MEANSTDDEVBOXCODER 81 | _BOXCODER.fields_by_name['square_box_coder'].message_type = object__detection_dot_protos_dot_square__box__coder__pb2._SQUAREBOXCODER 82 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 83 | _BOXCODER.fields_by_name['faster_rcnn_box_coder']) 84 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 85 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 86 | _BOXCODER.fields_by_name['mean_stddev_box_coder']) 87 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 88 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 89 | _BOXCODER.fields_by_name['square_box_coder']) 90 | _BOXCODER.fields_by_name['square_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 91 | DESCRIPTOR.message_types_by_name['BoxCoder'] = _BOXCODER 92 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 93 | 94 | BoxCoder = _reflection.GeneratedProtocolMessageType('BoxCoder', (_message.Message,), dict( 95 | DESCRIPTOR = _BOXCODER, 96 | __module__ = 'object_detection.protos.box_coder_pb2' 97 | # @@protoc_insertion_point(class_scope:object_detection.protos.BoxCoder) 98 | )) 99 | _sym_db.RegisterMessage(BoxCoder) 100 | 101 | 102 | # @@protoc_insertion_point(module_scope) 103 | -------------------------------------------------------------------------------- /object_detection/protos/box_predictor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/hyperparams.proto"; 6 | 7 | 8 | // Configuration proto for box predictor. See core/box_predictor.py for details. 9 | message BoxPredictor { 10 | oneof box_predictor_oneof { 11 | ConvolutionalBoxPredictor convolutional_box_predictor = 1; 12 | MaskRCNNBoxPredictor mask_rcnn_box_predictor = 2; 13 | RfcnBoxPredictor rfcn_box_predictor = 3; 14 | } 15 | } 16 | 17 | // Configuration proto for Convolutional box predictor. 18 | message ConvolutionalBoxPredictor { 19 | // Hyperparameters for convolution ops used in the box predictor. 20 | optional Hyperparams conv_hyperparams = 1; 21 | 22 | // Minumum feature depth prior to predicting box encodings and class 23 | // predictions. 
24 | optional int32 min_depth = 2 [default = 0]; 25 | 26 | // Maximum feature depth prior to predicting box encodings and class 27 | // predictions. If max_depth is set to 0, no additional feature map will be 28 | // inserted before location and class predictions. 29 | optional int32 max_depth = 3 [default = 0]; 30 | 31 | // Number of the additional conv layers before the predictor. 32 | optional int32 num_layers_before_predictor = 4 [default = 0]; 33 | 34 | // Whether to use dropout for class prediction. 35 | optional bool use_dropout = 5 [default = true]; 36 | 37 | // Keep probability for dropout 38 | optional float dropout_keep_probability = 6 [default = 0.8]; 39 | 40 | // Size of final convolution kernel. If the spatial resolution of the feature 41 | // map is smaller than the kernel size, then the kernel size is set to 42 | // min(feature_width, feature_height). 43 | optional int32 kernel_size = 7 [default = 1]; 44 | 45 | // Size of the encoding for boxes. 46 | optional int32 box_code_size = 8 [default = 4]; 47 | 48 | // Whether to apply sigmoid to the output of class predictions. 49 | // TODO: Do we need this since we have a post processing module.? 50 | optional bool apply_sigmoid_to_scores = 9 [default = false]; 51 | } 52 | 53 | message MaskRCNNBoxPredictor { 54 | // Hyperparameters for fully connected ops used in the box predictor. 55 | optional Hyperparams fc_hyperparams = 1; 56 | 57 | // Whether to use dropout op prior to the both box and class predictions. 58 | optional bool use_dropout = 2 [default= false]; 59 | 60 | // Keep probability for dropout. This is only used if use_dropout is true. 61 | optional float dropout_keep_probability = 3 [default = 0.5]; 62 | 63 | // Size of the encoding for the boxes. 64 | optional int32 box_code_size = 4 [default = 4]; 65 | 66 | // Hyperparameters for convolution ops used in the box predictor. 67 | optional Hyperparams conv_hyperparams = 5; 68 | 69 | // Whether to predict instance masks inside detection boxes. 70 | optional bool predict_instance_masks = 6 [default = false]; 71 | 72 | // The depth for the first conv2d_transpose op applied to the 73 | // image_features in the mask prediciton branch 74 | optional int32 mask_prediction_conv_depth = 7 [default = 256]; 75 | 76 | // Whether to predict keypoints inside detection boxes. 77 | optional bool predict_keypoints = 8 [default = false]; 78 | } 79 | 80 | message RfcnBoxPredictor { 81 | // Hyperparameters for convolution ops used in the box predictor. 82 | optional Hyperparams conv_hyperparams = 1; 83 | 84 | // Bin sizes for RFCN crops. 85 | optional int32 num_spatial_bins_height = 2 [default = 3]; 86 | 87 | optional int32 num_spatial_bins_width = 3 [default = 3]; 88 | 89 | // Target depth to reduce the input image features to. 90 | optional int32 depth = 4 [default=1024]; 91 | 92 | // Size of the encoding for the boxes. 93 | optional int32 box_code_size = 5 [default = 4]; 94 | 95 | // Size to resize the rfcn crops to. 96 | optional int32 crop_height = 6 [default= 12]; 97 | 98 | optional int32 crop_width = 7 [default=12]; 99 | } 100 | -------------------------------------------------------------------------------- /object_detection/protos/eval.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring DetectionModel evaluation jobs (eval.py). 6 | message EvalConfig { 7 | // Number of visualization images to generate. 
8 | optional uint32 num_visualizations = 1 [default=10]; 9 | 10 | // Number of examples to process of evaluation. 11 | optional uint32 num_examples = 2 [default=5000]; 12 | 13 | // How often to run evaluation. 14 | optional uint32 eval_interval_secs = 3 [default=300]; 15 | 16 | // Maximum number of times to run evaluation. If set to 0, will run forever. 17 | optional uint32 max_evals = 4 [default=0]; 18 | 19 | // Whether the TensorFlow graph used for evaluation should be saved to disk. 20 | optional bool save_graph = 5 [default=false]; 21 | 22 | // Path to directory to store visualizations in. If empty, visualization 23 | // images are not exported (only shown on Tensorboard). 24 | optional string visualization_export_dir = 6 [default=""]; 25 | 26 | // BNS name of the TensorFlow master. 27 | optional string eval_master = 7 [default=""]; 28 | 29 | // Type of metrics to use for evaluation. Currently supports only Pascal VOC 30 | // detection metrics. 31 | optional string metrics_set = 8 [default="pascal_voc_metrics"]; 32 | 33 | // Path to export detections to COCO compatible JSON format. 34 | optional string export_path = 9 [default='']; 35 | 36 | // Option to not read groundtruth labels and only export detections to 37 | // COCO-compatible JSON file. 38 | optional bool ignore_groundtruth = 10 [default=false]; 39 | 40 | // Use exponential moving averages of variables for evaluation. 41 | // TODO: When this is false make sure the model is constructed 42 | // without moving averages in restore_fn. 43 | optional bool use_moving_averages = 11 [default=false]; 44 | 45 | // Whether to evaluate instance masks. 46 | optional bool eval_instance_masks = 12 [default=false]; 47 | } 48 | -------------------------------------------------------------------------------- /object_detection/protos/eval_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/eval.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/eval.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n\"object_detection/protos/eval.proto\x12\x17object_detection.protos\"\x80\x03\n\nEvalConfig\x12\x1e\n\x12num_visualizations\x18\x01 \x01(\r:\x02\x31\x30\x12\x1a\n\x0cnum_examples\x18\x02 \x01(\r:\x04\x35\x30\x30\x30\x12\x1f\n\x12\x65val_interval_secs\x18\x03 \x01(\r:\x03\x33\x30\x30\x12\x14\n\tmax_evals\x18\x04 \x01(\r:\x01\x30\x12\x19\n\nsave_graph\x18\x05 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18visualization_export_dir\x18\x06 \x01(\t:\x00\x12\x15\n\x0b\x65val_master\x18\x07 \x01(\t:\x00\x12\'\n\x0bmetrics_set\x18\x08 \x01(\t:\x12pascal_voc_metrics\x12\x15\n\x0b\x65xport_path\x18\t \x01(\t:\x00\x12!\n\x12ignore_groundtruth\x18\n \x01(\x08:\x05\x66\x61lse\x12\"\n\x13use_moving_averages\x18\x0b \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x65val_instance_masks\x18\x0c \x01(\x08:\x05\x66\x61lse') 23 | ) 24 | 25 | 26 | 27 | 28 | _EVALCONFIG = _descriptor.Descriptor( 29 | name='EvalConfig', 30 | full_name='object_detection.protos.EvalConfig', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='num_visualizations', full_name='object_detection.protos.EvalConfig.num_visualizations', index=0, 37 | number=1, type=13, cpp_type=3, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='num_examples', full_name='object_detection.protos.EvalConfig.num_examples', index=1, 44 | number=2, type=13, cpp_type=3, label=1, 45 | has_default_value=True, default_value=5000, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='eval_interval_secs', full_name='object_detection.protos.EvalConfig.eval_interval_secs', index=2, 51 | number=3, type=13, cpp_type=3, label=1, 52 | has_default_value=True, default_value=300, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='max_evals', full_name='object_detection.protos.EvalConfig.max_evals', index=3, 58 | number=4, type=13, cpp_type=3, label=1, 59 | has_default_value=True, default_value=0, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='save_graph', full_name='object_detection.protos.EvalConfig.save_graph', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=False, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 
71 | name='visualization_export_dir', full_name='object_detection.protos.EvalConfig.visualization_export_dir', index=5, 72 | number=6, type=9, cpp_type=9, label=1, 73 | has_default_value=True, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='eval_master', full_name='object_detection.protos.EvalConfig.eval_master', index=6, 79 | number=7, type=9, cpp_type=9, label=1, 80 | has_default_value=True, default_value=_b("").decode('utf-8'), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='metrics_set', full_name='object_detection.protos.EvalConfig.metrics_set', index=7, 86 | number=8, type=9, cpp_type=9, label=1, 87 | has_default_value=True, default_value=_b("pascal_voc_metrics").decode('utf-8'), 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | _descriptor.FieldDescriptor( 92 | name='export_path', full_name='object_detection.protos.EvalConfig.export_path', index=8, 93 | number=9, type=9, cpp_type=9, label=1, 94 | has_default_value=True, default_value=_b("").decode('utf-8'), 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None), 98 | _descriptor.FieldDescriptor( 99 | name='ignore_groundtruth', full_name='object_detection.protos.EvalConfig.ignore_groundtruth', index=9, 100 | number=10, type=8, cpp_type=7, label=1, 101 | has_default_value=True, default_value=False, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None), 105 | _descriptor.FieldDescriptor( 106 | name='use_moving_averages', full_name='object_detection.protos.EvalConfig.use_moving_averages', index=10, 107 | number=11, type=8, cpp_type=7, label=1, 108 | has_default_value=True, default_value=False, 109 | message_type=None, enum_type=None, containing_type=None, 110 | is_extension=False, extension_scope=None, 111 | options=None), 112 | _descriptor.FieldDescriptor( 113 | name='eval_instance_masks', full_name='object_detection.protos.EvalConfig.eval_instance_masks', index=11, 114 | number=12, type=8, cpp_type=7, label=1, 115 | has_default_value=True, default_value=False, 116 | message_type=None, enum_type=None, containing_type=None, 117 | is_extension=False, extension_scope=None, 118 | options=None), 119 | ], 120 | extensions=[ 121 | ], 122 | nested_types=[], 123 | enum_types=[ 124 | ], 125 | options=None, 126 | is_extendable=False, 127 | syntax='proto2', 128 | extension_ranges=[], 129 | oneofs=[ 130 | ], 131 | serialized_start=64, 132 | serialized_end=448, 133 | ) 134 | 135 | DESCRIPTOR.message_types_by_name['EvalConfig'] = _EVALCONFIG 136 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 137 | 138 | EvalConfig = _reflection.GeneratedProtocolMessageType('EvalConfig', (_message.Message,), dict( 139 | DESCRIPTOR = _EVALCONFIG, 140 | __module__ = 'object_detection.protos.eval_pb2' 141 | # @@protoc_insertion_point(class_scope:object_detection.protos.EvalConfig) 142 | )) 143 | _sym_db.RegisterMessage(EvalConfig) 144 | 145 | 146 | # @@protoc_insertion_point(module_scope) 147 | -------------------------------------------------------------------------------- /object_detection/protos/faster_rcnn.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/anchor_generator.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import "object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/losses.proto"; 10 | import "object_detection/protos/post_processing.proto"; 11 | 12 | // Configuration for Faster R-CNN models. 13 | // See meta_architectures/faster_rcnn_meta_arch.py and models/model_builder.py 14 | // 15 | // Naming conventions: 16 | // Faster R-CNN models have two stages: a first stage region proposal network 17 | // (or RPN) and a second stage box classifier. We thus use the prefixes 18 | // `first_stage_` and `second_stage_` to indicate the stage to which each 19 | // parameter pertains when relevant. 20 | message FasterRcnn { 21 | 22 | // Whether to construct only the Region Proposal Network (RPN). 23 | optional bool first_stage_only = 1 [default=false]; 24 | 25 | // Number of classes to predict. 26 | optional int32 num_classes = 3; 27 | 28 | // Image resizer for preprocessing the input image. 29 | optional ImageResizer image_resizer = 4; 30 | 31 | // Feature extractor config. 32 | optional FasterRcnnFeatureExtractor feature_extractor = 5; 33 | 34 | 35 | // (First stage) region proposal network (RPN) parameters. 36 | 37 | // Anchor generator to compute RPN anchors. 38 | optional AnchorGenerator first_stage_anchor_generator = 6; 39 | 40 | // Atrous rate for the convolution op applied to the 41 | // `first_stage_features_to_crop` tensor to obtain box predictions. 42 | optional int32 first_stage_atrous_rate = 7 [default=1]; 43 | 44 | // Hyperparameters for the convolutional RPN box predictor. 45 | optional Hyperparams first_stage_box_predictor_conv_hyperparams = 8; 46 | 47 | // Kernel size to use for the convolution op just prior to RPN box 48 | // predictions. 49 | optional int32 first_stage_box_predictor_kernel_size = 9 [default=3]; 50 | 51 | // Output depth for the convolution op just prior to RPN box predictions. 52 | optional int32 first_stage_box_predictor_depth = 10 [default=512]; 53 | 54 | // The batch size to use for computing the first stage objectness and 55 | // location losses. 56 | optional int32 first_stage_minibatch_size = 11 [default=256]; 57 | 58 | // Fraction of positive examples per image for the RPN. 59 | optional float first_stage_positive_balance_fraction = 12 [default=0.5]; 60 | 61 | // Non max suppression score threshold applied to first stage RPN proposals. 62 | optional float first_stage_nms_score_threshold = 13 [default=0.0]; 63 | 64 | // Non max suppression IOU threshold applied to first stage RPN proposals. 65 | optional float first_stage_nms_iou_threshold = 14 [default=0.7]; 66 | 67 | // Maximum number of RPN proposals retained after first stage postprocessing. 68 | optional int32 first_stage_max_proposals = 15 [default=300]; 69 | 70 | // First stage RPN localization loss weight. 71 | optional float first_stage_localization_loss_weight = 16 [default=1.0]; 72 | 73 | // First stage RPN objectness loss weight. 74 | optional float first_stage_objectness_loss_weight = 17 [default=1.0]; 75 | 76 | 77 | // Per-region cropping parameters. 78 | // Note that if a R-FCN model is constructed the per region cropping 79 | // parameters below are ignored. 
80 | 81 | // Output size (width and height are set to be the same) of the initial 82 | // bilinear interpolation based cropping during ROI pooling. 83 | optional int32 initial_crop_size = 18; 84 | 85 | // Kernel size of the max pool op on the cropped feature map during 86 | // ROI pooling. 87 | optional int32 maxpool_kernel_size = 19; 88 | 89 | // Stride of the max pool op on the cropped feature map during ROI pooling. 90 | optional int32 maxpool_stride = 20; 91 | 92 | 93 | // (Second stage) box classifier parameters 94 | 95 | // Hyperparameters for the second stage box predictor. If box predictor type 96 | // is set to rfcn_box_predictor, a R-FCN model is constructed, otherwise a 97 | // Faster R-CNN model is constructed. 98 | optional BoxPredictor second_stage_box_predictor = 21; 99 | 100 | // The batch size per image used for computing the classification and refined 101 | // location loss of the box classifier. 102 | // Note that this field is ignored if `hard_example_miner` is configured. 103 | optional int32 second_stage_batch_size = 22 [default=64]; 104 | 105 | // Fraction of positive examples to use per image for the box classifier. 106 | optional float second_stage_balance_fraction = 23 [default=0.25]; 107 | 108 | // Post processing to apply on the second stage box classifier predictions. 109 | // Note: the `score_converter` provided to the FasterRCNNMetaArch constructor 110 | // is taken from this `second_stage_post_processing` proto. 111 | optional PostProcessing second_stage_post_processing = 24; 112 | 113 | // Second stage refined localization loss weight. 114 | optional float second_stage_localization_loss_weight = 25 [default=1.0]; 115 | 116 | // Second stage classification loss weight 117 | optional float second_stage_classification_loss_weight = 26 [default=1.0]; 118 | 119 | // If not left to default, applies hard example mining. 120 | optional HardExampleMiner hard_example_miner = 27; 121 | } 122 | 123 | 124 | message FasterRcnnFeatureExtractor { 125 | // Type of Faster R-CNN model (e.g., 'faster_rcnn_resnet101'; 126 | // See models/model_builder.py for expected types). 127 | optional string type = 1; 128 | 129 | // Output stride of extracted RPN feature map. 130 | optional int32 first_stage_features_stride = 2 [default=16]; 131 | } 132 | -------------------------------------------------------------------------------- /object_detection/protos/faster_rcnn_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for FasterRCNNBoxCoder. See 6 | // box_coders/faster_rcnn_box_coder.py for details. 7 | message FasterRcnnBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box height. 13 | optional float height_scale = 3 [default = 5.0]; 14 | 15 | // Scale factor for anchor encoded box width. 16 | optional float width_scale = 4 [default = 5.0]; 17 | } 18 | -------------------------------------------------------------------------------- /object_detection/protos/faster_rcnn_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
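For illustration only (a usage sketch, not code from this repo's scripts), the FasterRcnnBoxCoder message defined above can be constructed directly from the generated faster_rcnn_box_coder_pb2 module that follows; the override value is an example assumption.

from object_detection.protos import faster_rcnn_box_coder_pb2

box_coder = faster_rcnn_box_coder_pb2.FasterRcnnBoxCoder()
# The defaults come straight from faster_rcnn_box_coder.proto above.
assert box_coder.y_scale == 10.0 and box_coder.height_scale == 5.0
box_coder.height_scale = 2.5  # example override; the other fields keep their defaults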
2 | # source: object_detection/protos/faster_rcnn_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/faster_rcnn_box_coder.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n3object_detection/protos/faster_rcnn_box_coder.proto\x12\x17object_detection.protos\"o\n\x12\x46\x61sterRcnnBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x03 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x04 \x01(\x02:\x01\x35') 23 | ) 24 | 25 | 26 | 27 | 28 | _FASTERRCNNBOXCODER = _descriptor.Descriptor( 29 | name='FasterRcnnBoxCoder', 30 | full_name='object_detection.protos.FasterRcnnBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=float(10), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=float(10), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='height_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.height_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=float(5), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='width_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.width_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=True, default_value=float(5), 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | ], 64 | extensions=[ 65 | ], 66 | nested_types=[], 67 | enum_types=[ 68 | ], 69 | options=None, 70 | is_extendable=False, 71 | syntax='proto2', 72 | extension_ranges=[], 73 | oneofs=[ 74 | ], 75 | serialized_start=80, 76 | serialized_end=191, 77 | ) 78 | 79 | DESCRIPTOR.message_types_by_name['FasterRcnnBoxCoder'] = _FASTERRCNNBOXCODER 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | FasterRcnnBoxCoder = _reflection.GeneratedProtocolMessageType('FasterRcnnBoxCoder', (_message.Message,), dict( 83 | DESCRIPTOR = _FASTERRCNNBOXCODER, 84 | __module__ = 'object_detection.protos.faster_rcnn_box_coder_pb2' 85 | # @@protoc_insertion_point(class_scope:object_detection.protos.FasterRcnnBoxCoder) 86 | )) 87 | _sym_db.RegisterMessage(FasterRcnnBoxCoder) 88 | 89 | 90 | # 
@@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /object_detection/protos/grid_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for GridAnchorGenerator. See 6 | // anchor_generators/grid_anchor_generator.py for details. 7 | message GridAnchorGenerator { 8 | // Anchor height in pixels. 9 | optional int32 height = 1 [default = 256]; 10 | 11 | // Anchor width in pixels. 12 | optional int32 width = 2 [default = 256]; 13 | 14 | // Anchor stride in height dimension in pixels. 15 | optional int32 height_stride = 3 [default = 16]; 16 | 17 | // Anchor stride in width dimension in pixels. 18 | optional int32 width_stride = 4 [default = 16]; 19 | 20 | // Anchor height offset in pixels. 21 | optional int32 height_offset = 5 [default = 0]; 22 | 23 | // Anchor width offset in pixels. 24 | optional int32 width_offset = 6 [default = 0]; 25 | 26 | // At any given location, len(scales) * len(aspect_ratios) anchors are 27 | // generated with all possible combinations of scales and aspect ratios. 28 | 29 | // List of scales for the anchors. 30 | repeated float scales = 7; 31 | 32 | // List of aspect ratios for the anchors. 33 | repeated float aspect_ratios = 8; 34 | } 35 | -------------------------------------------------------------------------------- /object_detection/protos/grid_anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/grid_anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/grid_anchor_generator.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n3object_detection/protos/grid_anchor_generator.proto\x12\x17object_detection.protos\"\xcd\x01\n\x13GridAnchorGenerator\x12\x13\n\x06height\x18\x01 \x01(\x05:\x03\x32\x35\x36\x12\x12\n\x05width\x18\x02 \x01(\x05:\x03\x32\x35\x36\x12\x19\n\rheight_stride\x18\x03 \x01(\x05:\x02\x31\x36\x12\x18\n\x0cwidth_stride\x18\x04 \x01(\x05:\x02\x31\x36\x12\x18\n\rheight_offset\x18\x05 \x01(\x05:\x01\x30\x12\x17\n\x0cwidth_offset\x18\x06 \x01(\x05:\x01\x30\x12\x0e\n\x06scales\x18\x07 \x03(\x02\x12\x15\n\raspect_ratios\x18\x08 \x03(\x02') 23 | ) 24 | 25 | 26 | 27 | 28 | _GRIDANCHORGENERATOR = _descriptor.Descriptor( 29 | name='GridAnchorGenerator', 30 | full_name='object_detection.protos.GridAnchorGenerator', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='height', full_name='object_detection.protos.GridAnchorGenerator.height', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=True, default_value=256, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | 
options=None), 42 | _descriptor.FieldDescriptor( 43 | name='width', full_name='object_detection.protos.GridAnchorGenerator.width', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=True, default_value=256, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='height_stride', full_name='object_detection.protos.GridAnchorGenerator.height_stride', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=True, default_value=16, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='width_stride', full_name='object_detection.protos.GridAnchorGenerator.width_stride', index=3, 58 | number=4, type=5, cpp_type=1, label=1, 59 | has_default_value=True, default_value=16, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='height_offset', full_name='object_detection.protos.GridAnchorGenerator.height_offset', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=True, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='width_offset', full_name='object_detection.protos.GridAnchorGenerator.width_offset', index=5, 72 | number=6, type=5, cpp_type=1, label=1, 73 | has_default_value=True, default_value=0, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='scales', full_name='object_detection.protos.GridAnchorGenerator.scales', index=6, 79 | number=7, type=2, cpp_type=6, label=3, 80 | has_default_value=False, default_value=[], 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='aspect_ratios', full_name='object_detection.protos.GridAnchorGenerator.aspect_ratios', index=7, 86 | number=8, type=2, cpp_type=6, label=3, 87 | has_default_value=False, default_value=[], 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | ], 92 | extensions=[ 93 | ], 94 | nested_types=[], 95 | enum_types=[ 96 | ], 97 | options=None, 98 | is_extendable=False, 99 | syntax='proto2', 100 | extension_ranges=[], 101 | oneofs=[ 102 | ], 103 | serialized_start=81, 104 | serialized_end=286, 105 | ) 106 | 107 | DESCRIPTOR.message_types_by_name['GridAnchorGenerator'] = _GRIDANCHORGENERATOR 108 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 109 | 110 | GridAnchorGenerator = _reflection.GeneratedProtocolMessageType('GridAnchorGenerator', (_message.Message,), dict( 111 | DESCRIPTOR = _GRIDANCHORGENERATOR, 112 | __module__ = 'object_detection.protos.grid_anchor_generator_pb2' 113 | # @@protoc_insertion_point(class_scope:object_detection.protos.GridAnchorGenerator) 114 | )) 115 | _sym_db.RegisterMessage(GridAnchorGenerator) 116 | 117 | 118 | # @@protoc_insertion_point(module_scope) 119 | -------------------------------------------------------------------------------- /object_detection/protos/hyperparams.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for the convolution op hyperparameters to use in the 6 | // object detection pipeline. 7 | message Hyperparams { 8 | 9 | // Operations affected by hyperparameters. 10 | enum Op { 11 | // Convolution, Separable Convolution, Convolution transpose. 12 | CONV = 1; 13 | 14 | // Fully connected 15 | FC = 2; 16 | } 17 | optional Op op = 1 [default = CONV]; 18 | 19 | // Regularizer for the weights of the convolution op. 20 | optional Regularizer regularizer = 2; 21 | 22 | // Initializer for the weights of the convolution op. 23 | optional Initializer initializer = 3; 24 | 25 | // Type of activation to apply after convolution. 26 | enum Activation { 27 | // Use None (no activation) 28 | NONE = 0; 29 | 30 | // Use tf.nn.relu 31 | RELU = 1; 32 | 33 | // Use tf.nn.relu6 34 | RELU_6 = 2; 35 | } 36 | optional Activation activation = 4 [default = RELU]; 37 | 38 | // BatchNorm hyperparameters. If this parameter is NOT set then BatchNorm is 39 | // not applied! 40 | optional BatchNorm batch_norm = 5; 41 | } 42 | 43 | // Proto with one-of field for regularizers. 44 | message Regularizer { 45 | oneof regularizer_oneof { 46 | L1Regularizer l1_regularizer = 1; 47 | L2Regularizer l2_regularizer = 2; 48 | } 49 | } 50 | 51 | // Configuration proto for L1 Regularizer. 52 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l1_regularizer 53 | message L1Regularizer { 54 | optional float weight = 1 [default = 1.0]; 55 | } 56 | 57 | // Configuration proto for L2 Regularizer. 58 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l2_regularizer 59 | message L2Regularizer { 60 | optional float weight = 1 [default = 1.0]; 61 | } 62 | 63 | // Proto with one-of field for initializers. 64 | message Initializer { 65 | oneof initializer_oneof { 66 | TruncatedNormalInitializer truncated_normal_initializer = 1; 67 | VarianceScalingInitializer variance_scaling_initializer = 2; 68 | } 69 | } 70 | 71 | // Configuration proto for truncated normal initializer. See 72 | // https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer 73 | message TruncatedNormalInitializer { 74 | optional float mean = 1 [default = 0.0]; 75 | optional float stddev = 2 [default = 1.0]; 76 | } 77 | 78 | // Configuration proto for variance scaling initializer. See 79 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/ 80 | // variance_scaling_initializer 81 | message VarianceScalingInitializer { 82 | optional float factor = 1 [default = 2.0]; 83 | optional bool uniform = 2 [default = false]; 84 | enum Mode { 85 | FAN_IN = 0; 86 | FAN_OUT = 1; 87 | FAN_AVG = 2; 88 | } 89 | optional Mode mode = 3 [default = FAN_IN]; 90 | } 91 | 92 | // Configuration proto for batch norm to apply after convolution op. See 93 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm 94 | message BatchNorm { 95 | optional float decay = 1 [default = 0.999]; 96 | optional bool center = 2 [default = true]; 97 | optional bool scale = 3 [default = false]; 98 | optional float epsilon = 4 [default = 0.001]; 99 | // Whether to train the batch norm variables. If this is set to false during 100 | // training, the current value of the batch_norm variables are used for 101 | // forward pass but they are never updated. 
102 | optional bool train = 5 [default = true]; 103 | } 104 | -------------------------------------------------------------------------------- /object_detection/protos/image_resizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for image resizing operations. 6 | // See builders/image_resizer_builder.py for details. 7 | message ImageResizer { 8 | oneof image_resizer_oneof { 9 | KeepAspectRatioResizer keep_aspect_ratio_resizer = 1; 10 | FixedShapeResizer fixed_shape_resizer = 2; 11 | } 12 | } 13 | 14 | 15 | // Configuration proto for image resizer that keeps aspect ratio. 16 | message KeepAspectRatioResizer { 17 | // Desired size of the smaller image dimension in pixels. 18 | optional int32 min_dimension = 1 [default = 600]; 19 | 20 | // Desired size of the larger image dimension in pixels. 21 | optional int32 max_dimension = 2 [default = 1024]; 22 | } 23 | 24 | 25 | // Configuration proto for image resizer that resizes to a fixed shape. 26 | message FixedShapeResizer { 27 | // Desired height of image in pixels. 28 | optional int32 height = 1 [default = 300]; 29 | 30 | // Desired width of image in pixels. 31 | optional int32 width = 2 [default = 300]; 32 | } 33 | -------------------------------------------------------------------------------- /object_detection/protos/image_resizer_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/image_resizer.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/image_resizer.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n+object_detection/protos/image_resizer.proto\x12\x17object_detection.protos\"\xc6\x01\n\x0cImageResizer\x12T\n\x19keep_aspect_ratio_resizer\x18\x01 \x01(\x0b\x32/.object_detection.protos.KeepAspectRatioResizerH\x00\x12I\n\x13\x66ixed_shape_resizer\x18\x02 \x01(\x0b\x32*.object_detection.protos.FixedShapeResizerH\x00\x42\x15\n\x13image_resizer_oneof\"Q\n\x16KeepAspectRatioResizer\x12\x1a\n\rmin_dimension\x18\x01 \x01(\x05:\x03\x36\x30\x30\x12\x1b\n\rmax_dimension\x18\x02 \x01(\x05:\x04\x31\x30\x32\x34\"<\n\x11\x46ixedShapeResizer\x12\x13\n\x06height\x18\x01 \x01(\x05:\x03\x33\x30\x30\x12\x12\n\x05width\x18\x02 \x01(\x05:\x03\x33\x30\x30') 23 | ) 24 | 25 | 26 | 27 | 28 | _IMAGERESIZER = _descriptor.Descriptor( 29 | name='ImageResizer', 30 | full_name='object_detection.protos.ImageResizer', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='keep_aspect_ratio_resizer', full_name='object_detection.protos.ImageResizer.keep_aspect_ratio_resizer', index=0, 37 | number=1, type=11, cpp_type=10, label=1, 38 | has_default_value=False, default_value=None, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, 
extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='fixed_shape_resizer', full_name='object_detection.protos.ImageResizer.fixed_shape_resizer', index=1, 44 | number=2, type=11, cpp_type=10, label=1, 45 | has_default_value=False, default_value=None, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | ], 50 | extensions=[ 51 | ], 52 | nested_types=[], 53 | enum_types=[ 54 | ], 55 | options=None, 56 | is_extendable=False, 57 | syntax='proto2', 58 | extension_ranges=[], 59 | oneofs=[ 60 | _descriptor.OneofDescriptor( 61 | name='image_resizer_oneof', full_name='object_detection.protos.ImageResizer.image_resizer_oneof', 62 | index=0, containing_type=None, fields=[]), 63 | ], 64 | serialized_start=73, 65 | serialized_end=271, 66 | ) 67 | 68 | 69 | _KEEPASPECTRATIORESIZER = _descriptor.Descriptor( 70 | name='KeepAspectRatioResizer', 71 | full_name='object_detection.protos.KeepAspectRatioResizer', 72 | filename=None, 73 | file=DESCRIPTOR, 74 | containing_type=None, 75 | fields=[ 76 | _descriptor.FieldDescriptor( 77 | name='min_dimension', full_name='object_detection.protos.KeepAspectRatioResizer.min_dimension', index=0, 78 | number=1, type=5, cpp_type=1, label=1, 79 | has_default_value=True, default_value=600, 80 | message_type=None, enum_type=None, containing_type=None, 81 | is_extension=False, extension_scope=None, 82 | options=None), 83 | _descriptor.FieldDescriptor( 84 | name='max_dimension', full_name='object_detection.protos.KeepAspectRatioResizer.max_dimension', index=1, 85 | number=2, type=5, cpp_type=1, label=1, 86 | has_default_value=True, default_value=1024, 87 | message_type=None, enum_type=None, containing_type=None, 88 | is_extension=False, extension_scope=None, 89 | options=None), 90 | ], 91 | extensions=[ 92 | ], 93 | nested_types=[], 94 | enum_types=[ 95 | ], 96 | options=None, 97 | is_extendable=False, 98 | syntax='proto2', 99 | extension_ranges=[], 100 | oneofs=[ 101 | ], 102 | serialized_start=273, 103 | serialized_end=354, 104 | ) 105 | 106 | 107 | _FIXEDSHAPERESIZER = _descriptor.Descriptor( 108 | name='FixedShapeResizer', 109 | full_name='object_detection.protos.FixedShapeResizer', 110 | filename=None, 111 | file=DESCRIPTOR, 112 | containing_type=None, 113 | fields=[ 114 | _descriptor.FieldDescriptor( 115 | name='height', full_name='object_detection.protos.FixedShapeResizer.height', index=0, 116 | number=1, type=5, cpp_type=1, label=1, 117 | has_default_value=True, default_value=300, 118 | message_type=None, enum_type=None, containing_type=None, 119 | is_extension=False, extension_scope=None, 120 | options=None), 121 | _descriptor.FieldDescriptor( 122 | name='width', full_name='object_detection.protos.FixedShapeResizer.width', index=1, 123 | number=2, type=5, cpp_type=1, label=1, 124 | has_default_value=True, default_value=300, 125 | message_type=None, enum_type=None, containing_type=None, 126 | is_extension=False, extension_scope=None, 127 | options=None), 128 | ], 129 | extensions=[ 130 | ], 131 | nested_types=[], 132 | enum_types=[ 133 | ], 134 | options=None, 135 | is_extendable=False, 136 | syntax='proto2', 137 | extension_ranges=[], 138 | oneofs=[ 139 | ], 140 | serialized_start=356, 141 | serialized_end=416, 142 | ) 143 | 144 | _IMAGERESIZER.fields_by_name['keep_aspect_ratio_resizer'].message_type = _KEEPASPECTRATIORESIZER 145 | _IMAGERESIZER.fields_by_name['fixed_shape_resizer'].message_type = _FIXEDSHAPERESIZER 146 | 
_IMAGERESIZER.oneofs_by_name['image_resizer_oneof'].fields.append( 147 | _IMAGERESIZER.fields_by_name['keep_aspect_ratio_resizer']) 148 | _IMAGERESIZER.fields_by_name['keep_aspect_ratio_resizer'].containing_oneof = _IMAGERESIZER.oneofs_by_name['image_resizer_oneof'] 149 | _IMAGERESIZER.oneofs_by_name['image_resizer_oneof'].fields.append( 150 | _IMAGERESIZER.fields_by_name['fixed_shape_resizer']) 151 | _IMAGERESIZER.fields_by_name['fixed_shape_resizer'].containing_oneof = _IMAGERESIZER.oneofs_by_name['image_resizer_oneof'] 152 | DESCRIPTOR.message_types_by_name['ImageResizer'] = _IMAGERESIZER 153 | DESCRIPTOR.message_types_by_name['KeepAspectRatioResizer'] = _KEEPASPECTRATIORESIZER 154 | DESCRIPTOR.message_types_by_name['FixedShapeResizer'] = _FIXEDSHAPERESIZER 155 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 156 | 157 | ImageResizer = _reflection.GeneratedProtocolMessageType('ImageResizer', (_message.Message,), dict( 158 | DESCRIPTOR = _IMAGERESIZER, 159 | __module__ = 'object_detection.protos.image_resizer_pb2' 160 | # @@protoc_insertion_point(class_scope:object_detection.protos.ImageResizer) 161 | )) 162 | _sym_db.RegisterMessage(ImageResizer) 163 | 164 | KeepAspectRatioResizer = _reflection.GeneratedProtocolMessageType('KeepAspectRatioResizer', (_message.Message,), dict( 165 | DESCRIPTOR = _KEEPASPECTRATIORESIZER, 166 | __module__ = 'object_detection.protos.image_resizer_pb2' 167 | # @@protoc_insertion_point(class_scope:object_detection.protos.KeepAspectRatioResizer) 168 | )) 169 | _sym_db.RegisterMessage(KeepAspectRatioResizer) 170 | 171 | FixedShapeResizer = _reflection.GeneratedProtocolMessageType('FixedShapeResizer', (_message.Message,), dict( 172 | DESCRIPTOR = _FIXEDSHAPERESIZER, 173 | __module__ = 'object_detection.protos.image_resizer_pb2' 174 | # @@protoc_insertion_point(class_scope:object_detection.protos.FixedShapeResizer) 175 | )) 176 | _sym_db.RegisterMessage(FixedShapeResizer) 177 | 178 | 179 | # @@protoc_insertion_point(module_scope) 180 | -------------------------------------------------------------------------------- /object_detection/protos/input_reader.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for defining input readers that generate Object Detection 6 | // Examples from input sources. Input readers are expected to generate a 7 | // dictionary of tensors, with the following fields populated: 8 | // 9 | // 'image': an [image_height, image_width, channels] image tensor that detection 10 | // will be run on. 11 | // 'groundtruth_classes': a [num_boxes] int32 tensor storing the class 12 | // labels of detected boxes in the image. 13 | // 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of 14 | // detected boxes in the image. 15 | // 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height, 16 | // image_width] float tensor storing binary mask of the objects in boxes. 17 | 18 | message InputReader { 19 | // Path to StringIntLabelMap pbtxt file specifying the mapping from string 20 | // labels to integer ids. 21 | optional string label_map_path = 1 [default=""]; 22 | 23 | // Whether data should be processed in the order they are read in, or 24 | // shuffled randomly. 25 | optional bool shuffle = 2 [default=true]; 26 | 27 | // Maximum number of records to keep in reader queue. 28 | optional uint32 queue_capacity = 3 [default=2000]; 29 | 30 | // Minimum number of records to keep in reader queue. 
A large value is needed 31 | // to generate a good random shuffle. 32 | optional uint32 min_after_dequeue = 4 [default=1000]; 33 | 34 | // The number of times a data source is read. If set to zero, the data source 35 | // will be reused indefinitely. 36 | optional uint32 num_epochs = 5 [default=0]; 37 | 38 | // Number of reader instances to create. 39 | optional uint32 num_readers = 6 [default=8]; 40 | 41 | // Whether to load groundtruth instance masks. 42 | optional bool load_instance_masks = 7 [default = false]; 43 | 44 | oneof input_reader { 45 | TFRecordInputReader tf_record_input_reader = 8; 46 | ExternalInputReader external_input_reader = 9; 47 | } 48 | } 49 | 50 | // An input reader that reads TF Example protos from local TFRecord files. 51 | message TFRecordInputReader { 52 | // Path to TFRecordFile. 53 | optional string input_path = 1 [default=""]; 54 | } 55 | 56 | // An externally defined input reader. Users may define an extension to this 57 | // proto to interface their own input readers. 58 | message ExternalInputReader { 59 | extensions 1 to 999; 60 | } 61 | -------------------------------------------------------------------------------- /object_detection/protos/input_reader_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/input_reader.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/input_reader.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n*object_detection/protos/input_reader.proto\x12\x17object_detection.protos\"\xff\x02\n\x0bInputReader\x12\x18\n\x0elabel_map_path\x18\x01 \x01(\t:\x00\x12\x15\n\x07shuffle\x18\x02 \x01(\x08:\x04true\x12\x1c\n\x0equeue_capacity\x18\x03 \x01(\r:\x04\x32\x30\x30\x30\x12\x1f\n\x11min_after_dequeue\x18\x04 \x01(\r:\x04\x31\x30\x30\x30\x12\x15\n\nnum_epochs\x18\x05 \x01(\r:\x01\x30\x12\x16\n\x0bnum_readers\x18\x06 \x01(\r:\x01\x38\x12\"\n\x13load_instance_masks\x18\x07 \x01(\x08:\x05\x66\x61lse\x12N\n\x16tf_record_input_reader\x18\x08 \x01(\x0b\x32,.object_detection.protos.TFRecordInputReaderH\x00\x12M\n\x15\x65xternal_input_reader\x18\t \x01(\x0b\x32,.object_detection.protos.ExternalInputReaderH\x00\x42\x0e\n\x0cinput_reader\"+\n\x13TFRecordInputReader\x12\x14\n\ninput_path\x18\x01 \x01(\t:\x00\"\x1c\n\x13\x45xternalInputReader*\x05\x08\x01\x10\xe8\x07') 23 | ) 24 | 25 | 26 | 27 | 28 | _INPUTREADER = _descriptor.Descriptor( 29 | name='InputReader', 30 | full_name='object_detection.protos.InputReader', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='label_map_path', full_name='object_detection.protos.InputReader.label_map_path', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=True, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | 
options=None), 42 | _descriptor.FieldDescriptor( 43 | name='shuffle', full_name='object_detection.protos.InputReader.shuffle', index=1, 44 | number=2, type=8, cpp_type=7, label=1, 45 | has_default_value=True, default_value=True, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='queue_capacity', full_name='object_detection.protos.InputReader.queue_capacity', index=2, 51 | number=3, type=13, cpp_type=3, label=1, 52 | has_default_value=True, default_value=2000, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='min_after_dequeue', full_name='object_detection.protos.InputReader.min_after_dequeue', index=3, 58 | number=4, type=13, cpp_type=3, label=1, 59 | has_default_value=True, default_value=1000, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='num_epochs', full_name='object_detection.protos.InputReader.num_epochs', index=4, 65 | number=5, type=13, cpp_type=3, label=1, 66 | has_default_value=True, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='num_readers', full_name='object_detection.protos.InputReader.num_readers', index=5, 72 | number=6, type=13, cpp_type=3, label=1, 73 | has_default_value=True, default_value=8, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='load_instance_masks', full_name='object_detection.protos.InputReader.load_instance_masks', index=6, 79 | number=7, type=8, cpp_type=7, label=1, 80 | has_default_value=True, default_value=False, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='tf_record_input_reader', full_name='object_detection.protos.InputReader.tf_record_input_reader', index=7, 86 | number=8, type=11, cpp_type=10, label=1, 87 | has_default_value=False, default_value=None, 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | _descriptor.FieldDescriptor( 92 | name='external_input_reader', full_name='object_detection.protos.InputReader.external_input_reader', index=8, 93 | number=9, type=11, cpp_type=10, label=1, 94 | has_default_value=False, default_value=None, 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None), 98 | ], 99 | extensions=[ 100 | ], 101 | nested_types=[], 102 | enum_types=[ 103 | ], 104 | options=None, 105 | is_extendable=False, 106 | syntax='proto2', 107 | extension_ranges=[], 108 | oneofs=[ 109 | _descriptor.OneofDescriptor( 110 | name='input_reader', full_name='object_detection.protos.InputReader.input_reader', 111 | index=0, containing_type=None, fields=[]), 112 | ], 113 | serialized_start=72, 114 | serialized_end=455, 115 | ) 116 | 117 | 118 | _TFRECORDINPUTREADER = _descriptor.Descriptor( 119 | name='TFRecordInputReader', 120 | full_name='object_detection.protos.TFRecordInputReader', 121 | filename=None, 122 | file=DESCRIPTOR, 123 | 
containing_type=None, 124 | fields=[ 125 | _descriptor.FieldDescriptor( 126 | name='input_path', full_name='object_detection.protos.TFRecordInputReader.input_path', index=0, 127 | number=1, type=9, cpp_type=9, label=1, 128 | has_default_value=True, default_value=_b("").decode('utf-8'), 129 | message_type=None, enum_type=None, containing_type=None, 130 | is_extension=False, extension_scope=None, 131 | options=None), 132 | ], 133 | extensions=[ 134 | ], 135 | nested_types=[], 136 | enum_types=[ 137 | ], 138 | options=None, 139 | is_extendable=False, 140 | syntax='proto2', 141 | extension_ranges=[], 142 | oneofs=[ 143 | ], 144 | serialized_start=457, 145 | serialized_end=500, 146 | ) 147 | 148 | 149 | _EXTERNALINPUTREADER = _descriptor.Descriptor( 150 | name='ExternalInputReader', 151 | full_name='object_detection.protos.ExternalInputReader', 152 | filename=None, 153 | file=DESCRIPTOR, 154 | containing_type=None, 155 | fields=[ 156 | ], 157 | extensions=[ 158 | ], 159 | nested_types=[], 160 | enum_types=[ 161 | ], 162 | options=None, 163 | is_extendable=True, 164 | syntax='proto2', 165 | extension_ranges=[(1, 1000), ], 166 | oneofs=[ 167 | ], 168 | serialized_start=502, 169 | serialized_end=530, 170 | ) 171 | 172 | _INPUTREADER.fields_by_name['tf_record_input_reader'].message_type = _TFRECORDINPUTREADER 173 | _INPUTREADER.fields_by_name['external_input_reader'].message_type = _EXTERNALINPUTREADER 174 | _INPUTREADER.oneofs_by_name['input_reader'].fields.append( 175 | _INPUTREADER.fields_by_name['tf_record_input_reader']) 176 | _INPUTREADER.fields_by_name['tf_record_input_reader'].containing_oneof = _INPUTREADER.oneofs_by_name['input_reader'] 177 | _INPUTREADER.oneofs_by_name['input_reader'].fields.append( 178 | _INPUTREADER.fields_by_name['external_input_reader']) 179 | _INPUTREADER.fields_by_name['external_input_reader'].containing_oneof = _INPUTREADER.oneofs_by_name['input_reader'] 180 | DESCRIPTOR.message_types_by_name['InputReader'] = _INPUTREADER 181 | DESCRIPTOR.message_types_by_name['TFRecordInputReader'] = _TFRECORDINPUTREADER 182 | DESCRIPTOR.message_types_by_name['ExternalInputReader'] = _EXTERNALINPUTREADER 183 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 184 | 185 | InputReader = _reflection.GeneratedProtocolMessageType('InputReader', (_message.Message,), dict( 186 | DESCRIPTOR = _INPUTREADER, 187 | __module__ = 'object_detection.protos.input_reader_pb2' 188 | # @@protoc_insertion_point(class_scope:object_detection.protos.InputReader) 189 | )) 190 | _sym_db.RegisterMessage(InputReader) 191 | 192 | TFRecordInputReader = _reflection.GeneratedProtocolMessageType('TFRecordInputReader', (_message.Message,), dict( 193 | DESCRIPTOR = _TFRECORDINPUTREADER, 194 | __module__ = 'object_detection.protos.input_reader_pb2' 195 | # @@protoc_insertion_point(class_scope:object_detection.protos.TFRecordInputReader) 196 | )) 197 | _sym_db.RegisterMessage(TFRecordInputReader) 198 | 199 | ExternalInputReader = _reflection.GeneratedProtocolMessageType('ExternalInputReader', (_message.Message,), dict( 200 | DESCRIPTOR = _EXTERNALINPUTREADER, 201 | __module__ = 'object_detection.protos.input_reader_pb2' 202 | # @@protoc_insertion_point(class_scope:object_detection.protos.ExternalInputReader) 203 | )) 204 | _sym_db.RegisterMessage(ExternalInputReader) 205 | 206 | 207 | # @@protoc_insertion_point(module_scope) 208 | -------------------------------------------------------------------------------- /object_detection/protos/losses.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring the localization loss, classification loss and hard 6 | // example miner used for training object detection models. See core/losses.py 7 | // for details. 8 | message Loss { 9 | // Localization loss to use. 10 | optional LocalizationLoss localization_loss = 1; 11 | 12 | // Classification loss to use. 13 | optional ClassificationLoss classification_loss = 2; 14 | 15 | // If not left to default, applies hard example mining. 16 | optional HardExampleMiner hard_example_miner = 3; 17 | 18 | // Classification loss weight. 19 | optional float classification_weight = 4 [default=1.0]; 20 | 21 | // Localization loss weight. 22 | optional float localization_weight = 5 [default=1.0]; 23 | } 24 | 25 | // Configuration for bounding box localization loss function. 26 | message LocalizationLoss { 27 | oneof localization_loss { 28 | WeightedL2LocalizationLoss weighted_l2 = 1; 29 | WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2; 30 | WeightedIOULocalizationLoss weighted_iou = 3; 31 | } 32 | } 33 | 34 | // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2 35 | message WeightedL2LocalizationLoss { 36 | // Output loss per anchor. 37 | optional bool anchorwise_output = 1 [default=false]; 38 | } 39 | 40 | // SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5 41 | message WeightedSmoothL1LocalizationLoss { 42 | // Output loss per anchor. 43 | optional bool anchorwise_output = 1 [default=false]; 44 | } 45 | 46 | // Intersection over union location loss: 1 - IOU 47 | message WeightedIOULocalizationLoss { 48 | } 49 | 50 | // Configuration for class prediction loss function. 51 | message ClassificationLoss { 52 | oneof classification_loss { 53 | WeightedSigmoidClassificationLoss weighted_sigmoid = 1; 54 | WeightedSoftmaxClassificationLoss weighted_softmax = 2; 55 | BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3; 56 | } 57 | } 58 | 59 | // Classification loss using a sigmoid function over class predictions. 60 | message WeightedSigmoidClassificationLoss { 61 | // Output loss per anchor. 62 | optional bool anchorwise_output = 1 [default=false]; 63 | } 64 | 65 | // Classification loss using a softmax function over class predictions. 66 | message WeightedSoftmaxClassificationLoss { 67 | // Output loss per anchor. 68 | optional bool anchorwise_output = 1 [default=false]; 69 | } 70 | 71 | // Classification loss using a sigmoid function over the class prediction with 72 | // the highest prediction score. 73 | message BootstrappedSigmoidClassificationLoss { 74 | // Interpolation weight between 0 and 1. 75 | optional float alpha = 1; 76 | 77 | // Whether hard bootstrapping should be used. If true, only the class 78 | // favored by the model is used. Otherwise, all predicted class 79 | // probabilities are used. 80 | optional bool hard_bootstrap = 2 [default=false]; 81 | 82 | // Output loss per anchor. 83 | optional bool anchorwise_output = 3 [default=false]; 84 | } 85 | 86 | // Configuration for hard example miner. 87 | message HardExampleMiner { 88 | // Maximum number of hard examples to be selected per image (prior to 89 | // enforcing max negative to positive ratio constraint). If set to 0, 90 | // all examples obtained after NMS are considered. 91 | optional int32 num_hard_examples = 1 [default=64]; 92 | 93 | // Minimum intersection over union for an example to be discarded during NMS.
94 | optional float iou_threshold = 2 [default=0.7]; 95 | 96 | // Whether to use classification losses ('cls', default), localization losses 97 | // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and 98 | // loc_loss_weight are used to compute a weighted sum of the two losses. 99 | enum LossType { 100 | BOTH = 0; 101 | CLASSIFICATION = 1; 102 | LOCALIZATION = 2; 103 | } 104 | optional LossType loss_type = 3 [default=BOTH]; 105 | 106 | // Maximum number of negatives to retain for each positive anchor. If 107 | // max_negatives_per_positive is 0, no prespecified negative:positive ratio is 108 | // enforced. 109 | optional int32 max_negatives_per_positive = 4 [default=0]; 110 | 111 | // Minimum number of negative anchors to sample for a given image. Setting 112 | // this to a positive number allows sampling negatives in an image without any 113 | // positive anchors, and thus does not bias the model towards having at least 114 | // one detection per image. 115 | optional int32 min_negatives_per_image = 5 [default=0]; 116 | } 117 | -------------------------------------------------------------------------------- /object_detection/protos/matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/argmax_matcher.proto"; 6 | import "object_detection/protos/bipartite_matcher.proto"; 7 | 8 | // Configuration proto for the matcher to be used in the object detection 9 | // pipeline. See core/matcher.py for details. 10 | message Matcher { 11 | oneof matcher_oneof { 12 | ArgMaxMatcher argmax_matcher = 1; 13 | BipartiteMatcher bipartite_matcher = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /object_detection/protos/matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
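As an illustration only (hypothetical values, not taken from this repo), a Loss message combining the pieces defined in losses.proto above can be parsed from text format via the generated losses_pb2 module:

from google.protobuf import text_format
from object_detection.protos import losses_pb2

# Example combination: Smooth L1 localization loss, sigmoid classification loss,
# and a hard example miner using the defaults documented in losses.proto.
loss_config = text_format.Parse("""
  localization_loss { weighted_smooth_l1 { } }
  classification_loss { weighted_sigmoid { } }
  hard_example_miner {
    num_hard_examples: 64
    iou_threshold: 0.7
    loss_type: BOTH
  }
""", losses_pb2.Loss())
# classification_weight and localization_weight stay at their default of 1.0.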
2 | # source: object_detection/protos/matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import argmax_matcher_pb2 as object__detection_dot_protos_dot_argmax__matcher__pb2 17 | from object_detection.protos import bipartite_matcher_pb2 as object__detection_dot_protos_dot_bipartite__matcher__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/matcher.proto', 22 | package='object_detection.protos', 23 | syntax='proto2', 24 | serialized_pb=_b('\n%object_detection/protos/matcher.proto\x12\x17object_detection.protos\x1a,object_detection/protos/argmax_matcher.proto\x1a/object_detection/protos/bipartite_matcher.proto\"\xa4\x01\n\x07Matcher\x12@\n\x0e\x61rgmax_matcher\x18\x01 \x01(\x0b\x32&.object_detection.protos.ArgMaxMatcherH\x00\x12\x46\n\x11\x62ipartite_matcher\x18\x02 \x01(\x0b\x32).object_detection.protos.BipartiteMatcherH\x00\x42\x0f\n\rmatcher_oneof') 25 | , 26 | dependencies=[object__detection_dot_protos_dot_argmax__matcher__pb2.DESCRIPTOR,object__detection_dot_protos_dot_bipartite__matcher__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _MATCHER = _descriptor.Descriptor( 32 | name='Matcher', 33 | full_name='object_detection.protos.Matcher', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='argmax_matcher', full_name='object_detection.protos.Matcher.argmax_matcher', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='bipartite_matcher', full_name='object_detection.protos.Matcher.bipartite_matcher', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | syntax='proto2', 61 | extension_ranges=[], 62 | oneofs=[ 63 | _descriptor.OneofDescriptor( 64 | name='matcher_oneof', full_name='object_detection.protos.Matcher.matcher_oneof', 65 | index=0, containing_type=None, fields=[]), 66 | ], 67 | serialized_start=162, 68 | serialized_end=326, 69 | ) 70 | 71 | _MATCHER.fields_by_name['argmax_matcher'].message_type = object__detection_dot_protos_dot_argmax__matcher__pb2._ARGMAXMATCHER 72 | _MATCHER.fields_by_name['bipartite_matcher'].message_type = object__detection_dot_protos_dot_bipartite__matcher__pb2._BIPARTITEMATCHER 73 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 74 | _MATCHER.fields_by_name['argmax_matcher']) 75 | _MATCHER.fields_by_name['argmax_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 76 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 77 | _MATCHER.fields_by_name['bipartite_matcher']) 78 | 
_MATCHER.fields_by_name['bipartite_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 79 | DESCRIPTOR.message_types_by_name['Matcher'] = _MATCHER 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | Matcher = _reflection.GeneratedProtocolMessageType('Matcher', (_message.Message,), dict( 83 | DESCRIPTOR = _MATCHER, 84 | __module__ = 'object_detection.protos.matcher_pb2' 85 | # @@protoc_insertion_point(class_scope:object_detection.protos.Matcher) 86 | )) 87 | _sym_db.RegisterMessage(Matcher) 88 | 89 | 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /object_detection/protos/mean_stddev_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for MeanStddevBoxCoder. See 6 | // box_coders/mean_stddev_box_coder.py for details. 7 | message MeanStddevBoxCoder { 8 | } 9 | -------------------------------------------------------------------------------- /object_detection/protos/mean_stddev_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/mean_stddev_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/mean_stddev_box_coder.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"\x14\n\x12MeanStddevBoxCoder') 23 | ) 24 | 25 | 26 | 27 | 28 | _MEANSTDDEVBOXCODER = _descriptor.Descriptor( 29 | name='MeanStddevBoxCoder', 30 | full_name='object_detection.protos.MeanStddevBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | syntax='proto2', 44 | extension_ranges=[], 45 | oneofs=[ 46 | ], 47 | serialized_start=80, 48 | serialized_end=100, 49 | ) 50 | 51 | DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER 52 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 53 | 54 | MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), dict( 55 | DESCRIPTOR = _MEANSTDDEVBOXCODER, 56 | __module__ = 'object_detection.protos.mean_stddev_box_coder_pb2' 57 | # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder) 58 | )) 59 | _sym_db.RegisterMessage(MeanStddevBoxCoder) 60 | 61 | 62 | # @@protoc_insertion_point(module_scope) 63 | -------------------------------------------------------------------------------- /object_detection/protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import 
"object_detection/protos/faster_rcnn.proto"; 6 | import "object_detection/protos/ssd.proto"; 7 | 8 | // Top level configuration for DetectionModels. 9 | message DetectionModel { 10 | oneof model { 11 | FasterRcnn faster_rcnn = 1; 12 | Ssd ssd = 2; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /object_detection/protos/model_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/model.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import faster_rcnn_pb2 as object__detection_dot_protos_dot_faster__rcnn__pb2 17 | from object_detection.protos import ssd_pb2 as object__detection_dot_protos_dot_ssd__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/model.proto', 22 | package='object_detection.protos', 23 | syntax='proto2', 24 | serialized_pb=_b('\n#object_detection/protos/model.proto\x12\x17object_detection.protos\x1a)object_detection/protos/faster_rcnn.proto\x1a!object_detection/protos/ssd.proto\"\x82\x01\n\x0e\x44\x65tectionModel\x12:\n\x0b\x66\x61ster_rcnn\x18\x01 \x01(\x0b\x32#.object_detection.protos.FasterRcnnH\x00\x12+\n\x03ssd\x18\x02 \x01(\x0b\x32\x1c.object_detection.protos.SsdH\x00\x42\x07\n\x05model') 25 | , 26 | dependencies=[object__detection_dot_protos_dot_faster__rcnn__pb2.DESCRIPTOR,object__detection_dot_protos_dot_ssd__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _DETECTIONMODEL = _descriptor.Descriptor( 32 | name='DetectionModel', 33 | full_name='object_detection.protos.DetectionModel', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='faster_rcnn', full_name='object_detection.protos.DetectionModel.faster_rcnn', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd', full_name='object_detection.protos.DetectionModel.ssd', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | syntax='proto2', 61 | extension_ranges=[], 62 | oneofs=[ 63 | _descriptor.OneofDescriptor( 64 | name='model', full_name='object_detection.protos.DetectionModel.model', 65 | index=0, containing_type=None, fields=[]), 66 | ], 67 | serialized_start=143, 68 | serialized_end=273, 69 | ) 70 | 71 | _DETECTIONMODEL.fields_by_name['faster_rcnn'].message_type = object__detection_dot_protos_dot_faster__rcnn__pb2._FASTERRCNN 72 | _DETECTIONMODEL.fields_by_name['ssd'].message_type = 
object__detection_dot_protos_dot_ssd__pb2._SSD 73 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 74 | _DETECTIONMODEL.fields_by_name['faster_rcnn']) 75 | _DETECTIONMODEL.fields_by_name['faster_rcnn'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 76 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 77 | _DETECTIONMODEL.fields_by_name['ssd']) 78 | _DETECTIONMODEL.fields_by_name['ssd'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 79 | DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL 80 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 81 | 82 | DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), dict( 83 | DESCRIPTOR = _DETECTIONMODEL, 84 | __module__ = 'object_detection.protos.model_pb2' 85 | # @@protoc_insertion_point(class_scope:object_detection.protos.DetectionModel) 86 | )) 87 | _sym_db.RegisterMessage(DetectionModel) 88 | 89 | 90 | # @@protoc_insertion_point(module_scope) 91 | -------------------------------------------------------------------------------- /object_detection/protos/optimizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Messages for configuring the optimizing strategy for training object 6 | // detection models. 7 | 8 | // Top level optimizer message. 9 | message Optimizer { 10 | oneof optimizer { 11 | RMSPropOptimizer rms_prop_optimizer = 1; 12 | MomentumOptimizer momentum_optimizer = 2; 13 | AdamOptimizer adam_optimizer = 3; 14 | } 15 | optional bool use_moving_average = 4 [default=true]; 16 | optional float moving_average_decay = 5 [default=0.9999]; 17 | } 18 | 19 | // Configuration message for the RMSPropOptimizer 20 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 21 | message RMSPropOptimizer { 22 | optional LearningRate learning_rate = 1; 23 | optional float momentum_optimizer_value = 2 [default=0.9]; 24 | optional float decay = 3 [default=0.9]; 25 | optional float epsilon = 4 [default=1.0]; 26 | } 27 | 28 | // Configuration message for the MomentumOptimizer 29 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 30 | message MomentumOptimizer { 31 | optional LearningRate learning_rate = 1; 32 | optional float momentum_optimizer_value = 2 [default=0.9]; 33 | } 34 | 35 | // Configuration message for the AdamOptimizer 36 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 37 | message AdamOptimizer { 38 | optional LearningRate learning_rate = 1; 39 | } 40 | 41 | // Configuration message for optimizer learning rate. 42 | message LearningRate { 43 | oneof learning_rate { 44 | ConstantLearningRate constant_learning_rate = 1; 45 | ExponentialDecayLearningRate exponential_decay_learning_rate = 2; 46 | ManualStepLearningRate manual_step_learning_rate = 3; 47 | } 48 | } 49 | 50 | // Configuration message for a constant learning rate. 51 | message ConstantLearningRate { 52 | optional float learning_rate = 1 [default=0.002]; 53 | } 54 | 55 | // Configuration message for an exponentially decaying learning rate. 
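// As an illustrative sketch (assuming the semantics of tf.train.exponential_decay),
// the effective rate is roughly
//   initial_learning_rate * decay_factor ^ (global_step / decay_steps),
// where the exponent is truncated to an integer when staircase is true.
// A hypothetical config using the message defined below:
//   exponential_decay_learning_rate {
//     initial_learning_rate: 0.004
//     decay_steps: 800720
//     decay_factor: 0.95
//   }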
56 | // See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ 57 | // decaying_the_learning_rate#exponential_decay 58 | message ExponentialDecayLearningRate { 59 | optional float initial_learning_rate = 1 [default=0.002]; 60 | optional uint32 decay_steps = 2 [default=4000000]; 61 | optional float decay_factor = 3 [default=0.95]; 62 | optional bool staircase = 4 [default=true]; 63 | } 64 | 65 | // Configuration message for a manually defined learning rate schedule. 66 | message ManualStepLearningRate { 67 | optional float initial_learning_rate = 1 [default=0.002]; 68 | message LearningRateSchedule { 69 | optional uint32 step = 1; 70 | optional float learning_rate = 2 [default=0.002]; 71 | } 72 | repeated LearningRateSchedule schedule = 2; 73 | } 74 | -------------------------------------------------------------------------------- /object_detection/protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/eval.proto"; 6 | import "object_detection/protos/input_reader.proto"; 7 | import "object_detection/protos/model.proto"; 8 | import "object_detection/protos/train.proto"; 9 | 10 | // Convenience message for configuring a training and eval pipeline. Allows all 11 | // of the pipeline parameters to be configured from one file. 12 | message TrainEvalPipelineConfig { 13 | optional DetectionModel model = 1; 14 | optional TrainConfig train_config = 2; 15 | optional InputReader train_input_reader = 3; 16 | optional EvalConfig eval_config = 4; 17 | optional InputReader eval_input_reader = 5; 18 | } 19 | -------------------------------------------------------------------------------- /object_detection/protos/pipeline_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/pipeline.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import eval_pb2 as object__detection_dot_protos_dot_eval__pb2 17 | from object_detection.protos import input_reader_pb2 as object__detection_dot_protos_dot_input__reader__pb2 18 | from object_detection.protos import model_pb2 as object__detection_dot_protos_dot_model__pb2 19 | from object_detection.protos import train_pb2 as object__detection_dot_protos_dot_train__pb2 20 | 21 | 22 | DESCRIPTOR = _descriptor.FileDescriptor( 23 | name='object_detection/protos/pipeline.proto', 24 | package='object_detection.protos', 25 | syntax='proto2', 26 | serialized_pb=_b('\n&object_detection/protos/pipeline.proto\x12\x17object_detection.protos\x1a\"object_detection/protos/eval.proto\x1a*object_detection/protos/input_reader.proto\x1a#object_detection/protos/model.proto\x1a#object_detection/protos/train.proto\"\xca\x02\n\x17TrainEvalPipelineConfig\x12\x36\n\x05model\x18\x01 \x01(\x0b\x32\'.object_detection.protos.DetectionModel\x12:\n\x0ctrain_config\x18\x02 \x01(\x0b\x32$.object_detection.protos.TrainConfig\x12@\n\x12train_input_reader\x18\x03 \x01(\x0b\x32$.object_detection.protos.InputReader\x12\x38\n\x0b\x65val_config\x18\x04 \x01(\x0b\x32#.object_detection.protos.EvalConfig\x12?\n\x11\x65val_input_reader\x18\x05 \x01(\x0b\x32$.object_detection.protos.InputReader') 27 | , 28 | dependencies=[object__detection_dot_protos_dot_eval__pb2.DESCRIPTOR,object__detection_dot_protos_dot_input__reader__pb2.DESCRIPTOR,object__detection_dot_protos_dot_model__pb2.DESCRIPTOR,object__detection_dot_protos_dot_train__pb2.DESCRIPTOR,]) 29 | 30 | 31 | 32 | 33 | _TRAINEVALPIPELINECONFIG = _descriptor.Descriptor( 34 | name='TrainEvalPipelineConfig', 35 | full_name='object_detection.protos.TrainEvalPipelineConfig', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='model', full_name='object_detection.protos.TrainEvalPipelineConfig.model', index=0, 42 | number=1, type=11, cpp_type=10, label=1, 43 | has_default_value=False, default_value=None, 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | options=None), 47 | _descriptor.FieldDescriptor( 48 | name='train_config', full_name='object_detection.protos.TrainEvalPipelineConfig.train_config', index=1, 49 | number=2, type=11, cpp_type=10, label=1, 50 | has_default_value=False, default_value=None, 51 | message_type=None, enum_type=None, containing_type=None, 52 | is_extension=False, extension_scope=None, 53 | options=None), 54 | _descriptor.FieldDescriptor( 55 | name='train_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.train_input_reader', index=2, 56 | number=3, type=11, cpp_type=10, label=1, 57 | has_default_value=False, default_value=None, 58 | message_type=None, enum_type=None, containing_type=None, 59 | is_extension=False, extension_scope=None, 60 | options=None), 61 | _descriptor.FieldDescriptor( 62 | name='eval_config', 
full_name='object_detection.protos.TrainEvalPipelineConfig.eval_config', index=3, 63 | number=4, type=11, cpp_type=10, label=1, 64 | has_default_value=False, default_value=None, 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, extension_scope=None, 67 | options=None), 68 | _descriptor.FieldDescriptor( 69 | name='eval_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.eval_input_reader', index=4, 70 | number=5, type=11, cpp_type=10, label=1, 71 | has_default_value=False, default_value=None, 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | options=None), 75 | ], 76 | extensions=[ 77 | ], 78 | nested_types=[], 79 | enum_types=[ 80 | ], 81 | options=None, 82 | is_extendable=False, 83 | syntax='proto2', 84 | extension_ranges=[], 85 | oneofs=[ 86 | ], 87 | serialized_start=222, 88 | serialized_end=552, 89 | ) 90 | 91 | _TRAINEVALPIPELINECONFIG.fields_by_name['model'].message_type = object__detection_dot_protos_dot_model__pb2._DETECTIONMODEL 92 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_config'].message_type = object__detection_dot_protos_dot_train__pb2._TRAINCONFIG 93 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_input_reader'].message_type = object__detection_dot_protos_dot_input__reader__pb2._INPUTREADER 94 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_config'].message_type = object__detection_dot_protos_dot_eval__pb2._EVALCONFIG 95 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_input_reader'].message_type = object__detection_dot_protos_dot_input__reader__pb2._INPUTREADER 96 | DESCRIPTOR.message_types_by_name['TrainEvalPipelineConfig'] = _TRAINEVALPIPELINECONFIG 97 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 98 | 99 | TrainEvalPipelineConfig = _reflection.GeneratedProtocolMessageType('TrainEvalPipelineConfig', (_message.Message,), dict( 100 | DESCRIPTOR = _TRAINEVALPIPELINECONFIG, 101 | __module__ = 'object_detection.protos.pipeline_pb2' 102 | # @@protoc_insertion_point(class_scope:object_detection.protos.TrainEvalPipelineConfig) 103 | )) 104 | _sym_db.RegisterMessage(TrainEvalPipelineConfig) 105 | 106 | 107 | # @@protoc_insertion_point(module_scope) 108 | -------------------------------------------------------------------------------- /object_detection/protos/post_processing.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for non-max-suppression operation on a batch of 6 | // detections. 7 | message BatchNonMaxSuppression { 8 | // Scalar threshold for score (low scoring boxes are removed). 9 | optional float score_threshold = 1 [default = 0.0]; 10 | 11 | // Scalar threshold for IOU (boxes that have high IOU overlap 12 | // with previously selected boxes are removed). 13 | optional float iou_threshold = 2 [default = 0.6]; 14 | 15 | // Maximum number of detections to retain per class. 16 | optional int32 max_detections_per_class = 3 [default = 100]; 17 | 18 | // Maximum number of detections to retain across all classes. 19 | optional int32 max_total_detections = 5 [default = 100]; 20 | } 21 | 22 | // Configuration proto for post-processing predicted boxes and 23 | // scores. 24 | message PostProcessing { 25 | // Non max suppression parameters. 26 | optional BatchNonMaxSuppression batch_non_max_suppression = 1; 27 | 28 | // Enum to specify how to convert the detection scores. 
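// For instance, SIGMOID is a natural choice for models trained with a sigmoid
// classification loss. An illustrative PostProcessing config (the values are
// hypothetical, not defaults of this repository):
//   post_processing {
//     batch_non_max_suppression {
//       score_threshold: 1e-8
//       iou_threshold: 0.6
//       max_detections_per_class: 100
//       max_total_detections: 100
//     }
//     score_converter: SIGMOID
//   }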
29 | enum ScoreConverter { 30 | // Input scores equals output scores. 31 | IDENTITY = 0; 32 | 33 | // Applies a sigmoid on input scores. 34 | SIGMOID = 1; 35 | 36 | // Applies a softmax on input scores 37 | SOFTMAX = 2; 38 | } 39 | 40 | // Score converter to use. 41 | optional ScoreConverter score_converter = 2 [default = IDENTITY]; 42 | } 43 | -------------------------------------------------------------------------------- /object_detection/protos/post_processing_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/post_processing.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/post_processing.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n-object_detection/protos/post_processing.proto\x12\x17object_detection.protos\"\x9a\x01\n\x16\x42\x61tchNonMaxSuppression\x12\x1a\n\x0fscore_threshold\x18\x01 \x01(\x02:\x01\x30\x12\x1a\n\riou_threshold\x18\x02 \x01(\x02:\x03\x30.6\x12%\n\x18max_detections_per_class\x18\x03 \x01(\x05:\x03\x31\x30\x30\x12!\n\x14max_total_detections\x18\x05 \x01(\x05:\x03\x31\x30\x30\"\xf9\x01\n\x0ePostProcessing\x12R\n\x19\x62\x61tch_non_max_suppression\x18\x01 \x01(\x0b\x32/.object_detection.protos.BatchNonMaxSuppression\x12Y\n\x0fscore_converter\x18\x02 \x01(\x0e\x32\x36.object_detection.protos.PostProcessing.ScoreConverter:\x08IDENTITY\"8\n\x0eScoreConverter\x12\x0c\n\x08IDENTITY\x10\x00\x12\x0b\n\x07SIGMOID\x10\x01\x12\x0b\n\x07SOFTMAX\x10\x02') 23 | ) 24 | 25 | 26 | 27 | _POSTPROCESSING_SCORECONVERTER = _descriptor.EnumDescriptor( 28 | name='ScoreConverter', 29 | full_name='object_detection.protos.PostProcessing.ScoreConverter', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='IDENTITY', index=0, number=0, 35 | options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='SIGMOID', index=1, number=1, 39 | options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='SOFTMAX', index=2, number=2, 43 | options=None, 44 | type=None), 45 | ], 46 | containing_type=None, 47 | options=None, 48 | serialized_start=425, 49 | serialized_end=481, 50 | ) 51 | _sym_db.RegisterEnumDescriptor(_POSTPROCESSING_SCORECONVERTER) 52 | 53 | 54 | _BATCHNONMAXSUPPRESSION = _descriptor.Descriptor( 55 | name='BatchNonMaxSuppression', 56 | full_name='object_detection.protos.BatchNonMaxSuppression', 57 | filename=None, 58 | file=DESCRIPTOR, 59 | containing_type=None, 60 | fields=[ 61 | _descriptor.FieldDescriptor( 62 | name='score_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.score_threshold', index=0, 63 | number=1, type=2, cpp_type=6, label=1, 64 | has_default_value=True, default_value=float(0), 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, extension_scope=None, 67 | options=None), 68 | _descriptor.FieldDescriptor( 69 | 
name='iou_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.iou_threshold', index=1, 70 | number=2, type=2, cpp_type=6, label=1, 71 | has_default_value=True, default_value=float(0.6), 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | options=None), 75 | _descriptor.FieldDescriptor( 76 | name='max_detections_per_class', full_name='object_detection.protos.BatchNonMaxSuppression.max_detections_per_class', index=2, 77 | number=3, type=5, cpp_type=1, label=1, 78 | has_default_value=True, default_value=100, 79 | message_type=None, enum_type=None, containing_type=None, 80 | is_extension=False, extension_scope=None, 81 | options=None), 82 | _descriptor.FieldDescriptor( 83 | name='max_total_detections', full_name='object_detection.protos.BatchNonMaxSuppression.max_total_detections', index=3, 84 | number=5, type=5, cpp_type=1, label=1, 85 | has_default_value=True, default_value=100, 86 | message_type=None, enum_type=None, containing_type=None, 87 | is_extension=False, extension_scope=None, 88 | options=None), 89 | ], 90 | extensions=[ 91 | ], 92 | nested_types=[], 93 | enum_types=[ 94 | ], 95 | options=None, 96 | is_extendable=False, 97 | syntax='proto2', 98 | extension_ranges=[], 99 | oneofs=[ 100 | ], 101 | serialized_start=75, 102 | serialized_end=229, 103 | ) 104 | 105 | 106 | _POSTPROCESSING = _descriptor.Descriptor( 107 | name='PostProcessing', 108 | full_name='object_detection.protos.PostProcessing', 109 | filename=None, 110 | file=DESCRIPTOR, 111 | containing_type=None, 112 | fields=[ 113 | _descriptor.FieldDescriptor( 114 | name='batch_non_max_suppression', full_name='object_detection.protos.PostProcessing.batch_non_max_suppression', index=0, 115 | number=1, type=11, cpp_type=10, label=1, 116 | has_default_value=False, default_value=None, 117 | message_type=None, enum_type=None, containing_type=None, 118 | is_extension=False, extension_scope=None, 119 | options=None), 120 | _descriptor.FieldDescriptor( 121 | name='score_converter', full_name='object_detection.protos.PostProcessing.score_converter', index=1, 122 | number=2, type=14, cpp_type=8, label=1, 123 | has_default_value=True, default_value=0, 124 | message_type=None, enum_type=None, containing_type=None, 125 | is_extension=False, extension_scope=None, 126 | options=None), 127 | ], 128 | extensions=[ 129 | ], 130 | nested_types=[], 131 | enum_types=[ 132 | _POSTPROCESSING_SCORECONVERTER, 133 | ], 134 | options=None, 135 | is_extendable=False, 136 | syntax='proto2', 137 | extension_ranges=[], 138 | oneofs=[ 139 | ], 140 | serialized_start=232, 141 | serialized_end=481, 142 | ) 143 | 144 | _POSTPROCESSING.fields_by_name['batch_non_max_suppression'].message_type = _BATCHNONMAXSUPPRESSION 145 | _POSTPROCESSING.fields_by_name['score_converter'].enum_type = _POSTPROCESSING_SCORECONVERTER 146 | _POSTPROCESSING_SCORECONVERTER.containing_type = _POSTPROCESSING 147 | DESCRIPTOR.message_types_by_name['BatchNonMaxSuppression'] = _BATCHNONMAXSUPPRESSION 148 | DESCRIPTOR.message_types_by_name['PostProcessing'] = _POSTPROCESSING 149 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 150 | 151 | BatchNonMaxSuppression = _reflection.GeneratedProtocolMessageType('BatchNonMaxSuppression', (_message.Message,), dict( 152 | DESCRIPTOR = _BATCHNONMAXSUPPRESSION, 153 | __module__ = 'object_detection.protos.post_processing_pb2' 154 | # @@protoc_insertion_point(class_scope:object_detection.protos.BatchNonMaxSuppression) 155 | )) 156 | _sym_db.RegisterMessage(BatchNonMaxSuppression) 
157 | 158 | PostProcessing = _reflection.GeneratedProtocolMessageType('PostProcessing', (_message.Message,), dict( 159 | DESCRIPTOR = _POSTPROCESSING, 160 | __module__ = 'object_detection.protos.post_processing_pb2' 161 | # @@protoc_insertion_point(class_scope:object_detection.protos.PostProcessing) 162 | )) 163 | _sym_db.RegisterMessage(PostProcessing) 164 | 165 | 166 | # @@protoc_insertion_point(module_scope) 167 | -------------------------------------------------------------------------------- /object_detection/protos/preprocessor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for defining a preprocessing operation on input data. 6 | // See: //object_detection/core/preprocessor.py 7 | message PreprocessingStep { 8 | oneof preprocessing_step { 9 | NormalizeImage normalize_image = 1; 10 | RandomHorizontalFlip random_horizontal_flip = 2; 11 | RandomPixelValueScale random_pixel_value_scale = 3; 12 | RandomImageScale random_image_scale = 4; 13 | RandomRGBtoGray random_rgb_to_gray = 5; 14 | RandomAdjustBrightness random_adjust_brightness = 6; 15 | RandomAdjustContrast random_adjust_contrast = 7; 16 | RandomAdjustHue random_adjust_hue = 8; 17 | RandomAdjustSaturation random_adjust_saturation = 9; 18 | RandomDistortColor random_distort_color = 10; 19 | RandomJitterBoxes random_jitter_boxes = 11; 20 | RandomCropImage random_crop_image = 12; 21 | RandomPadImage random_pad_image = 13; 22 | RandomCropPadImage random_crop_pad_image = 14; 23 | RandomCropToAspectRatio random_crop_to_aspect_ratio = 15; 24 | RandomBlackPatches random_black_patches = 16; 25 | RandomResizeMethod random_resize_method = 17; 26 | ScaleBoxesToPixelCoordinates scale_boxes_to_pixel_coordinates = 18; 27 | ResizeImage resize_image = 19; 28 | SubtractChannelMean subtract_channel_mean = 20; 29 | SSDRandomCrop ssd_random_crop = 21; 30 | SSDRandomCropPad ssd_random_crop_pad = 22; 31 | SSDRandomCropFixedAspectRatio ssd_random_crop_fixed_aspect_ratio = 23; 32 | } 33 | } 34 | 35 | // Normalizes pixel values in an image. 36 | // For every channel in the image, moves the pixel values from the range 37 | // [original_minval, original_maxval] to [target_minval, target_maxval]. 38 | message NormalizeImage { 39 | optional float original_minval = 1; 40 | optional float original_maxval = 2; 41 | optional float target_minval = 3 [default=0]; 42 | optional float target_maxval = 4 [default=1]; 43 | } 44 | 45 | // Randomly horizontally mirrors the image and detections 50% of the time. 46 | message RandomHorizontalFlip { 47 | } 48 | 49 | // Randomly scales the values of all pixels in the image by some constant value 50 | // between [minval, maxval], then clip the value to a range between [0, 1.0]. 51 | message RandomPixelValueScale { 52 | optional float minval = 1 [default=0.9]; 53 | optional float maxval = 2 [default=1.1]; 54 | } 55 | 56 | // Randomly enlarges or shrinks image (keeping aspect ratio). 57 | message RandomImageScale { 58 | optional float min_scale_ratio = 1 [default=0.5]; 59 | optional float max_scale_ratio = 2 [default=2.0]; 60 | } 61 | 62 | // Randomly convert entire image to grey scale. 63 | message RandomRGBtoGray { 64 | optional float probability = 1 [default=0.1]; 65 | } 66 | 67 | // Randomly changes image brightness by up to max_delta. Image outputs will be 68 | // saturated between 0 and 1. 
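// As a sketch (assuming the usual wiring of this API, where preprocessing steps
// are listed under data_augmentation_options in the train config), the message
// below would be enabled with something like:
//   data_augmentation_options {
//     random_adjust_brightness { max_delta: 0.2 }
//   }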
69 | message RandomAdjustBrightness { 70 | optional float max_delta=1 [default=0.2]; 71 | } 72 | 73 | // Randomly scales contrast by a value between [min_delta, max_delta]. 74 | message RandomAdjustContrast { 75 | optional float min_delta = 1 [default=0.8]; 76 | optional float max_delta = 2 [default=1.25]; 77 | } 78 | 79 | // Randomly alters hue by a value of up to max_delta. 80 | message RandomAdjustHue { 81 | optional float max_delta = 1 [default=0.02]; 82 | } 83 | 84 | // Randomly changes saturation by a value between [min_delta, max_delta]. 85 | message RandomAdjustSaturation { 86 | optional float min_delta = 1 [default=0.8]; 87 | optional float max_delta = 2 [default=1.25]; 88 | } 89 | 90 | // Performs a random color distortion. color_ordering should be either 0 or 1. 91 | message RandomDistortColor { 92 | optional int32 color_ordering = 1; 93 | } 94 | 95 | // Randomly jitters corners of boxes in the image determined by ratio. 96 | // i.e. If a box is [100, 200] and ratio is 0.02, the corners can move by [1, 4]. 97 | message RandomJitterBoxes { 98 | optional float ratio = 1 [default=0.05]; 99 | } 100 | 101 | // Randomly crops the image and bounding boxes. 102 | message RandomCropImage { 103 | // Cropped image must cover at least one box by this fraction. 104 | optional float min_object_covered = 1 [default=1.0]; 105 | 106 | // Aspect ratio bounds of cropped image. 107 | optional float min_aspect_ratio = 2 [default=0.75]; 108 | optional float max_aspect_ratio = 3 [default=1.33]; 109 | 110 | // Allowed area ratio of cropped image to original image. 111 | optional float min_area = 4 [default=0.1]; 112 | optional float max_area = 5 [default=1.0]; 113 | 114 | // Minimum overlap threshold of cropped boxes to keep in new image. If the 115 | // ratio between a cropped bounding box and the original is less than this 116 | // value, it is removed from the new image. 117 | optional float overlap_thresh = 6 [default=0.3]; 118 | 119 | // Probability of keeping the original image. 120 | optional float random_coef = 7 [default=0.0]; 121 | } 122 | 123 | // Randomly adds padding to the image. 124 | message RandomPadImage { 125 | // Minimum dimensions for padded image. If unset, will use original image 126 | // dimension as a lower bound. 127 | optional float min_image_height = 1; 128 | optional float min_image_width = 2; 129 | 130 | // Maximum dimensions for padded image. If unset, will use double the original 131 | // image dimension as an upper bound. 132 | optional float max_image_height = 3; 133 | optional float max_image_width = 4; 134 | 135 | // Color of the padding. If unset, will pad using average color of the input 136 | // image. 137 | repeated float pad_color = 5; 138 | } 139 | 140 | // Randomly crops an image followed by a random pad. 141 | message RandomCropPadImage { 142 | // Cropping operation must cover at least one box by this fraction. 143 | optional float min_object_covered = 1 [default=1.0]; 144 | 145 | // Aspect ratio bounds of image after cropping operation. 146 | optional float min_aspect_ratio = 2 [default=0.75]; 147 | optional float max_aspect_ratio = 3 [default=1.33]; 148 | 149 | // Allowed area ratio of image after cropping operation. 150 | optional float min_area = 4 [default=0.1]; 151 | optional float max_area = 5 [default=1.0]; 152 | 153 | // Minimum overlap threshold of cropped boxes to keep in new image. If the 154 | // ratio between a cropped bounding box and the original is less than this 155 | // value, it is removed from the new image.
156 | optional float overlap_thresh = 6 [default=0.3]; 157 | 158 | // Probability of keeping the original image during the crop operation. 159 | optional float random_coef = 7 [default=0.0]; 160 | 161 | // Maximum dimensions for padded image. If unset, will use double the original 162 | // image dimension as an upper bound. Both of the following fields should be 163 | // length 2. 164 | repeated float min_padded_size_ratio = 8; 165 | repeated float max_padded_size_ratio = 9; 166 | 167 | // Color of the padding. If unset, will pad using average color of the input 168 | // image. 169 | repeated float pad_color = 10; 170 | } 171 | 172 | // Randomly crops an image to a given aspect ratio. 173 | message RandomCropToAspectRatio { 174 | // Aspect ratio. 175 | optional float aspect_ratio = 1 [default=1.0]; 176 | 177 | // Minimum overlap threshold of cropped boxes to keep in new image. If the 178 | // ratio between a cropped bounding box and the original is less than this 179 | // value, it is removed from the new image. 180 | optional float overlap_thresh = 2 [default=0.3]; 181 | } 182 | 183 | // Randomly adds black square patches to an image. 184 | message RandomBlackPatches { 185 | // The maximum number of black patches to add. 186 | optional int32 max_black_patches = 1 [default=10]; 187 | 188 | // The probability of a black patch being added to an image. 189 | optional float probability = 2 [default=0.5]; 190 | 191 | // Ratio of the dimension of the black patch to the minimum dimension of 192 | // the image (patch_width = patch_height = size_to_image_ratio * min(image_height, image_width)). 193 | optional float size_to_image_ratio = 3 [default=0.1]; 194 | } 195 | 196 | // Randomly resizes the image up to [target_height, target_width]. 197 | message RandomResizeMethod { 198 | optional float target_height = 1; 199 | optional float target_width = 2; 200 | } 201 | 202 | // Scales boxes from normalized coordinates to pixel coordinates. 203 | message ScaleBoxesToPixelCoordinates { 204 | } 205 | 206 | // Resizes images to [new_height, new_width]. 207 | message ResizeImage { 208 | optional int32 new_height = 1; 209 | optional int32 new_width = 2; 210 | enum Method { 211 | AREA=1; 212 | BICUBIC=2; 213 | BILINEAR=3; 214 | NEAREST_NEIGHBOR=4; 215 | } 216 | optional Method method = 3 [default=BILINEAR]; 217 | } 218 | 219 | // Normalizes an image by subtracting a mean from each channel. 220 | message SubtractChannelMean { 221 | // The mean to subtract from each channel. Should have the same dimension as 222 | // the number of channels in the input image. 223 | repeated float means = 1; 224 | } 225 | 226 | message SSDRandomCropOperation { 227 | // Cropped image must cover at least this fraction of one original bounding 228 | // box. 229 | optional float min_object_covered = 1; 230 | 231 | // The aspect ratio of the cropped image must be within the range of 232 | // [min_aspect_ratio, max_aspect_ratio]. 233 | optional float min_aspect_ratio = 2; 234 | optional float max_aspect_ratio = 3; 235 | 236 | // The area of the cropped image must be within the range of 237 | // [min_area, max_area]. 238 | optional float min_area = 4; 239 | optional float max_area = 5; 240 | 241 | // Cropped box area ratio must be above this threshold to be kept. 242 | optional float overlap_thresh = 6; 243 | 244 | // Probability a crop operation is skipped. 245 | optional float random_coef = 7; 246 | } 247 | 248 | // Randomly crops an image according to: 249 | // Liu et al., SSD: Single shot multibox detector.
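// An illustrative (hypothetical) config with two candidate crop operations, of
// which one is chosen at random per image:
//   ssd_random_crop {
//     operations { min_object_covered: 0.25 min_aspect_ratio: 0.75
//                  max_aspect_ratio: 1.5 min_area: 0.5 max_area: 1.0
//                  overlap_thresh: 0.25 random_coef: 0.375 }
//     operations { min_object_covered: 0.75 min_aspect_ratio: 0.75
//                  max_aspect_ratio: 1.5 min_area: 0.5 max_area: 1.0
//                  overlap_thresh: 0.75 random_coef: 0.375 }
//   }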
250 | // This preprocessing step defines multiple SSDRandomCropOperations. Only one 251 | // operation (chosen at random) is actually performed on an image. 252 | message SSDRandomCrop { 253 | repeated SSDRandomCropOperation operations = 1; 254 | } 255 | 256 | message SSDRandomCropPadOperation { 257 | // Cropped image must cover at least this fraction of one original bounding 258 | // box. 259 | optional float min_object_covered = 1; 260 | 261 | // The aspect ratio of the cropped image must be within the range of 262 | // [min_aspect_ratio, max_aspect_ratio]. 263 | optional float min_aspect_ratio = 2; 264 | optional float max_aspect_ratio = 3; 265 | 266 | // The area of the cropped image must be within the range of 267 | // [min_area, max_area]. 268 | optional float min_area = 4; 269 | optional float max_area = 5; 270 | 271 | // Cropped box area ratio must be above this threshold to be kept. 272 | optional float overlap_thresh = 6; 273 | 274 | // Probability a crop operation is skipped. 275 | optional float random_coef = 7; 276 | 277 | // Min ratio of padded image height and width to the input image's height and 278 | // width. Two entries per operation. 279 | repeated float min_padded_size_ratio = 8; 280 | 281 | // Max ratio of padded image height and width to the input image's height and 282 | // width. Two entries per operation. 283 | repeated float max_padded_size_ratio = 9; 284 | 285 | // Padding color. 286 | optional float pad_color_r = 10; 287 | optional float pad_color_g = 11; 288 | optional float pad_color_b = 12; 289 | } 290 | 291 | // Randomly crops and pads an image according to: 292 | // Liu et al., SSD: Single shot multibox detector. 293 | // This preprocessing step defines multiple SSDRandomCropPadOperations. Only one 294 | // operation (chosen at random) is actually performed on an image. 295 | message SSDRandomCropPad { 296 | repeated SSDRandomCropPadOperation operations = 1; 297 | } 298 | 299 | message SSDRandomCropFixedAspectRatioOperation { 300 | // Cropped image must cover at least this fraction of one original bounding 301 | // box. 302 | optional float min_object_covered = 1; 303 | 304 | // The area of the cropped image must be within the range of 305 | // [min_area, max_area]. 306 | optional float min_area = 4; 307 | optional float max_area = 5; 308 | 309 | // Cropped box area ratio must be above this threshold to be kept. 310 | optional float overlap_thresh = 6; 311 | 312 | // Probability a crop operation is skipped. 313 | optional float random_coef = 7; 314 | } 315 | 316 | // Randomly crops an image to a fixed aspect ratio according to: 317 | // Liu et al., SSD: Single shot multibox detector. 318 | // Multiple SSDRandomCropFixedAspectRatioOperations are defined by this 319 | // preprocessing step. Only one operation (chosen at random) is actually 320 | // performed on an image. 321 | message SSDRandomCropFixedAspectRatio { 322 | repeated SSDRandomCropFixedAspectRatioOperation operations = 1; 323 | 324 | // Aspect ratio to crop to. This value is used for all crop operations. 325 | optional float aspect_ratio = 2 [default=1.0]; 326 | } 327 | -------------------------------------------------------------------------------- /object_detection/protos/region_similarity_calculator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for region similarity calculators. See 6 | // core/region_similarity_calculator.py for details.
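// SSD-style configs in this repository would typically pick the IOU calculator;
// a minimal sketch (via the similarity_calculator field of the Ssd message):
//   similarity_calculator {
//     iou_similarity {
//     }
//   }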
7 | message RegionSimilarityCalculator { 8 | oneof region_similarity { 9 | NegSqDistSimilarity neg_sq_dist_similarity = 1; 10 | IouSimilarity iou_similarity = 2; 11 | IoaSimilarity ioa_similarity = 3; 12 | } 13 | } 14 | 15 | // Configuration for negative squared distance similarity calculator. 16 | message NegSqDistSimilarity { 17 | } 18 | 19 | // Configuration for intersection-over-union (IOU) similarity calculator. 20 | message IouSimilarity { 21 | } 22 | 23 | // Configuration for intersection-over-area (IOA) similarity calculator. 24 | message IoaSimilarity { 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/region_similarity_calculator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/region_similarity_calculator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/region_similarity_calculator.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n:object_detection/protos/region_similarity_calculator.proto\x12\x17object_detection.protos\"\x85\x02\n\x1aRegionSimilarityCalculator\x12N\n\x16neg_sq_dist_similarity\x18\x01 \x01(\x0b\x32,.object_detection.protos.NegSqDistSimilarityH\x00\x12@\n\x0eiou_similarity\x18\x02 \x01(\x0b\x32&.object_detection.protos.IouSimilarityH\x00\x12@\n\x0eioa_similarity\x18\x03 \x01(\x0b\x32&.object_detection.protos.IoaSimilarityH\x00\x42\x13\n\x11region_similarity\"\x15\n\x13NegSqDistSimilarity\"\x0f\n\rIouSimilarity\"\x0f\n\rIoaSimilarity') 23 | ) 24 | 25 | 26 | 27 | 28 | _REGIONSIMILARITYCALCULATOR = _descriptor.Descriptor( 29 | name='RegionSimilarityCalculator', 30 | full_name='object_detection.protos.RegionSimilarityCalculator', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='neg_sq_dist_similarity', full_name='object_detection.protos.RegionSimilarityCalculator.neg_sq_dist_similarity', index=0, 37 | number=1, type=11, cpp_type=10, label=1, 38 | has_default_value=False, default_value=None, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='iou_similarity', full_name='object_detection.protos.RegionSimilarityCalculator.iou_similarity', index=1, 44 | number=2, type=11, cpp_type=10, label=1, 45 | has_default_value=False, default_value=None, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='ioa_similarity', full_name='object_detection.protos.RegionSimilarityCalculator.ioa_similarity', index=2, 51 | number=3, type=11, cpp_type=10, label=1, 52 | has_default_value=False, default_value=None, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, 
extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto2', 65 | extension_ranges=[], 66 | oneofs=[ 67 | _descriptor.OneofDescriptor( 68 | name='region_similarity', full_name='object_detection.protos.RegionSimilarityCalculator.region_similarity', 69 | index=0, containing_type=None, fields=[]), 70 | ], 71 | serialized_start=88, 72 | serialized_end=349, 73 | ) 74 | 75 | 76 | _NEGSQDISTSIMILARITY = _descriptor.Descriptor( 77 | name='NegSqDistSimilarity', 78 | full_name='object_detection.protos.NegSqDistSimilarity', 79 | filename=None, 80 | file=DESCRIPTOR, 81 | containing_type=None, 82 | fields=[ 83 | ], 84 | extensions=[ 85 | ], 86 | nested_types=[], 87 | enum_types=[ 88 | ], 89 | options=None, 90 | is_extendable=False, 91 | syntax='proto2', 92 | extension_ranges=[], 93 | oneofs=[ 94 | ], 95 | serialized_start=351, 96 | serialized_end=372, 97 | ) 98 | 99 | 100 | _IOUSIMILARITY = _descriptor.Descriptor( 101 | name='IouSimilarity', 102 | full_name='object_detection.protos.IouSimilarity', 103 | filename=None, 104 | file=DESCRIPTOR, 105 | containing_type=None, 106 | fields=[ 107 | ], 108 | extensions=[ 109 | ], 110 | nested_types=[], 111 | enum_types=[ 112 | ], 113 | options=None, 114 | is_extendable=False, 115 | syntax='proto2', 116 | extension_ranges=[], 117 | oneofs=[ 118 | ], 119 | serialized_start=374, 120 | serialized_end=389, 121 | ) 122 | 123 | 124 | _IOASIMILARITY = _descriptor.Descriptor( 125 | name='IoaSimilarity', 126 | full_name='object_detection.protos.IoaSimilarity', 127 | filename=None, 128 | file=DESCRIPTOR, 129 | containing_type=None, 130 | fields=[ 131 | ], 132 | extensions=[ 133 | ], 134 | nested_types=[], 135 | enum_types=[ 136 | ], 137 | options=None, 138 | is_extendable=False, 139 | syntax='proto2', 140 | extension_ranges=[], 141 | oneofs=[ 142 | ], 143 | serialized_start=391, 144 | serialized_end=406, 145 | ) 146 | 147 | _REGIONSIMILARITYCALCULATOR.fields_by_name['neg_sq_dist_similarity'].message_type = _NEGSQDISTSIMILARITY 148 | _REGIONSIMILARITYCALCULATOR.fields_by_name['iou_similarity'].message_type = _IOUSIMILARITY 149 | _REGIONSIMILARITYCALCULATOR.fields_by_name['ioa_similarity'].message_type = _IOASIMILARITY 150 | _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'].fields.append( 151 | _REGIONSIMILARITYCALCULATOR.fields_by_name['neg_sq_dist_similarity']) 152 | _REGIONSIMILARITYCALCULATOR.fields_by_name['neg_sq_dist_similarity'].containing_oneof = _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'] 153 | _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'].fields.append( 154 | _REGIONSIMILARITYCALCULATOR.fields_by_name['iou_similarity']) 155 | _REGIONSIMILARITYCALCULATOR.fields_by_name['iou_similarity'].containing_oneof = _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'] 156 | _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'].fields.append( 157 | _REGIONSIMILARITYCALCULATOR.fields_by_name['ioa_similarity']) 158 | _REGIONSIMILARITYCALCULATOR.fields_by_name['ioa_similarity'].containing_oneof = _REGIONSIMILARITYCALCULATOR.oneofs_by_name['region_similarity'] 159 | DESCRIPTOR.message_types_by_name['RegionSimilarityCalculator'] = _REGIONSIMILARITYCALCULATOR 160 | DESCRIPTOR.message_types_by_name['NegSqDistSimilarity'] = _NEGSQDISTSIMILARITY 161 | DESCRIPTOR.message_types_by_name['IouSimilarity'] = _IOUSIMILARITY 162 | DESCRIPTOR.message_types_by_name['IoaSimilarity'] = 
_IOASIMILARITY 163 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 164 | 165 | RegionSimilarityCalculator = _reflection.GeneratedProtocolMessageType('RegionSimilarityCalculator', (_message.Message,), dict( 166 | DESCRIPTOR = _REGIONSIMILARITYCALCULATOR, 167 | __module__ = 'object_detection.protos.region_similarity_calculator_pb2' 168 | # @@protoc_insertion_point(class_scope:object_detection.protos.RegionSimilarityCalculator) 169 | )) 170 | _sym_db.RegisterMessage(RegionSimilarityCalculator) 171 | 172 | NegSqDistSimilarity = _reflection.GeneratedProtocolMessageType('NegSqDistSimilarity', (_message.Message,), dict( 173 | DESCRIPTOR = _NEGSQDISTSIMILARITY, 174 | __module__ = 'object_detection.protos.region_similarity_calculator_pb2' 175 | # @@protoc_insertion_point(class_scope:object_detection.protos.NegSqDistSimilarity) 176 | )) 177 | _sym_db.RegisterMessage(NegSqDistSimilarity) 178 | 179 | IouSimilarity = _reflection.GeneratedProtocolMessageType('IouSimilarity', (_message.Message,), dict( 180 | DESCRIPTOR = _IOUSIMILARITY, 181 | __module__ = 'object_detection.protos.region_similarity_calculator_pb2' 182 | # @@protoc_insertion_point(class_scope:object_detection.protos.IouSimilarity) 183 | )) 184 | _sym_db.RegisterMessage(IouSimilarity) 185 | 186 | IoaSimilarity = _reflection.GeneratedProtocolMessageType('IoaSimilarity', (_message.Message,), dict( 187 | DESCRIPTOR = _IOASIMILARITY, 188 | __module__ = 'object_detection.protos.region_similarity_calculator_pb2' 189 | # @@protoc_insertion_point(class_scope:object_detection.protos.IoaSimilarity) 190 | )) 191 | _sym_db.RegisterMessage(IoaSimilarity) 192 | 193 | 194 | # @@protoc_insertion_point(module_scope) 195 | -------------------------------------------------------------------------------- /object_detection/protos/square_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SquareBoxCoder. See 6 | // box_coders/square_box_coder.py for details. 7 | message SquareBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box length. 13 | optional float length_scale = 3 [default = 5.0]; 14 | } 15 | -------------------------------------------------------------------------------- /object_detection/protos/square_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/square_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/square_box_coder.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35') 23 | ) 24 | 25 | 26 | 27 | 28 | _SQUAREBOXCODER = _descriptor.Descriptor( 29 | name='SquareBoxCoder', 30 | full_name='object_detection.protos.SquareBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=float(10), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=float(10), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='length_scale', full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=float(5), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto2', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=75, 69 | serialized_end=158, 70 | ) 71 | 72 | DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER 73 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 74 | 75 | SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), dict( 76 | DESCRIPTOR = _SQUAREBOXCODER, 77 | __module__ = 'object_detection.protos.square_box_coder_pb2' 78 | # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder) 79 | )) 80 | _sym_db.RegisterMessage(SquareBoxCoder) 81 | 82 | 83 | # @@protoc_insertion_point(module_scope) 84 | -------------------------------------------------------------------------------- /object_detection/protos/ssd.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package object_detection.protos; 3 | 4 | import "object_detection/protos/anchor_generator.proto"; 5 | import "object_detection/protos/box_coder.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import 
"object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/matcher.proto"; 10 | import "object_detection/protos/losses.proto"; 11 | import "object_detection/protos/post_processing.proto"; 12 | import "object_detection/protos/region_similarity_calculator.proto"; 13 | 14 | // Configuration for Single Shot Detection (SSD) models. 15 | message Ssd { 16 | 17 | // Number of classes to predict. 18 | optional int32 num_classes = 1; 19 | 20 | // Image resizer for preprocessing the input image. 21 | optional ImageResizer image_resizer = 2; 22 | 23 | // Feature extractor config. 24 | optional SsdFeatureExtractor feature_extractor = 3; 25 | 26 | // Box coder to encode the boxes. 27 | optional BoxCoder box_coder = 4; 28 | 29 | // Matcher to match groundtruth with anchors. 30 | optional Matcher matcher = 5; 31 | 32 | // Region similarity calculator to compute similarity of boxes. 33 | optional RegionSimilarityCalculator similarity_calculator = 6; 34 | 35 | // Box predictor to attach to the features. 36 | optional BoxPredictor box_predictor = 7; 37 | 38 | // Anchor generator to compute anchors. 39 | optional AnchorGenerator anchor_generator = 8; 40 | 41 | // Post processing to apply on the predictions. 42 | optional PostProcessing post_processing = 9; 43 | 44 | // Whether to normalize the loss by number of groundtruth boxes that match to 45 | // the anchors. 46 | optional bool normalize_loss_by_num_matches = 10 [default=true]; 47 | 48 | // Loss configuration for training. 49 | optional Loss loss = 11; 50 | } 51 | 52 | 53 | message SsdFeatureExtractor { 54 | // Type of ssd feature extractor. 55 | optional string type = 1; 56 | 57 | // The factor to alter the depth of the channels in the feature extractor. 58 | optional float depth_multiplier = 2 [default=1.0]; 59 | 60 | // Minimum number of the channels in the feature extractor. 61 | optional int32 min_depth = 3 [default=16]; 62 | 63 | // Hyperparameters for the feature extractor. 64 | optional Hyperparams conv_hyperparams = 4; 65 | } 66 | -------------------------------------------------------------------------------- /object_detection/protos/ssd_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SSD anchor generator described in 6 | // https://arxiv.org/abs/1512.02325. See 7 | // anchor_generators/multiple_grid_anchor_generator.py for details. 8 | message SsdAnchorGenerator { 9 | // Number of grid layers to create anchors for. 10 | optional int32 num_layers = 1 [default = 6]; 11 | 12 | // Scale of anchors corresponding to finest resolution. 13 | optional float min_scale = 2 [default = 0.2]; 14 | 15 | // Scale of anchors corresponding to coarsest resolution 16 | optional float max_scale = 3 [default = 0.95]; 17 | 18 | // Aspect ratios for anchors at each grid point. 19 | repeated float aspect_ratios = 4; 20 | 21 | // Whether to use the following aspect ratio and scale combination for the 22 | // layer with the finest resolution : (scale=0.1, aspect_ratio=1.0), 23 | // (scale=min_scale, aspect_ration=2.0), (scale=min_scale, aspect_ratio=0.5). 
24 | optional bool reduce_boxes_in_lowest_layer = 5 [default = true]; 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/ssd_anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/ssd_anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/ssd_anchor_generator.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n2object_detection/protos/ssd_anchor_generator.proto\x12\x17object_detection.protos\"\x9f\x01\n\x12SsdAnchorGenerator\x12\x15\n\nnum_layers\x18\x01 \x01(\x05:\x01\x36\x12\x16\n\tmin_scale\x18\x02 \x01(\x02:\x03\x30.2\x12\x17\n\tmax_scale\x18\x03 \x01(\x02:\x04\x30.95\x12\x15\n\raspect_ratios\x18\x04 \x03(\x02\x12*\n\x1creduce_boxes_in_lowest_layer\x18\x05 \x01(\x08:\x04true') 23 | ) 24 | 25 | 26 | 27 | 28 | _SSDANCHORGENERATOR = _descriptor.Descriptor( 29 | name='SsdAnchorGenerator', 30 | full_name='object_detection.protos.SsdAnchorGenerator', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='num_layers', full_name='object_detection.protos.SsdAnchorGenerator.num_layers', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=True, default_value=6, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='min_scale', full_name='object_detection.protos.SsdAnchorGenerator.min_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=float(0.2), 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='max_scale', full_name='object_detection.protos.SsdAnchorGenerator.max_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=float(0.95), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='aspect_ratios', full_name='object_detection.protos.SsdAnchorGenerator.aspect_ratios', index=3, 58 | number=4, type=2, cpp_type=6, label=3, 59 | has_default_value=False, default_value=[], 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='reduce_boxes_in_lowest_layer', full_name='object_detection.protos.SsdAnchorGenerator.reduce_boxes_in_lowest_layer', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=True, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 
| options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | syntax='proto2', 79 | extension_ranges=[], 80 | oneofs=[ 81 | ], 82 | serialized_start=80, 83 | serialized_end=239, 84 | ) 85 | 86 | DESCRIPTOR.message_types_by_name['SsdAnchorGenerator'] = _SSDANCHORGENERATOR 87 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 88 | 89 | SsdAnchorGenerator = _reflection.GeneratedProtocolMessageType('SsdAnchorGenerator', (_message.Message,), dict( 90 | DESCRIPTOR = _SSDANCHORGENERATOR, 91 | __module__ = 'object_detection.protos.ssd_anchor_generator_pb2' 92 | # @@protoc_insertion_point(class_scope:object_detection.protos.SsdAnchorGenerator) 93 | )) 94 | _sym_db.RegisterMessage(SsdAnchorGenerator) 95 | 96 | 97 | # @@protoc_insertion_point(module_scope) 98 | -------------------------------------------------------------------------------- /object_detection/protos/ssd_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/ssd.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import anchor_generator_pb2 as object__detection_dot_protos_dot_anchor__generator__pb2 17 | from object_detection.protos import box_coder_pb2 as object__detection_dot_protos_dot_box__coder__pb2 18 | from object_detection.protos import box_predictor_pb2 as object__detection_dot_protos_dot_box__predictor__pb2 19 | from object_detection.protos import hyperparams_pb2 as object__detection_dot_protos_dot_hyperparams__pb2 20 | from object_detection.protos import image_resizer_pb2 as object__detection_dot_protos_dot_image__resizer__pb2 21 | from object_detection.protos import matcher_pb2 as object__detection_dot_protos_dot_matcher__pb2 22 | from object_detection.protos import losses_pb2 as object__detection_dot_protos_dot_losses__pb2 23 | from object_detection.protos import post_processing_pb2 as object__detection_dot_protos_dot_post__processing__pb2 24 | from object_detection.protos import region_similarity_calculator_pb2 as object__detection_dot_protos_dot_region__similarity__calculator__pb2 25 | 26 | 27 | DESCRIPTOR = _descriptor.FileDescriptor( 28 | name='object_detection/protos/ssd.proto', 29 | package='object_detection.protos', 30 | syntax='proto2', 31 | serialized_pb=_b('\n!object_detection/protos/ssd.proto\x12\x17object_detection.protos\x1a.object_detection/protos/anchor_generator.proto\x1a\'object_detection/protos/box_coder.proto\x1a+object_detection/protos/box_predictor.proto\x1a)object_detection/protos/hyperparams.proto\x1a+object_detection/protos/image_resizer.proto\x1a%object_detection/protos/matcher.proto\x1a$object_detection/protos/losses.proto\x1a-object_detection/protos/post_processing.proto\x1a:object_detection/protos/region_similarity_calculator.proto\"\xfc\x04\n\x03Ssd\x12\x13\n\x0bnum_classes\x18\x01 \x01(\x05\x12<\n\rimage_resizer\x18\x02 \x01(\x0b\x32%.object_detection.protos.ImageResizer\x12G\n\x11\x66\x65\x61ture_extractor\x18\x03 
\x01(\x0b\x32,.object_detection.protos.SsdFeatureExtractor\x12\x34\n\tbox_coder\x18\x04 \x01(\x0b\x32!.object_detection.protos.BoxCoder\x12\x31\n\x07matcher\x18\x05 \x01(\x0b\x32 .object_detection.protos.Matcher\x12R\n\x15similarity_calculator\x18\x06 \x01(\x0b\x32\x33.object_detection.protos.RegionSimilarityCalculator\x12<\n\rbox_predictor\x18\x07 \x01(\x0b\x32%.object_detection.protos.BoxPredictor\x12\x42\n\x10\x61nchor_generator\x18\x08 \x01(\x0b\x32(.object_detection.protos.AnchorGenerator\x12@\n\x0fpost_processing\x18\t \x01(\x0b\x32\'.object_detection.protos.PostProcessing\x12+\n\x1dnormalize_loss_by_num_matches\x18\n \x01(\x08:\x04true\x12+\n\x04loss\x18\x0b \x01(\x0b\x32\x1d.object_detection.protos.Loss\"\x97\x01\n\x13SsdFeatureExtractor\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x1b\n\x10\x64\x65pth_multiplier\x18\x02 \x01(\x02:\x01\x31\x12\x15\n\tmin_depth\x18\x03 \x01(\x05:\x02\x31\x36\x12>\n\x10\x63onv_hyperparams\x18\x04 \x01(\x0b\x32$.object_detection.protos.Hyperparams') 32 | , 33 | dependencies=[object__detection_dot_protos_dot_anchor__generator__pb2.DESCRIPTOR,object__detection_dot_protos_dot_box__coder__pb2.DESCRIPTOR,object__detection_dot_protos_dot_box__predictor__pb2.DESCRIPTOR,object__detection_dot_protos_dot_hyperparams__pb2.DESCRIPTOR,object__detection_dot_protos_dot_image__resizer__pb2.DESCRIPTOR,object__detection_dot_protos_dot_matcher__pb2.DESCRIPTOR,object__detection_dot_protos_dot_losses__pb2.DESCRIPTOR,object__detection_dot_protos_dot_post__processing__pb2.DESCRIPTOR,object__detection_dot_protos_dot_region__similarity__calculator__pb2.DESCRIPTOR,]) 34 | 35 | 36 | 37 | 38 | _SSD = _descriptor.Descriptor( 39 | name='Ssd', 40 | full_name='object_detection.protos.Ssd', 41 | filename=None, 42 | file=DESCRIPTOR, 43 | containing_type=None, 44 | fields=[ 45 | _descriptor.FieldDescriptor( 46 | name='num_classes', full_name='object_detection.protos.Ssd.num_classes', index=0, 47 | number=1, type=5, cpp_type=1, label=1, 48 | has_default_value=False, default_value=0, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | _descriptor.FieldDescriptor( 53 | name='image_resizer', full_name='object_detection.protos.Ssd.image_resizer', index=1, 54 | number=2, type=11, cpp_type=10, label=1, 55 | has_default_value=False, default_value=None, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None), 59 | _descriptor.FieldDescriptor( 60 | name='feature_extractor', full_name='object_detection.protos.Ssd.feature_extractor', index=2, 61 | number=3, type=11, cpp_type=10, label=1, 62 | has_default_value=False, default_value=None, 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None), 66 | _descriptor.FieldDescriptor( 67 | name='box_coder', full_name='object_detection.protos.Ssd.box_coder', index=3, 68 | number=4, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None), 73 | _descriptor.FieldDescriptor( 74 | name='matcher', full_name='object_detection.protos.Ssd.matcher', index=4, 75 | number=5, type=11, cpp_type=10, label=1, 76 | has_default_value=False, default_value=None, 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | options=None), 80 | 
_descriptor.FieldDescriptor( 81 | name='similarity_calculator', full_name='object_detection.protos.Ssd.similarity_calculator', index=5, 82 | number=6, type=11, cpp_type=10, label=1, 83 | has_default_value=False, default_value=None, 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | options=None), 87 | _descriptor.FieldDescriptor( 88 | name='box_predictor', full_name='object_detection.protos.Ssd.box_predictor', index=6, 89 | number=7, type=11, cpp_type=10, label=1, 90 | has_default_value=False, default_value=None, 91 | message_type=None, enum_type=None, containing_type=None, 92 | is_extension=False, extension_scope=None, 93 | options=None), 94 | _descriptor.FieldDescriptor( 95 | name='anchor_generator', full_name='object_detection.protos.Ssd.anchor_generator', index=7, 96 | number=8, type=11, cpp_type=10, label=1, 97 | has_default_value=False, default_value=None, 98 | message_type=None, enum_type=None, containing_type=None, 99 | is_extension=False, extension_scope=None, 100 | options=None), 101 | _descriptor.FieldDescriptor( 102 | name='post_processing', full_name='object_detection.protos.Ssd.post_processing', index=8, 103 | number=9, type=11, cpp_type=10, label=1, 104 | has_default_value=False, default_value=None, 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | options=None), 108 | _descriptor.FieldDescriptor( 109 | name='normalize_loss_by_num_matches', full_name='object_detection.protos.Ssd.normalize_loss_by_num_matches', index=9, 110 | number=10, type=8, cpp_type=7, label=1, 111 | has_default_value=True, default_value=True, 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | options=None), 115 | _descriptor.FieldDescriptor( 116 | name='loss', full_name='object_detection.protos.Ssd.loss', index=10, 117 | number=11, type=11, cpp_type=10, label=1, 118 | has_default_value=False, default_value=None, 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | options=None), 122 | ], 123 | extensions=[ 124 | ], 125 | nested_types=[], 126 | enum_types=[ 127 | ], 128 | options=None, 129 | is_extendable=False, 130 | syntax='proto2', 131 | extension_ranges=[], 132 | oneofs=[ 133 | ], 134 | serialized_start=469, 135 | serialized_end=1105, 136 | ) 137 | 138 | 139 | _SSDFEATUREEXTRACTOR = _descriptor.Descriptor( 140 | name='SsdFeatureExtractor', 141 | full_name='object_detection.protos.SsdFeatureExtractor', 142 | filename=None, 143 | file=DESCRIPTOR, 144 | containing_type=None, 145 | fields=[ 146 | _descriptor.FieldDescriptor( 147 | name='type', full_name='object_detection.protos.SsdFeatureExtractor.type', index=0, 148 | number=1, type=9, cpp_type=9, label=1, 149 | has_default_value=False, default_value=_b("").decode('utf-8'), 150 | message_type=None, enum_type=None, containing_type=None, 151 | is_extension=False, extension_scope=None, 152 | options=None), 153 | _descriptor.FieldDescriptor( 154 | name='depth_multiplier', full_name='object_detection.protos.SsdFeatureExtractor.depth_multiplier', index=1, 155 | number=2, type=2, cpp_type=6, label=1, 156 | has_default_value=True, default_value=float(1), 157 | message_type=None, enum_type=None, containing_type=None, 158 | is_extension=False, extension_scope=None, 159 | options=None), 160 | _descriptor.FieldDescriptor( 161 | name='min_depth', full_name='object_detection.protos.SsdFeatureExtractor.min_depth', 
index=2, 162 | number=3, type=5, cpp_type=1, label=1, 163 | has_default_value=True, default_value=16, 164 | message_type=None, enum_type=None, containing_type=None, 165 | is_extension=False, extension_scope=None, 166 | options=None), 167 | _descriptor.FieldDescriptor( 168 | name='conv_hyperparams', full_name='object_detection.protos.SsdFeatureExtractor.conv_hyperparams', index=3, 169 | number=4, type=11, cpp_type=10, label=1, 170 | has_default_value=False, default_value=None, 171 | message_type=None, enum_type=None, containing_type=None, 172 | is_extension=False, extension_scope=None, 173 | options=None), 174 | ], 175 | extensions=[ 176 | ], 177 | nested_types=[], 178 | enum_types=[ 179 | ], 180 | options=None, 181 | is_extendable=False, 182 | syntax='proto2', 183 | extension_ranges=[], 184 | oneofs=[ 185 | ], 186 | serialized_start=1108, 187 | serialized_end=1259, 188 | ) 189 | 190 | _SSD.fields_by_name['image_resizer'].message_type = object__detection_dot_protos_dot_image__resizer__pb2._IMAGERESIZER 191 | _SSD.fields_by_name['feature_extractor'].message_type = _SSDFEATUREEXTRACTOR 192 | _SSD.fields_by_name['box_coder'].message_type = object__detection_dot_protos_dot_box__coder__pb2._BOXCODER 193 | _SSD.fields_by_name['matcher'].message_type = object__detection_dot_protos_dot_matcher__pb2._MATCHER 194 | _SSD.fields_by_name['similarity_calculator'].message_type = object__detection_dot_protos_dot_region__similarity__calculator__pb2._REGIONSIMILARITYCALCULATOR 195 | _SSD.fields_by_name['box_predictor'].message_type = object__detection_dot_protos_dot_box__predictor__pb2._BOXPREDICTOR 196 | _SSD.fields_by_name['anchor_generator'].message_type = object__detection_dot_protos_dot_anchor__generator__pb2._ANCHORGENERATOR 197 | _SSD.fields_by_name['post_processing'].message_type = object__detection_dot_protos_dot_post__processing__pb2._POSTPROCESSING 198 | _SSD.fields_by_name['loss'].message_type = object__detection_dot_protos_dot_losses__pb2._LOSS 199 | _SSDFEATUREEXTRACTOR.fields_by_name['conv_hyperparams'].message_type = object__detection_dot_protos_dot_hyperparams__pb2._HYPERPARAMS 200 | DESCRIPTOR.message_types_by_name['Ssd'] = _SSD 201 | DESCRIPTOR.message_types_by_name['SsdFeatureExtractor'] = _SSDFEATUREEXTRACTOR 202 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 203 | 204 | Ssd = _reflection.GeneratedProtocolMessageType('Ssd', (_message.Message,), dict( 205 | DESCRIPTOR = _SSD, 206 | __module__ = 'object_detection.protos.ssd_pb2' 207 | # @@protoc_insertion_point(class_scope:object_detection.protos.Ssd) 208 | )) 209 | _sym_db.RegisterMessage(Ssd) 210 | 211 | SsdFeatureExtractor = _reflection.GeneratedProtocolMessageType('SsdFeatureExtractor', (_message.Message,), dict( 212 | DESCRIPTOR = _SSDFEATUREEXTRACTOR, 213 | __module__ = 'object_detection.protos.ssd_pb2' 214 | # @@protoc_insertion_point(class_scope:object_detection.protos.SsdFeatureExtractor) 215 | )) 216 | _sym_db.RegisterMessage(SsdFeatureExtractor) 217 | 218 | 219 | # @@protoc_insertion_point(module_scope) 220 | -------------------------------------------------------------------------------- /object_detection/protos/string_int_label_map.proto: -------------------------------------------------------------------------------- 1 | // Message to store the mapping from class label strings to class id. Datasets 2 | // use string labels to represent classes while the object detection framework 3 | // works with class ids. This message maps them so they can be converted back 4 | // and forth as needed. 
5 | syntax = "proto2"; 6 | 7 | package object_detection.protos; 8 | 9 | message StringIntLabelMapItem { 10 | // String name. The most common practice is to set this to a MID or synsets 11 | // id. 12 | optional string name = 1; 13 | 14 | // Integer id that maps to the string name above. Label ids should start from 15 | // 1. 16 | optional int32 id = 2; 17 | 18 | // Human readable string label. 19 | optional string display_name = 3; 20 | }; 21 | 22 | message StringIntLabelMap { 23 | repeated StringIntLabelMapItem item = 1; 24 | }; 25 | -------------------------------------------------------------------------------- /object_detection/protos/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/string_int_label_map.proto', 20 | package='object_detection.protos', 21 | syntax='proto2', 22 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 23 | ) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | syntax='proto2', 65 | extension_ranges=[], 66 | oneofs=[ 67 | ], 68 | serialized_start=79, 69 | serialized_end=150, 70 | ) 71 | 72 | 73 | _STRINGINTLABELMAP = 
_descriptor.Descriptor( 74 | name='StringIntLabelMap', 75 | full_name='object_detection.protos.StringIntLabelMap', 76 | filename=None, 77 | file=DESCRIPTOR, 78 | containing_type=None, 79 | fields=[ 80 | _descriptor.FieldDescriptor( 81 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 82 | number=1, type=11, cpp_type=10, label=3, 83 | has_default_value=False, default_value=[], 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | options=None), 87 | ], 88 | extensions=[ 89 | ], 90 | nested_types=[], 91 | enum_types=[ 92 | ], 93 | options=None, 94 | is_extendable=False, 95 | syntax='proto2', 96 | extension_ranges=[], 97 | oneofs=[ 98 | ], 99 | serialized_start=152, 100 | serialized_end=233, 101 | ) 102 | 103 | _STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 104 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 105 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 106 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 107 | 108 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 109 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 110 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 111 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 112 | )) 113 | _sym_db.RegisterMessage(StringIntLabelMapItem) 114 | 115 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 116 | DESCRIPTOR = _STRINGINTLABELMAP, 117 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 118 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 119 | )) 120 | _sym_db.RegisterMessage(StringIntLabelMap) 121 | 122 | 123 | # @@protoc_insertion_point(module_scope) 124 | -------------------------------------------------------------------------------- /object_detection/protos/train.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/optimizer.proto"; 6 | import "object_detection/protos/preprocessor.proto"; 7 | 8 | // Message for configuring DetectionModel training jobs (train.py). 9 | message TrainConfig { 10 | // Input queue batch size. 11 | optional uint32 batch_size = 1 [default=32]; 12 | 13 | // Data augmentation options. 14 | repeated PreprocessingStep data_augmentation_options = 2; 15 | 16 | // Whether to synchronize replicas during training. 17 | optional bool sync_replicas = 3 [default=false]; 18 | 19 | // How frequently to keep checkpoints. 20 | optional uint32 keep_checkpoint_every_n_hours = 4 [default=1000]; 21 | 22 | // Optimizer used to train the DetectionModel. 23 | optional Optimizer optimizer = 5; 24 | 25 | // If greater than 0, clips gradients by this value. 26 | optional float gradient_clipping_by_norm = 6 [default=0.0]; 27 | 28 | // Checkpoint to restore variables from. Typically used to load feature 29 | // extractor variables trained outside of object detection. 30 | optional string fine_tune_checkpoint = 7 [default=""]; 31 | 32 | // Specifies if the finetune checkpoint is from an object detection model. 33 | // If from an object detection model, the model being trained should have 34 | // the same parameters with the exception of the num_classes parameter. 
35 | // If false, it assumes the checkpoint was a object classification model. 36 | optional bool from_detection_checkpoint = 8 [default=false]; 37 | 38 | // Number of steps to train the DetectionModel for. If 0, will train the model 39 | // indefinitely. 40 | optional uint32 num_steps = 9 [default=0]; 41 | 42 | // Number of training steps between replica startup. 43 | // This flag must be set to 0 if sync_replicas is set to true. 44 | optional float startup_delay_steps = 10 [default=15]; 45 | 46 | // If greater than 0, multiplies the gradient of bias variables by this 47 | // amount. 48 | optional float bias_grad_multiplier = 11 [default=0]; 49 | 50 | // Variables that should not be updated during training. 51 | repeated string freeze_variables = 12; 52 | 53 | // Number of replicas to aggregate before making parameter updates. 54 | optional int32 replicas_to_aggregate = 13 [default=1]; 55 | 56 | // Maximum number of elements to store within a queue. 57 | optional int32 batch_queue_capacity = 14 [default=600]; 58 | 59 | // Number of threads to use for batching. 60 | optional int32 num_batch_queue_threads = 15 [default=8]; 61 | 62 | // Maximum capacity of the queue used to prefetch assembled batches. 63 | optional int32 prefetch_queue_capacity = 16 [default=10]; 64 | } 65 | -------------------------------------------------------------------------------- /object_detection/protos/train_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/train.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | from object_detection.protos import optimizer_pb2 as object__detection_dot_protos_dot_optimizer__pb2 17 | from object_detection.protos import preprocessor_pb2 as object__detection_dot_protos_dot_preprocessor__pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/train.proto', 22 | package='object_detection.protos', 23 | syntax='proto2', 24 | serialized_pb=_b('\n#object_detection/protos/train.proto\x12\x17object_detection.protos\x1a\'object_detection/protos/optimizer.proto\x1a*object_detection/protos/preprocessor.proto\"\xe6\x04\n\x0bTrainConfig\x12\x16\n\nbatch_size\x18\x01 \x01(\r:\x02\x33\x32\x12M\n\x19\x64\x61ta_augmentation_options\x18\x02 \x03(\x0b\x32*.object_detection.protos.PreprocessingStep\x12\x1c\n\rsync_replicas\x18\x03 \x01(\x08:\x05\x66\x61lse\x12+\n\x1dkeep_checkpoint_every_n_hours\x18\x04 \x01(\r:\x04\x31\x30\x30\x30\x12\x35\n\toptimizer\x18\x05 \x01(\x0b\x32\".object_detection.protos.Optimizer\x12$\n\x19gradient_clipping_by_norm\x18\x06 \x01(\x02:\x01\x30\x12\x1e\n\x14\x66ine_tune_checkpoint\x18\x07 \x01(\t:\x00\x12(\n\x19\x66rom_detection_checkpoint\x18\x08 \x01(\x08:\x05\x66\x61lse\x12\x14\n\tnum_steps\x18\t \x01(\r:\x01\x30\x12\x1f\n\x13startup_delay_steps\x18\n \x01(\x02:\x02\x31\x35\x12\x1f\n\x14\x62ias_grad_multiplier\x18\x0b \x01(\x02:\x01\x30\x12\x18\n\x10\x66reeze_variables\x18\x0c \x03(\t\x12 \n\x15replicas_to_aggregate\x18\r 
\x01(\x05:\x01\x31\x12!\n\x14\x62\x61tch_queue_capacity\x18\x0e \x01(\x05:\x03\x36\x30\x30\x12\"\n\x17num_batch_queue_threads\x18\x0f \x01(\x05:\x01\x38\x12#\n\x17prefetch_queue_capacity\x18\x10 \x01(\x05:\x02\x31\x30') 25 | , 26 | dependencies=[object__detection_dot_protos_dot_optimizer__pb2.DESCRIPTOR,object__detection_dot_protos_dot_preprocessor__pb2.DESCRIPTOR,]) 27 | 28 | 29 | 30 | 31 | _TRAINCONFIG = _descriptor.Descriptor( 32 | name='TrainConfig', 33 | full_name='object_detection.protos.TrainConfig', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='batch_size', full_name='object_detection.protos.TrainConfig.batch_size', index=0, 40 | number=1, type=13, cpp_type=3, label=1, 41 | has_default_value=True, default_value=32, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='data_augmentation_options', full_name='object_detection.protos.TrainConfig.data_augmentation_options', index=1, 47 | number=2, type=11, cpp_type=10, label=3, 48 | has_default_value=False, default_value=[], 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | _descriptor.FieldDescriptor( 53 | name='sync_replicas', full_name='object_detection.protos.TrainConfig.sync_replicas', index=2, 54 | number=3, type=8, cpp_type=7, label=1, 55 | has_default_value=True, default_value=False, 56 | message_type=None, enum_type=None, containing_type=None, 57 | is_extension=False, extension_scope=None, 58 | options=None), 59 | _descriptor.FieldDescriptor( 60 | name='keep_checkpoint_every_n_hours', full_name='object_detection.protos.TrainConfig.keep_checkpoint_every_n_hours', index=3, 61 | number=4, type=13, cpp_type=3, label=1, 62 | has_default_value=True, default_value=1000, 63 | message_type=None, enum_type=None, containing_type=None, 64 | is_extension=False, extension_scope=None, 65 | options=None), 66 | _descriptor.FieldDescriptor( 67 | name='optimizer', full_name='object_detection.protos.TrainConfig.optimizer', index=4, 68 | number=5, type=11, cpp_type=10, label=1, 69 | has_default_value=False, default_value=None, 70 | message_type=None, enum_type=None, containing_type=None, 71 | is_extension=False, extension_scope=None, 72 | options=None), 73 | _descriptor.FieldDescriptor( 74 | name='gradient_clipping_by_norm', full_name='object_detection.protos.TrainConfig.gradient_clipping_by_norm', index=5, 75 | number=6, type=2, cpp_type=6, label=1, 76 | has_default_value=True, default_value=float(0), 77 | message_type=None, enum_type=None, containing_type=None, 78 | is_extension=False, extension_scope=None, 79 | options=None), 80 | _descriptor.FieldDescriptor( 81 | name='fine_tune_checkpoint', full_name='object_detection.protos.TrainConfig.fine_tune_checkpoint', index=6, 82 | number=7, type=9, cpp_type=9, label=1, 83 | has_default_value=True, default_value=_b("").decode('utf-8'), 84 | message_type=None, enum_type=None, containing_type=None, 85 | is_extension=False, extension_scope=None, 86 | options=None), 87 | _descriptor.FieldDescriptor( 88 | name='from_detection_checkpoint', full_name='object_detection.protos.TrainConfig.from_detection_checkpoint', index=7, 89 | number=8, type=8, cpp_type=7, label=1, 90 | has_default_value=True, default_value=False, 91 | message_type=None, enum_type=None, containing_type=None, 92 | is_extension=False, extension_scope=None, 93 | 
options=None), 94 | _descriptor.FieldDescriptor( 95 | name='num_steps', full_name='object_detection.protos.TrainConfig.num_steps', index=8, 96 | number=9, type=13, cpp_type=3, label=1, 97 | has_default_value=True, default_value=0, 98 | message_type=None, enum_type=None, containing_type=None, 99 | is_extension=False, extension_scope=None, 100 | options=None), 101 | _descriptor.FieldDescriptor( 102 | name='startup_delay_steps', full_name='object_detection.protos.TrainConfig.startup_delay_steps', index=9, 103 | number=10, type=2, cpp_type=6, label=1, 104 | has_default_value=True, default_value=float(15), 105 | message_type=None, enum_type=None, containing_type=None, 106 | is_extension=False, extension_scope=None, 107 | options=None), 108 | _descriptor.FieldDescriptor( 109 | name='bias_grad_multiplier', full_name='object_detection.protos.TrainConfig.bias_grad_multiplier', index=10, 110 | number=11, type=2, cpp_type=6, label=1, 111 | has_default_value=True, default_value=float(0), 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | options=None), 115 | _descriptor.FieldDescriptor( 116 | name='freeze_variables', full_name='object_detection.protos.TrainConfig.freeze_variables', index=11, 117 | number=12, type=9, cpp_type=9, label=3, 118 | has_default_value=False, default_value=[], 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | options=None), 122 | _descriptor.FieldDescriptor( 123 | name='replicas_to_aggregate', full_name='object_detection.protos.TrainConfig.replicas_to_aggregate', index=12, 124 | number=13, type=5, cpp_type=1, label=1, 125 | has_default_value=True, default_value=1, 126 | message_type=None, enum_type=None, containing_type=None, 127 | is_extension=False, extension_scope=None, 128 | options=None), 129 | _descriptor.FieldDescriptor( 130 | name='batch_queue_capacity', full_name='object_detection.protos.TrainConfig.batch_queue_capacity', index=13, 131 | number=14, type=5, cpp_type=1, label=1, 132 | has_default_value=True, default_value=600, 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | options=None), 136 | _descriptor.FieldDescriptor( 137 | name='num_batch_queue_threads', full_name='object_detection.protos.TrainConfig.num_batch_queue_threads', index=14, 138 | number=15, type=5, cpp_type=1, label=1, 139 | has_default_value=True, default_value=8, 140 | message_type=None, enum_type=None, containing_type=None, 141 | is_extension=False, extension_scope=None, 142 | options=None), 143 | _descriptor.FieldDescriptor( 144 | name='prefetch_queue_capacity', full_name='object_detection.protos.TrainConfig.prefetch_queue_capacity', index=15, 145 | number=16, type=5, cpp_type=1, label=1, 146 | has_default_value=True, default_value=10, 147 | message_type=None, enum_type=None, containing_type=None, 148 | is_extension=False, extension_scope=None, 149 | options=None), 150 | ], 151 | extensions=[ 152 | ], 153 | nested_types=[], 154 | enum_types=[ 155 | ], 156 | options=None, 157 | is_extendable=False, 158 | syntax='proto2', 159 | extension_ranges=[], 160 | oneofs=[ 161 | ], 162 | serialized_start=150, 163 | serialized_end=764, 164 | ) 165 | 166 | _TRAINCONFIG.fields_by_name['data_augmentation_options'].message_type = object__detection_dot_protos_dot_preprocessor__pb2._PREPROCESSINGSTEP 167 | _TRAINCONFIG.fields_by_name['optimizer'].message_type = object__detection_dot_protos_dot_optimizer__pb2._OPTIMIZER 
168 | DESCRIPTOR.message_types_by_name['TrainConfig'] = _TRAINCONFIG 169 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 170 | 171 | TrainConfig = _reflection.GeneratedProtocolMessageType('TrainConfig', (_message.Message,), dict( 172 | DESCRIPTOR = _TRAINCONFIG, 173 | __module__ = 'object_detection.protos.train_pb2' 174 | # @@protoc_insertion_point(class_scope:object_detection.protos.TrainConfig) 175 | )) 176 | _sym_db.RegisterMessage(TrainConfig) 177 | 178 | 179 | # @@protoc_insertion_point(module_scope) 180 | -------------------------------------------------------------------------------- /object_detector.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class ObjectDetector(ABC): 5 | @abstractmethod 6 | def detect(self, frame, threshold=0.0): 7 | pass -------------------------------------------------------------------------------- /object_detector_detection_api.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from utils import label_map_util 7 | from object_detector import ObjectDetector 8 | 9 | 10 | basepath = path.dirname(__file__) 11 | 12 | # Path to the label map with the strings used to add the correct label to each box. 13 | PATH_TO_LABELS = path.join(basepath, 'data', 'mscoco_label_map.pbtxt') 14 | 15 | NUM_CLASSES = 90 16 | 17 | 18 | class ObjectDetectorDetectionAPI(ObjectDetector): 19 | def __init__(self, graph_path='frozen_inference_graph.pb'): 20 | """ 21 | Builds the TensorFlow graph and loads the model and labels 22 | """ 23 | 24 | # model_path = path.join(basepath, graph_path) 25 | 26 | # Load the TensorFlow model into memory 27 | self.detection_graph = tf.Graph() 28 | with self.detection_graph.as_default(): 29 | od_graph_def = tf.GraphDef() 30 | with tf.gfile.GFile(graph_path, 'rb') as fid: 31 | serialized_graph = fid.read() 32 | od_graph_def.ParseFromString(serialized_graph) 33 | tf.import_graph_def(od_graph_def, name='') 34 | 35 | # Load label map 36 | self._load_label(PATH_TO_LABELS, NUM_CLASSES, use_disp_name=True) 37 | 38 | with self.detection_graph.as_default(): 39 | self.sess = tf.Session(graph=self.detection_graph) 40 | # Define input and output tensors for detection_graph 41 | self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0') 42 | # Each box represents a part of the image where a particular object was detected. 43 | self.detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0') 44 | # Each score represents the level of confidence for each of the objects. 45 | # Score is shown on the result image, together with the class label. 46 | self.detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0') 47 | self.detection_classes = self.detection_graph.get_tensor_by_name('detection_classes:0') 48 | self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0') 49 | 50 | def close(self): 51 | tf.reset_default_graph() 52 | self.sess = tf.InteractiveSession() 53 | 54 | def detect(self, frame, threshold=0.1): 55 | """ 56 | Detects objects in the frame above the given confidence threshold 57 | Returns a list of [(top-left), (bottom-right), confidence, label] entries 58 | """ 59 | frames = np.expand_dims(frame, axis=0) 60 | # Actual detection.
61 | (boxes, scores, classes, num) = self.sess.run( 62 | [self.detection_boxes, self.detection_scores, 63 | self.detection_classes, self.num_detections], 64 | feed_dict={self.image_tensor: frames}) 65 | 66 | # Find detected boxes coordinates 67 | return [self._boxes_coordinates(frame, 68 | np.squeeze(boxes[0]), 69 | np.squeeze(i[2]).astype(np.int32), 70 | np.squeeze(i[3]), 71 | min_score_thresh=threshold, 72 | ) for i in zip(frames, boxes, classes, scores)][0] 73 | 74 | 75 | def _boxes_coordinates(self, 76 | image, 77 | boxes, 78 | classes, 79 | scores, 80 | max_boxes_to_draw=20, 81 | min_score_thresh=.5): 82 | """ 83 | Converts normalized detection boxes to pixel coordinates and keeps only 84 | the detections whose score exceeds the threshold, returning each kept 85 | box together with its score and class name. 86 | 87 | Args: 88 | image: uint8 numpy array with shape (img_height, img_width, 3) 89 | boxes: a numpy array of shape [N, 4] with normalized box coordinates 90 | classes: a numpy array of shape [N] 91 | scores: a numpy array of shape [N] or None. If scores=None, then 92 | all boxes are kept regardless of the score threshold. 93 | max_boxes_to_draw: maximum number of boxes to return. If None, 94 | return all boxes. 95 | min_score_thresh: minimum score threshold for a box to be returned 96 | 97 | Returns: 98 | A list of [(left, top), (right, bottom), score, class name] entries, 99 | one per kept detection, with coordinates given in pixels of the 100 | input image. 101 | """ 102 | 103 | if not max_boxes_to_draw: 104 | max_boxes_to_draw = boxes.shape[0] 105 | number_boxes = min(max_boxes_to_draw, boxes.shape[0]) 106 | person_boxes = [] 107 | # person_labels = [] 108 | for i in range(number_boxes): 109 | if scores is None or scores[i] > min_score_thresh: 110 | box = tuple(boxes[i].tolist()) 111 | ymin, xmin, ymax, xmax = box 112 | 113 | im_height, im_width, _ = image.shape 114 | left, right, top, bottom = [int(z) for z in (xmin * im_width, xmax * im_width, 115 | ymin * im_height, ymax * im_height)] 116 | 117 | person_boxes.append([(left, top), (right, bottom), scores[i], 118 | self.category_index[classes[i]]['name']]) 119 | return person_boxes 120 | 121 | def _load_label(self, path, num_c, use_disp_name=True): 122 | """ 123 | Loads labels 124 | """ 125 | label_map = label_map_util.load_labelmap(path) 126 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=num_c, 127 | use_display_name=use_disp_name) 128 | self.category_index = label_map_util.create_category_index(categories) 129 | -------------------------------------------------------------------------------- /object_detector_detection_api_lite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import cv2 4 | 5 | from object_detector_detection_api import ObjectDetectorDetectionAPI, \ 6 | PATH_TO_LABELS, NUM_CLASSES 7 | 8 | 9 | class ObjectDetectorLite(ObjectDetectorDetectionAPI): 10 | def __init__(self, model_path='detect.tflite'): 11 | """ 12 | Builds the TensorFlow Lite interpreter and loads the model and labels 13 | """ 14 | 15 | # Load label map 16 | self._load_label(PATH_TO_LABELS, NUM_CLASSES, use_disp_name=True) 17 | 18 | # Define the lite interpreter and load the TensorFlow Lite model into memory 19 | self.interpreter = tf.contrib.lite.Interpreter( 20 | model_path=model_path) 21 | self.interpreter.allocate_tensors() 22 | self.input_details =
self.interpreter.get_input_details() 23 | self.output_details = self.interpreter.get_output_details() 24 | 25 | def detect(self, image, threshold=0.1): 26 | """ 27 | Predicts person in frame with threshold level of confidence 28 | Returns list with top-left, bottom-right coordinates and list with labels, confidence in % 29 | """ 30 | 31 | # Resize and normalize image for network input 32 | frame = cv2.resize(image, (300, 300)) 33 | frame = np.expand_dims(frame, axis=0) 34 | frame = (2.0 / 255.0) * frame - 1.0 35 | frame = frame.astype('float32') 36 | 37 | # run model 38 | self.interpreter.set_tensor(self.input_details[0]['index'], frame) 39 | self.interpreter.invoke() 40 | 41 | # get results 42 | boxes = self.interpreter.get_tensor( 43 | self.output_details[0]['index']) 44 | classes = self.interpreter.get_tensor( 45 | self.output_details[1]['index']) 46 | scores = self.interpreter.get_tensor( 47 | self.output_details[2]['index']) 48 | num = self.interpreter.get_tensor( 49 | self.output_details[3]['index']) 50 | 51 | # Find detected boxes coordinates 52 | return self._boxes_coordinates(image, 53 | np.squeeze(boxes[0]), 54 | np.squeeze(classes[0]+1).astype(np.int32), 55 | np.squeeze(scores[0]), 56 | min_score_thresh=threshold) 57 | 58 | def close(self): 59 | pass 60 | 61 | 62 | if __name__ == '__main__': 63 | detector = ObjectDetectorLite() 64 | 65 | image = cv2.cvtColor(cv2.imread('dog.jpg'), cv2.COLOR_BGR2RGB) 66 | 67 | result = detector.detect(image, 0.4) 68 | print(result) 69 | 70 | for obj in result: 71 | print('coordinates: {} {}. class: "{}". confidence: {:.2f}'. 72 | format(obj[0], obj[1], obj[3], obj[2])) 73 | 74 | cv2.rectangle(image, obj[0], obj[1], (0, 255, 0), 2) 75 | cv2.putText(image, '{}: {:.2f}'.format(obj[3], obj[2]), 76 | (obj[0][0], obj[0][1] - 5), 77 | cv2.FONT_HERSHEY_PLAIN, 1, (0, 255, 0), 2) 78 | 79 | cv2.imwrite('r1.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 80 | 81 | detector.close() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/label_map_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/utils/__pycache__/label_map_util.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuantuMobileSoftware/mobile_detector/b8445328acfb8b2b7d3560135714142bf3f21351/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The 
TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | from object_detection.protos import string_int_label_map_pb2 23 | 24 | 25 | def _validate_label_map(label_map): 26 | """Checks if a label map is valid. 27 | 28 | Args: 29 | label_map: StringIntLabelMap to validate. 30 | 31 | Raises: 32 | ValueError: if label map is invalid. 33 | """ 34 | for item in label_map.item: 35 | if item.id < 1: 36 | raise ValueError('Label map ids should be >= 1.') 37 | 38 | 39 | def create_category_index(categories): 40 | """Creates dictionary of COCO compatible categories keyed by category id. 41 | 42 | Args: 43 | categories: a list of dicts, each of which has the following keys: 44 | 'id': (required) an integer id uniquely identifying this category. 45 | 'name': (required) string representing category name 46 | e.g., 'cat', 'dog', 'pizza'. 47 | 48 | Returns: 49 | category_index: a dict containing the same entries as categories, but keyed 50 | by the 'id' field of each category. 51 | """ 52 | category_index = {} 53 | for cat in categories: 54 | category_index[cat['id']] = cat 55 | return category_index 56 | 57 | 58 | def convert_label_map_to_categories(label_map, 59 | max_num_classes, 60 | use_display_name=True): 61 | """Loads label map proto and returns categories list compatible with eval. 62 | 63 | This function loads a label map and returns a list of dicts, each of which 64 | has the following keys: 65 | 'id': (required) an integer id uniquely identifying this category. 66 | 'name': (required) string representing category name 67 | e.g., 'cat', 'dog', 'pizza'. 68 | We only allow class into the list if its id-label_id_offset is 69 | between 0 (inclusive) and max_num_classes (exclusive). 70 | If there are several items mapping to the same id in the label map, 71 | we will only keep the first one in the categories list. 72 | 73 | Args: 74 | label_map: a StringIntLabelMapProto or None. If None, a default categories 75 | list is created with max_num_classes categories. 76 | max_num_classes: maximum number of (consecutive) label indices to include. 77 | use_display_name: (boolean) choose whether to load 'display_name' field 78 | as category name. If False or if the display_name field does not exist, 79 | uses 'name' field as category names instead. 80 | Returns: 81 | categories: a list of dictionaries representing all possible categories. 
82 | """ 83 | categories = [] 84 | list_of_ids_already_added = [] 85 | if not label_map: 86 | label_id_offset = 1 87 | for class_id in range(max_num_classes): 88 | categories.append({ 89 | 'id': class_id + label_id_offset, 90 | 'name': 'category_{}'.format(class_id + label_id_offset) 91 | }) 92 | return categories 93 | for item in label_map.item: 94 | if not 0 < item.id <= max_num_classes: 95 | logging.info('Ignore item %d since it falls outside of requested ' 96 | 'label range.', item.id) 97 | continue 98 | if use_display_name and item.HasField('display_name'): 99 | name = item.display_name 100 | else: 101 | name = item.name 102 | if item.id not in list_of_ids_already_added: 103 | list_of_ids_already_added.append(item.id) 104 | categories.append({'id': item.id, 'name': name}) 105 | return categories 106 | 107 | 108 | def load_labelmap(path): 109 | """Loads label map proto. 110 | 111 | Args: 112 | path: path to StringIntLabelMap proto text file. 113 | Returns: 114 | a StringIntLabelMapProto 115 | """ 116 | with tf.gfile.GFile(path, 'r') as fid: 117 | label_map_string = fid.read() 118 | label_map = string_int_label_map_pb2.StringIntLabelMap() 119 | try: 120 | text_format.Merge(label_map_string, label_map) 121 | except text_format.ParseError: 122 | label_map.ParseFromString(label_map_string) 123 | _validate_label_map(label_map) 124 | return label_map 125 | 126 | 127 | def get_label_map_dict(label_map_path): 128 | """Reads a label map and returns a dictionary of label names to id. 129 | 130 | Args: 131 | label_map_path: path to label_map. 132 | 133 | Returns: 134 | A dictionary mapping label names to id. 135 | """ 136 | label_map = load_labelmap(label_map_path) 137 | label_map_dict = {} 138 | for item in label_map.item: 139 | label_map_dict[item.name] = item.id 140 | return label_map_dict 141 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from enum import Enum 4 | 5 | 6 | class Models(Enum): 7 | ssd_lite = 'ssd_lite' 8 | tiny_yolo = 'tiny_yolo' 9 | tf_lite = 'tf_lite' 10 | 11 | def __str__(self): 12 | return self.value 13 | 14 | @staticmethod 15 | def from_string(s): 16 | try: 17 | return Models[s] 18 | except KeyError: 19 | raise ValueError() 20 | 21 | 22 | MAX_AREA = 0.019 # max area from train set 23 | RATIO_MEAN = 4.17 24 | RATIO_STD = 1.06 25 | 26 | 27 | def load_image_into_numpy_array(image_path): 28 | image = cv2.imread(image_path) 29 | return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 30 | 31 | 32 | def affine_tile_corners(x0, y0, theta, wp, hp): 33 | """ 34 | Find corners of tile defined by affine transformation. 35 | 36 | Find corners in original image for tile defined by affine transformation, 37 | i.e. a rotation and translation, given (x0, y0) the upper left corner of 38 | the tile, theta, the rotation angle of the tile in degrees, and the tile 39 | width wp, and height hp. 
40 | 41 | Args: 42 | x0 Horizontal coordinate of tile upper left corner (pixels) 43 | y0 Vertical coordinate of tile upper left corner (pixels) 44 | theta Rotation angle (degrees clockwise from vertical) 45 | wp Tile width (pixels) 46 | hp Tile height (pixels) 47 | Returns: 48 | corners Corner points, in clockwise order starting from upper left 49 | corner, ndarray size (4, 2) 50 | """ 51 | rot_angle = np.radians(theta) 52 | corners = np.array( 53 | [[x0, y0], 54 | [x0 + wp * np.cos(rot_angle), y0 + wp * np.sin(rot_angle)], 55 | [x0 + wp * np.cos(rot_angle) - hp * np.sin(rot_angle), 56 | y0 + wp * np.sin(rot_angle) + hp * np.cos(rot_angle)], 57 | [x0 - hp * np.sin(rot_angle), y0 + hp * np.cos(rot_angle)]]) 58 | return corners 59 | 60 | 61 | def tile_images(tiling_params, img): 62 | res = [] 63 | original_sizes = [] 64 | offset = [] 65 | 66 | for cur_pt, cur_theta, cur_multiplier in zip( 67 | tiling_params["upper_left_pts"], 68 | tiling_params["thetas"], 69 | tiling_params["multipliers"]): 70 | cur_x0, cur_y0 = cur_pt 71 | corners = affine_tile_corners( 72 | cur_x0, cur_y0, cur_theta, 73 | int(cur_multiplier * tiling_params["wp"]), 74 | int(cur_multiplier * tiling_params["hp"])).astype(int) 75 | 76 | top = min(corners[:, 1]) 77 | left = min(corners[:, 0]) 78 | bottom = max(corners[:, 1]) 79 | right = max(corners[:, 0]) 80 | h = bottom - top 81 | w = right - left 82 | tile = np.zeros((h, w, 3)).astype(np.uint8) 83 | 84 | # crop tile from image 85 | tmp = img[top: bottom, left: right] 86 | tile[:tmp.shape[0], :tmp.shape[1], :3] = tmp 87 | 88 | # resize the tile 89 | tile = cv2.resize(tile, (tiling_params["wp"], tiling_params["hp"]), 90 | interpolation=cv2.INTER_NEAREST) 91 | 92 | # rotate the tile 93 | image_center = tuple(np.array(tile.shape[1::-1]) / 2) 94 | rot_mat = cv2.getRotationMatrix2D(image_center, cur_theta, 1.0) 95 | tmp = cv2.warpAffine(tile, rot_mat, (tile.shape[1::-1]), 96 | flags=cv2.INTER_LINEAR) 97 | 98 | original_sizes.append((bottom - top, right - left)) 99 | offset.append((top, left)) 100 | res.append(tmp) 101 | 102 | return res, original_sizes, offset 103 | 104 | 105 | def rotate_points(points, rotation_matrix): 106 | # add ones 107 | points_ones = np.append(points, 1) 108 | 109 | # transform points 110 | transformed_points = rotation_matrix.dot(points_ones) 111 | return transformed_points# [:,::-1] 112 | 113 | 114 | def split_img(img, m, n): 115 | h, w, _ = img.shape 116 | tile_h = h // m 117 | tile_w = w // n 118 | padding_h = tile_h // 10 119 | padding_w = int(tile_w * 0.15) 120 | 121 | res = [] 122 | original_sizes = [] 123 | offset = [] 124 | for i in range(0, m): 125 | top = i * tile_h 126 | bottom = min(h, (i + 1) * tile_h + padding_h) 127 | for j in range(0, n): 128 | left = j * tile_w 129 | right = min(w, (j + 1) * tile_w + padding_w) 130 | original_sizes.append((bottom - top, right - left)) 131 | offset.append((top, left)) 132 | res.append(cv2.resize(img[top: bottom, left: right, :], 133 | (tile_w, tile_h), 134 | interpolation=cv2.INTER_NEAREST)) 135 | 136 | return res, original_sizes, offset 137 | 138 | 139 | def get_global_coord(point, img_size, original_size, offset): 140 | return [int(point[0] / img_size[1] * original_size[1] + offset[1]), \ 141 | int(point[1] / img_size[0] * original_size[0] + offset[0])] 142 | 143 | 144 | def non_max_suppression_fast(boxes, labels, overlap_thresh=0.5): 145 | # if there are no boxes, return an empty list 146 | boxes = np.array(boxes) 147 | if len(boxes) == 0: 148 | return [], [] 149 | 150 | # initialize the list of picked 
indexes 151 | pick = [] 152 | 153 | # grab the coordinates of the bounding boxes 154 | x1 = boxes[:, 0] 155 | y1 = boxes[:, 1] 156 | x2 = boxes[:, 2] 157 | y2 = boxes[:, 3] 158 | 159 | # compute the area of the bounding boxes and sort the bounding 160 | # boxes by the bottom-right y-coordinate of the bounding box 161 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 162 | idxs = np.argsort(y2) 163 | 164 | # keep looping while some indexes still remain in the indexes 165 | # list 166 | while len(idxs) > 0: 167 | # grab the last index in the indexes list and add the 168 | # index value to the list of picked indexes 169 | last = len(idxs) - 1 170 | i = idxs[last] 171 | pick.append(i) 172 | 173 | # find the largest (x, y) coordinates for the start of 174 | # the bounding box and the smallest (x, y) coordinates 175 | # for the end of the bounding box 176 | xx1 = np.maximum(x1[i], x1[idxs[:last]]) 177 | yy1 = np.maximum(y1[i], y1[idxs[:last]]) 178 | xx2 = np.minimum(x2[i], x2[idxs[:last]]) 179 | yy2 = np.minimum(y2[i], y2[idxs[:last]]) 180 | 181 | # compute the width and height of the bounding box 182 | w = np.maximum(0, xx2 - xx1 + 1) 183 | h = np.maximum(0, yy2 - yy1 + 1) 184 | 185 | # compute the ratio of overlap 186 | overlap = 1. * (w * h) / area[idxs[:last]] 187 | 188 | # delete all indexes from the index list that have overlap above the threshold 189 | idxs = np.delete(idxs, np.concatenate( 190 | ([last], np.where(overlap > overlap_thresh)[0]))) 191 | 192 | # return only the bounding boxes that were picked using the 193 | # integer data type 194 | return boxes[pick], [labels[i] for i in pick] 195 | 196 | 197 | def filter_bb_by_size(bbs, labels, img_area): 198 | res_bbs = [] 199 | res_labels = [] 200 | for bb, l in zip(bbs, labels): 201 | s = (bb[2] - bb[0]) * (bb[3] - bb[1]) / img_area 202 | r = (bb[3] - bb[1]) / (bb[2] - bb[0]) 203 | if s < MAX_AREA * 1.1 and RATIO_MEAN - 3 * RATIO_STD < r < RATIO_MEAN + 3 * RATIO_STD: 204 | res_bbs.append(bb) 205 | res_labels.append(l) 206 | 207 | return res_bbs, res_labels 208 | -------------------------------------------------------------------------------- /yolo_darfklow.py: -------------------------------------------------------------------------------- 1 | from darkflow.net.build import TFNet 2 | 3 | from object_detector import ObjectDetector 4 | 5 | 6 | class YOLODarkflowDetector(ObjectDetector): 7 | def __init__(self, cfg_path, weights_path): 8 | options = {"model": cfg_path, 9 | "load": weights_path, "threshold": 0.01} 10 | 11 | self.tfnet = TFNet(options) 12 | 13 | def detect(self, frame, threshold=0.1): 14 | results = self.tfnet.return_predict(frame) 15 | return self.__boxes_coordinates(results, threshold) 16 | 17 | def __boxes_coordinates(self, results, threshold): 18 | boxes = [] 19 | for i in results: 20 | if i['confidence'] <= threshold: continue 21 | boxes.append([ 22 | (i['topleft']['x'], i['topleft']['y']), 23 | (i['bottomright']['x'], i['bottomright']['y']), 24 | i['confidence'], 25 | i['label'] 26 | ]) 27 | 28 | return boxes --------------------------------------------------------------------------------
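Usage note: a minimal sketch (not a file in the repository) showing how the frozen-graph detector can be exercised, mirroring the __main__ block of object_detector_detection_api_lite.py. It assumes TensorFlow 1.x is installed and that a frozen_inference_graph.pb (the default path expected by ObjectDetectorDetectionAPI) has been placed in the repository root.

import cv2
from object_detector_detection_api import ObjectDetectorDetectionAPI

# Load the frozen SSD graph and the COCO label map (both paths are the class defaults).
detector = ObjectDetectorDetectionAPI('frozen_inference_graph.pb')

# The detectors expect RGB input, so convert from OpenCV's BGR ordering.
image = cv2.cvtColor(cv2.imread('dog.jpg'), cv2.COLOR_BGR2RGB)

# Each result is [(left, top), (right, bottom), score, class name].
for top_left, bottom_right, score, label in detector.detect(image, threshold=0.4):
    print('coordinates: {} {}. class: "{}". confidence: {:.2f}'.format(
        top_left, bottom_right, label, score))

detector.close()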