├── protos ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── eval_pb2.cpython-35.pyc │ ├── model_pb2.cpython-35.pyc │ ├── ssd_pb2.cpython-35.pyc │ ├── train_pb2.cpython-35.pyc │ ├── losses_pb2.cpython-35.pyc │ ├── matcher_pb2.cpython-35.pyc │ ├── box_coder_pb2.cpython-35.pyc │ ├── optimizer_pb2.cpython-35.pyc │ ├── pipeline_pb2.cpython-35.pyc │ ├── faster_rcnn_pb2.cpython-35.pyc │ ├── hyperparams_pb2.cpython-35.pyc │ ├── input_reader_pb2.cpython-35.pyc │ ├── preprocessor_pb2.cpython-35.pyc │ ├── argmax_matcher_pb2.cpython-35.pyc │ ├── box_predictor_pb2.cpython-35.pyc │ ├── image_resizer_pb2.cpython-35.pyc │ ├── post_processing_pb2.cpython-35.pyc │ ├── anchor_generator_pb2.cpython-35.pyc │ ├── bipartite_matcher_pb2.cpython-35.pyc │ ├── square_box_coder_pb2.cpython-35.pyc │ ├── keypoint_box_coder_pb2.cpython-35.pyc │ ├── ssd_anchor_generator_pb2.cpython-35.pyc │ ├── string_int_label_map_pb2.cpython-35.pyc │ ├── faster_rcnn_box_coder_pb2.cpython-35.pyc │ ├── grid_anchor_generator_pb2.cpython-35.pyc │ ├── mean_stddev_box_coder_pb2.cpython-35.pyc │ └── region_similarity_calculator_pb2.cpython-35.pyc ├── bipartite_matcher.proto ├── mean_stddev_box_coder.proto ├── model.proto ├── matcher.proto ├── square_box_coder.proto ├── anchor_generator.proto ├── faster_rcnn_box_coder.proto ├── keypoint_box_coder.proto ├── pipeline.proto ├── box_coder.proto ├── region_similarity_calculator.proto ├── string_int_label_map.proto ├── argmax_matcher.proto ├── grid_anchor_generator.proto ├── image_resizer.proto ├── post_processing.proto ├── eval.proto ├── bipartite_matcher_pb2.py ├── mean_stddev_box_coder_pb2.py ├── input_reader.proto ├── ssd_anchor_generator.proto ├── train.proto ├── square_box_coder_pb2.py ├── optimizer.proto ├── ssd.proto ├── hyperparams.proto ├── faster_rcnn_box_coder_pb2.py ├── model_pb2.py ├── matcher_pb2.py ├── box_predictor.proto ├── keypoint_box_coder_pb2.py ├── argmax_matcher_pb2.py ├── anchor_generator_pb2.py ├── string_int_label_map_pb2.py ├── losses.proto ├── grid_anchor_generator_pb2.py ├── pipeline_pb2.py ├── box_coder_pb2.py ├── faster_rcnn.proto ├── eval_pb2.py └── post_processing_pb2.py ├── utils ├── __init__.py ├── __pycache__ │ ├── ops.cpython-35.pyc │ ├── metrics.cpython-35.pyc │ ├── __init__.cpython-35.pyc │ ├── np_box_ops.cpython-35.pyc │ ├── config_util.cpython-35.pyc │ ├── dataset_util.cpython-35.pyc │ ├── np_box_list.cpython-35.pyc │ ├── shape_utils.cpython-35.pyc │ ├── static_shape.cpython-35.pyc │ ├── label_map_util.cpython-35.pyc │ ├── np_box_list_ops.cpython-35.pyc │ ├── variables_helper.cpython-35.pyc │ ├── learning_schedules.cpython-35.pyc │ ├── per_image_evaluation.cpython-35.pyc │ ├── visualization_utils.cpython-35.pyc │ └── object_detection_evaluation.cpython-35.pyc ├── dataset_util_test.py ├── category_util_test.py ├── static_shape_test.py ├── static_shape.py ├── category_util.py ├── np_box_ops_test.py ├── test_utils_test.py ├── dataset_util.py ├── learning_schedules_test.py ├── np_box_ops.py ├── metrics_test.py ├── shape_utils.py ├── np_box_list.py ├── test_utils.py ├── variables_helper.py ├── shape_utils_test.py ├── np_box_list_test.py ├── metrics.py ├── label_map_util.py ├── label_map_util_test.py └── learning_schedules.py ├── images ├── GRDD2020.png ├── installation1.png ├── installation2.png └── RoadDamageTypeDef.png ├── crackLabelMap.txt ├── LICENSE └── smartphoneAPPS.md /protos/__init__.py: -------------------------------------------------------------------------------- 1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/images/GRDD2020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/GRDD2020.png
--------------------------------------------------------------------------------
/images/installation1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/installation1.png
--------------------------------------------------------------------------------
/images/installation2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/installation2.png
--------------------------------------------------------------------------------
/images/RoadDamageTypeDef.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/RoadDamageTypeDef.png
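The sections below dump the vendored `protos/*.proto` definitions together with their pre-generated `*_pb2.py` modules, so nothing needs to be compiled just to run the detector. If the generated modules ever have to be rebuilt (for example after editing a `.proto`), the usual TensorFlow Object Detection API step is `protoc object_detection/protos/*.proto --python_out=.`, run from the directory that contains the `object_detection/` package — that is the root the `import "object_detection/protos/..."` statements in these files expect. A minimal Python sketch of the same step, assuming `grpcio-tools` is installed and that the protos are reachable under such an `object_detection/` package (an assumption, not something this repository states), is:

```python
# Sketch only: regenerate the *_pb2.py modules from the vendored .proto files.
# Assumes `pip install grpcio-tools` and that the current directory contains an
# object_detection/protos/ package, which is what the proto imports require.
import glob

from grpc_tools import protoc

for proto_file in sorted(glob.glob("object_detection/protos/*.proto")):
    exit_code = protoc.main([
        "grpc_tools.protoc",   # argv[0]; used only in error messages
        "--proto_path=.",      # root for resolving object_detection/protos/... imports
        "--python_out=.",      # write <name>_pb2.py next to each .proto
        proto_file,
    ])
    if exit_code != 0:
        raise RuntimeError("protoc failed for {}".format(proto_file))
```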
--------------------------------------------------------------------------------
/protos/bipartite_matcher.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | package object_detection.protos;
4 | 
5 | // Configuration proto for bipartite matcher. See
6 | // matchers/bipartite_matcher.py for details.
7 | message BipartiteMatcher {
8 | }
9 | 
--------------------------------------------------------------------------------
/protos/mean_stddev_box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | package object_detection.protos;
4 | 
5 | // Configuration proto for MeanStddevBoxCoder. See
6 | // box_coders/mean_stddev_box_coder.py for details.
7 | message MeanStddevBoxCoder {
8 | }
9 | 
--------------------------------------------------------------------------------
/protos/model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | package object_detection.protos;
4 | 
5 | import "object_detection/protos/faster_rcnn.proto";
6 | import "object_detection/protos/ssd.proto";
7 | 
8 | // Top level configuration for DetectionModels.
9 | message DetectionModel { 10 | oneof model { 11 | FasterRcnn faster_rcnn = 1; 12 | Ssd ssd = 2; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /crackLabelMap.txt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'D00' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'D01' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'D10' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'D11' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'D20' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'D40' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'D43' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'D44' 39 | } 40 | -------------------------------------------------------------------------------- /protos/matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/argmax_matcher.proto"; 6 | import "object_detection/protos/bipartite_matcher.proto"; 7 | 8 | // Configuration proto for the matcher to be used in the object detection 9 | // pipeline. See core/matcher.py for details. 10 | message Matcher { 11 | oneof matcher_oneof { 12 | ArgMaxMatcher argmax_matcher = 1; 13 | BipartiteMatcher bipartite_matcher = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /protos/square_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SquareBoxCoder. See 6 | // box_coders/square_box_coder.py for details. 7 | message SquareBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box length. 13 | optional float length_scale = 3 [default = 5.0]; 14 | } 15 | -------------------------------------------------------------------------------- /protos/anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/grid_anchor_generator.proto"; 6 | import "object_detection/protos/ssd_anchor_generator.proto"; 7 | 8 | // Configuration proto for the anchor generator to use in the object detection 9 | // pipeline. See core/anchor_generator.py for details. 10 | message AnchorGenerator { 11 | oneof anchor_generator_oneof { 12 | GridAnchorGenerator grid_anchor_generator = 1; 13 | SsdAnchorGenerator ssd_anchor_generator = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /protos/faster_rcnn_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for FasterRCNNBoxCoder. See 6 | // box_coders/faster_rcnn_box_coder.py for details. 7 | message FasterRcnnBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box height. 13 | optional float height_scale = 3 [default = 5.0]; 14 | 15 | // Scale factor for anchor encoded box width. 
16 | optional float width_scale = 4 [default = 5.0]; 17 | } 18 | -------------------------------------------------------------------------------- /protos/keypoint_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for KeypointBoxCoder. See 6 | // box_coders/keypoint_box_coder.py for details. 7 | message KeypointBoxCoder { 8 | optional int32 num_keypoints = 1; 9 | 10 | // Scale factor for anchor encoded box center and keypoints. 11 | optional float y_scale = 2 [default = 10.0]; 12 | optional float x_scale = 3 [default = 10.0]; 13 | 14 | // Scale factor for anchor encoded box height. 15 | optional float height_scale = 4 [default = 5.0]; 16 | 17 | // Scale factor for anchor encoded box width. 18 | optional float width_scale = 5 [default = 5.0]; 19 | } 20 | -------------------------------------------------------------------------------- /protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/eval.proto"; 6 | import "object_detection/protos/input_reader.proto"; 7 | import "object_detection/protos/model.proto"; 8 | import "object_detection/protos/train.proto"; 9 | 10 | // Convenience message for configuring a training and eval pipeline. Allows all 11 | // of the pipeline parameters to be configured from one file. 12 | message TrainEvalPipelineConfig { 13 | optional DetectionModel model = 1; 14 | optional TrainConfig train_config = 2; 15 | optional InputReader train_input_reader = 3; 16 | optional EvalConfig eval_config = 4; 17 | optional InputReader eval_input_reader = 5; 18 | } 19 | -------------------------------------------------------------------------------- /protos/box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn_box_coder.proto"; 6 | import "object_detection/protos/keypoint_box_coder.proto"; 7 | import "object_detection/protos/mean_stddev_box_coder.proto"; 8 | import "object_detection/protos/square_box_coder.proto"; 9 | 10 | // Configuration proto for the box coder to be used in the object detection 11 | // pipeline. See core/box_coder.py for details. 12 | message BoxCoder { 13 | oneof box_coder_oneof { 14 | FasterRcnnBoxCoder faster_rcnn_box_coder = 1; 15 | MeanStddevBoxCoder mean_stddev_box_coder = 2; 16 | SquareBoxCoder square_box_coder = 3; 17 | KeypointBoxCoder keypoint_box_coder = 4; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /protos/region_similarity_calculator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for region similarity calculators. See 6 | // core/region_similarity_calculator.py for details. 7 | message RegionSimilarityCalculator { 8 | oneof region_similarity { 9 | NegSqDistSimilarity neg_sq_dist_similarity = 1; 10 | IouSimilarity iou_similarity = 2; 11 | IoaSimilarity ioa_similarity = 3; 12 | } 13 | } 14 | 15 | // Configuration for negative squared distance similarity calculator. 16 | message NegSqDistSimilarity { 17 | } 18 | 19 | // Configuration for intersection-over-union (IOU) similarity calculator. 
20 | message IouSimilarity { 21 | } 22 | 23 | // Configuration for intersection-over-area (IOA) similarity calculator. 24 | message IoaSimilarity { 25 | } 26 | -------------------------------------------------------------------------------- /protos/string_int_label_map.proto: -------------------------------------------------------------------------------- 1 | // Message to store the mapping from class label strings to class id. Datasets 2 | // use string labels to represent classes while the object detection framework 3 | // works with class ids. This message maps them so they can be converted back 4 | // and forth as needed. 5 | syntax = "proto2"; 6 | 7 | package object_detection.protos; 8 | 9 | message StringIntLabelMapItem { 10 | // String name. The most common practice is to set this to a MID or synsets 11 | // id. 12 | optional string name = 1; 13 | 14 | // Integer id that maps to the string name above. Label ids should start from 15 | // 1. 16 | optional int32 id = 2; 17 | 18 | // Human readable string label. 19 | optional string display_name = 3; 20 | }; 21 | 22 | message StringIntLabelMap { 23 | repeated StringIntLabelMapItem item = 1; 24 | }; 25 | -------------------------------------------------------------------------------- /protos/argmax_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for ArgMaxMatcher. See 6 | // matchers/argmax_matcher.py for details. 7 | message ArgMaxMatcher { 8 | // Threshold for positive matches. 9 | optional float matched_threshold = 1 [default = 0.5]; 10 | 11 | // Threshold for negative matches. 12 | optional float unmatched_threshold = 2 [default = 0.5]; 13 | 14 | // Whether to construct ArgMaxMatcher without thresholds. 15 | optional bool ignore_thresholds = 3 [default = false]; 16 | 17 | // If True then negative matches are the ones below the unmatched_threshold, 18 | // whereas ignored matches are in between the matched and umatched 19 | // threshold. If False, then negative matches are in between the matched 20 | // and unmatched threshold, and everything lower than unmatched is ignored. 21 | optional bool negatives_lower_than_unmatched = 4 [default = true]; 22 | 23 | // Whether to ensure each row is matched to at least one column. 24 | optional bool force_match_for_each_row = 5 [default = false]; 25 | } 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 sekilab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /protos/grid_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for GridAnchorGenerator. See 6 | // anchor_generators/grid_anchor_generator.py for details. 7 | message GridAnchorGenerator { 8 | // Anchor height in pixels. 9 | optional int32 height = 1 [default = 256]; 10 | 11 | // Anchor width in pixels. 12 | optional int32 width = 2 [default = 256]; 13 | 14 | // Anchor stride in height dimension in pixels. 15 | optional int32 height_stride = 3 [default = 16]; 16 | 17 | // Anchor stride in width dimension in pixels. 18 | optional int32 width_stride = 4 [default = 16]; 19 | 20 | // Anchor height offset in pixels. 21 | optional int32 height_offset = 5 [default = 0]; 22 | 23 | // Anchor width offset in pixels. 24 | optional int32 width_offset = 6 [default = 0]; 25 | 26 | // At any given location, len(scales) * len(aspect_ratios) anchors are 27 | // generated with all possible combinations of scales and aspect ratios. 28 | 29 | // List of scales for the anchors. 30 | repeated float scales = 7; 31 | 32 | // List of aspect ratios for the anchors. 33 | repeated float aspect_ratios = 8; 34 | } 35 | -------------------------------------------------------------------------------- /utils/dataset_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.dataset_util.""" 17 | 18 | import os 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import dataset_util 22 | 23 | 24 | class DatasetUtilTest(tf.test.TestCase): 25 | 26 | def test_read_examples_list(self): 27 | example_list_data = """example1 1\nexample2 2""" 28 | example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt') 29 | with tf.gfile.Open(example_list_path, 'wb') as f: 30 | f.write(example_list_data) 31 | 32 | examples = dataset_util.read_examples_list(example_list_path) 33 | self.assertListEqual(['example1', 'example2'], examples) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /protos/image_resizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for image resizing operations. 6 | // See builders/image_resizer_builder.py for details. 7 | message ImageResizer { 8 | oneof image_resizer_oneof { 9 | KeepAspectRatioResizer keep_aspect_ratio_resizer = 1; 10 | FixedShapeResizer fixed_shape_resizer = 2; 11 | } 12 | } 13 | 14 | // Enumeration type for image resizing methods provided in TensorFlow. 15 | enum ResizeType { 16 | BILINEAR = 0; // Corresponds to tf.image.ResizeMethod.BILINEAR 17 | NEAREST_NEIGHBOR = 1; // Corresponds to tf.image.ResizeMethod.NEAREST_NEIGHBOR 18 | BICUBIC = 2; // Corresponds to tf.image.ResizeMethod.BICUBIC 19 | AREA = 3; // Corresponds to tf.image.ResizeMethod.AREA 20 | } 21 | 22 | // Configuration proto for image resizer that keeps aspect ratio. 23 | message KeepAspectRatioResizer { 24 | // Desired size of the smaller image dimension in pixels. 25 | optional int32 min_dimension = 1 [default = 600]; 26 | 27 | // Desired size of the larger image dimension in pixels. 28 | optional int32 max_dimension = 2 [default = 1024]; 29 | 30 | // Desired method when resizing image. 31 | optional ResizeType resize_method = 3 [default = BILINEAR]; 32 | } 33 | 34 | // Configuration proto for image resizer that resizes to a fixed shape. 35 | message FixedShapeResizer { 36 | // Desired height of image in pixels. 37 | optional int32 height = 1 [default = 300]; 38 | 39 | // Desired width of image in pixels. 40 | optional int32 width = 2 [default = 300]; 41 | 42 | // Desired method when resizing image. 43 | optional ResizeType resize_method = 3 [default = BILINEAR]; 44 | } 45 | -------------------------------------------------------------------------------- /protos/post_processing.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for non-max-suppression operation on a batch of 6 | // detections. 7 | message BatchNonMaxSuppression { 8 | // Scalar threshold for score (low scoring boxes are removed). 9 | optional float score_threshold = 1 [default = 0.0]; 10 | 11 | // Scalar threshold for IOU (boxes that have high IOU overlap 12 | // with previously selected boxes are removed). 13 | optional float iou_threshold = 2 [default = 0.6]; 14 | 15 | // Maximum number of detections to retain per class. 16 | optional int32 max_detections_per_class = 3 [default = 100]; 17 | 18 | // Maximum number of detections to retain across all classes. 
19 | optional int32 max_total_detections = 5 [default = 100]; 20 | } 21 | 22 | // Configuration proto for post-processing predicted boxes and 23 | // scores. 24 | message PostProcessing { 25 | // Non max suppression parameters. 26 | optional BatchNonMaxSuppression batch_non_max_suppression = 1; 27 | 28 | // Enum to specify how to convert the detection scores. 29 | enum ScoreConverter { 30 | // Input scores equals output scores. 31 | IDENTITY = 0; 32 | 33 | // Applies a sigmoid on input scores. 34 | SIGMOID = 1; 35 | 36 | // Applies a softmax on input scores 37 | SOFTMAX = 2; 38 | } 39 | 40 | // Score converter to use. 41 | optional ScoreConverter score_converter = 2 [default = IDENTITY]; 42 | // Scale logit (input) value before conversion in post-processing step. 43 | // Typically used for softmax distillation, though can be used to scale for 44 | // other reasons. 45 | optional float logit_scale = 3 [default = 1.0]; 46 | } 47 | -------------------------------------------------------------------------------- /protos/eval.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring DetectionModel evaluation jobs (eval.py). 6 | message EvalConfig { 7 | // Number of visualization images to generate. 8 | optional uint32 num_visualizations = 1 [default=10]; 9 | 10 | // Number of examples to process of evaluation. 11 | optional uint32 num_examples = 2 [default=5000]; 12 | 13 | // How often to run evaluation. 14 | optional uint32 eval_interval_secs = 3 [default=300]; 15 | 16 | // Maximum number of times to run evaluation. If set to 0, will run forever. 17 | optional uint32 max_evals = 4 [default=0]; 18 | 19 | // Whether the TensorFlow graph used for evaluation should be saved to disk. 20 | optional bool save_graph = 5 [default=false]; 21 | 22 | // Path to directory to store visualizations in. If empty, visualization 23 | // images are not exported (only shown on Tensorboard). 24 | optional string visualization_export_dir = 6 [default=""]; 25 | 26 | // BNS name of the TensorFlow master. 27 | optional string eval_master = 7 [default=""]; 28 | 29 | // Type of metrics to use for evaluation. Currently supports only Pascal VOC 30 | // detection metrics. 31 | optional string metrics_set = 8 [default="pascal_voc_metrics"]; 32 | 33 | // Path to export detections to COCO compatible JSON format. 34 | optional string export_path = 9 [default='']; 35 | 36 | // Option to not read groundtruth labels and only export detections to 37 | // COCO-compatible JSON file. 38 | optional bool ignore_groundtruth = 10 [default=false]; 39 | 40 | // Use exponential moving averages of variables for evaluation. 41 | optional bool use_moving_averages = 11 [default=false]; 42 | 43 | // Whether to evaluate instance masks. 44 | // Note that since there is no evaluation code currently for instance 45 | // segmenation this option is unused. 46 | optional bool eval_instance_masks = 12 [default=false]; 47 | } 48 | -------------------------------------------------------------------------------- /protos/bipartite_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/bipartite_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/bipartite_matcher.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n/object_detection/protos/bipartite_matcher.proto\x12\x17object_detection.protos\"\x12\n\x10\x42ipartiteMatcher') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _BIPARTITEMATCHER = _descriptor.Descriptor( 29 | name='BipartiteMatcher', 30 | full_name='object_detection.protos.BipartiteMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | extension_ranges=[], 44 | oneofs=[ 45 | ], 46 | serialized_start=76, 47 | serialized_end=94, 48 | ) 49 | 50 | DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER 51 | 52 | BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), dict( 53 | DESCRIPTOR = _BIPARTITEMATCHER, 54 | __module__ = 'object_detection.protos.bipartite_matcher_pb2' 55 | # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher) 56 | )) 57 | _sym_db.RegisterMessage(BipartiteMatcher) 58 | 59 | 60 | # @@protoc_insertion_point(module_scope) 61 | -------------------------------------------------------------------------------- /protos/mean_stddev_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/mean_stddev_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/mean_stddev_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"\x14\n\x12MeanStddevBoxCoder') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _MEANSTDDEVBOXCODER = _descriptor.Descriptor( 29 | name='MeanStddevBoxCoder', 30 | full_name='object_detection.protos.MeanStddevBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | ], 36 | extensions=[ 37 | ], 38 | nested_types=[], 39 | enum_types=[ 40 | ], 41 | options=None, 42 | is_extendable=False, 43 | extension_ranges=[], 44 | oneofs=[ 45 | ], 46 | serialized_start=80, 47 | serialized_end=100, 48 | ) 49 | 50 | DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER 51 | 52 | MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), dict( 53 | DESCRIPTOR = _MEANSTDDEVBOXCODER, 54 | __module__ = 'object_detection.protos.mean_stddev_box_coder_pb2' 55 | # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder) 56 | )) 57 | _sym_db.RegisterMessage(MeanStddevBoxCoder) 58 | 59 | 60 | # @@protoc_insertion_point(module_scope) 61 | -------------------------------------------------------------------------------- /utils/category_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.category_util.""" 17 | import os 18 | 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import category_util 22 | 23 | 24 | class EvalUtilTest(tf.test.TestCase): 25 | 26 | def test_load_categories_from_csv_file(self): 27 | csv_data = """ 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | """.strip(' ') 32 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 33 | with tf.gfile.Open(csv_path, 'wb') as f: 34 | f.write(csv_data) 35 | 36 | categories = category_util.load_categories_from_csv_file(csv_path) 37 | self.assertTrue({'id': 0, 'name': 'cat'} in categories) 38 | self.assertTrue({'id': 1, 'name': 'dog'} in categories) 39 | self.assertTrue({'id': 2, 'name': 'bird'} in categories) 40 | 41 | def test_save_categories_to_csv_file(self): 42 | categories = [ 43 | {'id': 0, 'name': 'cat'}, 44 | {'id': 1, 'name': 'dog'}, 45 | {'id': 2, 'name': 'bird'}, 46 | ] 47 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 48 | category_util.save_categories_to_csv_file(categories, csv_path) 49 | saved_categories = category_util.load_categories_from_csv_file(csv_path) 50 | self.assertEqual(saved_categories, categories) 51 | 52 | 53 | if __name__ == '__main__': 54 | tf.test.main() 55 | -------------------------------------------------------------------------------- /utils/static_shape_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.static_shape.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.utils import static_shape 21 | 22 | 23 | class StaticShapeTest(tf.test.TestCase): 24 | 25 | def test_return_correct_batchSize(self): 26 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 27 | self.assertEqual(32, static_shape.get_batch_size(tensor_shape)) 28 | 29 | def test_return_correct_height(self): 30 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 31 | self.assertEqual(299, static_shape.get_height(tensor_shape)) 32 | 33 | def test_return_correct_width(self): 34 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 35 | self.assertEqual(384, static_shape.get_width(tensor_shape)) 36 | 37 | def test_return_correct_depth(self): 38 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 39 | self.assertEqual(3, static_shape.get_depth(tensor_shape)) 40 | 41 | def test_die_on_tensor_shape_with_rank_three(self): 42 | tensor_shape = tf.TensorShape(dims=[32, 299, 384]) 43 | with self.assertRaises(ValueError): 44 | static_shape.get_batch_size(tensor_shape) 45 | static_shape.get_height(tensor_shape) 46 | static_shape.get_width(tensor_shape) 47 | static_shape.get_depth(tensor_shape) 48 | 49 | if __name__ == '__main__': 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /utils/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions to access TensorShape values. 17 | 18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 19 | """ 20 | 21 | 22 | def get_batch_size(tensor_shape): 23 | """Returns batch size from the tensor shape. 24 | 25 | Args: 26 | tensor_shape: A rank 4 TensorShape. 27 | 28 | Returns: 29 | An integer representing the batch size of the tensor. 30 | """ 31 | tensor_shape.assert_has_rank(rank=4) 32 | return tensor_shape[0].value 33 | 34 | 35 | def get_height(tensor_shape): 36 | """Returns height from the tensor shape. 37 | 38 | Args: 39 | tensor_shape: A rank 4 TensorShape. 40 | 41 | Returns: 42 | An integer representing the height of the tensor. 43 | """ 44 | tensor_shape.assert_has_rank(rank=4) 45 | return tensor_shape[1].value 46 | 47 | 48 | def get_width(tensor_shape): 49 | """Returns width from the tensor shape. 50 | 51 | Args: 52 | tensor_shape: A rank 4 TensorShape. 53 | 54 | Returns: 55 | An integer representing the width of the tensor. 56 | """ 57 | tensor_shape.assert_has_rank(rank=4) 58 | return tensor_shape[2].value 59 | 60 | 61 | def get_depth(tensor_shape): 62 | """Returns depth from the tensor shape. 63 | 64 | Args: 65 | tensor_shape: A rank 4 TensorShape. 
66 | 
67 |   Returns:
68 |     An integer representing the depth of the tensor.
69 |   """
70 |   tensor_shape.assert_has_rank(rank=4)
71 |   return tensor_shape[3].value
72 | 
--------------------------------------------------------------------------------
/smartphoneAPPS.md:
--------------------------------------------------------------------------------
1 | # RoadCrackDetector
2 | 
3 | ## What is RoadCrackDetector?
4 | RoadCrackDetector is a smartphone app that detects damage on the road surface by running a deep neural network model.
5 | 
6 | By running a deep learning model directly on the smartphone, the application detects road surface damage in the captured images.
7 | 
8 | ## How to use RoadCrackDetector?
9 | ### Basic Functions
10 | - Detect road damage while the car is moving
11 | - Stop processing automatically while the car is stopped
12 | - Record road images every second
13 | 
14 | ### Installation location
15 | - Car dashboard
16 | 
17 | img1
18 | img2
19 | 
20 | ## How to train the Crack Detection model?
21 | (March 2017)
22 | We used the [YOLO detector](https://pjreddie.com/darknet/yolo/ "yolo web") to train the model.
23 | The training dataset consists of more than 30,000 road images that include road damage.
24 | This version of the application can only detect "damage".
25 | 
26 | (September 2017)
27 | We used SSD with MobileNet to train the model.
28 | The model trained on 9,053 road images containing damage can be downloaded [here](https://s3-ap-northeast-1.amazonaws.com/mycityreport/trainedModels.tar.gz).
29 | This version of the application can detect "damage" and classify the "damage type".
30 | 
31 | ### Android application
32 | App with MobileNet+SSD (October 2018)
33 | [RoadCrackDetector.apk(26MB)](https://s3-ap-northeast-1.amazonaws.com/sekilab-students/maeda/kashiyama/mcr111_open.apk)
34 | (Android 7.1 or higher is required)
35 | 
36 | The app saves two kinds of files in 'ExternalStorage/Android/data/org.utokyo.sekilab.mcr/files'.
37 | The location file contains a GPS coordinate recorded every 3 seconds; the damage file contains road damage data and image data.
38 | 
39 | ### Experiments with local governments in Japan (August 2017)
40 | We carried out a road inspection with our app in Toga Village; please check [our website](http://sekilab.iis.u-tokyo.ac.jp/archives/category/news#post-1882)!
41 | You can also watch footage from the experiment as a demo ([demo movie](https://youtu.be/P74Hl0vr1-Y)).
42 | 
43 | We conducted an actual road surface inspection using this application in Toga Village, Toyama Prefecture ([reference](http://sekilab.iis.u-tokyo.ac.jp/archives/category/news#post-1882)).
44 | A video of the experiment is publicly available as a [demo movie](https://youtu.be/P74Hl0vr1-Y); please take a look.
45 | 
46 | 
47 | 
48 | 
--------------------------------------------------------------------------------
/protos/input_reader.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | package object_detection.protos;
4 | 
5 | // Configuration proto for defining input readers that generate Object Detection
6 | // Examples from input sources. Input readers are expected to generate a
7 | // dictionary of tensors, with the following fields populated:
8 | //
9 | // 'image': an [image_height, image_width, channels] image tensor that detection
10 | //    will be run on.
11 | // 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
12 | //    labels of detected boxes in the image.
13 | // 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
14 | //    detected boxes in the image.
15 | // 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
16 | //    image_width] float tensor storing binary mask of the objects in boxes.
17 | 
18 | message InputReader {
19 |   // Path to StringIntLabelMap pbtxt file specifying the mapping from string
20 |   // labels to integer ids.
21 |   optional string label_map_path = 1 [default=""];
22 | 
23 |   // Whether data should be processed in the order they are read in, or
24 |   // shuffled randomly.
25 |   optional bool shuffle = 2 [default=true];
26 | 
27 |   // Maximum number of records to keep in reader queue.
28 |   optional uint32 queue_capacity = 3 [default=2000];
29 | 
30 |   // Minimum number of records to keep in reader queue. A large value is needed
31 |   // to generate a good random shuffle.
32 |   optional uint32 min_after_dequeue = 4 [default=1000];
33 | 
34 |   // The number of times a data source is read. If set to zero, the data source
35 |   // will be reused indefinitely.
36 |   optional uint32 num_epochs = 5 [default=0];
37 | 
38 |   // Number of reader instances to create.
39 |   optional uint32 num_readers = 6 [default=8];
40 | 
41 |   // Whether to load groundtruth instance masks.
42 |   optional bool load_instance_masks = 7 [default = false];
43 | 
44 |   oneof input_reader {
45 |     TFRecordInputReader tf_record_input_reader = 8;
46 |     ExternalInputReader external_input_reader = 9;
47 |   }
48 | }
49 | 
50 | // An input reader that reads TF Example protos from local TFRecord files.
51 | message TFRecordInputReader {
52 |   // Path(s) to `TFRecordFile`s.
53 |   repeated string input_path = 1;
54 | }
55 | 
56 | // An externally defined input reader. Users may define an extension to this
57 | // proto to interface their own input readers.
58 | message ExternalInputReader {
59 |   extensions 1 to 999;
60 | }
61 | 
--------------------------------------------------------------------------------
/protos/ssd_anchor_generator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 | 
3 | package object_detection.protos;
4 | 
5 | // Configuration proto for SSD anchor generator described in
6 | // https://arxiv.org/abs/1512.02325. See
7 | // anchor_generators/multiple_grid_anchor_generator.py for details.
8 | message SsdAnchorGenerator {
9 |   // Number of grid layers to create anchors for.
10 |   optional int32 num_layers = 1 [default = 6];
11 | 
12 |   // Scale of anchors corresponding to finest resolution.
13 | optional float min_scale = 2 [default = 0.2]; 14 | 15 | // Scale of anchors corresponding to coarsest resolution 16 | optional float max_scale = 3 [default = 0.95]; 17 | 18 | // Can be used to override min_scale->max_scale, with an explicitly defined 19 | // set of scales. If empty, then min_scale->max_scale is used. 20 | repeated float scales = 12; 21 | 22 | // Aspect ratios for anchors at each grid point. 23 | repeated float aspect_ratios = 4; 24 | 25 | // When this aspect ratio is greater than 0, then an additional 26 | // anchor, with an interpolated scale is added with this aspect ratio. 27 | optional float interpolated_scale_aspect_ratio = 13 [default = 1.0]; 28 | 29 | // Whether to use the following aspect ratio and scale combination for the 30 | // layer with the finest resolution : (scale=0.1, aspect_ratio=1.0), 31 | // (scale=min_scale, aspect_ration=2.0), (scale=min_scale, aspect_ratio=0.5). 32 | optional bool reduce_boxes_in_lowest_layer = 5 [default = true]; 33 | 34 | // The base anchor size in height dimension. 35 | optional float base_anchor_height = 6 [default = 1.0]; 36 | 37 | // The base anchor size in width dimension. 38 | optional float base_anchor_width = 7 [default = 1.0]; 39 | 40 | // Anchor stride in height dimension in pixels for each layer. The length of 41 | // this field is expected to be equal to the value of num_layers. 42 | repeated int32 height_stride = 8; 43 | 44 | // Anchor stride in width dimension in pixels for each layer. The length of 45 | // this field is expected to be equal to the value of num_layers. 46 | repeated int32 width_stride = 9; 47 | 48 | // Anchor height offset in pixels for each layer. The length of this field is 49 | // expected to be equal to the value of num_layers. 50 | repeated int32 height_offset = 10; 51 | 52 | // Anchor width offset in pixels for each layer. The length of this field is 53 | // expected to be equal to the value of num_layers. 54 | repeated int32 width_offset = 11; 55 | } 56 | -------------------------------------------------------------------------------- /utils/category_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for importing/exporting Object Detection categories.""" 17 | import csv 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def load_categories_from_csv_file(csv_path): 23 | """Loads categories from a csv file. 24 | 25 | The CSV file should have one comma delimited numeric category id and string 26 | category name pair per line. For example: 27 | 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | ... 32 | 33 | Args: 34 | csv_path: Path to the csv file to be parsed into categories. 35 | Returns: 36 | categories: A list of dictionaries representing all possible categories. 
37 | The categories will contain an integer 'id' field and a string 38 | 'name' field. 39 | Raises: 40 | ValueError: If the csv file is incorrectly formatted. 41 | """ 42 | categories = [] 43 | 44 | with tf.gfile.Open(csv_path, 'r') as csvfile: 45 | reader = csv.reader(csvfile, delimiter=',', quotechar='"') 46 | for row in reader: 47 | if not row: 48 | continue 49 | 50 | if len(row) != 2: 51 | raise ValueError('Expected 2 fields per row in csv: %s' % ','.join(row)) 52 | 53 | category_id = int(row[0]) 54 | category_name = row[1] 55 | categories.append({'id': category_id, 'name': category_name}) 56 | 57 | return categories 58 | 59 | 60 | def save_categories_to_csv_file(categories, csv_path): 61 | """Saves categories to a csv file. 62 | 63 | Args: 64 | categories: A list of dictionaries representing categories to save to file. 65 | Each category must contain an 'id' and 'name' field. 66 | csv_path: Path to the csv file to be parsed into categories. 67 | """ 68 | categories.sort(key=lambda x: x['id']) 69 | with tf.gfile.Open(csv_path, 'w') as csvfile: 70 | writer = csv.writer(csvfile, delimiter=',', quotechar='"') 71 | for category in categories: 72 | writer.writerow([category['id'], category['name']]) 73 | -------------------------------------------------------------------------------- /utils/np_box_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.np_box_ops.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import np_box_ops 22 | 23 | 24 | class BoxOpsTests(tf.test.TestCase): 25 | 26 | def setUp(self): 27 | boxes1 = np.array([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], 28 | dtype=float) 29 | boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 30 | [0.0, 0.0, 20.0, 20.0]], 31 | dtype=float) 32 | self.boxes1 = boxes1 33 | self.boxes2 = boxes2 34 | 35 | def testArea(self): 36 | areas = np_box_ops.area(self.boxes1) 37 | expected_areas = np.array([6.0, 5.0], dtype=float) 38 | self.assertAllClose(expected_areas, areas) 39 | 40 | def testIntersection(self): 41 | intersection = np_box_ops.intersection(self.boxes1, self.boxes2) 42 | expected_intersection = np.array([[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]], 43 | dtype=float) 44 | self.assertAllClose(intersection, expected_intersection) 45 | 46 | def testIOU(self): 47 | iou = np_box_ops.iou(self.boxes1, self.boxes2) 48 | expected_iou = np.array([[2.0 / 16.0, 0.0, 6.0 / 400.0], 49 | [1.0 / 16.0, 0.0, 5.0 / 400.0]], 50 | dtype=float) 51 | self.assertAllClose(iou, expected_iou) 52 | 53 | def testIOA(self): 54 | boxes1 = np.array([[0.25, 0.25, 0.75, 0.75], 55 | [0.0, 0.0, 0.5, 0.75]], 56 | dtype=np.float32) 57 | boxes2 = np.array([[0.5, 0.25, 1.0, 1.0], 58 | [0.0, 0.0, 1.0, 1.0]], 59 | dtype=np.float32) 60 | ioa21 = np_box_ops.ioa(boxes2, boxes1) 61 | expected_ioa21 = np.array([[0.5, 0.0], 62 | [1.0, 1.0]], 63 | dtype=np.float32) 64 | self.assertAllClose(ioa21, expected_ioa21) 65 | 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /utils/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.test_utils.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import test_utils 22 | 23 | 24 | class TestUtilsTest(tf.test.TestCase): 25 | 26 | def test_diagonal_gradient_image(self): 27 | """Tests if a good pyramid image is created.""" 28 | pyramid_image = test_utils.create_diagonal_gradient_image(3, 4, 2) 29 | 30 | # Test which is easy to understand. 31 | expected_first_channel = np.array([[3, 2, 1, 0], 32 | [4, 3, 2, 1], 33 | [5, 4, 3, 2]], dtype=np.float32) 34 | self.assertAllEqual(np.squeeze(pyramid_image[:, :, 0]), 35 | expected_first_channel) 36 | 37 | # Actual test. 
38 | expected_image = np.array([[[3, 30], 39 | [2, 20], 40 | [1, 10], 41 | [0, 0]], 42 | [[4, 40], 43 | [3, 30], 44 | [2, 20], 45 | [1, 10]], 46 | [[5, 50], 47 | [4, 40], 48 | [3, 30], 49 | [2, 20]]], dtype=np.float32) 50 | 51 | self.assertAllEqual(pyramid_image, expected_image) 52 | 53 | def test_random_boxes(self): 54 | """Tests if valid random boxes are created.""" 55 | num_boxes = 1000 56 | max_height = 3 57 | max_width = 5 58 | boxes = test_utils.create_random_boxes(num_boxes, 59 | max_height, 60 | max_width) 61 | 62 | true_column = np.ones(shape=(num_boxes)) == 1 63 | self.assertAllEqual(boxes[:, 0] < boxes[:, 2], true_column) 64 | self.assertAllEqual(boxes[:, 1] < boxes[:, 3], true_column) 65 | 66 | self.assertTrue(boxes[:, 0].min() >= 0) 67 | self.assertTrue(boxes[:, 1].min() >= 0) 68 | self.assertTrue(boxes[:, 2].max() <= max_height) 69 | self.assertTrue(boxes[:, 3].max() <= max_width) 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def read_examples_list(path): 42 | """Read list of training or validation examples. 43 | 44 | The file is assumed to contain a single example per line where the first 45 | token in the line is an identifier that allows us to find the image and 46 | annotation xml for that example. 47 | 48 | For example, the line: 49 | xyz 3 50 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 51 | 52 | Args: 53 | path: absolute path to examples list file. 54 | 55 | Returns: 56 | list of example identifiers (strings). 57 | """ 58 | with tf.gfile.GFile(path) as fid: 59 | lines = fid.readlines() 60 | return [line.strip().split(' ')[0] for line in lines] 61 | 62 | 63 | def recursive_parse_xml_to_dict(xml): 64 | """Recursively parses XML contents to python dict. 65 | 66 | We assume that `object` tags are the only ones that can appear 67 | multiple times at the same level of a tree. 
68 | 69 | Args: 70 | xml: xml tree obtained by parsing XML file contents using lxml.etree 71 | 72 | Returns: 73 | Python dictionary holding XML contents. 74 | """ 75 | if not xml: 76 | return {xml.tag: xml.text} 77 | result = {} 78 | for child in xml: 79 | child_result = recursive_parse_xml_to_dict(child) 80 | if child.tag != 'object': 81 | result[child.tag] = child_result[child.tag] 82 | else: 83 | if child.tag not in result: 84 | result[child.tag] = [] 85 | result[child.tag].append(child_result[child.tag]) 86 | return {xml.tag: result} 87 | -------------------------------------------------------------------------------- /protos/train.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/optimizer.proto"; 6 | import "object_detection/protos/preprocessor.proto"; 7 | 8 | // Message for configuring DetectionModel training jobs (train.py). 9 | message TrainConfig { 10 | // Input queue batch size. 11 | optional uint32 batch_size = 1 [default=32]; 12 | 13 | // Data augmentation options. 14 | repeated PreprocessingStep data_augmentation_options = 2; 15 | 16 | // Whether to synchronize replicas during training. 17 | optional bool sync_replicas = 3 [default=false]; 18 | 19 | // How frequently to keep checkpoints. 20 | optional uint32 keep_checkpoint_every_n_hours = 4 [default=1000]; 21 | 22 | // Optimizer used to train the DetectionModel. 23 | optional Optimizer optimizer = 5; 24 | 25 | // If greater than 0, clips gradients by this value. 26 | optional float gradient_clipping_by_norm = 6 [default=0.0]; 27 | 28 | // Checkpoint to restore variables from. Typically used to load feature 29 | // extractor variables trained outside of object detection. 30 | optional string fine_tune_checkpoint = 7 [default=""]; 31 | 32 | // Specifies if the finetune checkpoint is from an object detection model. 33 | // If from an object detection model, the model being trained should have 34 | // the same parameters with the exception of the num_classes parameter. 35 | // If false, it assumes the checkpoint was a object classification model. 36 | optional bool from_detection_checkpoint = 8 [default=false]; 37 | 38 | // Number of steps to train the DetectionModel for. If 0, will train the model 39 | // indefinitely. 40 | optional uint32 num_steps = 9 [default=0]; 41 | 42 | // Number of training steps between replica startup. 43 | // This flag must be set to 0 if sync_replicas is set to true. 44 | optional float startup_delay_steps = 10 [default=15]; 45 | 46 | // If greater than 0, multiplies the gradient of bias variables by this 47 | // amount. 48 | optional float bias_grad_multiplier = 11 [default=0]; 49 | 50 | // Variables that should not be updated during training. 51 | repeated string freeze_variables = 12; 52 | 53 | // Number of replicas to aggregate before making parameter updates. 54 | optional int32 replicas_to_aggregate = 13 [default=1]; 55 | 56 | // Maximum number of elements to store within a queue. 57 | optional int32 batch_queue_capacity = 14 [default=150]; 58 | 59 | // Number of threads to use for batching. 60 | optional int32 num_batch_queue_threads = 15 [default=8]; 61 | 62 | // Maximum capacity of the queue used to prefetch assembled batches. 63 | optional int32 prefetch_queue_capacity = 16 [default=5]; 64 | 65 | // If true, boxes with the same coordinates will be merged together. 66 | // This is useful when each box can have multiple labels. 
67 | // Note that only Sigmoid classification losses should be used. 68 | optional bool merge_multiple_label_boxes = 17 [default=false]; 69 | } 70 | -------------------------------------------------------------------------------- /protos/square_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/square_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/square_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _SQUAREBOXCODER = _descriptor.Descriptor( 29 | name='SquareBoxCoder', 30 | full_name='object_detection.protos.SquareBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=10, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='length_scale', full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=5, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=75, 68 | serialized_end=158, 69 | ) 70 | 71 | DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER 72 | 73 | SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), dict( 74 | DESCRIPTOR = _SQUAREBOXCODER, 75 | __module__ = 'object_detection.protos.square_box_coder_pb2' 76 | # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder) 77 | )) 78 | _sym_db.RegisterMessage(SquareBoxCoder) 79 | 80 | 81 | # @@protoc_insertion_point(module_scope) 82 | -------------------------------------------------------------------------------- /protos/optimizer.proto: 
-------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Messages for configuring the optimizing strategy for training object 6 | // detection models. 7 | 8 | // Top level optimizer message. 9 | message Optimizer { 10 | oneof optimizer { 11 | RMSPropOptimizer rms_prop_optimizer = 1; 12 | MomentumOptimizer momentum_optimizer = 2; 13 | AdamOptimizer adam_optimizer = 3; 14 | } 15 | optional bool use_moving_average = 4 [default = true]; 16 | optional float moving_average_decay = 5 [default = 0.9999]; 17 | } 18 | 19 | // Configuration message for the RMSPropOptimizer 20 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 21 | message RMSPropOptimizer { 22 | optional LearningRate learning_rate = 1; 23 | optional float momentum_optimizer_value = 2 [default = 0.9]; 24 | optional float decay = 3 [default = 0.9]; 25 | optional float epsilon = 4 [default = 1.0]; 26 | } 27 | 28 | // Configuration message for the MomentumOptimizer 29 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 30 | message MomentumOptimizer { 31 | optional LearningRate learning_rate = 1; 32 | optional float momentum_optimizer_value = 2 [default = 0.9]; 33 | } 34 | 35 | // Configuration message for the AdamOptimizer 36 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 37 | message AdamOptimizer { 38 | optional LearningRate learning_rate = 1; 39 | } 40 | 41 | // Configuration message for optimizer learning rate. 42 | message LearningRate { 43 | oneof learning_rate { 44 | ConstantLearningRate constant_learning_rate = 1; 45 | ExponentialDecayLearningRate exponential_decay_learning_rate = 2; 46 | ManualStepLearningRate manual_step_learning_rate = 3; 47 | CosineDecayLearningRate cosine_decay_learning_rate = 4; 48 | } 49 | } 50 | 51 | // Configuration message for a constant learning rate. 52 | message ConstantLearningRate { 53 | optional float learning_rate = 1 [default = 0.002]; 54 | } 55 | 56 | // Configuration message for an exponentially decaying learning rate. 57 | // See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ 58 | // decaying_the_learning_rate#exponential_decay 59 | message ExponentialDecayLearningRate { 60 | optional float initial_learning_rate = 1 [default = 0.002]; 61 | optional uint32 decay_steps = 2 [default = 4000000]; 62 | optional float decay_factor = 3 [default = 0.95]; 63 | optional bool staircase = 4 [default = true]; 64 | } 65 | 66 | // Configuration message for a manually defined learning rate schedule. 
67 | message ManualStepLearningRate { 68 | optional float initial_learning_rate = 1 [default = 0.002]; 69 | message LearningRateSchedule { 70 | optional uint32 step = 1; 71 | optional float learning_rate = 2 [default = 0.002]; 72 | } 73 | repeated LearningRateSchedule schedule = 2; 74 | } 75 | 76 | // Configuration message for a cosine decaying learning rate as defined in 77 | // object_detection/utils/learning_schedules.py 78 | message CosineDecayLearningRate { 79 | optional float learning_rate_base = 1 [default = 0.002]; 80 | optional uint32 total_steps = 2 [default = 4000000]; 81 | optional float warmup_learning_rate = 3 [default = 0.0002]; 82 | optional uint32 warmup_steps = 4 [default = 10000]; 83 | } 84 | -------------------------------------------------------------------------------- /protos/ssd.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package object_detection.protos; 3 | 4 | import "object_detection/protos/anchor_generator.proto"; 5 | import "object_detection/protos/box_coder.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import "object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/matcher.proto"; 10 | import "object_detection/protos/losses.proto"; 11 | import "object_detection/protos/post_processing.proto"; 12 | import "object_detection/protos/region_similarity_calculator.proto"; 13 | 14 | // Configuration for Single Shot Detection (SSD) models. 15 | message Ssd { 16 | 17 | // Number of classes to predict. 18 | optional int32 num_classes = 1; 19 | 20 | // Image resizer for preprocessing the input image. 21 | optional ImageResizer image_resizer = 2; 22 | 23 | // Feature extractor config. 24 | optional SsdFeatureExtractor feature_extractor = 3; 25 | 26 | // Box coder to encode the boxes. 27 | optional BoxCoder box_coder = 4; 28 | 29 | // Matcher to match groundtruth with anchors. 30 | optional Matcher matcher = 5; 31 | 32 | // Region similarity calculator to compute similarity of boxes. 33 | optional RegionSimilarityCalculator similarity_calculator = 6; 34 | 35 | // Box predictor to attach to the features. 36 | optional BoxPredictor box_predictor = 7; 37 | 38 | // Anchor generator to compute anchors. 39 | optional AnchorGenerator anchor_generator = 8; 40 | 41 | // Post processing to apply on the predictions. 42 | optional PostProcessing post_processing = 9; 43 | 44 | // Whether to normalize the loss by number of groundtruth boxes that match to 45 | // the anchors. 46 | optional bool normalize_loss_by_num_matches = 10 [default=true]; 47 | 48 | // Loss configuration for training. 49 | optional Loss loss = 11; 50 | } 51 | 52 | 53 | message SsdFeatureExtractor { 54 | // Type of ssd feature extractor. 55 | optional string type = 1; 56 | 57 | // The factor to alter the depth of the channels in the feature extractor. 58 | optional float depth_multiplier = 2 [default=1.0]; 59 | 60 | // Minimum number of the channels in the feature extractor. 61 | optional int32 min_depth = 3 [default=16]; 62 | 63 | // Hyperparameters for the feature extractor. 64 | optional Hyperparams conv_hyperparams = 4; 65 | 66 | // The nearest multiple to zero-pad the input height and width dimensions to. 67 | // For example, if pad_to_multiple = 2, input dimensions are zero-padded 68 | // until the resulting dimensions are even. 
69 | optional int32 pad_to_multiple = 5 [default = 1]; 70 | 71 | // Whether to update batch norm parameters during training or not. 72 | // When training with a relative small batch size (e.g. 1), it is 73 | // desirable to disable batch norm update and use pretrained batch norm 74 | // params. 75 | // 76 | // Note: Some feature extractors are used with canned arg_scopes 77 | // (e.g resnet arg scopes). In these cases training behavior of batch norm 78 | // variables may depend on both values of `batch_norm_trainable` and 79 | // `is_training`. 80 | // 81 | // When canned arg_scopes are used with feature extractors `conv_hyperparams` 82 | // will apply only to the additional layers that are added and are outside the 83 | // canned arg_scope. 84 | optional bool batch_norm_trainable = 6 [default=true]; 85 | } 86 | -------------------------------------------------------------------------------- /utils/learning_schedules_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.learning_schedules.""" 17 | import tensorflow as tf 18 | 19 | from object_detection.utils import learning_schedules 20 | 21 | 22 | class LearningSchedulesTest(tf.test.TestCase): 23 | 24 | def testExponentialDecayWithBurnin(self): 25 | global_step = tf.placeholder(tf.int32, []) 26 | learning_rate_base = 1.0 27 | learning_rate_decay_steps = 3 28 | learning_rate_decay_factor = .1 29 | burnin_learning_rate = .5 30 | burnin_steps = 2 31 | exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01] 32 | learning_rate = learning_schedules.exponential_decay_with_burnin( 33 | global_step, learning_rate_base, learning_rate_decay_steps, 34 | learning_rate_decay_factor, burnin_learning_rate, burnin_steps) 35 | with self.test_session() as sess: 36 | output_rates = [] 37 | for input_global_step in range(8): 38 | output_rate = sess.run(learning_rate, 39 | feed_dict={global_step: input_global_step}) 40 | output_rates.append(output_rate) 41 | self.assertAllClose(output_rates, exp_rates) 42 | 43 | def testCosineDecayWithWarmup(self): 44 | global_step = tf.placeholder(tf.int32, []) 45 | learning_rate_base = 1.0 46 | total_steps = 100 47 | warmup_learning_rate = 0.1 48 | warmup_steps = 9 49 | input_global_steps = [0, 4, 8, 9, 100] 50 | exp_rates = [0.1, 0.5, 0.9, 1.0, 0] 51 | learning_rate = learning_schedules.cosine_decay_with_warmup( 52 | global_step, learning_rate_base, total_steps, 53 | warmup_learning_rate, warmup_steps) 54 | with self.test_session() as sess: 55 | output_rates = [] 56 | for input_global_step in input_global_steps: 57 | output_rate = sess.run(learning_rate, 58 | feed_dict={global_step: input_global_step}) 59 | output_rates.append(output_rate) 60 | self.assertAllClose(output_rates, exp_rates) 61 | 62 | def 
testManualStepping(self): 63 | global_step = tf.placeholder(tf.int64, []) 64 | boundaries = [2, 3, 7] 65 | rates = [1.0, 2.0, 3.0, 4.0] 66 | exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0] 67 | learning_rate = learning_schedules.manual_stepping(global_step, boundaries, 68 | rates) 69 | with self.test_session() as sess: 70 | output_rates = [] 71 | for input_global_step in range(10): 72 | output_rate = sess.run(learning_rate, 73 | feed_dict={global_step: input_global_step}) 74 | output_rates.append(output_rate) 75 | self.assertAllClose(output_rates, exp_rates) 76 | 77 | if __name__ == '__main__': 78 | tf.test.main() 79 | -------------------------------------------------------------------------------- /protos/hyperparams.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for the convolution op hyperparameters to use in the 6 | // object detection pipeline. 7 | message Hyperparams { 8 | 9 | // Operations affected by hyperparameters. 10 | enum Op { 11 | // Convolution, Separable Convolution, Convolution transpose. 12 | CONV = 1; 13 | 14 | // Fully connected 15 | FC = 2; 16 | } 17 | optional Op op = 1 [default = CONV]; 18 | 19 | // Regularizer for the weights of the convolution op. 20 | optional Regularizer regularizer = 2; 21 | 22 | // Initializer for the weights of the convolution op. 23 | optional Initializer initializer = 3; 24 | 25 | // Type of activation to apply after convolution. 26 | enum Activation { 27 | // Use None (no activation) 28 | NONE = 0; 29 | 30 | // Use tf.nn.relu 31 | RELU = 1; 32 | 33 | // Use tf.nn.relu6 34 | RELU_6 = 2; 35 | } 36 | optional Activation activation = 4 [default = RELU]; 37 | 38 | // BatchNorm hyperparameters. If this parameter is NOT set then BatchNorm is 39 | // not applied! 40 | optional BatchNorm batch_norm = 5; 41 | } 42 | 43 | // Proto with one-of field for regularizers. 44 | message Regularizer { 45 | oneof regularizer_oneof { 46 | L1Regularizer l1_regularizer = 1; 47 | L2Regularizer l2_regularizer = 2; 48 | } 49 | } 50 | 51 | // Configuration proto for L1 Regularizer. 52 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l1_regularizer 53 | message L1Regularizer { 54 | optional float weight = 1 [default = 1.0]; 55 | } 56 | 57 | // Configuration proto for L2 Regularizer. 58 | // See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l2_regularizer 59 | message L2Regularizer { 60 | optional float weight = 1 [default = 1.0]; 61 | } 62 | 63 | // Proto with one-of field for initializers. 64 | message Initializer { 65 | oneof initializer_oneof { 66 | TruncatedNormalInitializer truncated_normal_initializer = 1; 67 | VarianceScalingInitializer variance_scaling_initializer = 2; 68 | } 69 | } 70 | 71 | // Configuration proto for truncated normal initializer. See 72 | // https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer 73 | message TruncatedNormalInitializer { 74 | optional float mean = 1 [default = 0.0]; 75 | optional float stddev = 2 [default = 1.0]; 76 | } 77 | 78 | // Configuration proto for variance scaling initializer. 
See 79 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/ 80 | // variance_scaling_initializer 81 | message VarianceScalingInitializer { 82 | optional float factor = 1 [default = 2.0]; 83 | optional bool uniform = 2 [default = false]; 84 | enum Mode { 85 | FAN_IN = 0; 86 | FAN_OUT = 1; 87 | FAN_AVG = 2; 88 | } 89 | optional Mode mode = 3 [default = FAN_IN]; 90 | } 91 | 92 | // Configuration proto for batch norm to apply after convolution op. See 93 | // https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm 94 | message BatchNorm { 95 | optional float decay = 1 [default = 0.999]; 96 | optional bool center = 2 [default = true]; 97 | optional bool scale = 3 [default = false]; 98 | optional float epsilon = 4 [default = 0.001]; 99 | // Whether to train the batch norm variables. If this is set to false during 100 | // training, the current value of the batch_norm variables are used for 101 | // forward pass but they are never updated. 102 | optional bool train = 5 [default = true]; 103 | } 104 | -------------------------------------------------------------------------------- /utils/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | 25 | def area(boxes): 26 | """Computes area of boxes. 27 | 28 | Args: 29 | boxes: Numpy array with shape [N, 4] holding N boxes 30 | 31 | Returns: 32 | a numpy array with shape [N*1] representing box areas 33 | """ 34 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 35 | 36 | 37 | def intersection(boxes1, boxes2): 38 | """Compute pairwise intersection areas between boxes. 
39 | 40 | Args: 41 | boxes1: a numpy array with shape [N, 4] holding N boxes 42 | boxes2: a numpy array with shape [M, 4] holding M boxes 43 | 44 | Returns: 45 | a numpy array with shape [N*M] representing pairwise intersection area 46 | """ 47 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 48 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 49 | 50 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 51 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 52 | intersect_heights = np.maximum( 53 | np.zeros(all_pairs_max_ymin.shape), 54 | all_pairs_min_ymax - all_pairs_max_ymin) 55 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 56 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 57 | intersect_widths = np.maximum( 58 | np.zeros(all_pairs_max_xmin.shape), 59 | all_pairs_min_xmax - all_pairs_max_xmin) 60 | return intersect_heights * intersect_widths 61 | 62 | 63 | def iou(boxes1, boxes2): 64 | """Computes pairwise intersection-over-union between box collections. 65 | 66 | Args: 67 | boxes1: a numpy array with shape [N, 4] holding N boxes. 68 | boxes2: a numpy array with shape [M, 4] holding N boxes. 69 | 70 | Returns: 71 | a numpy array with shape [N, M] representing pairwise iou scores. 72 | """ 73 | intersect = intersection(boxes1, boxes2) 74 | area1 = area(boxes1) 75 | area2 = area(boxes2) 76 | union = np.expand_dims(area1, axis=1) + np.expand_dims( 77 | area2, axis=0) - intersect 78 | return intersect / union 79 | 80 | 81 | def ioa(boxes1, boxes2): 82 | """Computes pairwise intersection-over-area between box collections. 83 | 84 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 85 | their intersection area over box2's area. Note that ioa is not symmetric, 86 | that is, IOA(box1, box2) != IOA(box2, box1). 87 | 88 | Args: 89 | boxes1: a numpy array with shape [N, 4] holding N boxes. 90 | boxes2: a numpy array with shape [M, 4] holding N boxes. 91 | 92 | Returns: 93 | a numpy array with shape [N, M] representing pairwise ioa scores. 94 | """ 95 | intersect = intersection(boxes1, boxes2) 96 | areas = np.expand_dims(area(boxes2), axis=0) 97 | return intersect / areas 98 | -------------------------------------------------------------------------------- /protos/faster_rcnn_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/faster_rcnn_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/faster_rcnn_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n3object_detection/protos/faster_rcnn_box_coder.proto\x12\x17object_detection.protos\"o\n\x12\x46\x61sterRcnnBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x03 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x04 \x01(\x02:\x01\x35') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _FASTERRCNNBOXCODER = _descriptor.Descriptor( 29 | name='FasterRcnnBoxCoder', 30 | full_name='object_detection.protos.FasterRcnnBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='y_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.y_scale', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='x_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.x_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=10, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='height_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.height_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=5, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='width_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.width_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=True, default_value=5, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | ], 64 | extensions=[ 65 | ], 66 | nested_types=[], 67 | enum_types=[ 68 | ], 69 | options=None, 70 | is_extendable=False, 71 | extension_ranges=[], 72 | oneofs=[ 73 | ], 74 | serialized_start=80, 75 | serialized_end=191, 76 | ) 77 | 78 | DESCRIPTOR.message_types_by_name['FasterRcnnBoxCoder'] = _FASTERRCNNBOXCODER 79 | 80 | FasterRcnnBoxCoder = _reflection.GeneratedProtocolMessageType('FasterRcnnBoxCoder', (_message.Message,), dict( 81 | DESCRIPTOR = _FASTERRCNNBOXCODER, 82 | __module__ = 'object_detection.protos.faster_rcnn_box_coder_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.FasterRcnnBoxCoder) 84 | )) 85 | _sym_db.RegisterMessage(FasterRcnnBoxCoder) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | 
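The generated `*_pb2.py` modules above are what the training and evaluation configs get parsed into. As a quick illustration (not part of the repository), the minimal Python sketch below builds a `FasterRcnnBoxCoder` message with the defaults declared in its descriptor (y_scale = x_scale = 10, height_scale = width_scale = 5) and overrides one field from pbtxt-style text via `google.protobuf.text_format`; the `object_detection.protos` import path is an assumption that the protos live under the TensorFlow `object_detection` package, as the generated `__module__` strings indicate.

from google.protobuf import text_format

from object_detection.protos import faster_rcnn_box_coder_pb2

# Unset proto2 optional fields report the defaults baked into the descriptor:
# y_scale = x_scale = 10.0, height_scale = width_scale = 5.0.
box_coder = faster_rcnn_box_coder_pb2.FasterRcnnBoxCoder()
assert box_coder.y_scale == 10.0
assert box_coder.height_scale == 5.0

# Override a field the same way a pbtxt config file would.
text_format.Merge('y_scale: 8.0', box_coder)
print(box_coder.y_scale)  # 8.0

The same pattern applies to the other generated modules dumped here (e.g. `square_box_coder_pb2`, `matcher_pb2`): a pipeline config is plain pbtxt text that `text_format` merges into these generated message classes.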
-------------------------------------------------------------------------------- /utils/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.metrics.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import metrics 22 | 23 | 24 | class MetricsTest(tf.test.TestCase): 25 | 26 | def test_compute_cor_loc(self): 27 | num_gt_imgs_per_class = np.array([100, 1, 5, 1, 1], dtype=int) 28 | num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0], 29 | dtype=int) 30 | corloc = metrics.compute_cor_loc(num_gt_imgs_per_class, 31 | num_images_correctly_detected_per_class) 32 | expected_corloc = np.array([0.1, 0, 0.2, 0, 0], dtype=float) 33 | self.assertTrue(np.allclose(corloc, expected_corloc)) 34 | 35 | def test_compute_cor_loc_nans(self): 36 | num_gt_imgs_per_class = np.array([100, 0, 0, 1, 1], dtype=int) 37 | num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0], 38 | dtype=int) 39 | corloc = metrics.compute_cor_loc(num_gt_imgs_per_class, 40 | num_images_correctly_detected_per_class) 41 | expected_corloc = np.array([0.1, np.nan, np.nan, 0, 0], dtype=float) 42 | self.assertAllClose(corloc, expected_corloc) 43 | 44 | def test_compute_precision_recall(self): 45 | num_gt = 10 46 | scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float) 47 | labels = np.array([0, 1, 1, 0, 0, 1], dtype=bool) 48 | accumulated_tp_count = np.array([0, 1, 1, 2, 2, 3], dtype=float) 49 | expected_precision = accumulated_tp_count / np.array([1, 2, 3, 4, 5, 6]) 50 | expected_recall = accumulated_tp_count / num_gt 51 | precision, recall = metrics.compute_precision_recall(scores, labels, num_gt) 52 | self.assertAllClose(precision, expected_precision) 53 | self.assertAllClose(recall, expected_recall) 54 | 55 | def test_compute_average_precision(self): 56 | precision = np.array([0.8, 0.76, 0.9, 0.65, 0.7, 0.5, 0.55, 0], dtype=float) 57 | recall = np.array([0.3, 0.3, 0.4, 0.4, 0.45, 0.45, 0.5, 0.5], dtype=float) 58 | processed_precision = np.array([0.9, 0.9, 0.9, 0.7, 0.7, 0.55, 0.55, 0], 59 | dtype=float) 60 | recall_interval = np.array([0.3, 0, 0.1, 0, 0.05, 0, 0.05, 0], dtype=float) 61 | expected_mean_ap = np.sum(recall_interval * processed_precision) 62 | mean_ap = metrics.compute_average_precision(precision, recall) 63 | self.assertAlmostEqual(expected_mean_ap, mean_ap) 64 | 65 | def test_compute_precision_recall_and_ap_no_groundtruth(self): 66 | num_gt = 0 67 | scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float) 68 | labels = np.array([0, 0, 0, 0, 0, 0], dtype=bool) 69 | expected_precision = None 70 | expected_recall = None 71 | precision, recall = metrics.compute_precision_recall(scores, labels, num_gt) 72 | 
self.assertEquals(precision, expected_precision) 73 | self.assertEquals(recall, expected_recall) 74 | ap = metrics.compute_average_precision(precision, recall) 75 | self.assertTrue(np.isnan(ap)) 76 | 77 | 78 | if __name__ == '__main__': 79 | tf.test.main() 80 | -------------------------------------------------------------------------------- /protos/model_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/model.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.faster_rcnn_pb2 17 | import object_detection.protos.ssd_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/model.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n#object_detection/protos/model.proto\x12\x17object_detection.protos\x1a)object_detection/protos/faster_rcnn.proto\x1a!object_detection/protos/ssd.proto\"\x82\x01\n\x0e\x44\x65tectionModel\x12:\n\x0b\x66\x61ster_rcnn\x18\x01 \x01(\x0b\x32#.object_detection.protos.FasterRcnnH\x00\x12+\n\x03ssd\x18\x02 \x01(\x0b\x32\x1c.object_detection.protos.SsdH\x00\x42\x07\n\x05model') 24 | , 25 | dependencies=[object_detection.protos.faster_rcnn_pb2.DESCRIPTOR,object_detection.protos.ssd_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _DETECTIONMODEL = _descriptor.Descriptor( 32 | name='DetectionModel', 33 | full_name='object_detection.protos.DetectionModel', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='faster_rcnn', full_name='object_detection.protos.DetectionModel.faster_rcnn', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd', full_name='object_detection.protos.DetectionModel.ssd', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 62 | _descriptor.OneofDescriptor( 63 | name='model', full_name='object_detection.protos.DetectionModel.model', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=143, 67 | serialized_end=273, 68 | ) 69 | 70 | _DETECTIONMODEL.fields_by_name['faster_rcnn'].message_type = object_detection.protos.faster_rcnn_pb2._FASTERRCNN 71 | _DETECTIONMODEL.fields_by_name['ssd'].message_type = object_detection.protos.ssd_pb2._SSD 72 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 73 | _DETECTIONMODEL.fields_by_name['faster_rcnn']) 74 | 
_DETECTIONMODEL.fields_by_name['faster_rcnn'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 75 | _DETECTIONMODEL.oneofs_by_name['model'].fields.append( 76 | _DETECTIONMODEL.fields_by_name['ssd']) 77 | _DETECTIONMODEL.fields_by_name['ssd'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model'] 78 | DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL 79 | 80 | DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), dict( 81 | DESCRIPTOR = _DETECTIONMODEL, 82 | __module__ = 'object_detection.protos.model_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.DetectionModel) 84 | )) 85 | _sym_db.RegisterMessage(DetectionModel) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.argmax_matcher_pb2 17 | import object_detection.protos.bipartite_matcher_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/matcher.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n%object_detection/protos/matcher.proto\x12\x17object_detection.protos\x1a,object_detection/protos/argmax_matcher.proto\x1a/object_detection/protos/bipartite_matcher.proto\"\xa4\x01\n\x07Matcher\x12@\n\x0e\x61rgmax_matcher\x18\x01 \x01(\x0b\x32&.object_detection.protos.ArgMaxMatcherH\x00\x12\x46\n\x11\x62ipartite_matcher\x18\x02 \x01(\x0b\x32).object_detection.protos.BipartiteMatcherH\x00\x42\x0f\n\rmatcher_oneof') 24 | , 25 | dependencies=[object_detection.protos.argmax_matcher_pb2.DESCRIPTOR,object_detection.protos.bipartite_matcher_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _MATCHER = _descriptor.Descriptor( 32 | name='Matcher', 33 | full_name='object_detection.protos.Matcher', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='argmax_matcher', full_name='object_detection.protos.Matcher.argmax_matcher', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='bipartite_matcher', full_name='object_detection.protos.Matcher.bipartite_matcher', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 
62 | _descriptor.OneofDescriptor( 63 | name='matcher_oneof', full_name='object_detection.protos.Matcher.matcher_oneof', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=162, 67 | serialized_end=326, 68 | ) 69 | 70 | _MATCHER.fields_by_name['argmax_matcher'].message_type = object_detection.protos.argmax_matcher_pb2._ARGMAXMATCHER 71 | _MATCHER.fields_by_name['bipartite_matcher'].message_type = object_detection.protos.bipartite_matcher_pb2._BIPARTITEMATCHER 72 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 73 | _MATCHER.fields_by_name['argmax_matcher']) 74 | _MATCHER.fields_by_name['argmax_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 75 | _MATCHER.oneofs_by_name['matcher_oneof'].fields.append( 76 | _MATCHER.fields_by_name['bipartite_matcher']) 77 | _MATCHER.fields_by_name['bipartite_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof'] 78 | DESCRIPTOR.message_types_by_name['Matcher'] = _MATCHER 79 | 80 | Matcher = _reflection.GeneratedProtocolMessageType('Matcher', (_message.Message,), dict( 81 | DESCRIPTOR = _MATCHER, 82 | __module__ = 'object_detection.protos.matcher_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.Matcher) 84 | )) 85 | _sym_db.RegisterMessage(Matcher) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/box_predictor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/hyperparams.proto"; 6 | 7 | 8 | // Configuration proto for box predictor. See core/box_predictor.py for details. 9 | message BoxPredictor { 10 | oneof box_predictor_oneof { 11 | ConvolutionalBoxPredictor convolutional_box_predictor = 1; 12 | MaskRCNNBoxPredictor mask_rcnn_box_predictor = 2; 13 | RfcnBoxPredictor rfcn_box_predictor = 3; 14 | } 15 | } 16 | 17 | // Configuration proto for Convolutional box predictor. 18 | message ConvolutionalBoxPredictor { 19 | // Hyperparameters for convolution ops used in the box predictor. 20 | optional Hyperparams conv_hyperparams = 1; 21 | 22 | // Minumum feature depth prior to predicting box encodings and class 23 | // predictions. 24 | optional int32 min_depth = 2 [default = 0]; 25 | 26 | // Maximum feature depth prior to predicting box encodings and class 27 | // predictions. If max_depth is set to 0, no additional feature map will be 28 | // inserted before location and class predictions. 29 | optional int32 max_depth = 3 [default = 0]; 30 | 31 | // Number of the additional conv layers before the predictor. 32 | optional int32 num_layers_before_predictor = 4 [default = 0]; 33 | 34 | // Whether to use dropout for class prediction. 35 | optional bool use_dropout = 5 [default = true]; 36 | 37 | // Keep probability for dropout 38 | optional float dropout_keep_probability = 6 [default = 0.8]; 39 | 40 | // Size of final convolution kernel. If the spatial resolution of the feature 41 | // map is smaller than the kernel size, then the kernel size is set to 42 | // min(feature_width, feature_height). 43 | optional int32 kernel_size = 7 [default = 1]; 44 | 45 | // Size of the encoding for boxes. 46 | optional int32 box_code_size = 8 [default = 4]; 47 | 48 | // Whether to apply sigmoid to the output of class predictions. 49 | // TODO: Do we need this since we have a post processing module.? 
50 | optional bool apply_sigmoid_to_scores = 9 [default = false]; 51 | 52 | optional float class_prediction_bias_init = 10 [default = 0.0]; 53 | } 54 | 55 | message MaskRCNNBoxPredictor { 56 | // Hyperparameters for fully connected ops used in the box predictor. 57 | optional Hyperparams fc_hyperparams = 1; 58 | 59 | // Whether to use dropout op prior to the both box and class predictions. 60 | optional bool use_dropout = 2 [default= false]; 61 | 62 | // Keep probability for dropout. This is only used if use_dropout is true. 63 | optional float dropout_keep_probability = 3 [default = 0.5]; 64 | 65 | // Size of the encoding for the boxes. 66 | optional int32 box_code_size = 4 [default = 4]; 67 | 68 | // Hyperparameters for convolution ops used in the box predictor. 69 | optional Hyperparams conv_hyperparams = 5; 70 | 71 | // Whether to predict instance masks inside detection boxes. 72 | optional bool predict_instance_masks = 6 [default = false]; 73 | 74 | // The depth for the first conv2d_transpose op applied to the 75 | // image_features in the mask prediciton branch 76 | optional int32 mask_prediction_conv_depth = 7 [default = 256]; 77 | 78 | // Whether to predict keypoints inside detection boxes. 79 | optional bool predict_keypoints = 8 [default = false]; 80 | } 81 | 82 | message RfcnBoxPredictor { 83 | // Hyperparameters for convolution ops used in the box predictor. 84 | optional Hyperparams conv_hyperparams = 1; 85 | 86 | // Bin sizes for RFCN crops. 87 | optional int32 num_spatial_bins_height = 2 [default = 3]; 88 | 89 | optional int32 num_spatial_bins_width = 3 [default = 3]; 90 | 91 | // Target depth to reduce the input image features to. 92 | optional int32 depth = 4 [default=1024]; 93 | 94 | // Size of the encoding for the boxes. 95 | optional int32 box_code_size = 5 [default = 4]; 96 | 97 | // Size to resize the rfcn crops to. 98 | optional int32 crop_height = 6 [default= 12]; 99 | 100 | optional int32 crop_width = 7 [default=12]; 101 | } 102 | -------------------------------------------------------------------------------- /protos/keypoint_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/keypoint_box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/keypoint_box_coder.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n0object_detection/protos/keypoint_box_coder.proto\x12\x17object_detection.protos\"\x84\x01\n\x10KeypointBoxCoder\x12\x15\n\rnum_keypoints\x18\x01 \x01(\x05\x12\x13\n\x07y_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x03 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x04 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x05 \x01(\x02:\x01\x35') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _KEYPOINTBOXCODER = _descriptor.Descriptor( 29 | name='KeypointBoxCoder', 30 | full_name='object_detection.protos.KeypointBoxCoder', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='num_keypoints', full_name='object_detection.protos.KeypointBoxCoder.num_keypoints', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=False, default_value=0, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='y_scale', full_name='object_detection.protos.KeypointBoxCoder.y_scale', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=10, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='x_scale', full_name='object_detection.protos.KeypointBoxCoder.x_scale', index=2, 51 | number=3, type=2, cpp_type=6, label=1, 52 | has_default_value=True, default_value=10, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='height_scale', full_name='object_detection.protos.KeypointBoxCoder.height_scale', index=3, 58 | number=4, type=2, cpp_type=6, label=1, 59 | has_default_value=True, default_value=5, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='width_scale', full_name='object_detection.protos.KeypointBoxCoder.width_scale', index=4, 65 | number=5, type=2, cpp_type=6, label=1, 66 | has_default_value=True, default_value=5, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | extension_ranges=[], 79 | oneofs=[ 80 | ], 81 | serialized_start=78, 82 | serialized_end=210, 83 | ) 84 | 85 | DESCRIPTOR.message_types_by_name['KeypointBoxCoder'] = _KEYPOINTBOXCODER 86 | 87 | KeypointBoxCoder = _reflection.GeneratedProtocolMessageType('KeypointBoxCoder', 
(_message.Message,), dict( 88 | DESCRIPTOR = _KEYPOINTBOXCODER, 89 | __module__ = 'object_detection.protos.keypoint_box_coder_pb2' 90 | # @@protoc_insertion_point(class_scope:object_detection.protos.KeypointBoxCoder) 91 | )) 92 | _sym_db.RegisterMessage(KeypointBoxCoder) 93 | 94 | 95 | # @@protoc_insertion_point(module_scope) 96 | -------------------------------------------------------------------------------- /protos/argmax_matcher_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/argmax_matcher.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/argmax_matcher.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n,object_detection/protos/argmax_matcher.proto\x12\x17object_detection.protos\"\xca\x01\n\rArgMaxMatcher\x12\x1e\n\x11matched_threshold\x18\x01 \x01(\x02:\x03\x30.5\x12 \n\x13unmatched_threshold\x18\x02 \x01(\x02:\x03\x30.5\x12 \n\x11ignore_thresholds\x18\x03 \x01(\x08:\x05\x66\x61lse\x12,\n\x1enegatives_lower_than_unmatched\x18\x04 \x01(\x08:\x04true\x12\'\n\x18\x66orce_match_for_each_row\x18\x05 \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _ARGMAXMATCHER = _descriptor.Descriptor( 29 | name='ArgMaxMatcher', 30 | full_name='object_detection.protos.ArgMaxMatcher', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='matched_threshold', full_name='object_detection.protos.ArgMaxMatcher.matched_threshold', index=0, 37 | number=1, type=2, cpp_type=6, label=1, 38 | has_default_value=True, default_value=0.5, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='unmatched_threshold', full_name='object_detection.protos.ArgMaxMatcher.unmatched_threshold', index=1, 44 | number=2, type=2, cpp_type=6, label=1, 45 | has_default_value=True, default_value=0.5, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='ignore_thresholds', full_name='object_detection.protos.ArgMaxMatcher.ignore_thresholds', index=2, 51 | number=3, type=8, cpp_type=7, label=1, 52 | has_default_value=True, default_value=False, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='negatives_lower_than_unmatched', full_name='object_detection.protos.ArgMaxMatcher.negatives_lower_than_unmatched', index=3, 58 | number=4, type=8, cpp_type=7, label=1, 59 | has_default_value=True, default_value=True, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | 
name='force_match_for_each_row', full_name='object_detection.protos.ArgMaxMatcher.force_match_for_each_row', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=False, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | ], 71 | extensions=[ 72 | ], 73 | nested_types=[], 74 | enum_types=[ 75 | ], 76 | options=None, 77 | is_extendable=False, 78 | extension_ranges=[], 79 | oneofs=[ 80 | ], 81 | serialized_start=74, 82 | serialized_end=276, 83 | ) 84 | 85 | DESCRIPTOR.message_types_by_name['ArgMaxMatcher'] = _ARGMAXMATCHER 86 | 87 | ArgMaxMatcher = _reflection.GeneratedProtocolMessageType('ArgMaxMatcher', (_message.Message,), dict( 88 | DESCRIPTOR = _ARGMAXMATCHER, 89 | __module__ = 'object_detection.protos.argmax_matcher_pb2' 90 | # @@protoc_insertion_point(class_scope:object_detection.protos.ArgMaxMatcher) 91 | )) 92 | _sym_db.RegisterMessage(ArgMaxMatcher) 93 | 94 | 95 | # @@protoc_insertion_point(module_scope) 96 | -------------------------------------------------------------------------------- /protos/anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.grid_anchor_generator_pb2 17 | import object_detection.protos.ssd_anchor_generator_pb2 18 | 19 | 20 | DESCRIPTOR = _descriptor.FileDescriptor( 21 | name='object_detection/protos/anchor_generator.proto', 22 | package='object_detection.protos', 23 | serialized_pb=_b('\n.object_detection/protos/anchor_generator.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/grid_anchor_generator.proto\x1a\x32object_detection/protos/ssd_anchor_generator.proto\"\xc7\x01\n\x0f\x41nchorGenerator\x12M\n\x15grid_anchor_generator\x18\x01 \x01(\x0b\x32,.object_detection.protos.GridAnchorGeneratorH\x00\x12K\n\x14ssd_anchor_generator\x18\x02 \x01(\x0b\x32+.object_detection.protos.SsdAnchorGeneratorH\x00\x42\x18\n\x16\x61nchor_generator_oneof') 24 | , 25 | dependencies=[object_detection.protos.grid_anchor_generator_pb2.DESCRIPTOR,object_detection.protos.ssd_anchor_generator_pb2.DESCRIPTOR,]) 26 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 27 | 28 | 29 | 30 | 31 | _ANCHORGENERATOR = _descriptor.Descriptor( 32 | name='AnchorGenerator', 33 | full_name='object_detection.protos.AnchorGenerator', 34 | filename=None, 35 | file=DESCRIPTOR, 36 | containing_type=None, 37 | fields=[ 38 | _descriptor.FieldDescriptor( 39 | name='grid_anchor_generator', full_name='object_detection.protos.AnchorGenerator.grid_anchor_generator', index=0, 40 | number=1, type=11, cpp_type=10, label=1, 41 | has_default_value=False, default_value=None, 42 | message_type=None, enum_type=None, containing_type=None, 43 | is_extension=False, extension_scope=None, 44 | options=None), 45 | _descriptor.FieldDescriptor( 46 | name='ssd_anchor_generator', 
full_name='object_detection.protos.AnchorGenerator.ssd_anchor_generator', index=1, 47 | number=2, type=11, cpp_type=10, label=1, 48 | has_default_value=False, default_value=None, 49 | message_type=None, enum_type=None, containing_type=None, 50 | is_extension=False, extension_scope=None, 51 | options=None), 52 | ], 53 | extensions=[ 54 | ], 55 | nested_types=[], 56 | enum_types=[ 57 | ], 58 | options=None, 59 | is_extendable=False, 60 | extension_ranges=[], 61 | oneofs=[ 62 | _descriptor.OneofDescriptor( 63 | name='anchor_generator_oneof', full_name='object_detection.protos.AnchorGenerator.anchor_generator_oneof', 64 | index=0, containing_type=None, fields=[]), 65 | ], 66 | serialized_start=181, 67 | serialized_end=380, 68 | ) 69 | 70 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].message_type = object_detection.protos.grid_anchor_generator_pb2._GRIDANCHORGENERATOR 71 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].message_type = object_detection.protos.ssd_anchor_generator_pb2._SSDANCHORGENERATOR 72 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 73 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator']) 74 | _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 75 | _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append( 76 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator']) 77 | _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'] 78 | DESCRIPTOR.message_types_by_name['AnchorGenerator'] = _ANCHORGENERATOR 79 | 80 | AnchorGenerator = _reflection.GeneratedProtocolMessageType('AnchorGenerator', (_message.Message,), dict( 81 | DESCRIPTOR = _ANCHORGENERATOR, 82 | __module__ = 'object_detection.protos.anchor_generator_pb2' 83 | # @@protoc_insertion_point(class_scope:object_detection.protos.AnchorGenerator) 84 | )) 85 | _sym_db.RegisterMessage(AnchorGenerator) 86 | 87 | 88 | # @@protoc_insertion_point(module_scope) 89 | -------------------------------------------------------------------------------- /protos/string_int_label_map_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/string_int_label_map.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/string_int_label_map.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _STRINGINTLABELMAPITEM = _descriptor.Descriptor( 29 | name='StringIntLabelMapItem', 30 | full_name='object_detection.protos.StringIntLabelMapItem', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0, 37 | number=1, type=9, cpp_type=9, label=1, 38 | has_default_value=False, default_value=_b("").decode('utf-8'), 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=False, default_value=0, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2, 51 | number=3, type=9, cpp_type=9, label=1, 52 | has_default_value=False, default_value=_b("").decode('utf-8'), 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | ], 57 | extensions=[ 58 | ], 59 | nested_types=[], 60 | enum_types=[ 61 | ], 62 | options=None, 63 | is_extendable=False, 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=79, 68 | serialized_end=150, 69 | ) 70 | 71 | 72 | _STRINGINTLABELMAP = _descriptor.Descriptor( 73 | name='StringIntLabelMap', 74 | full_name='object_detection.protos.StringIntLabelMap', 75 | filename=None, 76 | file=DESCRIPTOR, 77 | containing_type=None, 78 | fields=[ 79 | _descriptor.FieldDescriptor( 80 | name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0, 81 | number=1, type=11, cpp_type=10, label=3, 82 | has_default_value=False, default_value=[], 83 | message_type=None, enum_type=None, containing_type=None, 84 | is_extension=False, extension_scope=None, 85 | options=None), 86 | ], 87 | extensions=[ 88 | ], 89 | nested_types=[], 90 | enum_types=[ 91 | ], 92 | options=None, 93 | is_extendable=False, 94 | extension_ranges=[], 95 | oneofs=[ 96 | ], 97 | serialized_start=152, 98 | serialized_end=233, 99 | ) 100 | 101 | 
_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM 102 | DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM 103 | DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP 104 | 105 | StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict( 106 | DESCRIPTOR = _STRINGINTLABELMAPITEM, 107 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 108 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem) 109 | )) 110 | _sym_db.RegisterMessage(StringIntLabelMapItem) 111 | 112 | StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict( 113 | DESCRIPTOR = _STRINGINTLABELMAP, 114 | __module__ = 'object_detection.protos.string_int_label_map_pb2' 115 | # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap) 116 | )) 117 | _sym_db.RegisterMessage(StringIntLabelMap) 118 | 119 | 120 | # @@protoc_insertion_point(module_scope) 121 | -------------------------------------------------------------------------------- /utils/shape_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utils used to manipulate tensor shapes.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def _is_tensor(t): 22 | """Returns a boolean indicating whether the input is a tensor. 23 | 24 | Args: 25 | t: the input to be tested. 26 | 27 | Returns: 28 | a boolean that indicates whether t is a tensor. 29 | """ 30 | return isinstance(t, (tf.Tensor, tf.SparseTensor, tf.Variable)) 31 | 32 | 33 | def _set_dim_0(t, d0): 34 | """Sets the 0-th dimension of the input tensor. 35 | 36 | Args: 37 | t: the input tensor, assuming the rank is at least 1. 38 | d0: an integer indicating the 0-th dimension of the input tensor. 39 | 40 | Returns: 41 | the tensor t with the 0-th dimension set. 42 | """ 43 | t_shape = t.get_shape().as_list() 44 | t_shape[0] = d0 45 | t.set_shape(t_shape) 46 | return t 47 | 48 | 49 | def pad_tensor(t, length): 50 | """Pads the input tensor with 0s along the first dimension up to the length. 51 | 52 | Args: 53 | t: the input tensor, assuming the rank is at least 1. 54 | length: a tensor of shape [1] or an integer, indicating the first dimension 55 | of the input tensor t after padding, assuming length <= t.shape[0]. 56 | 57 | Returns: 58 | padded_t: the padded tensor, whose first dimension is length. If the length 59 | is an integer, the first dimension of padded_t is set to length 60 | statically. 
61 | """ 62 | t_rank = tf.rank(t) 63 | t_shape = tf.shape(t) 64 | t_d0 = t_shape[0] 65 | pad_d0 = tf.expand_dims(length - t_d0, 0) 66 | pad_shape = tf.cond( 67 | tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0), 68 | lambda: tf.expand_dims(length - t_d0, 0)) 69 | padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0) 70 | if not _is_tensor(length): 71 | padded_t = _set_dim_0(padded_t, length) 72 | return padded_t 73 | 74 | 75 | def clip_tensor(t, length): 76 | """Clips the input tensor along the first dimension up to the length. 77 | 78 | Args: 79 | t: the input tensor, assuming the rank is at least 1. 80 | length: a tensor of shape [1] or an integer, indicating the first dimension 81 | of the input tensor t after clipping, assuming length <= t.shape[0]. 82 | 83 | Returns: 84 | clipped_t: the clipped tensor, whose first dimension is length. If the 85 | length is an integer, the first dimension of clipped_t is set to length 86 | statically. 87 | """ 88 | clipped_t = tf.gather(t, tf.range(length)) 89 | if not _is_tensor(length): 90 | clipped_t = _set_dim_0(clipped_t, length) 91 | return clipped_t 92 | 93 | 94 | def pad_or_clip_tensor(t, length): 95 | """Pad or clip the input tensor along the first dimension. 96 | 97 | Args: 98 | t: the input tensor, assuming the rank is at least 1. 99 | length: a tensor of shape [1] or an integer, indicating the first dimension 100 | of the input tensor t after processing. 101 | 102 | Returns: 103 | processed_t: the processed tensor, whose first dimension is length. If the 104 | length is an integer, the first dimension of the processed tensor is set 105 | to length statically. 106 | """ 107 | processed_t = tf.cond( 108 | tf.greater(tf.shape(t)[0], length), 109 | lambda: clip_tensor(t, length), 110 | lambda: pad_tensor(t, length)) 111 | if not _is_tensor(length): 112 | processed_t = _set_dim_0(processed_t, length) 113 | return processed_t 114 | 115 | 116 | def combined_static_and_dynamic_shape(tensor): 117 | """Returns a list containing static and dynamic values for the dimensions. 118 | 119 | Returns a list of static and dynamic values for shape dimensions. This is 120 | useful to preserve static shapes when available in reshape operation. 121 | 122 | Args: 123 | tensor: A tensor of any type. 124 | 125 | Returns: 126 | A list of size tensor.shape.ndims containing integers or a scalar tensor. 127 | """ 128 | static_shape = tensor.shape.as_list() 129 | dynamic_shape = tf.shape(tensor) 130 | combined_shape = [] 131 | for index, dim in enumerate(static_shape): 132 | if dim is not None: 133 | combined_shape.append(dim) 134 | else: 135 | combined_shape.append(dynamic_shape[index]) 136 | return combined_shape 137 | -------------------------------------------------------------------------------- /utils/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | import numpy as np 19 | 20 | 21 | class BoxList(object): 22 | """Box collection. 23 | 24 | BoxList represents a list of bounding boxes as numpy array, where each 25 | bounding box is represented as a row of 4 numbers, 26 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 27 | given list correspond to a single image. 28 | 29 | Optionally, users can add additional related fields (such as 30 | objectness/classification scores). 31 | """ 32 | 33 | def __init__(self, data): 34 | """Constructs box collection. 35 | 36 | Args: 37 | data: a numpy array of shape [N, 4] representing box coordinates 38 | 39 | Raises: 40 | ValueError: if bbox data is not a numpy array 41 | ValueError: if invalid dimensions for bbox data 42 | """ 43 | if not isinstance(data, np.ndarray): 44 | raise ValueError('data must be a numpy array.') 45 | if len(data.shape) != 2 or data.shape[1] != 4: 46 | raise ValueError('Invalid dimensions for box data.') 47 | if data.dtype != np.float32 and data.dtype != np.float64: 48 | raise ValueError('Invalid data type for box data: float is required.') 49 | if not self._is_valid_boxes(data): 50 | raise ValueError('Invalid box data. data must be a numpy array of ' 51 | 'N*[y_min, x_min, y_max, x_max]') 52 | self.data = {'boxes': data} 53 | 54 | def num_boxes(self): 55 | """Return number of boxes held in the collection.""" 56 | return self.data['boxes'].shape[0] 57 | 58 | def get_extra_fields(self): 59 | """Return all non-box fields.""" 60 | return [k for k in self.data.keys() if k != 'boxes'] 61 | 62 | def has_field(self, field): 63 | return field in self.data 64 | 65 | def add_field(self, field, field_data): 66 | """Add data to a specified field. 67 | 68 | Args: 69 | field: a string parameter used to specify a related field to be accessed. 70 | field_data: a numpy array of [N, ...] representing the data associated 71 | with the field. 72 | Raises: 73 | ValueError: if the field already exists or the dimension of the field 74 | data does not match the number of boxes. 75 | """ 76 | if self.has_field(field): 77 | raise ValueError('Field ' + field + ' already exists') 78 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 79 | raise ValueError('Invalid dimensions for field data') 80 | self.data[field] = field_data 81 | 82 | def get(self): 83 | """Convenience function for accessing box coordinates. 84 | 85 | Returns: 86 | a numpy array of shape [N, 4] representing box corners 87 | """ 88 | return self.get_field('boxes') 89 | 90 | def get_field(self, field): 91 | """Accesses data associated with the specified field in the box collection. 92 | 93 | Args: 94 | field: a string parameter used to specify a related field to be accessed. 95 | 96 | Returns: 97 | a numpy 1-d array representing data of an associated field 98 | 99 | Raises: 100 | ValueError: if invalid field 101 | """ 102 | if not self.has_field(field): 103 | raise ValueError('field {} does not exist'.format(field)) 104 | return self.data[field] 105 | 106 | def get_coordinates(self): 107 | """Get corner coordinates of boxes.
108 | 109 | Returns: 110 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 111 | """ 112 | box_coordinates = self.get() 113 | y_min = box_coordinates[:, 0] 114 | x_min = box_coordinates[:, 1] 115 | y_max = box_coordinates[:, 2] 116 | x_max = box_coordinates[:, 3] 117 | return [y_min, x_min, y_max, x_max] 118 | 119 | def _is_valid_boxes(self, data): 120 | """Check whether data fulfills the format of N*[ymin, xmin, ymax, xmax]. 121 | 122 | Args: 123 | data: a numpy array of shape [N, 4] representing box coordinates 124 | 125 | Returns: 126 | a boolean indicating whether all ymax of boxes are equal or greater than 127 | ymin, and all xmax of boxes are equal or greater than xmin. 128 | """ 129 | if data.shape[0] > 0: 130 | for i in range(data.shape[0]): 131 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 132 | return False 133 | return True 134 | -------------------------------------------------------------------------------- /utils/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains functions which are convenient for unit testing.""" 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from object_detection.core import anchor_generator 21 | from object_detection.core import box_coder 22 | from object_detection.core import box_list 23 | from object_detection.core import box_predictor 24 | from object_detection.core import matcher 25 | from object_detection.utils import shape_utils 26 | 27 | 28 | class MockBoxCoder(box_coder.BoxCoder): 29 | """Simple `difference` BoxCoder.""" 30 | 31 | @property 32 | def code_size(self): 33 | return 4 34 | 35 | def _encode(self, boxes, anchors): 36 | return boxes.get() - anchors.get() 37 | 38 | def _decode(self, rel_codes, anchors): 39 | return box_list.BoxList(rel_codes + anchors.get()) 40 | 41 | 42 | class MockBoxPredictor(box_predictor.BoxPredictor): 43 | """Simple box predictor that ignores inputs and outputs all zeros.""" 44 | 45 | def __init__(self, is_training, num_classes): 46 | super(MockBoxPredictor, self).__init__(is_training, num_classes) 47 | 48 | def _predict(self, image_features, num_predictions_per_location): 49 | combined_feature_shape = shape_utils.combined_static_and_dynamic_shape( 50 | image_features) 51 | batch_size = combined_feature_shape[0] 52 | num_anchors = (combined_feature_shape[1] * combined_feature_shape[2]) 53 | code_size = 4 54 | zero = tf.reduce_sum(0 * image_features) 55 | box_encodings = zero + tf.zeros( 56 | (batch_size, num_anchors, 1, code_size), dtype=tf.float32) 57 | class_predictions_with_background = zero + tf.zeros( 58 | (batch_size, num_anchors, self.num_classes + 1), dtype=tf.float32) 59 | return {box_predictor.BOX_ENCODINGS: box_encodings, 60 |
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND: 61 | class_predictions_with_background} 62 | 63 | 64 | class MockAnchorGenerator(anchor_generator.AnchorGenerator): 65 | """Mock anchor generator.""" 66 | 67 | def name_scope(self): 68 | return 'MockAnchorGenerator' 69 | 70 | def num_anchors_per_location(self): 71 | return [1] 72 | 73 | def _generate(self, feature_map_shape_list): 74 | num_anchors = sum([shape[0] * shape[1] for shape in feature_map_shape_list]) 75 | return box_list.BoxList(tf.zeros((num_anchors, 4), dtype=tf.float32)) 76 | 77 | 78 | class MockMatcher(matcher.Matcher): 79 | """Simple matcher that matches first anchor to first groundtruth box.""" 80 | 81 | def _match(self, similarity_matrix): 82 | return tf.constant([0, -1, -1, -1], dtype=tf.int32) 83 | 84 | 85 | def create_diagonal_gradient_image(height, width, depth): 86 | """Creates pyramid image. Useful for testing. 87 | 88 | For example, pyramid_image(5, 6, 1) looks like: 89 | # [[[ 5. 4. 3. 2. 1. 0.] 90 | # [ 6. 5. 4. 3. 2. 1.] 91 | # [ 7. 6. 5. 4. 3. 2.] 92 | # [ 8. 7. 6. 5. 4. 3.] 93 | # [ 9. 8. 7. 6. 5. 4.]]] 94 | 95 | Args: 96 | height: height of image 97 | width: width of image 98 | depth: depth of image 99 | 100 | Returns: 101 | pyramid image 102 | """ 103 | row = np.arange(height) 104 | col = np.arange(width)[::-1] 105 | image_layer = np.expand_dims(row, 1) + col 106 | image_layer = np.expand_dims(image_layer, 2) 107 | 108 | image = image_layer 109 | for i in range(1, depth): 110 | image = np.concatenate((image, image_layer * pow(10, i)), 2) 111 | 112 | return image.astype(np.float32) 113 | 114 | 115 | def create_random_boxes(num_boxes, max_height, max_width): 116 | """Creates random bounding boxes of specific maximum height and width. 117 | 118 | Args: 119 | num_boxes: number of boxes. 120 | max_height: maximum height of boxes. 121 | max_width: maximum width of boxes. 122 | 123 | Returns: 124 | boxes: numpy array of shape [num_boxes, 4]. Each row is in form 125 | [y_min, x_min, y_max, x_max]. 126 | """ 127 | 128 | y_1 = np.random.uniform(size=(1, num_boxes)) * max_height 129 | y_2 = np.random.uniform(size=(1, num_boxes)) * max_height 130 | x_1 = np.random.uniform(size=(1, num_boxes)) * max_width 131 | x_2 = np.random.uniform(size=(1, num_boxes)) * max_width 132 | 133 | boxes = np.zeros(shape=(num_boxes, 4)) 134 | boxes[:, 0] = np.minimum(y_1, y_2) 135 | boxes[:, 1] = np.minimum(x_1, x_2) 136 | boxes[:, 2] = np.maximum(y_1, y_2) 137 | boxes[:, 3] = np.maximum(x_1, x_2) 138 | 139 | return boxes.astype(np.float32) 140 | -------------------------------------------------------------------------------- /protos/losses.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring the localization loss, classification loss and hard 6 | // example miner used for training object detection models. See core/losses.py 7 | // for details 8 | message Loss { 9 | // Localization loss to use. 10 | optional LocalizationLoss localization_loss = 1; 11 | 12 | // Classification loss to use. 13 | optional ClassificationLoss classification_loss = 2; 14 | 15 | // If not left to default, applies hard example mining. 16 | optional HardExampleMiner hard_example_miner = 3; 17 | 18 | // Classification loss weight. 19 | optional float classification_weight = 4 [default=1.0]; 20 | 21 | // Localization loss weight. 
22 | optional float localization_weight = 5 [default=1.0]; 23 | } 24 | 25 | // Configuration for bounding box localization loss function. 26 | message LocalizationLoss { 27 | oneof localization_loss { 28 | WeightedL2LocalizationLoss weighted_l2 = 1; 29 | WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2; 30 | WeightedIOULocalizationLoss weighted_iou = 3; 31 | } 32 | } 33 | 34 | // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2 35 | message WeightedL2LocalizationLoss { 36 | // Output loss per anchor. 37 | optional bool anchorwise_output = 1 [default=false]; 38 | } 39 | 40 | // SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5 41 | message WeightedSmoothL1LocalizationLoss { 42 | // Output loss per anchor. 43 | optional bool anchorwise_output = 1 [default=false]; 44 | } 45 | 46 | // Intersection over union location loss: 1 - IOU 47 | message WeightedIOULocalizationLoss { 48 | } 49 | 50 | // Configuration for class prediction loss function. 51 | message ClassificationLoss { 52 | oneof classification_loss { 53 | WeightedSigmoidClassificationLoss weighted_sigmoid = 1; 54 | WeightedSoftmaxClassificationLoss weighted_softmax = 2; 55 | BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3; 56 | SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4; 57 | } 58 | } 59 | 60 | // Classification loss using a sigmoid function over class predictions. 61 | message WeightedSigmoidClassificationLoss { 62 | // Output loss per anchor. 63 | optional bool anchorwise_output = 1 [default=false]; 64 | } 65 | 66 | // Sigmoid Focal cross entropy loss as described in 67 | // https://arxiv.org/abs/1708.02002 68 | message SigmoidFocalClassificationLoss { 69 | optional bool anchorwise_output = 1 [default = false]; 70 | // Modulating factor for the loss. 71 | optional float gamma = 2 [default = 2.0]; 72 | // Alpha weighting factor for the loss. 73 | optional float alpha = 3; 74 | } 75 | 76 | // Classification loss using a softmax function over class predictions. 77 | message WeightedSoftmaxClassificationLoss { 78 | // Output loss per anchor. 79 | optional bool anchorwise_output = 1 [default=false]; 80 | // Scale logit (input) value before calculating softmax classification loss. 81 | // Typically used for softmax distillation. 82 | optional float logit_scale = 2 [default = 1.0]; 83 | } 84 | 85 | // Classification loss using a sigmoid function over the class prediction with 86 | // the highest prediction score. 87 | message BootstrappedSigmoidClassificationLoss { 88 | // Interpolation weight between 0 and 1. 89 | optional float alpha = 1; 90 | 91 | // Whether hard bootstrapping should be used or not. If true, will only use 92 | // the one class favored by the model. Otherwise, will use all predicted class 93 | // probabilities. 94 | optional bool hard_bootstrap = 2 [default=false]; 95 | 96 | // Output loss per anchor. 97 | optional bool anchorwise_output = 3 [default=false]; 98 | } 99 | 100 | // Configuration for hard example miner. 101 | message HardExampleMiner { 102 | // Maximum number of hard examples to be selected per image (prior to 103 | // enforcing max negative to positive ratio constraint). If set to 0, 104 | // all examples obtained after NMS are considered. 105 | optional int32 num_hard_examples = 1 [default=64]; 106 | 107 | // Minimum intersection over union for an example to be discarded during NMS.
108 | optional float iou_threshold = 2 [default=0.7]; 109 | 110 | // Whether to use classification losses ('cls', default), localization losses 111 | // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and 112 | // loc_loss_weight are used to compute weighted sum of the two losses. 113 | enum LossType { 114 | BOTH = 0; 115 | CLASSIFICATION = 1; 116 | LOCALIZATION = 2; 117 | } 118 | optional LossType loss_type = 3 [default=BOTH]; 119 | 120 | // Maximum number of negatives to retain for each positive anchor. If 121 | // num_negatives_per_positive is 0 no prespecified negative:positive ratio is 122 | // enforced. 123 | optional int32 max_negatives_per_positive = 4 [default=0]; 124 | 125 | // Minimum number of negative anchors to sample for a given image. Setting 126 | // this to a positive number samples negatives in an image without any 127 | // positive anchors and thus not bias the model towards having at least one 128 | // detection per image. 129 | optional int32 min_negatives_per_image = 5 [default=0]; 130 | } 131 | -------------------------------------------------------------------------------- /utils/variables_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions for manipulating collections of variables during training. 17 | """ 18 | import logging 19 | import re 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | # TODO: Consider replacing with tf.contrib.filter_variables in 27 | # tensorflow/contrib/framework/python/ops/variables.py 28 | def filter_variables(variables, filter_regex_list, invert=False): 29 | """Filters out the variables matching the filter_regex. 30 | 31 | Filter out the variables whose name matches the any of the regular 32 | expressions in filter_regex_list and returns the remaining variables. 33 | Optionally, if invert=True, the complement set is returned. 34 | 35 | Args: 36 | variables: a list of tensorflow variables. 37 | filter_regex_list: a list of string regular expressions. 38 | invert: (boolean). If True, returns the complement of the filter set; that 39 | is, all variables matching filter_regex are kept and all others discarded. 40 | 41 | Returns: 42 | a list of filtered variables. 43 | """ 44 | kept_vars = [] 45 | variables_to_ignore_patterns = filter(None, filter_regex_list) 46 | for var in variables: 47 | add = True 48 | for pattern in variables_to_ignore_patterns: 49 | if re.match(pattern, var.op.name): 50 | add = False 51 | break 52 | if add != invert: 53 | kept_vars.append(var) 54 | return kept_vars 55 | 56 | 57 | def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier): 58 | """Multiply gradients whose variable names match a regular expression. 
59 | 60 | Args: 61 | grads_and_vars: A list of gradient to variable pairs (tuples). 62 | regex_list: A list of string regular expressions. 63 | multiplier: A (float) multiplier to apply to each gradient matching the 64 | regular expression. 65 | 66 | Returns: 67 | grads_and_vars: A list of gradient to variable pairs (tuples). 68 | """ 69 | variables = [pair[1] for pair in grads_and_vars] 70 | matching_vars = filter_variables(variables, regex_list, invert=True) 71 | for var in matching_vars: 72 | logging.info('Applying multiplier %f to variable [%s]', 73 | multiplier, var.op.name) 74 | grad_multipliers = {var: float(multiplier) for var in matching_vars} 75 | return slim.learning.multiply_gradients(grads_and_vars, 76 | grad_multipliers) 77 | 78 | 79 | def freeze_gradients_matching_regex(grads_and_vars, regex_list): 80 | """Freeze gradients whose variable names match a regular expression. 81 | 82 | Args: 83 | grads_and_vars: A list of gradient to variable pairs (tuples). 84 | regex_list: A list of string regular expressions. 85 | 86 | Returns: 87 | grads_and_vars: A list of gradient to variable pairs (tuples) that do not 88 | contain the variables and gradients matching the regex. 89 | """ 90 | variables = [pair[1] for pair in grads_and_vars] 91 | matching_vars = filter_variables(variables, regex_list, invert=True) 92 | kept_grads_and_vars = [pair for pair in grads_and_vars 93 | if pair[1] not in matching_vars] 94 | for var in matching_vars: 95 | logging.info('Freezing variable [%s]', var.op.name) 96 | return kept_grads_and_vars 97 | 98 | 99 | def get_variables_available_in_checkpoint(variables, checkpoint_path): 100 | """Returns the subset of variables available in the checkpoint. 101 | 102 | Inspects given checkpoint and returns the subset of variables that are 103 | available in it. 104 | 105 | TODO: force input and output to be a dictionary. 106 | 107 | Args: 108 | variables: a list or dictionary of variables to find in checkpoint. 109 | checkpoint_path: path to the checkpoint to restore variables from. 110 | 111 | Returns: 112 | A list or dictionary of variables. 113 | Raises: 114 | ValueError: if `variables` is not a list or dict. 115 | """ 116 | if isinstance(variables, list): 117 | variable_names_map = {variable.op.name: variable for variable in variables} 118 | elif isinstance(variables, dict): 119 | variable_names_map = variables 120 | else: 121 | raise ValueError('`variables` is expected to be a list or dict.') 122 | ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path) 123 | ckpt_vars = ckpt_reader.get_variable_to_shape_map().keys() 124 | vars_in_ckpt = {} 125 | for variable_name, variable in sorted(variable_names_map.items()): 126 | if variable_name in ckpt_vars: 127 | vars_in_ckpt[variable_name] = variable 128 | else: 129 | logging.warning('Variable [%s] not available in checkpoint', 130 | variable_name) 131 | if isinstance(variables, list): 132 | return vars_in_ckpt.values() 133 | return vars_in_ckpt 134 | -------------------------------------------------------------------------------- /protos/grid_anchor_generator_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/grid_anchor_generator.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/grid_anchor_generator.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n3object_detection/protos/grid_anchor_generator.proto\x12\x17object_detection.protos\"\xcd\x01\n\x13GridAnchorGenerator\x12\x13\n\x06height\x18\x01 \x01(\x05:\x03\x32\x35\x36\x12\x12\n\x05width\x18\x02 \x01(\x05:\x03\x32\x35\x36\x12\x19\n\rheight_stride\x18\x03 \x01(\x05:\x02\x31\x36\x12\x18\n\x0cwidth_stride\x18\x04 \x01(\x05:\x02\x31\x36\x12\x18\n\rheight_offset\x18\x05 \x01(\x05:\x01\x30\x12\x17\n\x0cwidth_offset\x18\x06 \x01(\x05:\x01\x30\x12\x0e\n\x06scales\x18\x07 \x03(\x02\x12\x15\n\raspect_ratios\x18\x08 \x03(\x02') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _GRIDANCHORGENERATOR = _descriptor.Descriptor( 29 | name='GridAnchorGenerator', 30 | full_name='object_detection.protos.GridAnchorGenerator', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='height', full_name='object_detection.protos.GridAnchorGenerator.height', index=0, 37 | number=1, type=5, cpp_type=1, label=1, 38 | has_default_value=True, default_value=256, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='width', full_name='object_detection.protos.GridAnchorGenerator.width', index=1, 44 | number=2, type=5, cpp_type=1, label=1, 45 | has_default_value=True, default_value=256, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='height_stride', full_name='object_detection.protos.GridAnchorGenerator.height_stride', index=2, 51 | number=3, type=5, cpp_type=1, label=1, 52 | has_default_value=True, default_value=16, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='width_stride', full_name='object_detection.protos.GridAnchorGenerator.width_stride', index=3, 58 | number=4, type=5, cpp_type=1, label=1, 59 | has_default_value=True, default_value=16, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='height_offset', full_name='object_detection.protos.GridAnchorGenerator.height_offset', index=4, 65 | number=5, type=5, cpp_type=1, label=1, 66 | has_default_value=True, default_value=0, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | _descriptor.FieldDescriptor( 71 | name='width_offset', full_name='object_detection.protos.GridAnchorGenerator.width_offset', index=5, 72 | number=6, type=5, cpp_type=1, label=1, 73 | has_default_value=True, 
default_value=0, 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='scales', full_name='object_detection.protos.GridAnchorGenerator.scales', index=6, 79 | number=7, type=2, cpp_type=6, label=3, 80 | has_default_value=False, default_value=[], 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='aspect_ratios', full_name='object_detection.protos.GridAnchorGenerator.aspect_ratios', index=7, 86 | number=8, type=2, cpp_type=6, label=3, 87 | has_default_value=False, default_value=[], 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | ], 92 | extensions=[ 93 | ], 94 | nested_types=[], 95 | enum_types=[ 96 | ], 97 | options=None, 98 | is_extendable=False, 99 | extension_ranges=[], 100 | oneofs=[ 101 | ], 102 | serialized_start=81, 103 | serialized_end=286, 104 | ) 105 | 106 | DESCRIPTOR.message_types_by_name['GridAnchorGenerator'] = _GRIDANCHORGENERATOR 107 | 108 | GridAnchorGenerator = _reflection.GeneratedProtocolMessageType('GridAnchorGenerator', (_message.Message,), dict( 109 | DESCRIPTOR = _GRIDANCHORGENERATOR, 110 | __module__ = 'object_detection.protos.grid_anchor_generator_pb2' 111 | # @@protoc_insertion_point(class_scope:object_detection.protos.GridAnchorGenerator) 112 | )) 113 | _sym_db.RegisterMessage(GridAnchorGenerator) 114 | 115 | 116 | # @@protoc_insertion_point(module_scope) 117 | -------------------------------------------------------------------------------- /protos/pipeline_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/pipeline.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.eval_pb2 17 | import object_detection.protos.input_reader_pb2 18 | import object_detection.protos.model_pb2 19 | import object_detection.protos.train_pb2 20 | 21 | 22 | DESCRIPTOR = _descriptor.FileDescriptor( 23 | name='object_detection/protos/pipeline.proto', 24 | package='object_detection.protos', 25 | serialized_pb=_b('\n&object_detection/protos/pipeline.proto\x12\x17object_detection.protos\x1a\"object_detection/protos/eval.proto\x1a*object_detection/protos/input_reader.proto\x1a#object_detection/protos/model.proto\x1a#object_detection/protos/train.proto\"\xca\x02\n\x17TrainEvalPipelineConfig\x12\x36\n\x05model\x18\x01 \x01(\x0b\x32\'.object_detection.protos.DetectionModel\x12:\n\x0ctrain_config\x18\x02 \x01(\x0b\x32$.object_detection.protos.TrainConfig\x12@\n\x12train_input_reader\x18\x03 \x01(\x0b\x32$.object_detection.protos.InputReader\x12\x38\n\x0b\x65val_config\x18\x04 \x01(\x0b\x32#.object_detection.protos.EvalConfig\x12?\n\x11\x65val_input_reader\x18\x05 \x01(\x0b\x32$.object_detection.protos.InputReader') 26 | , 27 | dependencies=[object_detection.protos.eval_pb2.DESCRIPTOR,object_detection.protos.input_reader_pb2.DESCRIPTOR,object_detection.protos.model_pb2.DESCRIPTOR,object_detection.protos.train_pb2.DESCRIPTOR,]) 28 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 29 | 30 | 31 | 32 | 33 | _TRAINEVALPIPELINECONFIG = _descriptor.Descriptor( 34 | name='TrainEvalPipelineConfig', 35 | full_name='object_detection.protos.TrainEvalPipelineConfig', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='model', full_name='object_detection.protos.TrainEvalPipelineConfig.model', index=0, 42 | number=1, type=11, cpp_type=10, label=1, 43 | has_default_value=False, default_value=None, 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | options=None), 47 | _descriptor.FieldDescriptor( 48 | name='train_config', full_name='object_detection.protos.TrainEvalPipelineConfig.train_config', index=1, 49 | number=2, type=11, cpp_type=10, label=1, 50 | has_default_value=False, default_value=None, 51 | message_type=None, enum_type=None, containing_type=None, 52 | is_extension=False, extension_scope=None, 53 | options=None), 54 | _descriptor.FieldDescriptor( 55 | name='train_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.train_input_reader', index=2, 56 | number=3, type=11, cpp_type=10, label=1, 57 | has_default_value=False, default_value=None, 58 | message_type=None, enum_type=None, containing_type=None, 59 | is_extension=False, extension_scope=None, 60 | options=None), 61 | _descriptor.FieldDescriptor( 62 | name='eval_config', full_name='object_detection.protos.TrainEvalPipelineConfig.eval_config', index=3, 63 | number=4, type=11, cpp_type=10, label=1, 64 | has_default_value=False, default_value=None, 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, 
extension_scope=None, 67 | options=None), 68 | _descriptor.FieldDescriptor( 69 | name='eval_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.eval_input_reader', index=4, 70 | number=5, type=11, cpp_type=10, label=1, 71 | has_default_value=False, default_value=None, 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | options=None), 75 | ], 76 | extensions=[ 77 | ], 78 | nested_types=[], 79 | enum_types=[ 80 | ], 81 | options=None, 82 | is_extendable=False, 83 | extension_ranges=[], 84 | oneofs=[ 85 | ], 86 | serialized_start=222, 87 | serialized_end=552, 88 | ) 89 | 90 | _TRAINEVALPIPELINECONFIG.fields_by_name['model'].message_type = object_detection.protos.model_pb2._DETECTIONMODEL 91 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_config'].message_type = object_detection.protos.train_pb2._TRAINCONFIG 92 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_input_reader'].message_type = object_detection.protos.input_reader_pb2._INPUTREADER 93 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_config'].message_type = object_detection.protos.eval_pb2._EVALCONFIG 94 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_input_reader'].message_type = object_detection.protos.input_reader_pb2._INPUTREADER 95 | DESCRIPTOR.message_types_by_name['TrainEvalPipelineConfig'] = _TRAINEVALPIPELINECONFIG 96 | 97 | TrainEvalPipelineConfig = _reflection.GeneratedProtocolMessageType('TrainEvalPipelineConfig', (_message.Message,), dict( 98 | DESCRIPTOR = _TRAINEVALPIPELINECONFIG, 99 | __module__ = 'object_detection.protos.pipeline_pb2' 100 | # @@protoc_insertion_point(class_scope:object_detection.protos.TrainEvalPipelineConfig) 101 | )) 102 | _sym_db.RegisterMessage(TrainEvalPipelineConfig) 103 | 104 | 105 | # @@protoc_insertion_point(module_scope) 106 | -------------------------------------------------------------------------------- /utils/shape_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.shape_utils.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.utils import shape_utils 21 | 22 | 23 | class UtilTest(tf.test.TestCase): 24 | 25 | def test_pad_tensor_using_integer_input(self): 26 | t1 = tf.constant([1], dtype=tf.int32) 27 | pad_t1 = shape_utils.pad_tensor(t1, 2) 28 | t2 = tf.constant([[0.1, 0.2]], dtype=tf.float32) 29 | pad_t2 = shape_utils.pad_tensor(t2, 2) 30 | 31 | self.assertEqual(2, pad_t1.get_shape()[0]) 32 | self.assertEqual(2, pad_t2.get_shape()[0]) 33 | 34 | with self.test_session() as sess: 35 | pad_t1_result, pad_t2_result = sess.run([pad_t1, pad_t2]) 36 | self.assertAllEqual([1, 0], pad_t1_result) 37 | self.assertAllClose([[0.1, 0.2], [0, 0]], pad_t2_result) 38 | 39 | def test_pad_tensor_using_tensor_input(self): 40 | t1 = tf.constant([1], dtype=tf.int32) 41 | pad_t1 = shape_utils.pad_tensor(t1, tf.constant(2)) 42 | t2 = tf.constant([[0.1, 0.2]], dtype=tf.float32) 43 | pad_t2 = shape_utils.pad_tensor(t2, tf.constant(2)) 44 | 45 | with self.test_session() as sess: 46 | pad_t1_result, pad_t2_result = sess.run([pad_t1, pad_t2]) 47 | self.assertAllEqual([1, 0], pad_t1_result) 48 | self.assertAllClose([[0.1, 0.2], [0, 0]], pad_t2_result) 49 | 50 | def test_clip_tensor_using_integer_input(self): 51 | t1 = tf.constant([1, 2, 3], dtype=tf.int32) 52 | clip_t1 = shape_utils.clip_tensor(t1, 2) 53 | t2 = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]], dtype=tf.float32) 54 | clip_t2 = shape_utils.clip_tensor(t2, 2) 55 | 56 | self.assertEqual(2, clip_t1.get_shape()[0]) 57 | self.assertEqual(2, clip_t2.get_shape()[0]) 58 | 59 | with self.test_session() as sess: 60 | clip_t1_result, clip_t2_result = sess.run([clip_t1, clip_t2]) 61 | self.assertAllEqual([1, 2], clip_t1_result) 62 | self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], clip_t2_result) 63 | 64 | def test_clip_tensor_using_tensor_input(self): 65 | t1 = tf.constant([1, 2, 3], dtype=tf.int32) 66 | clip_t1 = shape_utils.clip_tensor(t1, tf.constant(2)) 67 | t2 = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]], dtype=tf.float32) 68 | clip_t2 = shape_utils.clip_tensor(t2, tf.constant(2)) 69 | 70 | with self.test_session() as sess: 71 | clip_t1_result, clip_t2_result = sess.run([clip_t1, clip_t2]) 72 | self.assertAllEqual([1, 2], clip_t1_result) 73 | self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], clip_t2_result) 74 | 75 | def test_pad_or_clip_tensor_using_integer_input(self): 76 | t1 = tf.constant([1], dtype=tf.int32) 77 | tt1 = shape_utils.pad_or_clip_tensor(t1, 2) 78 | t2 = tf.constant([[0.1, 0.2]], dtype=tf.float32) 79 | tt2 = shape_utils.pad_or_clip_tensor(t2, 2) 80 | 81 | t3 = tf.constant([1, 2, 3], dtype=tf.int32) 82 | tt3 = shape_utils.clip_tensor(t3, 2) 83 | t4 = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]], dtype=tf.float32) 84 | tt4 = shape_utils.clip_tensor(t4, 2) 85 | 86 | self.assertEqual(2, tt1.get_shape()[0]) 87 | self.assertEqual(2, tt2.get_shape()[0]) 88 | self.assertEqual(2, tt3.get_shape()[0]) 89 | self.assertEqual(2, tt4.get_shape()[0]) 90 | 91 | with self.test_session() as sess: 92 | tt1_result, tt2_result, tt3_result, tt4_result = sess.run( 93 | [tt1, tt2, tt3, tt4]) 94 | self.assertAllEqual([1, 0], tt1_result) 95 | self.assertAllClose([[0.1, 0.2], [0, 0]], tt2_result) 96 | self.assertAllEqual([1, 2], tt3_result) 97 | self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], tt4_result) 98 | 99 | def test_pad_or_clip_tensor_using_tensor_input(self): 100 | t1 = 
tf.constant([1], dtype=tf.int32) 101 | tt1 = shape_utils.pad_or_clip_tensor(t1, tf.constant(2)) 102 | t2 = tf.constant([[0.1, 0.2]], dtype=tf.float32) 103 | tt2 = shape_utils.pad_or_clip_tensor(t2, tf.constant(2)) 104 | 105 | t3 = tf.constant([1, 2, 3], dtype=tf.int32) 106 | tt3 = shape_utils.clip_tensor(t3, tf.constant(2)) 107 | t4 = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]], dtype=tf.float32) 108 | tt4 = shape_utils.clip_tensor(t4, tf.constant(2)) 109 | 110 | with self.test_session() as sess: 111 | tt1_result, tt2_result, tt3_result, tt4_result = sess.run( 112 | [tt1, tt2, tt3, tt4]) 113 | self.assertAllEqual([1, 0], tt1_result) 114 | self.assertAllClose([[0.1, 0.2], [0, 0]], tt2_result) 115 | self.assertAllEqual([1, 2], tt3_result) 116 | self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], tt4_result) 117 | 118 | def test_combines_static_dynamic_shape(self): 119 | tensor = tf.placeholder(tf.float32, shape=(None, 2, 3)) 120 | combined_shape = shape_utils.combined_static_and_dynamic_shape( 121 | tensor) 122 | self.assertTrue(tf.contrib.framework.is_tensor(combined_shape[0])) 123 | self.assertListEqual(combined_shape[1:], [2, 3]) 124 | 125 | 126 | if __name__ == '__main__': 127 | tf.test.main() 128 | -------------------------------------------------------------------------------- /utils/np_box_list_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.np_box_list_test.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import np_box_list 22 | 23 | 24 | class BoxListTest(tf.test.TestCase): 25 | 26 | def test_invalid_box_data(self): 27 | with self.assertRaises(ValueError): 28 | np_box_list.BoxList([0, 0, 1, 1]) 29 | 30 | with self.assertRaises(ValueError): 31 | np_box_list.BoxList(np.array([[0, 0, 1, 1]], dtype=int)) 32 | 33 | with self.assertRaises(ValueError): 34 | np_box_list.BoxList(np.array([0, 1, 1, 3, 4], dtype=float)) 35 | 36 | with self.assertRaises(ValueError): 37 | np_box_list.BoxList(np.array([[0, 1, 1, 3], [3, 1, 1, 5]], dtype=float)) 38 | 39 | def test_has_field_with_existed_field(self): 40 | boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 41 | [0.0, 0.0, 20.0, 20.0]], 42 | dtype=float) 43 | boxlist = np_box_list.BoxList(boxes) 44 | self.assertTrue(boxlist.has_field('boxes')) 45 | 46 | def test_has_field_with_nonexisted_field(self): 47 | boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 48 | [0.0, 0.0, 20.0, 20.0]], 49 | dtype=float) 50 | boxlist = np_box_list.BoxList(boxes) 51 | self.assertFalse(boxlist.has_field('scores')) 52 | 53 | def test_get_field_with_existed_field(self): 54 | boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 55 | [0.0, 0.0, 20.0, 20.0]], 56 | dtype=float) 57 | boxlist = np_box_list.BoxList(boxes) 58 | self.assertTrue(np.allclose(boxlist.get_field('boxes'), boxes)) 59 | 60 | def test_get_field_with_nonexited_field(self): 61 | boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 62 | [0.0, 0.0, 20.0, 20.0]], 63 | dtype=float) 64 | boxlist = np_box_list.BoxList(boxes) 65 | with self.assertRaises(ValueError): 66 | boxlist.get_field('scores') 67 | 68 | 69 | class AddExtraFieldTest(tf.test.TestCase): 70 | 71 | def setUp(self): 72 | boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 73 | [0.0, 0.0, 20.0, 20.0]], 74 | dtype=float) 75 | self.boxlist = np_box_list.BoxList(boxes) 76 | 77 | def test_add_already_existed_field(self): 78 | with self.assertRaises(ValueError): 79 | self.boxlist.add_field('boxes', np.array([[0, 0, 0, 1, 0]], dtype=float)) 80 | 81 | def test_add_invalid_field_data(self): 82 | with self.assertRaises(ValueError): 83 | self.boxlist.add_field('scores', np.array([0.5, 0.7], dtype=float)) 84 | with self.assertRaises(ValueError): 85 | self.boxlist.add_field('scores', 86 | np.array([0.5, 0.7, 0.9, 0.1], dtype=float)) 87 | 88 | def test_add_single_dimensional_field_data(self): 89 | boxlist = self.boxlist 90 | scores = np.array([0.5, 0.7, 0.9], dtype=float) 91 | boxlist.add_field('scores', scores) 92 | self.assertTrue(np.allclose(scores, self.boxlist.get_field('scores'))) 93 | 94 | def test_add_multi_dimensional_field_data(self): 95 | boxlist = self.boxlist 96 | labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]], 97 | dtype=int) 98 | boxlist.add_field('labels', labels) 99 | self.assertTrue(np.allclose(labels, self.boxlist.get_field('labels'))) 100 | 101 | def test_get_extra_fields(self): 102 | boxlist = self.boxlist 103 | self.assertSameElements(boxlist.get_extra_fields(), []) 104 | 105 | scores = np.array([0.5, 0.7, 0.9], dtype=float) 106 | boxlist.add_field('scores', scores) 107 | self.assertSameElements(boxlist.get_extra_fields(), ['scores']) 108 | 109 | labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 
1]], 110 | dtype=int) 111 | boxlist.add_field('labels', labels) 112 | self.assertSameElements(boxlist.get_extra_fields(), ['scores', 'labels']) 113 | 114 | def test_get_coordinates(self): 115 | y_min, x_min, y_max, x_max = self.boxlist.get_coordinates() 116 | 117 | expected_y_min = np.array([3.0, 14.0, 0.0], dtype=float) 118 | expected_x_min = np.array([4.0, 14.0, 0.0], dtype=float) 119 | expected_y_max = np.array([6.0, 15.0, 20.0], dtype=float) 120 | expected_x_max = np.array([8.0, 15.0, 20.0], dtype=float) 121 | 122 | self.assertTrue(np.allclose(y_min, expected_y_min)) 123 | self.assertTrue(np.allclose(x_min, expected_x_min)) 124 | self.assertTrue(np.allclose(y_max, expected_y_max)) 125 | self.assertTrue(np.allclose(x_max, expected_x_max)) 126 | 127 | def test_num_boxes(self): 128 | boxes = np.array([[0., 0., 100., 100.], [10., 30., 50., 70.]], dtype=float) 129 | boxlist = np_box_list.BoxList(boxes) 130 | expected_num_boxes = 2 131 | self.assertEquals(boxlist.num_boxes(), expected_num_boxes) 132 | 133 | 134 | if __name__ == '__main__': 135 | tf.test.main() 136 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for computing metrics like precision, recall, CorLoc and etc.""" 17 | from __future__ import division 18 | 19 | import numpy as np 20 | 21 | 22 | def compute_precision_recall(scores, labels, num_gt): 23 | """Compute precision and recall. 24 | 25 | Args: 26 | scores: A float numpy array representing detection score 27 | labels: A boolean numpy array representing true/false positive labels 28 | num_gt: Number of ground truth instances 29 | 30 | Raises: 31 | ValueError: if the input is not of the correct format 32 | 33 | Returns: 34 | precision: Fraction of positive instances over detected ones. This value is 35 | None if no ground truth labels are present. 36 | recall: Fraction of detected positive instance over all positive instances. 37 | This value is None if no ground truth labels are present. 
38 | 39 | """ 40 | if not isinstance( 41 | labels, np.ndarray) or labels.dtype != np.bool or len(labels.shape) != 1: 42 | raise ValueError("labels must be single dimension bool numpy array") 43 | 44 | if not isinstance( 45 | scores, np.ndarray) or len(scores.shape) != 1: 46 | raise ValueError("scores must be single dimension numpy array") 47 | 48 | if num_gt < np.sum(labels): 49 | raise ValueError("Number of true positives must be smaller than num_gt.") 50 | 51 | if len(scores) != len(labels): 52 | raise ValueError("scores and labels must be of the same size.") 53 | 54 | if num_gt == 0: 55 | return None, None 56 | 57 | sorted_indices = np.argsort(scores) 58 | sorted_indices = sorted_indices[::-1] 59 | labels = labels.astype(int) 60 | true_positive_labels = labels[sorted_indices] 61 | false_positive_labels = 1 - true_positive_labels 62 | cum_true_positives = np.cumsum(true_positive_labels) 63 | cum_false_positives = np.cumsum(false_positive_labels) 64 | precision = cum_true_positives.astype(float) / ( 65 | cum_true_positives + cum_false_positives) 66 | recall = cum_true_positives.astype(float) / num_gt 67 | return precision, recall 68 | 69 | 70 | def compute_average_precision(precision, recall): 71 | """Compute Average Precision according to the definition in VOCdevkit. 72 | 73 | Precision is modified to ensure that it does not decrease as recall 74 | decrease. 75 | 76 | Args: 77 | precision: A float [N, 1] numpy array of precisions 78 | recall: A float [N, 1] numpy array of recalls 79 | 80 | Raises: 81 | ValueError: if the input is not of the correct format 82 | 83 | Returns: 84 | average_precison: The area under the precision recall curve. NaN if 85 | precision and recall are None. 86 | 87 | """ 88 | if precision is None: 89 | if recall is not None: 90 | raise ValueError("If precision is None, recall must also be None") 91 | return np.NAN 92 | 93 | if not isinstance(precision, np.ndarray) or not isinstance(recall, 94 | np.ndarray): 95 | raise ValueError("precision and recall must be numpy array") 96 | if precision.dtype != np.float or recall.dtype != np.float: 97 | raise ValueError("input must be float numpy array.") 98 | if len(precision) != len(recall): 99 | raise ValueError("precision and recall must be of the same size.") 100 | if not precision.size: 101 | return 0.0 102 | if np.amin(precision) < 0 or np.amax(precision) > 1: 103 | raise ValueError("Precision must be in the range of [0, 1].") 104 | if np.amin(recall) < 0 or np.amax(recall) > 1: 105 | raise ValueError("recall must be in the range of [0, 1].") 106 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 107 | raise ValueError("recall must be a non-decreasing array") 108 | 109 | recall = np.concatenate([[0], recall, [1]]) 110 | precision = np.concatenate([[0], precision, [0]]) 111 | 112 | # Preprocess precision to be a non-decreasing array 113 | for i in range(len(precision) - 2, -1, -1): 114 | precision[i] = np.maximum(precision[i], precision[i + 1]) 115 | 116 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 117 | average_precision = np.sum( 118 | (recall[indices] - recall[indices - 1]) * precision[indices]) 119 | return average_precision 120 | 121 | 122 | def compute_cor_loc(num_gt_imgs_per_class, 123 | num_images_correctly_detected_per_class): 124 | """Compute CorLoc according to the definition in the following paper. 125 | 126 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 127 | 128 | Returns nans if there are no ground truth images for a class. 
129 | 130 | Args: 131 | num_gt_imgs_per_class: 1D array, representing number of images containing 132 | at least one object instance of a particular class 133 | num_images_correctly_detected_per_class: 1D array, representing number of 134 | images that are correctly detected at least one object instance of a 135 | particular class 136 | 137 | Returns: 138 | corloc_per_class: A float numpy array represents the corloc score of each 139 | class 140 | """ 141 | return np.where( 142 | num_gt_imgs_per_class == 0, 143 | np.nan, 144 | num_images_correctly_detected_per_class / num_gt_imgs_per_class) 145 | -------------------------------------------------------------------------------- /protos/box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/box_coder.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | import object_detection.protos.faster_rcnn_box_coder_pb2 17 | import object_detection.protos.keypoint_box_coder_pb2 18 | import object_detection.protos.mean_stddev_box_coder_pb2 19 | import object_detection.protos.square_box_coder_pb2 20 | 21 | 22 | DESCRIPTOR = _descriptor.FileDescriptor( 23 | name='object_detection/protos/box_coder.proto', 24 | package='object_detection.protos', 25 | serialized_pb=_b('\n\'object_detection/protos/box_coder.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/faster_rcnn_box_coder.proto\x1a\x30object_detection/protos/keypoint_box_coder.proto\x1a\x33object_detection/protos/mean_stddev_box_coder.proto\x1a.object_detection/protos/square_box_coder.proto\"\xc7\x02\n\x08\x42oxCoder\x12L\n\x15\x66\x61ster_rcnn_box_coder\x18\x01 \x01(\x0b\x32+.object_detection.protos.FasterRcnnBoxCoderH\x00\x12L\n\x15mean_stddev_box_coder\x18\x02 \x01(\x0b\x32+.object_detection.protos.MeanStddevBoxCoderH\x00\x12\x43\n\x10square_box_coder\x18\x03 \x01(\x0b\x32\'.object_detection.protos.SquareBoxCoderH\x00\x12G\n\x12keypoint_box_coder\x18\x04 \x01(\x0b\x32).object_detection.protos.KeypointBoxCoderH\x00\x42\x11\n\x0f\x62ox_coder_oneof') 26 | , 27 | dependencies=[object_detection.protos.faster_rcnn_box_coder_pb2.DESCRIPTOR,object_detection.protos.keypoint_box_coder_pb2.DESCRIPTOR,object_detection.protos.mean_stddev_box_coder_pb2.DESCRIPTOR,object_detection.protos.square_box_coder_pb2.DESCRIPTOR,]) 28 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 29 | 30 | 31 | 32 | 33 | _BOXCODER = _descriptor.Descriptor( 34 | name='BoxCoder', 35 | full_name='object_detection.protos.BoxCoder', 36 | filename=None, 37 | file=DESCRIPTOR, 38 | containing_type=None, 39 | fields=[ 40 | _descriptor.FieldDescriptor( 41 | name='faster_rcnn_box_coder', full_name='object_detection.protos.BoxCoder.faster_rcnn_box_coder', index=0, 42 | number=1, type=11, cpp_type=10, label=1, 43 | has_default_value=False, default_value=None, 44 | message_type=None, enum_type=None, containing_type=None, 45 | is_extension=False, extension_scope=None, 46 | options=None), 47 | _descriptor.FieldDescriptor( 48 | name='mean_stddev_box_coder', 
full_name='object_detection.protos.BoxCoder.mean_stddev_box_coder', index=1, 49 | number=2, type=11, cpp_type=10, label=1, 50 | has_default_value=False, default_value=None, 51 | message_type=None, enum_type=None, containing_type=None, 52 | is_extension=False, extension_scope=None, 53 | options=None), 54 | _descriptor.FieldDescriptor( 55 | name='square_box_coder', full_name='object_detection.protos.BoxCoder.square_box_coder', index=2, 56 | number=3, type=11, cpp_type=10, label=1, 57 | has_default_value=False, default_value=None, 58 | message_type=None, enum_type=None, containing_type=None, 59 | is_extension=False, extension_scope=None, 60 | options=None), 61 | _descriptor.FieldDescriptor( 62 | name='keypoint_box_coder', full_name='object_detection.protos.BoxCoder.keypoint_box_coder', index=3, 63 | number=4, type=11, cpp_type=10, label=1, 64 | has_default_value=False, default_value=None, 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, extension_scope=None, 67 | options=None), 68 | ], 69 | extensions=[ 70 | ], 71 | nested_types=[], 72 | enum_types=[ 73 | ], 74 | options=None, 75 | is_extendable=False, 76 | extension_ranges=[], 77 | oneofs=[ 78 | _descriptor.OneofDescriptor( 79 | name='box_coder_oneof', full_name='object_detection.protos.BoxCoder.box_coder_oneof', 80 | index=0, containing_type=None, fields=[]), 81 | ], 82 | serialized_start=273, 83 | serialized_end=600, 84 | ) 85 | 86 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].message_type = object_detection.protos.faster_rcnn_box_coder_pb2._FASTERRCNNBOXCODER 87 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].message_type = object_detection.protos.mean_stddev_box_coder_pb2._MEANSTDDEVBOXCODER 88 | _BOXCODER.fields_by_name['square_box_coder'].message_type = object_detection.protos.square_box_coder_pb2._SQUAREBOXCODER 89 | _BOXCODER.fields_by_name['keypoint_box_coder'].message_type = object_detection.protos.keypoint_box_coder_pb2._KEYPOINTBOXCODER 90 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 91 | _BOXCODER.fields_by_name['faster_rcnn_box_coder']) 92 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 93 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 94 | _BOXCODER.fields_by_name['mean_stddev_box_coder']) 95 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 96 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 97 | _BOXCODER.fields_by_name['square_box_coder']) 98 | _BOXCODER.fields_by_name['square_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 99 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append( 100 | _BOXCODER.fields_by_name['keypoint_box_coder']) 101 | _BOXCODER.fields_by_name['keypoint_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof'] 102 | DESCRIPTOR.message_types_by_name['BoxCoder'] = _BOXCODER 103 | 104 | BoxCoder = _reflection.GeneratedProtocolMessageType('BoxCoder', (_message.Message,), dict( 105 | DESCRIPTOR = _BOXCODER, 106 | __module__ = 'object_detection.protos.box_coder_pb2' 107 | # @@protoc_insertion_point(class_scope:object_detection.protos.BoxCoder) 108 | )) 109 | _sym_db.RegisterMessage(BoxCoder) 110 | 111 | 112 | # @@protoc_insertion_point(module_scope) 113 | -------------------------------------------------------------------------------- /utils/label_map_util.py: -------------------------------------------------------------------------------- 1 
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Label map utility functions.""" 17 | 18 | import logging 19 | 20 | import tensorflow as tf 21 | from google.protobuf import text_format 22 | #from object_detection.protos import string_int_label_map_pb2 23 | from protos import string_int_label_map_pb2 24 | 25 | 26 | def _validate_label_map(label_map): 27 | """Checks if a label map is valid. 28 | 29 | Args: 30 | label_map: StringIntLabelMap to validate. 31 | 32 | Raises: 33 | ValueError: if label map is invalid. 34 | """ 35 | for item in label_map.item: 36 | if item.id < 1: 37 | raise ValueError('Label map ids should be >= 1.') 38 | 39 | 40 | def create_category_index(categories): 41 | """Creates dictionary of COCO compatible categories keyed by category id. 42 | 43 | Args: 44 | categories: a list of dicts, each of which has the following keys: 45 | 'id': (required) an integer id uniquely identifying this category. 46 | 'name': (required) string representing category name 47 | e.g., 'cat', 'dog', 'pizza'. 48 | 49 | Returns: 50 | category_index: a dict containing the same entries as categories, but keyed 51 | by the 'id' field of each category. 52 | """ 53 | category_index = {} 54 | for cat in categories: 55 | category_index[cat['id']] = cat 56 | return category_index 57 | 58 | 59 | def convert_label_map_to_categories(label_map, 60 | max_num_classes, 61 | use_display_name=True): 62 | """Loads label map proto and returns categories list compatible with eval. 63 | 64 | This function loads a label map and returns a list of dicts, each of which 65 | has the following keys: 66 | 'id': (required) an integer id uniquely identifying this category. 67 | 'name': (required) string representing category name 68 | e.g., 'cat', 'dog', 'pizza'. 69 | We only allow class into the list if its id-label_id_offset is 70 | between 0 (inclusive) and max_num_classes (exclusive). 71 | If there are several items mapping to the same id in the label map, 72 | we will only keep the first one in the categories list. 73 | 74 | Args: 75 | label_map: a StringIntLabelMapProto or None. If None, a default categories 76 | list is created with max_num_classes categories. 77 | max_num_classes: maximum number of (consecutive) label indices to include. 78 | use_display_name: (boolean) choose whether to load 'display_name' field 79 | as category name. If False or if the display_name field does not exist, 80 | uses 'name' field as category names instead. 81 | Returns: 82 | categories: a list of dictionaries representing all possible categories. 
83 | """ 84 | categories = [] 85 | list_of_ids_already_added = [] 86 | if not label_map: 87 | label_id_offset = 1 88 | for class_id in range(max_num_classes): 89 | categories.append({ 90 | 'id': class_id + label_id_offset, 91 | 'name': 'category_{}'.format(class_id + label_id_offset) 92 | }) 93 | return categories 94 | for item in label_map.item: 95 | if not 0 < item.id <= max_num_classes: 96 | logging.info('Ignore item %d since it falls outside of requested ' 97 | 'label range.', item.id) 98 | continue 99 | if use_display_name and item.HasField('display_name'): 100 | name = item.display_name 101 | else: 102 | name = item.name 103 | if item.id not in list_of_ids_already_added: 104 | list_of_ids_already_added.append(item.id) 105 | categories.append({'id': item.id, 'name': name}) 106 | return categories 107 | 108 | 109 | def load_labelmap(path): 110 | """Loads label map proto. 111 | 112 | Args: 113 | path: path to StringIntLabelMap proto text file. 114 | Returns: 115 | a StringIntLabelMapProto 116 | """ 117 | with tf.gfile.GFile(path, 'r') as fid: 118 | label_map_string = fid.read() 119 | label_map = string_int_label_map_pb2.StringIntLabelMap() 120 | try: 121 | text_format.Merge(label_map_string, label_map) 122 | except text_format.ParseError: 123 | label_map.ParseFromString(label_map_string) 124 | _validate_label_map(label_map) 125 | return label_map 126 | 127 | 128 | def get_label_map_dict(label_map_path, use_display_name=False): 129 | """Reads a label map and returns a dictionary of label names to id. 130 | 131 | Args: 132 | label_map_path: path to label_map. 133 | use_display_name: whether to use the label map items' display names as keys. 134 | 135 | Returns: 136 | A dictionary mapping label names to id. 137 | """ 138 | label_map = load_labelmap(label_map_path) 139 | label_map_dict = {} 140 | for item in label_map.item: 141 | if use_display_name: 142 | label_map_dict[item.display_name] = item.id 143 | else: 144 | label_map_dict[item.name] = item.id 145 | return label_map_dict 146 | 147 | 148 | def create_category_index_from_labelmap(label_map_path): 149 | """Reads a label map and returns a category index. 150 | 151 | Args: 152 | label_map_path: Path to `StringIntLabelMap` proto text file. 153 | 154 | Returns: 155 | A category index, which is a dictionary that maps integer ids to dicts 156 | containing categories, e.g. 157 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 158 | """ 159 | label_map = load_labelmap(label_map_path) 160 | max_num_classes = max(item.id for item in label_map.item) 161 | categories = convert_label_map_to_categories(label_map, max_num_classes) 162 | return create_category_index(categories) 163 | 164 | 165 | def create_class_agnostic_category_index(): 166 | """Creates a category index with a single `object` class.""" 167 | return {1: {'id': 1, 'name': 'object'}} 168 | -------------------------------------------------------------------------------- /protos/faster_rcnn.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/anchor_generator.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import "object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/losses.proto"; 10 | import "object_detection/protos/post_processing.proto"; 11 | 12 | // Configuration for Faster R-CNN models. 
13 | // See meta_architectures/faster_rcnn_meta_arch.py and models/model_builder.py 14 | // 15 | // Naming conventions: 16 | // Faster R-CNN models have two stages: a first stage region proposal network 17 | // (or RPN) and a second stage box classifier. We thus use the prefixes 18 | // `first_stage_` and `second_stage_` to indicate the stage to which each 19 | // parameter pertains when relevant. 20 | message FasterRcnn { 21 | 22 | // Whether to construct only the Region Proposal Network (RPN). 23 | optional bool first_stage_only = 1 [default=false]; 24 | 25 | // Number of classes to predict. 26 | optional int32 num_classes = 3; 27 | 28 | // Image resizer for preprocessing the input image. 29 | optional ImageResizer image_resizer = 4; 30 | 31 | // Feature extractor config. 32 | optional FasterRcnnFeatureExtractor feature_extractor = 5; 33 | 34 | 35 | // (First stage) region proposal network (RPN) parameters. 36 | 37 | // Anchor generator to compute RPN anchors. 38 | optional AnchorGenerator first_stage_anchor_generator = 6; 39 | 40 | // Atrous rate for the convolution op applied to the 41 | // `first_stage_features_to_crop` tensor to obtain box predictions. 42 | optional int32 first_stage_atrous_rate = 7 [default=1]; 43 | 44 | // Hyperparameters for the convolutional RPN box predictor. 45 | optional Hyperparams first_stage_box_predictor_conv_hyperparams = 8; 46 | 47 | // Kernel size to use for the convolution op just prior to RPN box 48 | // predictions. 49 | optional int32 first_stage_box_predictor_kernel_size = 9 [default=3]; 50 | 51 | // Output depth for the convolution op just prior to RPN box predictions. 52 | optional int32 first_stage_box_predictor_depth = 10 [default=512]; 53 | 54 | // The batch size to use for computing the first stage objectness and 55 | // location losses. 56 | optional int32 first_stage_minibatch_size = 11 [default=256]; 57 | 58 | // Fraction of positive examples per image for the RPN. 59 | optional float first_stage_positive_balance_fraction = 12 [default=0.5]; 60 | 61 | // Non max suppression score threshold applied to first stage RPN proposals. 62 | optional float first_stage_nms_score_threshold = 13 [default=0.0]; 63 | 64 | // Non max suppression IOU threshold applied to first stage RPN proposals. 65 | optional float first_stage_nms_iou_threshold = 14 [default=0.7]; 66 | 67 | // Maximum number of RPN proposals retained after first stage postprocessing. 68 | optional int32 first_stage_max_proposals = 15 [default=300]; 69 | 70 | // First stage RPN localization loss weight. 71 | optional float first_stage_localization_loss_weight = 16 [default=1.0]; 72 | 73 | // First stage RPN objectness loss weight. 74 | optional float first_stage_objectness_loss_weight = 17 [default=1.0]; 75 | 76 | 77 | // Per-region cropping parameters. 78 | // Note that if a R-FCN model is constructed the per region cropping 79 | // parameters below are ignored. 80 | 81 | // Output size (width and height are set to be the same) of the initial 82 | // bilinear interpolation based cropping during ROI pooling. 83 | optional int32 initial_crop_size = 18; 84 | 85 | // Kernel size of the max pool op on the cropped feature map during 86 | // ROI pooling. 87 | optional int32 maxpool_kernel_size = 19; 88 | 89 | // Stride of the max pool op on the cropped feature map during ROI pooling. 90 | optional int32 maxpool_stride = 20; 91 | 92 | 93 | // (Second stage) box classifier parameters 94 | 95 | // Hyperparameters for the second stage box predictor. 
If box predictor type 96 | // is set to rfcn_box_predictor, a R-FCN model is constructed, otherwise a 97 | // Faster R-CNN model is constructed. 98 | optional BoxPredictor second_stage_box_predictor = 21; 99 | 100 | // The batch size per image used for computing the classification and refined 101 | // location loss of the box classifier. 102 | // Note that this field is ignored if `hard_example_miner` is configured. 103 | optional int32 second_stage_batch_size = 22 [default=64]; 104 | 105 | // Fraction of positive examples to use per image for the box classifier. 106 | optional float second_stage_balance_fraction = 23 [default=0.25]; 107 | 108 | // Post processing to apply on the second stage box classifier predictions. 109 | // Note: the `score_converter` provided to the FasterRCNNMetaArch constructor 110 | // is taken from this `second_stage_post_processing` proto. 111 | optional PostProcessing second_stage_post_processing = 24; 112 | 113 | // Second stage refined localization loss weight. 114 | optional float second_stage_localization_loss_weight = 25 [default=1.0]; 115 | 116 | // Second stage classification loss weight 117 | optional float second_stage_classification_loss_weight = 26 [default=1.0]; 118 | 119 | // Second stage instance mask loss weight. Note that this is only applicable 120 | // when `MaskRCNNBoxPredictor` is selected for second stage and configured to 121 | // predict instance masks. 122 | optional float second_stage_mask_prediction_loss_weight = 27 [default=1.0]; 123 | 124 | // If not left to default, applies hard example mining only to classification 125 | // and localization loss.. 126 | optional HardExampleMiner hard_example_miner = 28; 127 | 128 | // Loss for second stage box classifers, supports Softmax and Sigmoid. 129 | // Note that score converter must be consistent with loss type. 130 | // When there are multiple labels assigned to the same boxes, recommend 131 | // to use sigmoid loss and enable merge_multiple_label_boxes. 132 | // If not specified, Softmax loss is used as default. 133 | optional ClassificationLoss second_stage_classification_loss = 29; 134 | } 135 | 136 | 137 | message FasterRcnnFeatureExtractor { 138 | // Type of Faster R-CNN model (e.g., 'faster_rcnn_resnet101'; 139 | // See builders/model_builder.py for expected types). 140 | optional string type = 1; 141 | 142 | // Output stride of extracted RPN feature map. 143 | optional int32 first_stage_features_stride = 2 [default=16]; 144 | 145 | // Whether to update batch norm parameters during training or not. 146 | // When training with a relative large batch size (e.g. 8), it could be 147 | // desirable to enable batch norm update. 148 | optional bool batch_norm_trainable = 3 [default=false]; 149 | } 150 | -------------------------------------------------------------------------------- /utils/label_map_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.label_map_util.""" 17 | 18 | import os 19 | import tensorflow as tf 20 | 21 | from google.protobuf import text_format 22 | from object_detection.protos import string_int_label_map_pb2 23 | from object_detection.utils import label_map_util 24 | 25 | 26 | class LabelMapUtilTest(tf.test.TestCase): 27 | 28 | def _generate_label_map(self, num_classes): 29 | label_map_proto = string_int_label_map_pb2.StringIntLabelMap() 30 | for i in range(1, num_classes + 1): 31 | item = label_map_proto.item.add() 32 | item.id = i 33 | item.name = 'label_' + str(i) 34 | item.display_name = str(i) 35 | return label_map_proto 36 | 37 | def test_get_label_map_dict(self): 38 | label_map_string = """ 39 | item { 40 | id:2 41 | name:'cat' 42 | } 43 | item { 44 | id:1 45 | name:'dog' 46 | } 47 | """ 48 | label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt') 49 | with tf.gfile.Open(label_map_path, 'wb') as f: 50 | f.write(label_map_string) 51 | 52 | label_map_dict = label_map_util.get_label_map_dict(label_map_path) 53 | self.assertEqual(label_map_dict['dog'], 1) 54 | self.assertEqual(label_map_dict['cat'], 2) 55 | 56 | def test_get_label_map_dict_display(self): 57 | label_map_string = """ 58 | item { 59 | id:2 60 | display_name:'cat' 61 | } 62 | item { 63 | id:1 64 | display_name:'dog' 65 | } 66 | """ 67 | label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt') 68 | with tf.gfile.Open(label_map_path, 'wb') as f: 69 | f.write(label_map_string) 70 | 71 | label_map_dict = label_map_util.get_label_map_dict( 72 | label_map_path, use_display_name=True) 73 | self.assertEqual(label_map_dict['dog'], 1) 74 | self.assertEqual(label_map_dict['cat'], 2) 75 | 76 | def test_load_bad_label_map(self): 77 | label_map_string = """ 78 | item { 79 | id:0 80 | name:'class that should not be indexed at zero' 81 | } 82 | item { 83 | id:2 84 | name:'cat' 85 | } 86 | item { 87 | id:1 88 | name:'dog' 89 | } 90 | """ 91 | label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt') 92 | with tf.gfile.Open(label_map_path, 'wb') as f: 93 | f.write(label_map_string) 94 | 95 | with self.assertRaises(ValueError): 96 | label_map_util.load_labelmap(label_map_path) 97 | 98 | def test_keep_categories_with_unique_id(self): 99 | label_map_proto = string_int_label_map_pb2.StringIntLabelMap() 100 | label_map_string = """ 101 | item { 102 | id:2 103 | name:'cat' 104 | } 105 | item { 106 | id:1 107 | name:'child' 108 | } 109 | item { 110 | id:1 111 | name:'person' 112 | } 113 | item { 114 | id:1 115 | name:'n00007846' 116 | } 117 | """ 118 | text_format.Merge(label_map_string, label_map_proto) 119 | categories = label_map_util.convert_label_map_to_categories( 120 | label_map_proto, max_num_classes=3) 121 | self.assertListEqual([{ 122 | 'id': 2, 123 | 'name': u'cat' 124 | }, { 125 | 'id': 1, 126 | 'name': u'child' 127 | }], categories) 128 | 129 | def test_convert_label_map_to_categories_no_label_map(self): 130 | categories = label_map_util.convert_label_map_to_categories( 131 | None, max_num_classes=3) 132 | expected_categories_list = [{ 133 | 'name': u'category_1', 134 | 'id': 1 135 | }, { 136 | 'name': u'category_2', 137 | 'id': 2 138 | }, { 139 | 'name': u'category_3', 140 | 'id': 3 141 | }] 142 | self.assertListEqual(expected_categories_list, categories) 143 | 144 | def 
test_convert_label_map_to_coco_categories(self): 145 | label_map_proto = self._generate_label_map(num_classes=4) 146 | categories = label_map_util.convert_label_map_to_categories( 147 | label_map_proto, max_num_classes=3) 148 | expected_categories_list = [{ 149 | 'name': u'1', 150 | 'id': 1 151 | }, { 152 | 'name': u'2', 153 | 'id': 2 154 | }, { 155 | 'name': u'3', 156 | 'id': 3 157 | }] 158 | self.assertListEqual(expected_categories_list, categories) 159 | 160 | def test_convert_label_map_to_coco_categories_with_few_classes(self): 161 | label_map_proto = self._generate_label_map(num_classes=4) 162 | cat_no_offset = label_map_util.convert_label_map_to_categories( 163 | label_map_proto, max_num_classes=2) 164 | expected_categories_list = [{ 165 | 'name': u'1', 166 | 'id': 1 167 | }, { 168 | 'name': u'2', 169 | 'id': 2 170 | }] 171 | self.assertListEqual(expected_categories_list, cat_no_offset) 172 | 173 | def test_create_category_index(self): 174 | categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}] 175 | category_index = label_map_util.create_category_index(categories) 176 | self.assertDictEqual({ 177 | 1: { 178 | 'name': u'1', 179 | 'id': 1 180 | }, 181 | 2: { 182 | 'name': u'2', 183 | 'id': 2 184 | } 185 | }, category_index) 186 | 187 | def test_create_category_index_from_labelmap(self): 188 | label_map_string = """ 189 | item { 190 | id:2 191 | name:'cat' 192 | } 193 | item { 194 | id:1 195 | name:'dog' 196 | } 197 | """ 198 | label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt') 199 | with tf.gfile.Open(label_map_path, 'wb') as f: 200 | f.write(label_map_string) 201 | 202 | category_index = label_map_util.create_category_index_from_labelmap( 203 | label_map_path) 204 | self.assertDictEqual({ 205 | 1: { 206 | 'name': u'dog', 207 | 'id': 1 208 | }, 209 | 2: { 210 | 'name': u'cat', 211 | 'id': 2 212 | } 213 | }, category_index) 214 | 215 | 216 | if __name__ == '__main__': 217 | tf.test.main() 218 | -------------------------------------------------------------------------------- /utils/learning_schedules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Library of common learning rate schedules.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | 21 | def exponential_decay_with_burnin(global_step, 22 | learning_rate_base, 23 | learning_rate_decay_steps, 24 | learning_rate_decay_factor, 25 | burnin_learning_rate=0.0, 26 | burnin_steps=0): 27 | """Exponential decay schedule with burn-in period. 28 | 29 | In this schedule, learning rate is fixed at burnin_learning_rate 30 | for a fixed period, before transitioning to a regular exponential 31 | decay schedule. 32 | 33 | Args: 34 | global_step: int tensor representing global step. 35 | learning_rate_base: base learning rate. 
36 | learning_rate_decay_steps: steps to take between decaying the learning rate. 37 | Note that this includes the number of burn-in steps. 38 | learning_rate_decay_factor: multiplicative factor by which to decay 39 | learning rate. 40 | burnin_learning_rate: initial learning rate during burn-in period. If 41 | 0.0 (which is the default), then the burn-in learning rate is simply 42 | set to learning_rate_base. 43 | burnin_steps: number of steps to use burnin learning rate. 44 | 45 | Returns: 46 | a (scalar) float tensor representing learning rate 47 | """ 48 | if burnin_learning_rate == 0: 49 | burnin_learning_rate = learning_rate_base 50 | post_burnin_learning_rate = tf.train.exponential_decay( 51 | learning_rate_base, 52 | global_step, 53 | learning_rate_decay_steps, 54 | learning_rate_decay_factor, 55 | staircase=True) 56 | return tf.cond( 57 | tf.less(global_step, burnin_steps), 58 | lambda: tf.convert_to_tensor(burnin_learning_rate), 59 | lambda: post_burnin_learning_rate) 60 | 61 | 62 | def cosine_decay_with_warmup(global_step, 63 | learning_rate_base, 64 | total_steps, 65 | warmup_learning_rate=0.0, 66 | warmup_steps=0): 67 | """Cosine decay schedule with warm up period. 68 | 69 | Cosine annealing learning rate as described in: 70 | Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts. 71 | ICLR 2017. https://arxiv.org/abs/1608.03983 72 | In this schedule, the learning rate grows linearly from warmup_learning_rate 73 | to learning_rate_base for warmup_steps, then transitions to a cosine decay 74 | schedule. 75 | 76 | Args: 77 | global_step: int64 (scalar) tensor representing global step. 78 | learning_rate_base: base learning rate. 79 | total_steps: total number of training steps. 80 | warmup_learning_rate: initial learning rate for warm up. 81 | warmup_steps: number of warmup steps. 82 | 83 | Returns: 84 | a (scalar) float tensor representing learning rate. 85 | 86 | Raises: 87 | ValueError: if warmup_learning_rate is larger than learning_rate_base, 88 | or if warmup_steps is larger than total_steps. 89 | """ 90 | if learning_rate_base < warmup_learning_rate: 91 | raise ValueError('learning_rate_base must be larger ' 92 | 'or equal to warmup_learning_rate.') 93 | if total_steps < warmup_steps: 94 | raise ValueError('total_steps must be larger or equal to ' 95 | 'warmup_steps.') 96 | learning_rate = 0.5 * learning_rate_base * ( 97 | 1 + tf.cos(np.pi * tf.cast( 98 | global_step - warmup_steps, tf.float32 99 | ) / float(total_steps - warmup_steps))) 100 | if warmup_steps > 0: 101 | slope = (learning_rate_base - warmup_learning_rate) / warmup_steps 102 | pre_cosine_learning_rate = slope * tf.cast( 103 | global_step, tf.float32) + warmup_learning_rate 104 | learning_rate = tf.cond( 105 | tf.less(global_step, warmup_steps), lambda: pre_cosine_learning_rate, 106 | lambda: learning_rate) 107 | return learning_rate 108 | 109 | 110 | def manual_stepping(global_step, boundaries, rates): 111 | """Manually stepped learning rate schedule. 112 | 113 | This function provides fine grained control over learning rates. One must 114 | specify a sequence of learning rates as well as a set of integer steps 115 | at which the current learning rate must transition to the next. For example, 116 | if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning 117 | rate returned by this function is .1 for global_step=0,...,4, .01 for 118 | global_step=5...9, and .001 for global_step=10 and onward. 
119 | 120 | Args: 121 | global_step: int64 (scalar) tensor representing global step. 122 | boundaries: a list of global steps at which to switch learning 123 | rates. This list is assumed to consist of increasing positive integers. 124 | rates: a list of (float) learning rates corresponding to intervals between 125 | the boundaries. The length of this list must be exactly 126 | len(boundaries) + 1. 127 | 128 | Returns: 129 | a (scalar) float tensor representing learning rate 130 | Raises: 131 | ValueError: if one of the following checks fails: 132 | 1. boundaries is a strictly increasing list of positive integers 133 | 2. len(rates) == len(boundaries) + 1 134 | """ 135 | if any([b < 0 for b in boundaries]) or any( 136 | [not isinstance(b, int) for b in boundaries]): 137 | raise ValueError('boundaries must be a list of positive integers') 138 | if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]): 139 | raise ValueError('Entries in boundaries must be strictly increasing.') 140 | if any([not isinstance(r, float) for r in rates]): 141 | raise ValueError('Learning rates must be floats') 142 | if len(rates) != len(boundaries) + 1: 143 | raise ValueError('Number of provided learning rates must exceed ' 144 | 'number of boundary points by exactly 1.') 145 | step_boundaries = tf.constant(boundaries, tf.int64) 146 | learning_rates = tf.constant(rates, tf.float32) 147 | unreached_boundaries = tf.reshape( 148 | tf.where(tf.greater(step_boundaries, global_step)), [-1]) 149 | unreached_boundaries = tf.concat([unreached_boundaries, [len(boundaries)]], 0) 150 | index = tf.reshape(tf.reduce_min(unreached_boundaries), [1]) 151 | return tf.reshape(tf.slice(learning_rates, index, [1]), []) 152 | -------------------------------------------------------------------------------- /protos/eval_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
2 | # source: object_detection/protos/eval.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/eval.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n\"object_detection/protos/eval.proto\x12\x17object_detection.protos\"\x80\x03\n\nEvalConfig\x12\x1e\n\x12num_visualizations\x18\x01 \x01(\r:\x02\x31\x30\x12\x1a\n\x0cnum_examples\x18\x02 \x01(\r:\x04\x35\x30\x30\x30\x12\x1f\n\x12\x65val_interval_secs\x18\x03 \x01(\r:\x03\x33\x30\x30\x12\x14\n\tmax_evals\x18\x04 \x01(\r:\x01\x30\x12\x19\n\nsave_graph\x18\x05 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18visualization_export_dir\x18\x06 \x01(\t:\x00\x12\x15\n\x0b\x65val_master\x18\x07 \x01(\t:\x00\x12\'\n\x0bmetrics_set\x18\x08 \x01(\t:\x12pascal_voc_metrics\x12\x15\n\x0b\x65xport_path\x18\t \x01(\t:\x00\x12!\n\x12ignore_groundtruth\x18\n \x01(\x08:\x05\x66\x61lse\x12\"\n\x13use_moving_averages\x18\x0b \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x65val_instance_masks\x18\x0c \x01(\x08:\x05\x66\x61lse') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | 28 | _EVALCONFIG = _descriptor.Descriptor( 29 | name='EvalConfig', 30 | full_name='object_detection.protos.EvalConfig', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | containing_type=None, 34 | fields=[ 35 | _descriptor.FieldDescriptor( 36 | name='num_visualizations', full_name='object_detection.protos.EvalConfig.num_visualizations', index=0, 37 | number=1, type=13, cpp_type=3, label=1, 38 | has_default_value=True, default_value=10, 39 | message_type=None, enum_type=None, containing_type=None, 40 | is_extension=False, extension_scope=None, 41 | options=None), 42 | _descriptor.FieldDescriptor( 43 | name='num_examples', full_name='object_detection.protos.EvalConfig.num_examples', index=1, 44 | number=2, type=13, cpp_type=3, label=1, 45 | has_default_value=True, default_value=5000, 46 | message_type=None, enum_type=None, containing_type=None, 47 | is_extension=False, extension_scope=None, 48 | options=None), 49 | _descriptor.FieldDescriptor( 50 | name='eval_interval_secs', full_name='object_detection.protos.EvalConfig.eval_interval_secs', index=2, 51 | number=3, type=13, cpp_type=3, label=1, 52 | has_default_value=True, default_value=300, 53 | message_type=None, enum_type=None, containing_type=None, 54 | is_extension=False, extension_scope=None, 55 | options=None), 56 | _descriptor.FieldDescriptor( 57 | name='max_evals', full_name='object_detection.protos.EvalConfig.max_evals', index=3, 58 | number=4, type=13, cpp_type=3, label=1, 59 | has_default_value=True, default_value=0, 60 | message_type=None, enum_type=None, containing_type=None, 61 | is_extension=False, extension_scope=None, 62 | options=None), 63 | _descriptor.FieldDescriptor( 64 | name='save_graph', full_name='object_detection.protos.EvalConfig.save_graph', index=4, 65 | number=5, type=8, cpp_type=7, label=1, 66 | has_default_value=True, default_value=False, 67 | message_type=None, enum_type=None, containing_type=None, 68 | is_extension=False, extension_scope=None, 69 | options=None), 70 | 
_descriptor.FieldDescriptor( 71 | name='visualization_export_dir', full_name='object_detection.protos.EvalConfig.visualization_export_dir', index=5, 72 | number=6, type=9, cpp_type=9, label=1, 73 | has_default_value=True, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='eval_master', full_name='object_detection.protos.EvalConfig.eval_master', index=6, 79 | number=7, type=9, cpp_type=9, label=1, 80 | has_default_value=True, default_value=_b("").decode('utf-8'), 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | _descriptor.FieldDescriptor( 85 | name='metrics_set', full_name='object_detection.protos.EvalConfig.metrics_set', index=7, 86 | number=8, type=9, cpp_type=9, label=1, 87 | has_default_value=True, default_value=_b("pascal_voc_metrics").decode('utf-8'), 88 | message_type=None, enum_type=None, containing_type=None, 89 | is_extension=False, extension_scope=None, 90 | options=None), 91 | _descriptor.FieldDescriptor( 92 | name='export_path', full_name='object_detection.protos.EvalConfig.export_path', index=8, 93 | number=9, type=9, cpp_type=9, label=1, 94 | has_default_value=True, default_value=_b("").decode('utf-8'), 95 | message_type=None, enum_type=None, containing_type=None, 96 | is_extension=False, extension_scope=None, 97 | options=None), 98 | _descriptor.FieldDescriptor( 99 | name='ignore_groundtruth', full_name='object_detection.protos.EvalConfig.ignore_groundtruth', index=9, 100 | number=10, type=8, cpp_type=7, label=1, 101 | has_default_value=True, default_value=False, 102 | message_type=None, enum_type=None, containing_type=None, 103 | is_extension=False, extension_scope=None, 104 | options=None), 105 | _descriptor.FieldDescriptor( 106 | name='use_moving_averages', full_name='object_detection.protos.EvalConfig.use_moving_averages', index=10, 107 | number=11, type=8, cpp_type=7, label=1, 108 | has_default_value=True, default_value=False, 109 | message_type=None, enum_type=None, containing_type=None, 110 | is_extension=False, extension_scope=None, 111 | options=None), 112 | _descriptor.FieldDescriptor( 113 | name='eval_instance_masks', full_name='object_detection.protos.EvalConfig.eval_instance_masks', index=11, 114 | number=12, type=8, cpp_type=7, label=1, 115 | has_default_value=True, default_value=False, 116 | message_type=None, enum_type=None, containing_type=None, 117 | is_extension=False, extension_scope=None, 118 | options=None), 119 | ], 120 | extensions=[ 121 | ], 122 | nested_types=[], 123 | enum_types=[ 124 | ], 125 | options=None, 126 | is_extendable=False, 127 | extension_ranges=[], 128 | oneofs=[ 129 | ], 130 | serialized_start=64, 131 | serialized_end=448, 132 | ) 133 | 134 | DESCRIPTOR.message_types_by_name['EvalConfig'] = _EVALCONFIG 135 | 136 | EvalConfig = _reflection.GeneratedProtocolMessageType('EvalConfig', (_message.Message,), dict( 137 | DESCRIPTOR = _EVALCONFIG, 138 | __module__ = 'object_detection.protos.eval_pb2' 139 | # @@protoc_insertion_point(class_scope:object_detection.protos.EvalConfig) 140 | )) 141 | _sym_db.RegisterMessage(EvalConfig) 142 | 143 | 144 | # @@protoc_insertion_point(module_scope) 145 | -------------------------------------------------------------------------------- /protos/post_processing_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated 
by the protocol buffer compiler. DO NOT EDIT! 2 | # source: object_detection/protos/post_processing.proto 3 | 4 | import sys 5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | from google.protobuf import descriptor_pb2 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='object_detection/protos/post_processing.proto', 20 | package='object_detection.protos', 21 | serialized_pb=_b('\n-object_detection/protos/post_processing.proto\x12\x17object_detection.protos\"\x9a\x01\n\x16\x42\x61tchNonMaxSuppression\x12\x1a\n\x0fscore_threshold\x18\x01 \x01(\x02:\x01\x30\x12\x1a\n\riou_threshold\x18\x02 \x01(\x02:\x03\x30.6\x12%\n\x18max_detections_per_class\x18\x03 \x01(\x05:\x03\x31\x30\x30\x12!\n\x14max_total_detections\x18\x05 \x01(\x05:\x03\x31\x30\x30\"\x91\x02\n\x0ePostProcessing\x12R\n\x19\x62\x61tch_non_max_suppression\x18\x01 \x01(\x0b\x32/.object_detection.protos.BatchNonMaxSuppression\x12Y\n\x0fscore_converter\x18\x02 \x01(\x0e\x32\x36.object_detection.protos.PostProcessing.ScoreConverter:\x08IDENTITY\x12\x16\n\x0blogit_scale\x18\x03 \x01(\x02:\x01\x31\"8\n\x0eScoreConverter\x12\x0c\n\x08IDENTITY\x10\x00\x12\x0b\n\x07SIGMOID\x10\x01\x12\x0b\n\x07SOFTMAX\x10\x02') 22 | ) 23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 24 | 25 | 26 | 27 | _POSTPROCESSING_SCORECONVERTER = _descriptor.EnumDescriptor( 28 | name='ScoreConverter', 29 | full_name='object_detection.protos.PostProcessing.ScoreConverter', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | values=[ 33 | _descriptor.EnumValueDescriptor( 34 | name='IDENTITY', index=0, number=0, 35 | options=None, 36 | type=None), 37 | _descriptor.EnumValueDescriptor( 38 | name='SIGMOID', index=1, number=1, 39 | options=None, 40 | type=None), 41 | _descriptor.EnumValueDescriptor( 42 | name='SOFTMAX', index=2, number=2, 43 | options=None, 44 | type=None), 45 | ], 46 | containing_type=None, 47 | options=None, 48 | serialized_start=449, 49 | serialized_end=505, 50 | ) 51 | _sym_db.RegisterEnumDescriptor(_POSTPROCESSING_SCORECONVERTER) 52 | 53 | 54 | _BATCHNONMAXSUPPRESSION = _descriptor.Descriptor( 55 | name='BatchNonMaxSuppression', 56 | full_name='object_detection.protos.BatchNonMaxSuppression', 57 | filename=None, 58 | file=DESCRIPTOR, 59 | containing_type=None, 60 | fields=[ 61 | _descriptor.FieldDescriptor( 62 | name='score_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.score_threshold', index=0, 63 | number=1, type=2, cpp_type=6, label=1, 64 | has_default_value=True, default_value=0, 65 | message_type=None, enum_type=None, containing_type=None, 66 | is_extension=False, extension_scope=None, 67 | options=None), 68 | _descriptor.FieldDescriptor( 69 | name='iou_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.iou_threshold', index=1, 70 | number=2, type=2, cpp_type=6, label=1, 71 | has_default_value=True, default_value=0.6, 72 | message_type=None, enum_type=None, containing_type=None, 73 | is_extension=False, extension_scope=None, 74 | options=None), 75 | _descriptor.FieldDescriptor( 76 | name='max_detections_per_class', full_name='object_detection.protos.BatchNonMaxSuppression.max_detections_per_class', index=2, 77 | number=3, type=5, 
cpp_type=1, label=1, 78 | has_default_value=True, default_value=100, 79 | message_type=None, enum_type=None, containing_type=None, 80 | is_extension=False, extension_scope=None, 81 | options=None), 82 | _descriptor.FieldDescriptor( 83 | name='max_total_detections', full_name='object_detection.protos.BatchNonMaxSuppression.max_total_detections', index=3, 84 | number=5, type=5, cpp_type=1, label=1, 85 | has_default_value=True, default_value=100, 86 | message_type=None, enum_type=None, containing_type=None, 87 | is_extension=False, extension_scope=None, 88 | options=None), 89 | ], 90 | extensions=[ 91 | ], 92 | nested_types=[], 93 | enum_types=[ 94 | ], 95 | options=None, 96 | is_extendable=False, 97 | extension_ranges=[], 98 | oneofs=[ 99 | ], 100 | serialized_start=75, 101 | serialized_end=229, 102 | ) 103 | 104 | 105 | _POSTPROCESSING = _descriptor.Descriptor( 106 | name='PostProcessing', 107 | full_name='object_detection.protos.PostProcessing', 108 | filename=None, 109 | file=DESCRIPTOR, 110 | containing_type=None, 111 | fields=[ 112 | _descriptor.FieldDescriptor( 113 | name='batch_non_max_suppression', full_name='object_detection.protos.PostProcessing.batch_non_max_suppression', index=0, 114 | number=1, type=11, cpp_type=10, label=1, 115 | has_default_value=False, default_value=None, 116 | message_type=None, enum_type=None, containing_type=None, 117 | is_extension=False, extension_scope=None, 118 | options=None), 119 | _descriptor.FieldDescriptor( 120 | name='score_converter', full_name='object_detection.protos.PostProcessing.score_converter', index=1, 121 | number=2, type=14, cpp_type=8, label=1, 122 | has_default_value=True, default_value=0, 123 | message_type=None, enum_type=None, containing_type=None, 124 | is_extension=False, extension_scope=None, 125 | options=None), 126 | _descriptor.FieldDescriptor( 127 | name='logit_scale', full_name='object_detection.protos.PostProcessing.logit_scale', index=2, 128 | number=3, type=2, cpp_type=6, label=1, 129 | has_default_value=True, default_value=1, 130 | message_type=None, enum_type=None, containing_type=None, 131 | is_extension=False, extension_scope=None, 132 | options=None), 133 | ], 134 | extensions=[ 135 | ], 136 | nested_types=[], 137 | enum_types=[ 138 | _POSTPROCESSING_SCORECONVERTER, 139 | ], 140 | options=None, 141 | is_extendable=False, 142 | extension_ranges=[], 143 | oneofs=[ 144 | ], 145 | serialized_start=232, 146 | serialized_end=505, 147 | ) 148 | 149 | _POSTPROCESSING.fields_by_name['batch_non_max_suppression'].message_type = _BATCHNONMAXSUPPRESSION 150 | _POSTPROCESSING.fields_by_name['score_converter'].enum_type = _POSTPROCESSING_SCORECONVERTER 151 | _POSTPROCESSING_SCORECONVERTER.containing_type = _POSTPROCESSING 152 | DESCRIPTOR.message_types_by_name['BatchNonMaxSuppression'] = _BATCHNONMAXSUPPRESSION 153 | DESCRIPTOR.message_types_by_name['PostProcessing'] = _POSTPROCESSING 154 | 155 | BatchNonMaxSuppression = _reflection.GeneratedProtocolMessageType('BatchNonMaxSuppression', (_message.Message,), dict( 156 | DESCRIPTOR = _BATCHNONMAXSUPPRESSION, 157 | __module__ = 'object_detection.protos.post_processing_pb2' 158 | # @@protoc_insertion_point(class_scope:object_detection.protos.BatchNonMaxSuppression) 159 | )) 160 | _sym_db.RegisterMessage(BatchNonMaxSuppression) 161 | 162 | PostProcessing = _reflection.GeneratedProtocolMessageType('PostProcessing', (_message.Message,), dict( 163 | DESCRIPTOR = _POSTPROCESSING, 164 | __module__ = 'object_detection.protos.post_processing_pb2' 165 | # 
@@protoc_insertion_point(class_scope:object_detection.protos.PostProcessing) 166 | )) 167 | _sym_db.RegisterMessage(PostProcessing) 168 | 169 | 170 | # @@protoc_insertion_point(module_scope) 171 | --------------------------------------------------------------------------------
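Usage sketch for the label-map helpers in utils/label_map_util.py (illustrative only, not part of the repository sources). The repository-local import path mirrors label_map_util.py's own "from protos import ..." style, and the label-map file name is an assumption:

    from utils import label_map_util  # assumed repository-local import path

    LABEL_MAP_PATH = 'crackLabelMap.txt'  # assumed: a StringIntLabelMap text file at the repository root

    # Map label names to integer ids, keyed by each item's 'name' field.
    label_map_dict = label_map_util.get_label_map_dict(LABEL_MAP_PATH)

    # Build {id: {'id': id, 'name': name}, ...} for use by visualization and evaluation code.
    category_index = label_map_util.create_category_index_from_labelmap(LABEL_MAP_PATH)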
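A short worked example for the single-class precision/recall helpers in utils/metrics.py; the numbers in the comments follow directly from the definitions in that file, and the import path is again an assumption:

    import numpy as np

    from utils import metrics  # assumed repository-local import path

    # Detections of one class; `labels` marks which detections matched a ground-truth box.
    scores = np.array([0.9, 0.75, 0.6], dtype=float)
    labels = np.array([True, False, True], dtype=bool)
    num_gt = 3  # total ground-truth instances of this class

    precision, recall = metrics.compute_precision_recall(scores, labels, num_gt)
    # precision == [1.0, 0.5, 2/3], recall == [1/3, 1/3, 2/3]

    average_precision = metrics.compute_average_precision(precision, recall)
    # VOC-style AP for this toy input: (1/3)*1.0 + (1/3)*(2/3) + (1/3)*0.0 ≈ 0.556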
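The schedules in utils/learning_schedules.py each return a scalar learning-rate tensor computed from the global step, so they can be passed straight to a TF 1.x optimizer. A minimal sketch of manual_stepping, assuming TF 1.x graph mode; the boundary steps, rates, and optimizer choice are illustrative values, not taken from any config in this repository:

    import tensorflow as tf

    from utils import learning_schedules  # assumed repository-local import path

    global_step = tf.train.get_or_create_global_step()

    # 0.003 for steps [0, 90000), 0.0003 for [90000, 120000), 0.00003 afterwards.
    learning_rate = learning_schedules.manual_stepping(
        global_step,
        boundaries=[90000, 120000],
        rates=[3e-3, 3e-4, 3e-5])

    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)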