├── protos
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-35.pyc
│ ├── eval_pb2.cpython-35.pyc
│ ├── model_pb2.cpython-35.pyc
│ ├── ssd_pb2.cpython-35.pyc
│ ├── train_pb2.cpython-35.pyc
│ ├── losses_pb2.cpython-35.pyc
│ ├── matcher_pb2.cpython-35.pyc
│ ├── box_coder_pb2.cpython-35.pyc
│ ├── optimizer_pb2.cpython-35.pyc
│ ├── pipeline_pb2.cpython-35.pyc
│ ├── faster_rcnn_pb2.cpython-35.pyc
│ ├── hyperparams_pb2.cpython-35.pyc
│ ├── input_reader_pb2.cpython-35.pyc
│ ├── preprocessor_pb2.cpython-35.pyc
│ ├── argmax_matcher_pb2.cpython-35.pyc
│ ├── box_predictor_pb2.cpython-35.pyc
│ ├── image_resizer_pb2.cpython-35.pyc
│ ├── post_processing_pb2.cpython-35.pyc
│ ├── anchor_generator_pb2.cpython-35.pyc
│ ├── bipartite_matcher_pb2.cpython-35.pyc
│ ├── square_box_coder_pb2.cpython-35.pyc
│ ├── keypoint_box_coder_pb2.cpython-35.pyc
│ ├── ssd_anchor_generator_pb2.cpython-35.pyc
│ ├── string_int_label_map_pb2.cpython-35.pyc
│ ├── faster_rcnn_box_coder_pb2.cpython-35.pyc
│ ├── grid_anchor_generator_pb2.cpython-35.pyc
│ ├── mean_stddev_box_coder_pb2.cpython-35.pyc
│ └── region_similarity_calculator_pb2.cpython-35.pyc
├── bipartite_matcher.proto
├── mean_stddev_box_coder.proto
├── model.proto
├── matcher.proto
├── square_box_coder.proto
├── anchor_generator.proto
├── faster_rcnn_box_coder.proto
├── keypoint_box_coder.proto
├── pipeline.proto
├── box_coder.proto
├── region_similarity_calculator.proto
├── string_int_label_map.proto
├── argmax_matcher.proto
├── grid_anchor_generator.proto
├── image_resizer.proto
├── post_processing.proto
├── eval.proto
├── bipartite_matcher_pb2.py
├── mean_stddev_box_coder_pb2.py
├── input_reader.proto
├── ssd_anchor_generator.proto
├── train.proto
├── square_box_coder_pb2.py
├── optimizer.proto
├── ssd.proto
├── hyperparams.proto
├── faster_rcnn_box_coder_pb2.py
├── model_pb2.py
├── matcher_pb2.py
├── box_predictor.proto
├── keypoint_box_coder_pb2.py
├── argmax_matcher_pb2.py
├── anchor_generator_pb2.py
├── string_int_label_map_pb2.py
├── losses.proto
├── grid_anchor_generator_pb2.py
├── pipeline_pb2.py
├── box_coder_pb2.py
├── faster_rcnn.proto
├── eval_pb2.py
└── post_processing_pb2.py
├── utils
├── __init__.py
├── __pycache__
│ ├── ops.cpython-35.pyc
│ ├── metrics.cpython-35.pyc
│ ├── __init__.cpython-35.pyc
│ ├── np_box_ops.cpython-35.pyc
│ ├── config_util.cpython-35.pyc
│ ├── dataset_util.cpython-35.pyc
│ ├── np_box_list.cpython-35.pyc
│ ├── shape_utils.cpython-35.pyc
│ ├── static_shape.cpython-35.pyc
│ ├── label_map_util.cpython-35.pyc
│ ├── np_box_list_ops.cpython-35.pyc
│ ├── variables_helper.cpython-35.pyc
│ ├── learning_schedules.cpython-35.pyc
│ ├── per_image_evaluation.cpython-35.pyc
│ ├── visualization_utils.cpython-35.pyc
│ └── object_detection_evaluation.cpython-35.pyc
├── dataset_util_test.py
├── category_util_test.py
├── static_shape_test.py
├── static_shape.py
├── category_util.py
├── np_box_ops_test.py
├── test_utils_test.py
├── dataset_util.py
├── learning_schedules_test.py
├── np_box_ops.py
├── metrics_test.py
├── shape_utils.py
├── np_box_list.py
├── test_utils.py
├── variables_helper.py
├── shape_utils_test.py
├── np_box_list_test.py
├── metrics.py
├── label_map_util.py
├── label_map_util_test.py
└── learning_schedules.py
├── images
├── GRDD2020.png
├── installation1.png
├── installation2.png
└── RoadDamageTypeDef.png
├── crackLabelMap.txt
├── LICENSE
└── smartphoneAPPS.md
/protos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/GRDD2020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/GRDD2020.png
--------------------------------------------------------------------------------
/images/installation1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/installation1.png
--------------------------------------------------------------------------------
/images/installation2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/installation2.png
--------------------------------------------------------------------------------
/images/RoadDamageTypeDef.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/images/RoadDamageTypeDef.png
--------------------------------------------------------------------------------
/utils/__pycache__/ops.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/ops.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/metrics.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/eval_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/eval_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/model_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/model_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/ssd_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/ssd_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/train_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/train_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/np_box_ops.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/np_box_ops.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/losses_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/losses_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/matcher_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/matcher_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config_util.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/config_util.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_util.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/dataset_util.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/np_box_list.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/np_box_list.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/shape_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/shape_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/static_shape.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/static_shape.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/box_coder_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/box_coder_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/optimizer_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/optimizer_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/pipeline_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/pipeline_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/label_map_util.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/label_map_util.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/np_box_list_ops.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/np_box_list_ops.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/faster_rcnn_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/faster_rcnn_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/hyperparams_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/hyperparams_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/input_reader_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/input_reader_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/preprocessor_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/preprocessor_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/variables_helper.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/variables_helper.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/argmax_matcher_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/argmax_matcher_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/box_predictor_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/box_predictor_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/image_resizer_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/image_resizer_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/post_processing_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/post_processing_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/learning_schedules.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/learning_schedules.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/per_image_evaluation.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/per_image_evaluation.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/visualization_utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/visualization_utils.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/anchor_generator_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/anchor_generator_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/bipartite_matcher_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/bipartite_matcher_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/square_box_coder_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/square_box_coder_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/keypoint_box_coder_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/keypoint_box_coder_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/ssd_anchor_generator_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/ssd_anchor_generator_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/string_int_label_map_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/string_int_label_map_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/faster_rcnn_box_coder_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/faster_rcnn_box_coder_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/grid_anchor_generator_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/grid_anchor_generator_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/mean_stddev_box_coder_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/mean_stddev_box_coder_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/object_detection_evaluation.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/utils/__pycache__/object_detection_evaluation.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/__pycache__/region_similarity_calculator_pb2.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sekilab/RoadDamageDetector/HEAD/protos/__pycache__/region_similarity_calculator_pb2.cpython-35.pyc
--------------------------------------------------------------------------------
/protos/bipartite_matcher.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for bipartite matcher. See
6 | // matchers/bipartite_matcher.py for details.
7 | message BipartiteMatcher {
8 | }
9 |
--------------------------------------------------------------------------------
/protos/mean_stddev_box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for MeanStddevBoxCoder. See
6 | // box_coders/mean_stddev_box_coder.py for details.
7 | message MeanStddevBoxCoder {
8 | }
9 |
--------------------------------------------------------------------------------
/protos/model.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/faster_rcnn.proto";
6 | import "object_detection/protos/ssd.proto";
7 |
8 | // Top level configuration for DetectionModels.
9 | message DetectionModel {
10 | oneof model {
11 | FasterRcnn faster_rcnn = 1;
12 | Ssd ssd = 2;
13 | }
14 | }
15 |
--------------------------------------------------------------------------------
/crackLabelMap.txt:
--------------------------------------------------------------------------------
1 | item {
2 | id: 1
3 | name: 'D00'
4 | }
5 |
6 | item {
7 | id: 2
8 | name: 'D01'
9 | }
10 |
11 | item {
12 | id: 3
13 | name: 'D10'
14 | }
15 |
16 | item {
17 | id: 4
18 | name: 'D11'
19 | }
20 |
21 | item {
22 | id: 5
23 | name: 'D20'
24 | }
25 |
26 | item {
27 | id: 6
28 | name: 'D40'
29 | }
30 |
31 | item {
32 | id: 7
33 | name: 'D43'
34 | }
35 |
36 | item {
37 | id: 8
38 | name: 'D44'
39 | }
40 |
--------------------------------------------------------------------------------
/protos/matcher.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/argmax_matcher.proto";
6 | import "object_detection/protos/bipartite_matcher.proto";
7 |
8 | // Configuration proto for the matcher to be used in the object detection
9 | // pipeline. See core/matcher.py for details.
10 | message Matcher {
11 | oneof matcher_oneof {
12 | ArgMaxMatcher argmax_matcher = 1;
13 | BipartiteMatcher bipartite_matcher = 2;
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/protos/square_box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for SquareBoxCoder. See
6 | // box_coders/square_box_coder.py for details.
7 | message SquareBoxCoder {
8 | // Scale factor for anchor encoded box center.
9 | optional float y_scale = 1 [default = 10.0];
10 | optional float x_scale = 2 [default = 10.0];
11 |
12 | // Scale factor for anchor encoded box length.
13 | optional float length_scale = 3 [default = 5.0];
14 | }
15 |
--------------------------------------------------------------------------------
/protos/anchor_generator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/grid_anchor_generator.proto";
6 | import "object_detection/protos/ssd_anchor_generator.proto";
7 |
8 | // Configuration proto for the anchor generator to use in the object detection
9 | // pipeline. See core/anchor_generator.py for details.
10 | message AnchorGenerator {
11 | oneof anchor_generator_oneof {
12 | GridAnchorGenerator grid_anchor_generator = 1;
13 | SsdAnchorGenerator ssd_anchor_generator = 2;
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/protos/faster_rcnn_box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for FasterRCNNBoxCoder. See
6 | // box_coders/faster_rcnn_box_coder.py for details.
7 | message FasterRcnnBoxCoder {
8 | // Scale factor for anchor encoded box center.
9 | optional float y_scale = 1 [default = 10.0];
10 | optional float x_scale = 2 [default = 10.0];
11 |
12 | // Scale factor for anchor encoded box height.
13 | optional float height_scale = 3 [default = 5.0];
14 |
15 | // Scale factor for anchor encoded box width.
16 | optional float width_scale = 4 [default = 5.0];
17 | }
18 |
--------------------------------------------------------------------------------
/protos/keypoint_box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for KeypointBoxCoder. See
6 | // box_coders/keypoint_box_coder.py for details.
7 | message KeypointBoxCoder {
8 | optional int32 num_keypoints = 1;
9 |
10 | // Scale factor for anchor encoded box center and keypoints.
11 | optional float y_scale = 2 [default = 10.0];
12 | optional float x_scale = 3 [default = 10.0];
13 |
14 | // Scale factor for anchor encoded box height.
15 | optional float height_scale = 4 [default = 5.0];
16 |
17 | // Scale factor for anchor encoded box width.
18 | optional float width_scale = 5 [default = 5.0];
19 | }
20 |
--------------------------------------------------------------------------------
/protos/pipeline.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/eval.proto";
6 | import "object_detection/protos/input_reader.proto";
7 | import "object_detection/protos/model.proto";
8 | import "object_detection/protos/train.proto";
9 |
10 | // Convenience message for configuring a training and eval pipeline. Allows all
11 | // of the pipeline parameters to be configured from one file.
12 | message TrainEvalPipelineConfig {
13 | optional DetectionModel model = 1;
14 | optional TrainConfig train_config = 2;
15 | optional InputReader train_input_reader = 3;
16 | optional EvalConfig eval_config = 4;
17 | optional InputReader eval_input_reader = 5;
18 | }
19 |
--------------------------------------------------------------------------------
/protos/box_coder.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/faster_rcnn_box_coder.proto";
6 | import "object_detection/protos/keypoint_box_coder.proto";
7 | import "object_detection/protos/mean_stddev_box_coder.proto";
8 | import "object_detection/protos/square_box_coder.proto";
9 |
10 | // Configuration proto for the box coder to be used in the object detection
11 | // pipeline. See core/box_coder.py for details.
12 | message BoxCoder {
13 | oneof box_coder_oneof {
14 | FasterRcnnBoxCoder faster_rcnn_box_coder = 1;
15 | MeanStddevBoxCoder mean_stddev_box_coder = 2;
16 | SquareBoxCoder square_box_coder = 3;
17 | KeypointBoxCoder keypoint_box_coder = 4;
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/protos/region_similarity_calculator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for region similarity calculators. See
6 | // core/region_similarity_calculator.py for details.
7 | message RegionSimilarityCalculator {
8 | oneof region_similarity {
9 | NegSqDistSimilarity neg_sq_dist_similarity = 1;
10 | IouSimilarity iou_similarity = 2;
11 | IoaSimilarity ioa_similarity = 3;
12 | }
13 | }
14 |
15 | // Configuration for negative squared distance similarity calculator.
16 | message NegSqDistSimilarity {
17 | }
18 |
19 | // Configuration for intersection-over-union (IOU) similarity calculator.
20 | message IouSimilarity {
21 | }
22 |
23 | // Configuration for intersection-over-area (IOA) similarity calculator.
24 | message IoaSimilarity {
25 | }
26 |
--------------------------------------------------------------------------------
/protos/string_int_label_map.proto:
--------------------------------------------------------------------------------
1 | // Message to store the mapping from class label strings to class id. Datasets
2 | // use string labels to represent classes while the object detection framework
3 | // works with class ids. This message maps them so they can be converted back
4 | // and forth as needed.
5 | syntax = "proto2";
6 |
7 | package object_detection.protos;
8 |
9 | message StringIntLabelMapItem {
10 | // String name. The most common practice is to set this to a MID or synsets
11 | // id.
12 | optional string name = 1;
13 |
14 | // Integer id that maps to the string name above. Label ids should start from
15 | // 1.
16 | optional int32 id = 2;
17 |
18 | // Human readable string label.
19 | optional string display_name = 3;
20 | };
21 |
22 | message StringIntLabelMap {
23 | repeated StringIntLabelMapItem item = 1;
24 | };
25 |
--------------------------------------------------------------------------------
/protos/argmax_matcher.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for ArgMaxMatcher. See
6 | // matchers/argmax_matcher.py for details.
7 | message ArgMaxMatcher {
8 | // Threshold for positive matches.
9 | optional float matched_threshold = 1 [default = 0.5];
10 |
11 | // Threshold for negative matches.
12 | optional float unmatched_threshold = 2 [default = 0.5];
13 |
14 | // Whether to construct ArgMaxMatcher without thresholds.
15 | optional bool ignore_thresholds = 3 [default = false];
16 |
17 | // If True then negative matches are the ones below the unmatched_threshold,
18 | // whereas ignored matches are in between the matched and unmatched
19 | // threshold. If False, then negative matches are in between the matched
20 | // and unmatched threshold, and everything lower than unmatched is ignored.
21 | optional bool negatives_lower_than_unmatched = 4 [default = true];
22 |
23 | // Whether to ensure each row is matched to at least one column.
24 | optional bool force_match_for_each_row = 5 [default = false];
25 | }
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 sekilab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/protos/grid_anchor_generator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for GridAnchorGenerator. See
6 | // anchor_generators/grid_anchor_generator.py for details.
message GridAnchorGenerator {
  // Anchor height in pixels (base anchor size in the height dimension).
  optional int32 height = 1 [default = 256];

  // Anchor width in pixels (base anchor size in the width dimension).
  optional int32 width = 2 [default = 256];

  // Anchor stride in height dimension in pixels.
  optional int32 height_stride = 3 [default = 16];

  // Anchor stride in width dimension in pixels.
  optional int32 width_stride = 4 [default = 16];

  // Anchor height offset in pixels.
  optional int32 height_offset = 5 [default = 0];

  // Anchor width offset in pixels.
  optional int32 width_offset = 6 [default = 0];

  // At any given location, len(scales) * len(aspect_ratios) anchors are
  // generated with all possible combinations of scales and aspect ratios.

  // List of scales for the anchors. Presumably multipliers applied to the
  // base height/width above — confirm in
  // anchor_generators/grid_anchor_generator.py.
  repeated float scales = 7;

  // List of aspect ratios for the anchors.
  repeated float aspect_ratios = 8;
}
35 |
--------------------------------------------------------------------------------
/utils/dataset_util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.dataset_util."""
17 |
18 | import os
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import dataset_util
22 |
23 |
class DatasetUtilTest(tf.test.TestCase):
  """Tests for object_detection.utils.dataset_util helpers."""

  def test_read_examples_list(self):
    """read_examples_list should return the first token of each line."""
    example_list_data = """example1 1\nexample2 2"""
    example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt')
    # Open in text mode ('w'), not binary ('wb'): example_list_data is a
    # str, and writing a str to a binary-mode file fails on Python 3.
    with tf.gfile.Open(example_list_path, 'w') as f:
      f.write(example_list_data)

    examples = dataset_util.read_examples_list(example_list_path)
    self.assertListEqual(['example1', 'example2'], examples)


if __name__ == '__main__':
  tf.test.main()
38 |
--------------------------------------------------------------------------------
/protos/image_resizer.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for image resizing operations.
6 | // See builders/image_resizer_builder.py for details.
message ImageResizer {
  // Exactly one concrete resizer configuration must be set.
  oneof image_resizer_oneof {
    KeepAspectRatioResizer keep_aspect_ratio_resizer = 1;
    FixedShapeResizer fixed_shape_resizer = 2;
  }
}

// Enumeration type for image resizing methods provided in TensorFlow.
enum ResizeType {
  BILINEAR = 0; // Corresponds to tf.image.ResizeMethod.BILINEAR
  NEAREST_NEIGHBOR = 1; // Corresponds to tf.image.ResizeMethod.NEAREST_NEIGHBOR
  BICUBIC = 2; // Corresponds to tf.image.ResizeMethod.BICUBIC
  AREA = 3; // Corresponds to tf.image.ResizeMethod.AREA
}

// Configuration proto for image resizer that keeps aspect ratio.
message KeepAspectRatioResizer {
  // Desired size of the smaller image dimension in pixels.
  optional int32 min_dimension = 1 [default = 600];

  // Desired size of the larger image dimension in pixels.
  optional int32 max_dimension = 2 [default = 1024];

  // Desired method when resizing image.
  optional ResizeType resize_method = 3 [default = BILINEAR];
}

// Configuration proto for image resizer that resizes to a fixed shape.
message FixedShapeResizer {
  // Desired height of image in pixels.
  optional int32 height = 1 [default = 300];

  // Desired width of image in pixels.
  optional int32 width = 2 [default = 300];

  // Desired method when resizing image.
  optional ResizeType resize_method = 3 [default = BILINEAR];
}
45 |
--------------------------------------------------------------------------------
/protos/post_processing.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for non-max-suppression operation on a batch of
6 | // detections.
message BatchNonMaxSuppression {
  // Scalar threshold for score (low scoring boxes are removed).
  optional float score_threshold = 1 [default = 0.0];

  // Scalar threshold for IOU (boxes that have high IOU overlap
  // with previously selected boxes are removed).
  optional float iou_threshold = 2 [default = 0.6];

  // Maximum number of detections to retain per class.
  optional int32 max_detections_per_class = 3 [default = 100];

  // NOTE(review): field number 4 is skipped (not declared and not reserved);
  // do not reuse it for a new field without checking the proto's history.
  // Maximum number of detections to retain across all classes.
  optional int32 max_total_detections = 5 [default = 100];
}
21 |
22 | // Configuration proto for post-processing predicted boxes and
23 | // scores.
message PostProcessing {
  // Non max suppression parameters.
  optional BatchNonMaxSuppression batch_non_max_suppression = 1;

  // Enum to specify how to convert the detection scores.
  enum ScoreConverter {
    // Input scores equals output scores.
    IDENTITY = 0;

    // Applies a sigmoid on input scores.
    SIGMOID = 1;

    // Applies a softmax on input scores.
    SOFTMAX = 2;
  }

  // Score converter to use.
  optional ScoreConverter score_converter = 2 [default = IDENTITY];
  // Scale logit (input) value before conversion in post-processing step.
  // Typically used for softmax distillation, though can be used to scale for
  // other reasons.
  optional float logit_scale = 3 [default = 1.0];
}
47 |
--------------------------------------------------------------------------------
/protos/eval.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Message for configuring DetectionModel evaluation jobs (eval.py).
message EvalConfig {
  // Number of visualization images to generate.
  optional uint32 num_visualizations = 1 [default=10];

  // Number of examples to process during evaluation.
  optional uint32 num_examples = 2 [default=5000];

  // How often to run evaluation, in seconds.
  optional uint32 eval_interval_secs = 3 [default=300];

  // Maximum number of times to run evaluation. If set to 0, will run forever.
  optional uint32 max_evals = 4 [default=0];

  // Whether the TensorFlow graph used for evaluation should be saved to disk.
  optional bool save_graph = 5 [default=false];

  // Path to directory to store visualizations in. If empty, visualization
  // images are not exported (only shown on Tensorboard).
  optional string visualization_export_dir = 6 [default=""];

  // BNS name of the TensorFlow master.
  optional string eval_master = 7 [default=""];

  // Type of metrics to use for evaluation. Currently supports only Pascal VOC
  // detection metrics.
  optional string metrics_set = 8 [default="pascal_voc_metrics"];

  // Path to export detections to COCO compatible JSON format.
  optional string export_path = 9 [default=''];

  // Option to not read groundtruth labels and only export detections to a
  // COCO-compatible JSON file.
  optional bool ignore_groundtruth = 10 [default=false];

  // Use exponential moving averages of variables for evaluation.
  optional bool use_moving_averages = 11 [default=false];

  // Whether to evaluate instance masks.
  // Note that since there is no evaluation code currently for instance
  // segmentation this option is unused.
  optional bool eval_instance_masks = 12 [default=false];
}
48 |
--------------------------------------------------------------------------------
/protos/bipartite_matcher_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/bipartite_matcher.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
# protoc-generated module body — do not hand-edit the code below; regenerate
# from protos/bipartite_matcher.proto instead.

# File-level descriptor; serialized_pb holds the compiled binary
# FileDescriptorProto for bipartite_matcher.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/bipartite_matcher.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n/object_detection/protos/bipartite_matcher.proto\x12\x17object_detection.protos\"\x12\n\x10\x42ipartiteMatcher')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Descriptor for the BipartiteMatcher message (it declares no fields).
_BIPARTITEMATCHER = _descriptor.Descriptor(
  name='BipartiteMatcher',
  full_name='object_detection.protos.BipartiteMatcher',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=76,
  serialized_end=94,
)

DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER

# Concrete message class built from the descriptor via the generated
# protocol-message metaclass, then registered with the symbol database.
BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), dict(
  DESCRIPTOR = _BIPARTITEMATCHER,
  __module__ = 'object_detection.protos.bipartite_matcher_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher)
  ))
_sym_db.RegisterMessage(BipartiteMatcher)


# @@protoc_insertion_point(module_scope)
61 |
--------------------------------------------------------------------------------
/protos/mean_stddev_box_coder_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/mean_stddev_box_coder.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
# protoc-generated module body — do not hand-edit the code below; regenerate
# from protos/mean_stddev_box_coder.proto instead.

# File-level descriptor; serialized_pb holds the compiled binary
# FileDescriptorProto for mean_stddev_box_coder.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/mean_stddev_box_coder.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"\x14\n\x12MeanStddevBoxCoder')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Descriptor for the MeanStddevBoxCoder message (it declares no fields).
_MEANSTDDEVBOXCODER = _descriptor.Descriptor(
  name='MeanStddevBoxCoder',
  full_name='object_detection.protos.MeanStddevBoxCoder',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=80,
  serialized_end=100,
)

DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER

# Concrete message class built from the descriptor via the generated
# protocol-message metaclass, then registered with the symbol database.
MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), dict(
  DESCRIPTOR = _MEANSTDDEVBOXCODER,
  __module__ = 'object_detection.protos.mean_stddev_box_coder_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder)
  ))
_sym_db.RegisterMessage(MeanStddevBoxCoder)


# @@protoc_insertion_point(module_scope)
61 |
--------------------------------------------------------------------------------
/utils/category_util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.category_util."""
17 | import os
18 |
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import category_util
22 |
23 |
# NOTE(review): the class name looks copy-pasted from an eval_util test;
# it actually tests category_util. Name kept to avoid churn in test runners.
class EvalUtilTest(tf.test.TestCase):
  """Tests for object_detection.utils.category_util CSV round-tripping."""

  def test_load_categories_from_csv_file(self):
    """Loading a well-formed CSV yields one {'id', 'name'} dict per row."""
    csv_data = """
        0,"cat"
        1,"dog"
        2,"bird"
    """.strip(' ')
    csv_path = os.path.join(self.get_temp_dir(), 'test.csv')
    # Open in text mode ('w'), not binary ('wb'): csv_data is a str, and
    # writing a str to a binary-mode file fails on Python 3.
    with tf.gfile.Open(csv_path, 'w') as f:
      f.write(csv_data)

    categories = category_util.load_categories_from_csv_file(csv_path)
    self.assertTrue({'id': 0, 'name': 'cat'} in categories)
    self.assertTrue({'id': 1, 'name': 'dog'} in categories)
    self.assertTrue({'id': 2, 'name': 'bird'} in categories)

  def test_save_categories_to_csv_file(self):
    """Saving then re-loading categories preserves content and id order."""
    categories = [
        {'id': 0, 'name': 'cat'},
        {'id': 1, 'name': 'dog'},
        {'id': 2, 'name': 'bird'},
    ]
    csv_path = os.path.join(self.get_temp_dir(), 'test.csv')
    category_util.save_categories_to_csv_file(categories, csv_path)
    saved_categories = category_util.load_categories_from_csv_file(csv_path)
    self.assertEqual(saved_categories, categories)


if __name__ == '__main__':
  tf.test.main()
55 |
--------------------------------------------------------------------------------
/utils/static_shape_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.static_shape."""
17 |
18 | import tensorflow as tf
19 |
20 | from object_detection.utils import static_shape
21 |
22 |
class StaticShapeTest(tf.test.TestCase):
  """Tests for object_detection.utils.static_shape accessors."""

  def test_return_correct_batchSize(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(32, static_shape.get_batch_size(tensor_shape))

  def test_return_correct_height(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(299, static_shape.get_height(tensor_shape))

  def test_return_correct_width(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(384, static_shape.get_width(tensor_shape))

  def test_return_correct_depth(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(3, static_shape.get_depth(tensor_shape))

  def test_die_on_tensor_shape_with_rank_three(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384])
    # Each accessor needs its own assertRaises context. In the original, all
    # four calls shared one context, so the first raise skipped the remaining
    # three calls and they were never actually verified.
    with self.assertRaises(ValueError):
      static_shape.get_batch_size(tensor_shape)
    with self.assertRaises(ValueError):
      static_shape.get_height(tensor_shape)
    with self.assertRaises(ValueError):
      static_shape.get_width(tensor_shape)
    with self.assertRaises(ValueError):
      static_shape.get_depth(tensor_shape)


if __name__ == '__main__':
  tf.test.main()
51 |
--------------------------------------------------------------------------------
/utils/static_shape.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helper functions to access TensorShape values.
17 |
18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth].
19 | """
20 |
21 |
def _get_dim(tensor_shape, index):
  """Returns the static size of dimension `index` of a rank-4 TensorShape.

  Shared implementation for the four public accessors below.

  Args:
    tensor_shape: A rank 4 TensorShape.
    index: Integer dimension index in [0, 4).

  Returns:
    An integer representing the size of that dimension (or None if the
    dimension is not statically known).
  """
  tensor_shape.assert_has_rank(rank=4)
  return tensor_shape[index].value


def get_batch_size(tensor_shape):
  """Returns batch size from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the batch size of the tensor.
  """
  return _get_dim(tensor_shape, 0)


def get_height(tensor_shape):
  """Returns height from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the height of the tensor.
  """
  return _get_dim(tensor_shape, 1)


def get_width(tensor_shape):
  """Returns width from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the width of the tensor.
  """
  return _get_dim(tensor_shape, 2)


def get_depth(tensor_shape):
  """Returns depth from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the depth of the tensor.
  """
  return _get_dim(tensor_shape, 3)
72 |
--------------------------------------------------------------------------------
/smartphoneAPPS.md:
--------------------------------------------------------------------------------
1 | # RoadCrackDetector
2 |
3 | ## What is RoadCrackDetector?
RoadCrackDetector is a smartphone app that detects damage on the road by utilizing a deep neural network model.
5 |
6 | スマートフォン上で深層学習モデルを動かすことで、道路路面の損傷画像を検出するアプリケーションです。
7 |
8 | ## How to use RoadCrackDetector?
9 | ### Basic Functions
10 | - Detect Road Damages when the car is running
11 | - Stop processing automatically when the car is stopping
12 | - Record road images every 1 second
13 |
### Installation location
15 | - Car dashboard
16 |
17 |
18 |
19 |
20 | ## How to train the Crack Detection model?
21 | (March 2017)
22 | We used [yolo detector](https://pjreddie.com/darknet/yolo/ "yolo web") for training the model.
The training dataset contains more than 30,000 road images, including road damage.
24 | This application can just detect "damages".
25 |
26 | (September 2017)
We used SSD with MobileNet for training the model.
28 | Trained model with 9,053 road images with damages can be accessed [here](https://s3-ap-northeast-1.amazonaws.com/mycityreport/trainedModels.tar.gz).
29 | This application can just detect "damages" and classify "damage types".
30 |
31 | ### Android application
App with MobileNet+SSD (October 2018)
33 | [RoadCrackDetector.apk(26MB)](https://s3-ap-northeast-1.amazonaws.com/sekilab-students/maeda/kashiyama/mcr111_open.apk)
34 | (Android 7.1 or higher is required)
35 |
36 | The app saves 2 kinds of files in 'ExternalStorage/Android/data/org.utokyo.sekilab.mcr/files'.
37 | Location file contains GPS coordinate every 3 seconds. Damage file contains road damage data and image data.
38 |
39 | ### Experiments in some local governments in Japan(August 2017)
We did a road inspection with our app in Toga village; please check [our website](http://sekilab.iis.u-tokyo.ac.jp/archives/category/news#post-1882)!
41 | You can also watch the movie in the experiment as a demo([demo movie](https://youtu.be/P74Hl0vr1-Y))
42 |
43 | 富山県の利賀村にて、本アプリケーションを用いて実際に路面点検を実施しました ([参考](http://sekilab.iis.u-tokyo.ac.jp/archives/category/news#post-1882))。
44 | 実験の様子を[デモ動画](https://youtu.be/P74Hl0vr1-Y)として公開していますので、ご覧ください。
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/protos/input_reader.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for defining input readers that generate Object Detection
6 | // Examples from input sources. Input readers are expected to generate a
7 | // dictionary of tensors, with the following fields populated:
8 | //
9 | // 'image': an [image_height, image_width, channels] image tensor that detection
10 | // will be run on.
11 | // 'groundtruth_classes': a [num_boxes] int32 tensor storing the class
12 | // labels of detected boxes in the image.
13 | // 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of
14 | // detected boxes in the image.
15 | // 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height,
16 | // image_width] float tensor storing binary mask of the objects in boxes.
17 |
message InputReader {
  // Path to StringIntLabelMap pbtxt file specifying the mapping from string
  // labels to integer ids.
  optional string label_map_path = 1 [default=""];

  // Whether data should be processed in the order they are read in, or
  // shuffled randomly.
  optional bool shuffle = 2 [default=true];

  // Maximum number of records to keep in reader queue.
  optional uint32 queue_capacity = 3 [default=2000];

  // Minimum number of records to keep in reader queue. A large value is needed
  // to generate a good random shuffle.
  optional uint32 min_after_dequeue = 4 [default=1000];

  // The number of times a data source is read. If set to zero, the data source
  // will be reused indefinitely.
  optional uint32 num_epochs = 5 [default=0];

  // Number of reader instances to create.
  optional uint32 num_readers = 6 [default=8];

  // Whether to load groundtruth instance masks.
  optional bool load_instance_masks = 7 [default = false];

  // Exactly one concrete reader configuration must be set.
  oneof input_reader {
    TFRecordInputReader tf_record_input_reader = 8;
    ExternalInputReader external_input_reader = 9;
  }
}

// An input reader that reads TF Example protos from local TFRecord files.
message TFRecordInputReader {
  // Path(s) to `TFRecordFile`s.
  repeated string input_path = 1;
}

// An externally defined input reader. Users may define an extension to this
// proto to interface their own input readers.
message ExternalInputReader {
  extensions 1 to 999;
}
61 |
--------------------------------------------------------------------------------
/protos/ssd_anchor_generator.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | // Configuration proto for SSD anchor generator described in
6 | // https://arxiv.org/abs/1512.02325. See
7 | // anchor_generators/multiple_grid_anchor_generator.py for details.
message SsdAnchorGenerator {
  // Number of grid layers to create anchors for.
  optional int32 num_layers = 1 [default = 6];

  // Scale of anchors corresponding to finest resolution.
  optional float min_scale = 2 [default = 0.2];

  // Scale of anchors corresponding to coarsest resolution.
  optional float max_scale = 3 [default = 0.95];

  // Can be used to override min_scale->max_scale, with an explicitly defined
  // set of scales. If empty, then min_scale->max_scale is used.
  repeated float scales = 12;

  // Aspect ratios for anchors at each grid point.
  repeated float aspect_ratios = 4;

  // When this aspect ratio is greater than 0, then an additional
  // anchor, with an interpolated scale is added with this aspect ratio.
  optional float interpolated_scale_aspect_ratio = 13 [default = 1.0];

  // Whether to use the following aspect ratio and scale combination for the
  // layer with the finest resolution: (scale=0.1, aspect_ratio=1.0),
  // (scale=min_scale, aspect_ratio=2.0), (scale=min_scale, aspect_ratio=0.5).
  optional bool reduce_boxes_in_lowest_layer = 5 [default = true];

  // The base anchor size in height dimension.
  optional float base_anchor_height = 6 [default = 1.0];

  // The base anchor size in width dimension.
  optional float base_anchor_width = 7 [default = 1.0];

  // Anchor stride in height dimension in pixels for each layer. The length of
  // this field is expected to be equal to the value of num_layers.
  repeated int32 height_stride = 8;

  // Anchor stride in width dimension in pixels for each layer. The length of
  // this field is expected to be equal to the value of num_layers.
  repeated int32 width_stride = 9;

  // Anchor height offset in pixels for each layer. The length of this field is
  // expected to be equal to the value of num_layers.
  repeated int32 height_offset = 10;

  // Anchor width offset in pixels for each layer. The length of this field is
  // expected to be equal to the value of num_layers.
  repeated int32 width_offset = 11;
}
56 |
--------------------------------------------------------------------------------
/utils/category_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Functions for importing/exporting Object Detection categories."""
17 | import csv
18 |
19 | import tensorflow as tf
20 |
21 |
def load_categories_from_csv_file(csv_path):
  """Reads category definitions out of a two-column CSV file.

  Each non-empty line must contain a comma-separated pair: a numeric
  category id followed by a quoted category name, e.g.:

    0,"cat"
    1,"dog"
    2,"bird"
    ...

  Args:
    csv_path: Path to the csv file to be parsed into categories.
  Returns:
    categories: A list of dictionaries representing all possible categories.
      The categories will contain an integer 'id' field and a string
      'name' field.
  Raises:
    ValueError: If the csv file is incorrectly formatted.
  """
  categories = []
  with tf.gfile.Open(csv_path, 'r') as csvfile:
    for row in csv.reader(csvfile, delimiter=',', quotechar='"'):
      # Blank lines are tolerated and skipped.
      if not row:
        continue
      if len(row) != 2:
        raise ValueError('Expected 2 fields per row in csv: %s' % ','.join(row))
      categories.append({'id': int(row[0]), 'name': row[1]})
  return categories
58 |
59 |
def save_categories_to_csv_file(categories, csv_path):
  """Saves categories to a csv file, sorted by category id.

  Args:
    categories: A list of dictionaries representing categories to save to
      file. Each category must contain an 'id' and 'name' field.
    csv_path: Path to the csv file to write the categories to.
  """
  # Iterate over a sorted copy instead of calling categories.sort() so the
  # caller's list is not reordered as a hidden side effect.
  with tf.gfile.Open(csv_path, 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter=',', quotechar='"')
    for category in sorted(categories, key=lambda x: x['id']):
      writer.writerow([category['id'], category['name']])
73 |
--------------------------------------------------------------------------------
/utils/np_box_ops_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.np_box_ops."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import np_box_ops
22 |
23 |
class BoxOpsTests(tf.test.TestCase):
  """Tests for np_box_ops area/intersection/iou/ioa on [ymin, xmin, ymax, xmax] boxes."""

  def setUp(self):
    self.boxes1 = np.array(
        [[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], dtype=float)
    self.boxes2 = np.array(
        [[3.0, 4.0, 6.0, 8.0],
         [14.0, 14.0, 15.0, 15.0],
         [0.0, 0.0, 20.0, 20.0]], dtype=float)

  def testArea(self):
    expected_areas = np.array([6.0, 5.0], dtype=float)
    self.assertAllClose(expected_areas, np_box_ops.area(self.boxes1))

  def testIntersection(self):
    expected_intersection = np.array(
        [[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]], dtype=float)
    self.assertAllClose(np_box_ops.intersection(self.boxes1, self.boxes2),
                        expected_intersection)

  def testIOU(self):
    expected_iou = np.array(
        [[2.0 / 16.0, 0.0, 6.0 / 400.0],
         [1.0 / 16.0, 0.0, 5.0 / 400.0]], dtype=float)
    self.assertAllClose(np_box_ops.iou(self.boxes1, self.boxes2),
                        expected_iou)

  def testIOA(self):
    # Fresh normalized boxes rather than the fixtures: IOA is asymmetric, so
    # this exercises ioa(boxes2, boxes1) with known overlap fractions.
    boxes1 = np.array([[0.25, 0.25, 0.75, 0.75],
                       [0.0, 0.0, 0.5, 0.75]],
                      dtype=np.float32)
    boxes2 = np.array([[0.5, 0.25, 1.0, 1.0],
                       [0.0, 0.0, 1.0, 1.0]],
                      dtype=np.float32)
    expected_ioa21 = np.array([[0.5, 0.0],
                               [1.0, 1.0]],
                              dtype=np.float32)
    self.assertAllClose(np_box_ops.ioa(boxes2, boxes1), expected_ioa21)


if __name__ == '__main__':
  tf.test.main()
69 |
--------------------------------------------------------------------------------
/utils/test_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.test_utils."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import test_utils
22 |
23 |
class TestUtilsTest(tf.test.TestCase):
  """Exercises the helpers in object_detection.utils.test_utils."""

  def test_diagonal_gradient_image(self):
    """Tests if a good pyramid image is created."""
    pyramid_image = test_utils.create_diagonal_gradient_image(3, 4, 2)

    # Easy-to-read sanity check on the first channel alone.
    expected_first_channel = np.array([[3, 2, 1, 0],
                                       [4, 3, 2, 1],
                                       [5, 4, 3, 2]], dtype=np.float32)
    self.assertAllEqual(np.squeeze(pyramid_image[:, :, 0]),
                        expected_first_channel)

    # Full check over both channels.
    expected_image = np.array([[[3, 30],
                                [2, 20],
                                [1, 10],
                                [0, 0]],
                               [[4, 40],
                                [3, 30],
                                [2, 20],
                                [1, 10]],
                               [[5, 50],
                                [4, 40],
                                [3, 30],
                                [2, 20]]], dtype=np.float32)

    self.assertAllEqual(pyramid_image, expected_image)

  def test_random_boxes(self):
    """Tests if valid random boxes are created."""
    num_boxes = 1000
    max_height = 3
    max_width = 5
    boxes = test_utils.create_random_boxes(num_boxes,
                                           max_height,
                                           max_width)

    # Every box must be well-formed: min corner strictly before max corner.
    all_true = np.ones(shape=(num_boxes)) == 1
    self.assertAllEqual(boxes[:, 0] < boxes[:, 2], all_true)
    self.assertAllEqual(boxes[:, 1] < boxes[:, 3], all_true)

    # All coordinates must lie inside the [0, max_height] x [0, max_width]
    # window.
    self.assertTrue(boxes[:, 0].min() >= 0)
    self.assertTrue(boxes[:, 1].min() >= 0)
    self.assertTrue(boxes[:, 2].max() <= max_height)
    self.assertTrue(boxes[:, 3].max() <= max_width)
70 |
71 |
# Run all test cases in this module when executed directly.
if __name__ == '__main__':
  tf.test.main()
74 |
--------------------------------------------------------------------------------
/utils/dataset_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for creating TFRecord data sets."""
17 |
18 | import tensorflow as tf
19 |
20 |
def int64_feature(value):
  """Returns a tf.train.Feature wrapping the scalar `value` as an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
23 |
24 |
def int64_list_feature(value):
  """Returns a tf.train.Feature wrapping the int sequence `value`."""
  int64_list = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=int64_list)
27 |
28 |
def bytes_feature(value):
  """Returns a tf.train.Feature wrapping the single bytes string `value`."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)
31 |
32 |
def bytes_list_feature(value):
  """Returns a tf.train.Feature wrapping the bytes sequence `value`."""
  bytes_list = tf.train.BytesList(value=value)
  return tf.train.Feature(bytes_list=bytes_list)
35 |
36 |
def float_list_feature(value):
  """Returns a tf.train.Feature wrapping the float sequence `value`."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)
39 |
40 |
def read_examples_list(path):
  """Reads a list of training or validation examples.

  Each line of the file holds one example; only the first space-separated
  token (an example identifier) is kept. For instance, the line:
    xyz 3
  yields identifier `xyz`, which locates files xyz.jpg and xyz.xml (the
  trailing `3` is ignored).

  Args:
    path: absolute path to examples list file.

  Returns:
    list of example identifiers (strings).
  """
  with tf.gfile.GFile(path) as fid:
    return [line.strip().split(' ')[0] for line in fid.readlines()]
61 |
62 |
def recursive_parse_xml_to_dict(xml):
  """Recursively parses XML contents to python dict.

  We assume that `object` tags are the only ones that can appear
  multiple times at the same level of a tree.

  Args:
    xml: xml tree obtained by parsing XML file contents using lxml.etree
      (an xml.etree.ElementTree element also works, since only iteration,
      `len`, `.tag` and `.text` are used).

  Returns:
    Python dictionary holding XML contents.
  """
  # Leaf element (no children): map its tag directly to its text payload.
  # Use len() instead of truth-testing the element: Element.__bool__ is
  # deprecated in the standard library and ambiguous to readers.
  if len(xml) == 0:
    return {xml.tag: xml.text}
  result = {}
  for child in xml:
    child_result = recursive_parse_xml_to_dict(child)
    if child.tag != 'object':
      result[child.tag] = child_result[child.tag]
    else:
      # Multiple <object> entries may appear at one level; collect them in a
      # list instead of overwriting.
      if child.tag not in result:
        result[child.tag] = []
      result[child.tag].append(child_result[child.tag])
  return {xml.tag: result}
87 |
--------------------------------------------------------------------------------
/protos/train.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

package object_detection.protos;

import "object_detection/protos/optimizer.proto";
import "object_detection/protos/preprocessor.proto";

// Message for configuring DetectionModel training jobs (train.py).
message TrainConfig {
  // Input queue batch size.
  optional uint32 batch_size = 1 [default=32];

  // Data augmentation options.
  repeated PreprocessingStep data_augmentation_options = 2;

  // Whether to synchronize replicas during training.
  optional bool sync_replicas = 3 [default=false];

  // How frequently to keep checkpoints.
  optional uint32 keep_checkpoint_every_n_hours = 4 [default=1000];

  // Optimizer used to train the DetectionModel.
  optional Optimizer optimizer = 5;

  // If greater than 0, clips gradients by this value.
  // The default of 0.0 disables gradient clipping.
  optional float gradient_clipping_by_norm = 6 [default=0.0];

  // Checkpoint to restore variables from. Typically used to load feature
  // extractor variables trained outside of object detection.
  // NOTE(review): the empty default presumably means no checkpoint is
  // restored — confirm against train.py.
  optional string fine_tune_checkpoint = 7 [default=""];

  // Specifies if the finetune checkpoint is from an object detection model.
  // If from an object detection model, the model being trained should have
  // the same parameters with the exception of the num_classes parameter.
  // If false, it assumes the checkpoint was a object classification model.
  optional bool from_detection_checkpoint = 8 [default=false];

  // Number of steps to train the DetectionModel for. If 0, will train the model
  // indefinitely.
  optional uint32 num_steps = 9 [default=0];

  // Number of training steps between replica startup.
  // This flag must be set to 0 if sync_replicas is set to true.
  optional float startup_delay_steps = 10 [default=15];

  // If greater than 0, multiplies the gradient of bias variables by this
  // amount.
  optional float bias_grad_multiplier = 11 [default=0];

  // Variables that should not be updated during training.
  repeated string freeze_variables = 12;

  // Number of replicas to aggregate before making parameter updates.
  // Only meaningful when sync_replicas is true.
  optional int32 replicas_to_aggregate = 13 [default=1];

  // Maximum number of elements to store within a queue.
  optional int32 batch_queue_capacity = 14 [default=150];

  // Number of threads to use for batching.
  optional int32 num_batch_queue_threads = 15 [default=8];

  // Maximum capacity of the queue used to prefetch assembled batches.
  optional int32 prefetch_queue_capacity = 16 [default=5];

  // If true, boxes with the same coordinates will be merged together.
  // This is useful when each box can have multiple labels.
  // Note that only Sigmoid classification losses should be used.
  optional bool merge_multiple_label_boxes = 17 [default=false];
}
70 |
--------------------------------------------------------------------------------
/protos/square_box_coder_pb2.py:
--------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/square_box_coder.proto
#
# NOTE(review): machine-generated module; regenerate with protoc rather than
# editing by hand.

import sys
# On Python 3, descriptor byte strings must be latin-1 encoded; on Python 2
# they pass through unchanged.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# File descriptor built from the serialized form of square_box_coder.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/square_box_coder.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Descriptor for the SquareBoxCoder message (fields: y_scale=10, x_scale=10,
# length_scale=5 defaults, mirroring square_box_coder.proto).
_SQUAREBOXCODER = _descriptor.Descriptor(
  name='SquareBoxCoder',
  full_name='object_detection.protos.SquareBoxCoder',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='length_scale', full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2,
      number=3, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=75,
  serialized_end=158,
)

DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER

# Concrete message class generated from the descriptor above.
SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), dict(
  DESCRIPTOR = _SQUAREBOXCODER,
  __module__ = 'object_detection.protos.square_box_coder_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder)
  ))
_sym_db.RegisterMessage(SquareBoxCoder)


# @@protoc_insertion_point(module_scope)
82 |
--------------------------------------------------------------------------------
/protos/optimizer.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

package object_detection.protos;

// Messages for configuring the optimizing strategy for training object
// detection models.

// Top level optimizer message.
message Optimizer {
  // Exactly one concrete optimizer may be set.
  oneof optimizer {
    RMSPropOptimizer rms_prop_optimizer = 1;
    MomentumOptimizer momentum_optimizer = 2;
    AdamOptimizer adam_optimizer = 3;
  }
  // Whether to keep an exponential moving average of the trained variables.
  optional bool use_moving_average = 4 [default = true];
  // Decay used for the moving average when use_moving_average is true.
  optional float moving_average_decay = 5 [default = 0.9999];
}

// Configuration message for the RMSPropOptimizer
// See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
message RMSPropOptimizer {
  optional LearningRate learning_rate = 1;
  optional float momentum_optimizer_value = 2 [default = 0.9];
  optional float decay = 3 [default = 0.9];
  optional float epsilon = 4 [default = 1.0];
}

// Configuration message for the MomentumOptimizer
// See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
message MomentumOptimizer {
  optional LearningRate learning_rate = 1;
  optional float momentum_optimizer_value = 2 [default = 0.9];
}

// Configuration message for the AdamOptimizer
// See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
message AdamOptimizer {
  optional LearningRate learning_rate = 1;
}

// Configuration message for optimizer learning rate.
message LearningRate {
  // Exactly one learning rate schedule may be set.
  oneof learning_rate {
    ConstantLearningRate constant_learning_rate = 1;
    ExponentialDecayLearningRate exponential_decay_learning_rate = 2;
    ManualStepLearningRate manual_step_learning_rate = 3;
    CosineDecayLearningRate cosine_decay_learning_rate = 4;
  }
}

// Configuration message for a constant learning rate.
message ConstantLearningRate {
  optional float learning_rate = 1 [default = 0.002];
}

// Configuration message for an exponentially decaying learning rate.
// See https://www.tensorflow.org/versions/master/api_docs/python/train/ \
//     decaying_the_learning_rate#exponential_decay
message ExponentialDecayLearningRate {
  optional float initial_learning_rate = 1 [default = 0.002];
  optional uint32 decay_steps = 2 [default = 4000000];
  optional float decay_factor = 3 [default = 0.95];
  optional bool staircase = 4 [default = true];
}

// Configuration message for a manually defined learning rate schedule.
message ManualStepLearningRate {
  optional float initial_learning_rate = 1 [default = 0.002];
  // A (step, rate) pair: `learning_rate` takes effect at `step`.
  message LearningRateSchedule {
    optional uint32 step = 1;
    optional float learning_rate = 2 [default = 0.002];
  }
  repeated LearningRateSchedule schedule = 2;
}

// Configuration message for a cosine decaying learning rate as defined in
// object_detection/utils/learning_schedules.py
message CosineDecayLearningRate {
  optional float learning_rate_base = 1 [default = 0.002];
  optional uint32 total_steps = 2 [default = 4000000];
  optional float warmup_learning_rate = 3 [default = 0.0002];
  optional uint32 warmup_steps = 4 [default = 10000];
}
84 |
--------------------------------------------------------------------------------
/protos/ssd.proto:
--------------------------------------------------------------------------------
syntax = "proto2";
package object_detection.protos;

import "object_detection/protos/anchor_generator.proto";
import "object_detection/protos/box_coder.proto";
import "object_detection/protos/box_predictor.proto";
import "object_detection/protos/hyperparams.proto";
import "object_detection/protos/image_resizer.proto";
import "object_detection/protos/matcher.proto";
import "object_detection/protos/losses.proto";
import "object_detection/protos/post_processing.proto";
import "object_detection/protos/region_similarity_calculator.proto";

// Configuration for Single Shot Detection (SSD) models.
message Ssd {

  // Number of classes to predict.
  // NOTE(review): whether this count includes a background class is not
  // visible here — confirm against the model builder / label map convention.
  optional int32 num_classes = 1;

  // Image resizer for preprocessing the input image.
  optional ImageResizer image_resizer = 2;

  // Feature extractor config.
  optional SsdFeatureExtractor feature_extractor = 3;

  // Box coder to encode the boxes.
  optional BoxCoder box_coder = 4;

  // Matcher to match groundtruth with anchors.
  optional Matcher matcher = 5;

  // Region similarity calculator to compute similarity of boxes.
  optional RegionSimilarityCalculator similarity_calculator = 6;

  // Box predictor to attach to the features.
  optional BoxPredictor box_predictor = 7;

  // Anchor generator to compute anchors.
  optional AnchorGenerator anchor_generator = 8;

  // Post processing to apply on the predictions.
  optional PostProcessing post_processing = 9;

  // Whether to normalize the loss by number of groundtruth boxes that match to
  // the anchors.
  optional bool normalize_loss_by_num_matches = 10 [default=true];

  // Loss configuration for training.
  optional Loss loss = 11;
}


message SsdFeatureExtractor {
  // Type of ssd feature extractor.
  optional string type = 1;

  // The factor to alter the depth of the channels in the feature extractor.
  optional float depth_multiplier = 2 [default=1.0];

  // Minimum number of the channels in the feature extractor.
  optional int32 min_depth = 3 [default=16];

  // Hyperparameters for the feature extractor.
  optional Hyperparams conv_hyperparams = 4;

  // The nearest multiple to zero-pad the input height and width dimensions to.
  // For example, if pad_to_multiple = 2, input dimensions are zero-padded
  // until the resulting dimensions are even.
  // The default of 1 leaves the input unpadded.
  optional int32 pad_to_multiple = 5 [default = 1];

  // Whether to update batch norm parameters during training or not.
  // When training with a relative small batch size (e.g. 1), it is
  // desirable to disable batch norm update and use pretrained batch norm
  // params.
  //
  // Note: Some feature extractors are used with canned arg_scopes
  // (e.g resnet arg scopes). In these cases training behavior of batch norm
  // variables may depend on both values of `batch_norm_trainable` and
  // `is_training`.
  //
  // When canned arg_scopes are used with feature extractors `conv_hyperparams`
  // will apply only to the additional layers that are added and are outside the
  // canned arg_scope.
  optional bool batch_norm_trainable = 6 [default=true];
}
86 |
--------------------------------------------------------------------------------
/utils/learning_schedules_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.learning_schedules."""
17 | import tensorflow as tf
18 |
19 | from object_detection.utils import learning_schedules
20 |
21 |
class LearningSchedulesTest(tf.test.TestCase):
  """Evaluates each learning-rate schedule op over a sweep of global steps."""

  def _eval_rates(self, learning_rate, global_step, step_values):
    # Runs the schedule op once per global step and collects the results.
    with self.test_session() as sess:
      return [sess.run(learning_rate, feed_dict={global_step: step})
              for step in step_values]

  def testExponentialDecayWithBurnin(self):
    global_step = tf.placeholder(tf.int32, [])
    learning_rate_base = 1.0
    learning_rate_decay_steps = 3
    learning_rate_decay_factor = .1
    burnin_learning_rate = .5
    burnin_steps = 2
    exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01]
    learning_rate = learning_schedules.exponential_decay_with_burnin(
        global_step, learning_rate_base, learning_rate_decay_steps,
        learning_rate_decay_factor, burnin_learning_rate, burnin_steps)
    output_rates = self._eval_rates(learning_rate, global_step, range(8))
    self.assertAllClose(output_rates, exp_rates)

  def testCosineDecayWithWarmup(self):
    global_step = tf.placeholder(tf.int32, [])
    learning_rate_base = 1.0
    total_steps = 100
    warmup_learning_rate = 0.1
    warmup_steps = 9
    input_global_steps = [0, 4, 8, 9, 100]
    exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
    learning_rate = learning_schedules.cosine_decay_with_warmup(
        global_step, learning_rate_base, total_steps,
        warmup_learning_rate, warmup_steps)
    output_rates = self._eval_rates(learning_rate, global_step,
                                    input_global_steps)
    self.assertAllClose(output_rates, exp_rates)

  def testManualStepping(self):
    global_step = tf.placeholder(tf.int64, [])
    boundaries = [2, 3, 7]
    rates = [1.0, 2.0, 3.0, 4.0]
    exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
    learning_rate = learning_schedules.manual_stepping(global_step, boundaries,
                                                       rates)
    output_rates = self._eval_rates(learning_rate, global_step, range(10))
    self.assertAllClose(output_rates, exp_rates)
76 |
# Run all test cases in this module when executed directly.
if __name__ == '__main__':
  tf.test.main()
79 |
--------------------------------------------------------------------------------
/protos/hyperparams.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

package object_detection.protos;

// Configuration proto for the convolution op hyperparameters to use in the
// object detection pipeline.
message Hyperparams {

  // Operations affected by hyperparameters.
  enum Op {
    // Convolution, Separable Convolution, Convolution transpose.
    CONV = 1;

    // Fully connected
    FC = 2;
  }
  // The op type these hyperparameters apply to.
  optional Op op = 1 [default = CONV];

  // Regularizer for the weights of the convolution op.
  optional Regularizer regularizer = 2;

  // Initializer for the weights of the convolution op.
  optional Initializer initializer = 3;

  // Type of activation to apply after convolution.
  enum Activation {
    // Use None (no activation)
    NONE = 0;

    // Use tf.nn.relu
    RELU = 1;

    // Use tf.nn.relu6
    RELU_6 = 2;
  }
  optional Activation activation = 4 [default = RELU];

  // BatchNorm hyperparameters. If this parameter is NOT set then BatchNorm is
  // not applied!
  optional BatchNorm batch_norm = 5;
}

// Proto with one-of field for regularizers.
message Regularizer {
  oneof regularizer_oneof {
    L1Regularizer l1_regularizer = 1;
    L2Regularizer l2_regularizer = 2;
  }
}

// Configuration proto for L1 Regularizer.
// See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l1_regularizer
message L1Regularizer {
  optional float weight = 1 [default = 1.0];
}

// Configuration proto for L2 Regularizer.
// See https://www.tensorflow.org/api_docs/python/tf/contrib/layers/l2_regularizer
message L2Regularizer {
  optional float weight = 1 [default = 1.0];
}

// Proto with one-of field for initializers.
message Initializer {
  oneof initializer_oneof {
    TruncatedNormalInitializer truncated_normal_initializer = 1;
    VarianceScalingInitializer variance_scaling_initializer = 2;
  }
}

// Configuration proto for truncated normal initializer. See
// https://www.tensorflow.org/api_docs/python/tf/truncated_normal_initializer
message TruncatedNormalInitializer {
  optional float mean = 1 [default = 0.0];
  optional float stddev = 2 [default = 1.0];
}

// Configuration proto for variance scaling initializer. See
// https://www.tensorflow.org/api_docs/python/tf/contrib/layers/
// variance_scaling_initializer
message VarianceScalingInitializer {
  optional float factor = 1 [default = 2.0];
  optional bool uniform = 2 [default = false];
  // Which fan measure scales the initializer variance (see linked docs).
  enum Mode {
    FAN_IN = 0;
    FAN_OUT = 1;
    FAN_AVG = 2;
  }
  optional Mode mode = 3 [default = FAN_IN];
}

// Configuration proto for batch norm to apply after convolution op. See
// https://www.tensorflow.org/api_docs/python/tf/contrib/layers/batch_norm
message BatchNorm {
  optional float decay = 1 [default = 0.999];
  optional bool center = 2 [default = true];
  optional bool scale = 3 [default = false];
  optional float epsilon = 4 [default = 0.001];
  // Whether to train the batch norm variables. If this is set to false during
  // training, the current value of the batch_norm variables are used for
  // forward pass but they are never updated.
  optional bool train = 5 [default = true];
}
104 |
--------------------------------------------------------------------------------
/utils/np_box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Operations for [N, 4] numpy arrays representing bounding boxes.
17 |
18 | Example box operations that are supported:
19 | * Areas: compute bounding box areas
20 | * IOU: pairwise intersection-over-union scores
21 | """
22 | import numpy as np
23 |
24 |
def area(boxes):
  """Computes area of boxes.

  Args:
    boxes: Numpy array with shape [N, 4] holding N boxes in
      [ymin, xmin, ymax, xmax] order.

  Returns:
    a numpy array with shape [N] representing box areas
  """
  heights = boxes[:, 2] - boxes[:, 0]
  widths = boxes[:, 3] - boxes[:, 1]
  return heights * widths
35 |
36 |
def intersection(boxes1, boxes2):
  """Compute pairwise intersection areas between boxes.

  Args:
    boxes1: a numpy array with shape [N, 4] holding N boxes
    boxes2: a numpy array with shape [M, 4] holding M boxes

  Returns:
    a numpy array with shape [N, M] representing pairwise intersection areas
  """
  y_min1, x_min1, y_max1, x_max1 = np.split(boxes1, 4, axis=1)
  y_min2, x_min2, y_max2, x_max2 = np.split(boxes2, 4, axis=1)

  # Broadcast [N, 1] against [1, M] to get the pairwise overlap extent on
  # each axis; non-overlapping pairs produce negative extents, clipped to 0.
  overlap_heights = np.clip(
      np.minimum(y_max1, y_max2.T) - np.maximum(y_min1, y_min2.T), 0, None)
  overlap_widths = np.clip(
      np.minimum(x_max1, x_max2.T) - np.maximum(x_min1, x_min2.T), 0, None)
  return overlap_heights * overlap_widths
61 |
62 |
def iou(boxes1, boxes2):
  """Computes pairwise intersection-over-union between box collections.

  Args:
    boxes1: a numpy array with shape [N, 4] holding N boxes.
    boxes2: a numpy array with shape [M, 4] holding M boxes.

  Returns:
    a numpy array with shape [N, M] representing pairwise iou scores.
  """
  intersect = intersection(boxes1, boxes2)
  # union = area1 + area2 - intersection, broadcast to the [N, M] grid.
  union = (np.expand_dims(area(boxes1), axis=1) +
           np.expand_dims(area(boxes2), axis=0) - intersect)
  return intersect / union
79 |
80 |
def ioa(boxes1, boxes2):
  """Computes pairwise intersection-over-area between box collections.

  Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
  their intersection area over box2's area. Note that ioa is not symmetric,
  that is, IOA(box1, box2) != IOA(box2, box1).

  Args:
    boxes1: a numpy array with shape [N, 4] holding N boxes.
    boxes2: a numpy array with shape [M, 4] holding M boxes.

  Returns:
    a numpy array with shape [N, M] representing pairwise ioa scores.
  """
  # Normalize each pairwise intersection by the area of the boxes2 member.
  return intersection(boxes1, boxes2) / np.expand_dims(area(boxes2), axis=0)
98 |
--------------------------------------------------------------------------------
/protos/faster_rcnn_box_coder_pb2.py:
--------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/faster_rcnn_box_coder.proto
#
# NOTE(review): machine-generated module; regenerate with protoc rather than
# editing by hand.

import sys
# On Python 3, descriptor byte strings must be latin-1 encoded; on Python 2
# they pass through unchanged.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# File descriptor built from the serialized form of
# faster_rcnn_box_coder.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/faster_rcnn_box_coder.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n3object_detection/protos/faster_rcnn_box_coder.proto\x12\x17object_detection.protos\"o\n\x12\x46\x61sterRcnnBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x03 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x04 \x01(\x02:\x01\x35')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Descriptor for the FasterRcnnBoxCoder message (defaults: y_scale=10,
# x_scale=10, height_scale=5, width_scale=5, mirroring the .proto source).
_FASTERRCNNBOXCODER = _descriptor.Descriptor(
  name='FasterRcnnBoxCoder',
  full_name='object_detection.protos.FasterRcnnBoxCoder',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='y_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.y_scale', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='x_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.x_scale', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='height_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.height_scale', index=2,
      number=3, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='width_scale', full_name='object_detection.protos.FasterRcnnBoxCoder.width_scale', index=3,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=80,
  serialized_end=191,
)

DESCRIPTOR.message_types_by_name['FasterRcnnBoxCoder'] = _FASTERRCNNBOXCODER

# Concrete message class generated from the descriptor above.
FasterRcnnBoxCoder = _reflection.GeneratedProtocolMessageType('FasterRcnnBoxCoder', (_message.Message,), dict(
  DESCRIPTOR = _FASTERRCNNBOXCODER,
  __module__ = 'object_detection.protos.faster_rcnn_box_coder_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.FasterRcnnBoxCoder)
  ))
_sym_db.RegisterMessage(FasterRcnnBoxCoder)


# @@protoc_insertion_point(module_scope)
89 |
--------------------------------------------------------------------------------
/utils/metrics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.metrics."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import metrics
22 |
23 |
class MetricsTest(tf.test.TestCase):
  """Unit tests for object_detection.utils.metrics.

  Covers CorLoc computation (including classes with zero groundtruth images,
  which are expected to yield NaN), precision/recall curves, and average
  precision, including the degenerate no-groundtruth case.
  """

  def test_compute_cor_loc(self):
    num_gt_imgs_per_class = np.array([100, 1, 5, 1, 1], dtype=int)
    num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0],
                                                       dtype=int)
    corloc = metrics.compute_cor_loc(num_gt_imgs_per_class,
                                     num_images_correctly_detected_per_class)
    expected_corloc = np.array([0.1, 0, 0.2, 0, 0], dtype=float)
    # assertAllClose matches the style of the other tests and gives a readable
    # element-wise failure message, unlike assertTrue(np.allclose(...)).
    self.assertAllClose(corloc, expected_corloc)

  def test_compute_cor_loc_nans(self):
    # Classes 1 and 2 have zero groundtruth images, so their CorLoc is NaN.
    num_gt_imgs_per_class = np.array([100, 0, 0, 1, 1], dtype=int)
    num_images_correctly_detected_per_class = np.array([10, 0, 1, 0, 0],
                                                       dtype=int)
    corloc = metrics.compute_cor_loc(num_gt_imgs_per_class,
                                     num_images_correctly_detected_per_class)
    expected_corloc = np.array([0.1, np.nan, np.nan, 0, 0], dtype=float)
    self.assertAllClose(corloc, expected_corloc)

  def test_compute_precision_recall(self):
    num_gt = 10
    scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float)
    labels = np.array([0, 1, 1, 0, 0, 1], dtype=bool)
    # Cumulative true-positive counts after sorting detections by descending
    # score (0.7, 0.6, 0.4, 0.3, 0.2, 0.1 -> labels 0, 1, 0, 1, 0, 1).
    accumulated_tp_count = np.array([0, 1, 1, 2, 2, 3], dtype=float)
    expected_precision = accumulated_tp_count / np.array([1, 2, 3, 4, 5, 6])
    expected_recall = accumulated_tp_count / num_gt
    precision, recall = metrics.compute_precision_recall(scores, labels, num_gt)
    self.assertAllClose(precision, expected_precision)
    self.assertAllClose(recall, expected_recall)

  def test_compute_average_precision(self):
    precision = np.array([0.8, 0.76, 0.9, 0.65, 0.7, 0.5, 0.55, 0], dtype=float)
    recall = np.array([0.3, 0.3, 0.4, 0.4, 0.45, 0.45, 0.5, 0.5], dtype=float)
    # Interpolated precision: each entry is the maximum precision at any recall
    # greater than or equal to its own recall level.
    processed_precision = np.array([0.9, 0.9, 0.9, 0.7, 0.7, 0.55, 0.55, 0],
                                   dtype=float)
    recall_interval = np.array([0.3, 0, 0.1, 0, 0.05, 0, 0.05, 0], dtype=float)
    expected_mean_ap = np.sum(recall_interval * processed_precision)
    mean_ap = metrics.compute_average_precision(precision, recall)
    self.assertAlmostEqual(expected_mean_ap, mean_ap)

  def test_compute_precision_recall_and_ap_no_groundtruth(self):
    num_gt = 0
    scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float)
    labels = np.array([0, 0, 0, 0, 0, 0], dtype=bool)
    expected_precision = None
    expected_recall = None
    precision, recall = metrics.compute_precision_recall(scores, labels, num_gt)
    # assertEqual, not the deprecated assertEquals alias (removed in Py 3.12).
    self.assertEqual(precision, expected_precision)
    self.assertEqual(recall, expected_recall)
    # With no groundtruth, average precision is undefined and returned as NaN.
    ap = metrics.compute_average_precision(precision, recall)
    self.assertTrue(np.isnan(ap))
76 |
77 |
# Run all tests in this module when it is invoked directly.
if __name__ == '__main__':
  tf.test.main()
80 |
--------------------------------------------------------------------------------
/protos/model_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/model.proto
3 |
import sys
# Py2/Py3 shim: on Python 3 the serialized descriptor literal below must be
# encoded to latin-1 bytes; on Python 2 it is already a byte string.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

# Default symbol database; every descriptor and message class below is
# registered into it so other generated modules can resolve them by name.
_sym_db = _symbol_database.Default()


import object_detection.protos.faster_rcnn_pb2
import object_detection.protos.ssd_pb2


# File-level descriptor deserialized from the FileDescriptorProto emitted by
# protoc for model.proto (imports faster_rcnn.proto and ssd.proto).
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/model.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n#object_detection/protos/model.proto\x12\x17object_detection.protos\x1a)object_detection/protos/faster_rcnn.proto\x1a!object_detection/protos/ssd.proto\"\x82\x01\n\x0e\x44\x65tectionModel\x12:\n\x0b\x66\x61ster_rcnn\x18\x01 \x01(\x0b\x32#.object_detection.protos.FasterRcnnH\x00\x12+\n\x03ssd\x18\x02 \x01(\x0b\x32\x1c.object_detection.protos.SsdH\x00\x42\x07\n\x05model')
  ,
  dependencies=[object_detection.protos.faster_rcnn_pb2.DESCRIPTOR,object_detection.protos.ssd_pb2.DESCRIPTOR,])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Message descriptor for DetectionModel: a oneof named 'model' holding either
# a FasterRcnn or an Ssd sub-message.
_DETECTIONMODEL = _descriptor.Descriptor(
  name='DetectionModel',
  full_name='object_detection.protos.DetectionModel',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='faster_rcnn', full_name='object_detection.protos.DetectionModel.faster_rcnn', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ssd', full_name='object_detection.protos.DetectionModel.ssd', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
    _descriptor.OneofDescriptor(
      name='model', full_name='object_detection.protos.DetectionModel.model',
      index=0, containing_type=None, fields=[]),
  ],
  serialized_start=143,
  serialized_end=273,
)

# Resolve cross-file message references and wire each field into the 'model'
# oneof (these links cannot be expressed in the flat descriptor above).
_DETECTIONMODEL.fields_by_name['faster_rcnn'].message_type = object_detection.protos.faster_rcnn_pb2._FASTERRCNN
_DETECTIONMODEL.fields_by_name['ssd'].message_type = object_detection.protos.ssd_pb2._SSD
_DETECTIONMODEL.oneofs_by_name['model'].fields.append(
    _DETECTIONMODEL.fields_by_name['faster_rcnn'])
_DETECTIONMODEL.fields_by_name['faster_rcnn'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model']
_DETECTIONMODEL.oneofs_by_name['model'].fields.append(
    _DETECTIONMODEL.fields_by_name['ssd'])
_DETECTIONMODEL.fields_by_name['ssd'].containing_oneof = _DETECTIONMODEL.oneofs_by_name['model']
DESCRIPTOR.message_types_by_name['DetectionModel'] = _DETECTIONMODEL

# Concrete message class, built by the reflection metaclass which fills in the
# field accessors from the descriptor above.
DetectionModel = _reflection.GeneratedProtocolMessageType('DetectionModel', (_message.Message,), dict(
  DESCRIPTOR = _DETECTIONMODEL,
  __module__ = 'object_detection.protos.model_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.DetectionModel)
  ))
_sym_db.RegisterMessage(DetectionModel)


# @@protoc_insertion_point(module_scope)
89 |
--------------------------------------------------------------------------------
/protos/matcher_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/matcher.proto
3 |
import sys
# Py2/Py3 shim: on Python 3 the serialized descriptor literal below must be
# encoded to latin-1 bytes; on Python 2 it is already a byte string.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

# Default symbol database; every descriptor and message class below is
# registered into it so other generated modules can resolve them by name.
_sym_db = _symbol_database.Default()


import object_detection.protos.argmax_matcher_pb2
import object_detection.protos.bipartite_matcher_pb2


# File-level descriptor deserialized from the FileDescriptorProto emitted by
# protoc for matcher.proto (imports argmax_matcher.proto and
# bipartite_matcher.proto).
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/matcher.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n%object_detection/protos/matcher.proto\x12\x17object_detection.protos\x1a,object_detection/protos/argmax_matcher.proto\x1a/object_detection/protos/bipartite_matcher.proto\"\xa4\x01\n\x07Matcher\x12@\n\x0e\x61rgmax_matcher\x18\x01 \x01(\x0b\x32&.object_detection.protos.ArgMaxMatcherH\x00\x12\x46\n\x11\x62ipartite_matcher\x18\x02 \x01(\x0b\x32).object_detection.protos.BipartiteMatcherH\x00\x42\x0f\n\rmatcher_oneof')
  ,
  dependencies=[object_detection.protos.argmax_matcher_pb2.DESCRIPTOR,object_detection.protos.bipartite_matcher_pb2.DESCRIPTOR,])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Message descriptor for Matcher: a oneof named 'matcher_oneof' holding either
# an ArgMaxMatcher or a BipartiteMatcher sub-message.
_MATCHER = _descriptor.Descriptor(
  name='Matcher',
  full_name='object_detection.protos.Matcher',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='argmax_matcher', full_name='object_detection.protos.Matcher.argmax_matcher', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='bipartite_matcher', full_name='object_detection.protos.Matcher.bipartite_matcher', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
    _descriptor.OneofDescriptor(
      name='matcher_oneof', full_name='object_detection.protos.Matcher.matcher_oneof',
      index=0, containing_type=None, fields=[]),
  ],
  serialized_start=162,
  serialized_end=326,
)

# Resolve cross-file message references and wire each field into the
# 'matcher_oneof' oneof.
_MATCHER.fields_by_name['argmax_matcher'].message_type = object_detection.protos.argmax_matcher_pb2._ARGMAXMATCHER
_MATCHER.fields_by_name['bipartite_matcher'].message_type = object_detection.protos.bipartite_matcher_pb2._BIPARTITEMATCHER
_MATCHER.oneofs_by_name['matcher_oneof'].fields.append(
    _MATCHER.fields_by_name['argmax_matcher'])
_MATCHER.fields_by_name['argmax_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof']
_MATCHER.oneofs_by_name['matcher_oneof'].fields.append(
    _MATCHER.fields_by_name['bipartite_matcher'])
_MATCHER.fields_by_name['bipartite_matcher'].containing_oneof = _MATCHER.oneofs_by_name['matcher_oneof']
DESCRIPTOR.message_types_by_name['Matcher'] = _MATCHER

# Concrete message class, built by the reflection metaclass which fills in the
# field accessors from the descriptor above.
Matcher = _reflection.GeneratedProtocolMessageType('Matcher', (_message.Message,), dict(
  DESCRIPTOR = _MATCHER,
  __module__ = 'object_detection.protos.matcher_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.Matcher)
  ))
_sym_db.RegisterMessage(Matcher)


# @@protoc_insertion_point(module_scope)
89 |
--------------------------------------------------------------------------------
/protos/box_predictor.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/hyperparams.proto";
6 |
7 |
// Configuration proto for box predictor. See core/box_predictor.py for details.
message BoxPredictor {
  // Exactly one concrete predictor configuration may be set.
  oneof box_predictor_oneof {
    ConvolutionalBoxPredictor convolutional_box_predictor = 1;
    MaskRCNNBoxPredictor mask_rcnn_box_predictor = 2;
    RfcnBoxPredictor rfcn_box_predictor = 3;
  }
}
16 |
// Configuration proto for Convolutional box predictor.
message ConvolutionalBoxPredictor {
  // Hyperparameters for convolution ops used in the box predictor.
  optional Hyperparams conv_hyperparams = 1;

  // Minimum feature depth prior to predicting box encodings and class
  // predictions.
  optional int32 min_depth = 2 [default = 0];

  // Maximum feature depth prior to predicting box encodings and class
  // predictions. If max_depth is set to 0, no additional feature map will be
  // inserted before location and class predictions.
  optional int32 max_depth = 3 [default = 0];

  // Number of the additional conv layers before the predictor.
  optional int32 num_layers_before_predictor = 4 [default = 0];

  // Whether to use dropout for class prediction.
  optional bool use_dropout = 5 [default = true];

  // Keep probability for dropout.
  optional float dropout_keep_probability = 6 [default = 0.8];

  // Size of final convolution kernel. If the spatial resolution of the feature
  // map is smaller than the kernel size, then the kernel size is set to
  // min(feature_width, feature_height).
  optional int32 kernel_size = 7 [default = 1];

  // Size of the encoding for boxes.
  optional int32 box_code_size = 8 [default = 4];

  // Whether to apply sigmoid to the output of class predictions.
  // TODO: Do we need this since we have a post processing module?
  optional bool apply_sigmoid_to_scores = 9 [default = false];

  // Initialization value for the class prediction bias (per the field name;
  // see core/box_predictor.py for how it is consumed).
  optional float class_prediction_bias_init = 10 [default = 0.0];
}
54 |
// Configuration proto for the Mask R-CNN box predictor.
message MaskRCNNBoxPredictor {
  // Hyperparameters for fully connected ops used in the box predictor.
  optional Hyperparams fc_hyperparams = 1;

  // Whether to use dropout op prior to the both box and class predictions.
  optional bool use_dropout = 2 [default= false];

  // Keep probability for dropout. This is only used if use_dropout is true.
  optional float dropout_keep_probability = 3 [default = 0.5];

  // Size of the encoding for the boxes.
  optional int32 box_code_size = 4 [default = 4];

  // Hyperparameters for convolution ops used in the box predictor.
  optional Hyperparams conv_hyperparams = 5;

  // Whether to predict instance masks inside detection boxes.
  optional bool predict_instance_masks = 6 [default = false];

  // The depth for the first conv2d_transpose op applied to the
  // image_features in the mask prediction branch.
  optional int32 mask_prediction_conv_depth = 7 [default = 256];

  // Whether to predict keypoints inside detection boxes.
  optional bool predict_keypoints = 8 [default = false];
}
81 |
// Configuration proto for the R-FCN box predictor.
message RfcnBoxPredictor {
  // Hyperparameters for convolution ops used in the box predictor.
  optional Hyperparams conv_hyperparams = 1;

  // Bin sizes for RFCN crops.
  optional int32 num_spatial_bins_height = 2 [default = 3];

  optional int32 num_spatial_bins_width = 3 [default = 3];

  // Target depth to reduce the input image features to.
  optional int32 depth = 4 [default=1024];

  // Size of the encoding for the boxes.
  optional int32 box_code_size = 5 [default = 4];

  // Size to resize the rfcn crops to.
  optional int32 crop_height = 6 [default= 12];

  optional int32 crop_width = 7 [default=12];
}
102 |
--------------------------------------------------------------------------------
/protos/keypoint_box_coder_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/keypoint_box_coder.proto
3 |
import sys
# Py2/Py3 shim: on Python 3 the serialized descriptor literal below must be
# encoded to latin-1 bytes; on Python 2 it is already a byte string.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

# Default symbol database; every descriptor and message class below is
# registered into it so other generated modules can resolve them by name.
_sym_db = _symbol_database.Default()




# File-level descriptor, deserialized from the FileDescriptorProto that protoc
# emitted for keypoint_box_coder.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/keypoint_box_coder.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n0object_detection/protos/keypoint_box_coder.proto\x12\x17object_detection.protos\"\x84\x01\n\x10KeypointBoxCoder\x12\x15\n\rnum_keypoints\x18\x01 \x01(\x05\x12\x13\n\x07y_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x03 \x01(\x02:\x02\x31\x30\x12\x17\n\x0cheight_scale\x18\x04 \x01(\x02:\x01\x35\x12\x16\n\x0bwidth_scale\x18\x05 \x01(\x02:\x01\x35')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Message descriptor for KeypointBoxCoder: an int32 num_keypoints plus four
# optional float scales (y/x default 10, height/width default 5).
_KEYPOINTBOXCODER = _descriptor.Descriptor(
  name='KeypointBoxCoder',
  full_name='object_detection.protos.KeypointBoxCoder',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='num_keypoints', full_name='object_detection.protos.KeypointBoxCoder.num_keypoints', index=0,
      number=1, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='y_scale', full_name='object_detection.protos.KeypointBoxCoder.y_scale', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='x_scale', full_name='object_detection.protos.KeypointBoxCoder.x_scale', index=2,
      number=3, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=10,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='height_scale', full_name='object_detection.protos.KeypointBoxCoder.height_scale', index=3,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='width_scale', full_name='object_detection.protos.KeypointBoxCoder.width_scale', index=4,
      number=5, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=78,
  serialized_end=210,
)

DESCRIPTOR.message_types_by_name['KeypointBoxCoder'] = _KEYPOINTBOXCODER

# Concrete message class, built by the reflection metaclass which fills in the
# field accessors from the descriptor above.
KeypointBoxCoder = _reflection.GeneratedProtocolMessageType('KeypointBoxCoder', (_message.Message,), dict(
  DESCRIPTOR = _KEYPOINTBOXCODER,
  __module__ = 'object_detection.protos.keypoint_box_coder_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.KeypointBoxCoder)
  ))
_sym_db.RegisterMessage(KeypointBoxCoder)


# @@protoc_insertion_point(module_scope)
96 |
--------------------------------------------------------------------------------
/protos/argmax_matcher_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/argmax_matcher.proto
3 |
import sys
# Py2/Py3 shim: on Python 3 the serialized descriptor literal below must be
# encoded to latin-1 bytes; on Python 2 it is already a byte string.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

# Default symbol database; every descriptor and message class below is
# registered into it so other generated modules can resolve them by name.
_sym_db = _symbol_database.Default()




# File-level descriptor, deserialized from the FileDescriptorProto that protoc
# emitted for argmax_matcher.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/argmax_matcher.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n,object_detection/protos/argmax_matcher.proto\x12\x17object_detection.protos\"\xca\x01\n\rArgMaxMatcher\x12\x1e\n\x11matched_threshold\x18\x01 \x01(\x02:\x03\x30.5\x12 \n\x13unmatched_threshold\x18\x02 \x01(\x02:\x03\x30.5\x12 \n\x11ignore_thresholds\x18\x03 \x01(\x08:\x05\x66\x61lse\x12,\n\x1enegatives_lower_than_unmatched\x18\x04 \x01(\x08:\x04true\x12\'\n\x18\x66orce_match_for_each_row\x18\x05 \x01(\x08:\x05\x66\x61lse')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Message descriptor for ArgMaxMatcher: two float thresholds (default 0.5)
# and three boolean flags controlling the matching behavior.
_ARGMAXMATCHER = _descriptor.Descriptor(
  name='ArgMaxMatcher',
  full_name='object_detection.protos.ArgMaxMatcher',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='matched_threshold', full_name='object_detection.protos.ArgMaxMatcher.matched_threshold', index=0,
      number=1, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=0.5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='unmatched_threshold', full_name='object_detection.protos.ArgMaxMatcher.unmatched_threshold', index=1,
      number=2, type=2, cpp_type=6, label=1,
      has_default_value=True, default_value=0.5,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ignore_thresholds', full_name='object_detection.protos.ArgMaxMatcher.ignore_thresholds', index=2,
      number=3, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='negatives_lower_than_unmatched', full_name='object_detection.protos.ArgMaxMatcher.negatives_lower_than_unmatched', index=3,
      number=4, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=True,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='force_match_for_each_row', full_name='object_detection.protos.ArgMaxMatcher.force_match_for_each_row', index=4,
      number=5, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=74,
  serialized_end=276,
)

DESCRIPTOR.message_types_by_name['ArgMaxMatcher'] = _ARGMAXMATCHER

# Concrete message class, built by the reflection metaclass which fills in the
# field accessors from the descriptor above.
ArgMaxMatcher = _reflection.GeneratedProtocolMessageType('ArgMaxMatcher', (_message.Message,), dict(
  DESCRIPTOR = _ARGMAXMATCHER,
  __module__ = 'object_detection.protos.argmax_matcher_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.ArgMaxMatcher)
  ))
_sym_db.RegisterMessage(ArgMaxMatcher)


# @@protoc_insertion_point(module_scope)
96 |
--------------------------------------------------------------------------------
/protos/anchor_generator_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/anchor_generator.proto
3 |
import sys
# Py2/Py3 shim: on Python 3 the serialized descriptor literal below must be
# encoded to latin-1 bytes; on Python 2 it is already a byte string.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

# Default symbol database; every descriptor and message class below is
# registered into it so other generated modules can resolve them by name.
_sym_db = _symbol_database.Default()


import object_detection.protos.grid_anchor_generator_pb2
import object_detection.protos.ssd_anchor_generator_pb2


# File-level descriptor deserialized from the FileDescriptorProto emitted by
# protoc for anchor_generator.proto (imports grid_anchor_generator.proto and
# ssd_anchor_generator.proto).
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/anchor_generator.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n.object_detection/protos/anchor_generator.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/grid_anchor_generator.proto\x1a\x32object_detection/protos/ssd_anchor_generator.proto\"\xc7\x01\n\x0f\x41nchorGenerator\x12M\n\x15grid_anchor_generator\x18\x01 \x01(\x0b\x32,.object_detection.protos.GridAnchorGeneratorH\x00\x12K\n\x14ssd_anchor_generator\x18\x02 \x01(\x0b\x32+.object_detection.protos.SsdAnchorGeneratorH\x00\x42\x18\n\x16\x61nchor_generator_oneof')
  ,
  dependencies=[object_detection.protos.grid_anchor_generator_pb2.DESCRIPTOR,object_detection.protos.ssd_anchor_generator_pb2.DESCRIPTOR,])
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Message descriptor for AnchorGenerator: a oneof 'anchor_generator_oneof'
# holding either a GridAnchorGenerator or an SsdAnchorGenerator sub-message.
_ANCHORGENERATOR = _descriptor.Descriptor(
  name='AnchorGenerator',
  full_name='object_detection.protos.AnchorGenerator',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='grid_anchor_generator', full_name='object_detection.protos.AnchorGenerator.grid_anchor_generator', index=0,
      number=1, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ssd_anchor_generator', full_name='object_detection.protos.AnchorGenerator.ssd_anchor_generator', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
    _descriptor.OneofDescriptor(
      name='anchor_generator_oneof', full_name='object_detection.protos.AnchorGenerator.anchor_generator_oneof',
      index=0, containing_type=None, fields=[]),
  ],
  serialized_start=181,
  serialized_end=380,
)

# Resolve cross-file message references and wire each field into the
# 'anchor_generator_oneof' oneof.
_ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].message_type = object_detection.protos.grid_anchor_generator_pb2._GRIDANCHORGENERATOR
_ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].message_type = object_detection.protos.ssd_anchor_generator_pb2._SSDANCHORGENERATOR
_ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append(
    _ANCHORGENERATOR.fields_by_name['grid_anchor_generator'])
_ANCHORGENERATOR.fields_by_name['grid_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof']
_ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof'].fields.append(
    _ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'])
_ANCHORGENERATOR.fields_by_name['ssd_anchor_generator'].containing_oneof = _ANCHORGENERATOR.oneofs_by_name['anchor_generator_oneof']
DESCRIPTOR.message_types_by_name['AnchorGenerator'] = _ANCHORGENERATOR

# Concrete message class, built by the reflection metaclass which fills in the
# field accessors from the descriptor above.
AnchorGenerator = _reflection.GeneratedProtocolMessageType('AnchorGenerator', (_message.Message,), dict(
  DESCRIPTOR = _ANCHORGENERATOR,
  __module__ = 'object_detection.protos.anchor_generator_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.AnchorGenerator)
  ))
_sym_db.RegisterMessage(AnchorGenerator)


# @@protoc_insertion_point(module_scope)
89 |
--------------------------------------------------------------------------------
/protos/string_int_label_map_pb2.py:
--------------------------------------------------------------------------------
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: object_detection/protos/string_int_label_map.proto

import sys
# Py2/Py3 shim: serialized descriptor bytes are latin-1 encoded on Python 3.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




# File-level descriptor holding the serialized FileDescriptorProto for
# string_int_label_map.proto.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='object_detection/protos/string_int_label_map.proto',
  package='object_detection.protos',
  serialized_pb=_b('\n2object_detection/protos/string_int_label_map.proto\x12\x17object_detection.protos\"G\n\x15StringIntLabelMapItem\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x12\x14\n\x0c\x64isplay_name\x18\x03 \x01(\t\"Q\n\x11StringIntLabelMap\x12<\n\x04item\x18\x01 \x03(\x0b\x32..object_detection.protos.StringIntLabelMapItem')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)




# Descriptor for StringIntLabelMapItem: fields name (string), id (int32),
# display_name (string).
_STRINGINTLABELMAPITEM = _descriptor.Descriptor(
  name='StringIntLabelMapItem',
  full_name='object_detection.protos.StringIntLabelMapItem',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='object_detection.protos.StringIntLabelMapItem.name', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='id', full_name='object_detection.protos.StringIntLabelMapItem.id', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='display_name', full_name='object_detection.protos.StringIntLabelMapItem.display_name', index=2,
      number=3, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=79,
  serialized_end=150,
)


# Descriptor for StringIntLabelMap: a single repeated message field `item`.
_STRINGINTLABELMAP = _descriptor.Descriptor(
  name='StringIntLabelMap',
  full_name='object_detection.protos.StringIntLabelMap',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='item', full_name='object_detection.protos.StringIntLabelMap.item', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=152,
  serialized_end=233,
)

# Wire up the cross-message reference and register both descriptors.
_STRINGINTLABELMAP.fields_by_name['item'].message_type = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMapItem'] = _STRINGINTLABELMAPITEM
DESCRIPTOR.message_types_by_name['StringIntLabelMap'] = _STRINGINTLABELMAP

# Concrete message classes generated via the protobuf reflection API.
StringIntLabelMapItem = _reflection.GeneratedProtocolMessageType('StringIntLabelMapItem', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAPITEM,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMapItem)
  ))
_sym_db.RegisterMessage(StringIntLabelMapItem)

StringIntLabelMap = _reflection.GeneratedProtocolMessageType('StringIntLabelMap', (_message.Message,), dict(
  DESCRIPTOR = _STRINGINTLABELMAP,
  __module__ = 'object_detection.protos.string_int_label_map_pb2'
  # @@protoc_insertion_point(class_scope:object_detection.protos.StringIntLabelMap)
  ))
_sym_db.RegisterMessage(StringIntLabelMap)


# @@protoc_insertion_point(module_scope)
121 |
--------------------------------------------------------------------------------
/utils/shape_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils used to manipulate tensor shapes."""
17 |
18 | import tensorflow as tf
19 |
20 |
def _is_tensor(t):
  """Determines whether the input is a TensorFlow tensor-like object.

  Args:
    t: the object to be tested.

  Returns:
    True if t is a tf.Tensor, tf.SparseTensor or tf.Variable, else False.
  """
  tensor_types = (tf.Tensor, tf.SparseTensor, tf.Variable)
  return isinstance(t, tensor_types)
31 |
32 |
def _set_dim_0(t, d0):
  """Statically fixes the 0-th dimension of the input tensor.

  Args:
    t: the input tensor; its rank must be at least 1.
    d0: an integer to use as the static size of dimension 0.

  Returns:
    the tensor t with its static shape updated so dimension 0 equals d0.
  """
  # Keep the trailing dimensions as-is and pin only the leading one.
  new_shape = [d0] + t.get_shape().as_list()[1:]
  t.set_shape(new_shape)
  return t
47 |
48 |
def pad_tensor(t, length):
  """Pads the input tensor with 0s along the first dimension up to the length.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after padding, assuming length >= t.shape[0].

  Returns:
    padded_t: the padded tensor, whose first dimension is length. If the length
      is an integer, the first dimension of padded_t is set to length
      statically.
  """
  t_rank = tf.rank(t)
  t_shape = tf.shape(t)
  t_d0 = t_shape[0]
  # Number of rows of zeros to append along dimension 0, as a shape-[1] tensor.
  pad_d0 = tf.expand_dims(length - t_d0, 0)
  # For rank > 1 the padding must match t's trailing dimensions; for rank 1
  # the padding shape is just pad_d0 (the original recomputed this expression
  # in the false branch — reuse pad_d0 instead).
  pad_shape = tf.cond(
      tf.greater(t_rank, 1), lambda: tf.concat([pad_d0, t_shape[1:]], 0),
      lambda: pad_d0)
  padded_t = tf.concat([t, tf.zeros(pad_shape, dtype=t.dtype)], 0)
  if not _is_tensor(length):
    padded_t = _set_dim_0(padded_t, length)
  return padded_t
73 |
74 |
def clip_tensor(t, length):
  """Clips the input tensor along the first dimension up to the length.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after clipping, assuming length <= t.shape[0].

  Returns:
    clipped_t: the clipped tensor, whose first dimension is length. If the
      length is an integer, the first dimension of clipped_t is set to length
      statically.
  """
  # Keep only the first `length` entries along dimension 0.
  keep_indices = tf.range(length)
  clipped_t = tf.gather(t, keep_indices)
  if not _is_tensor(length):
    clipped_t = _set_dim_0(clipped_t, length)
  return clipped_t
92 |
93 |
def pad_or_clip_tensor(t, length):
  """Pad or clip the input tensor along the first dimension.

  Args:
    t: the input tensor, assuming the rank is at least 1.
    length: a tensor of shape [1] or an integer, indicating the first dimension
      of the input tensor t after processing.

  Returns:
    processed_t: the processed tensor, whose first dimension is length. If the
      length is an integer, the first dimension of the processed tensor is set
      to length statically.
  """
  # Clip when the tensor is longer than `length`, otherwise pad with zeros.
  needs_clipping = tf.greater(tf.shape(t)[0], length)
  processed_t = tf.cond(
      needs_clipping,
      lambda: clip_tensor(t, length),
      lambda: pad_tensor(t, length))
  if not _is_tensor(length):
    processed_t = _set_dim_0(processed_t, length)
  return processed_t
114 |
115 |
def combined_static_and_dynamic_shape(tensor):
  """Returns a list containing static and dynamic values for the dimensions.

  Returns a list of static and dynamic values for shape dimensions. This is
  useful to preserve static shapes when available in reshape operation.

  Args:
    tensor: A tensor of any type.

  Returns:
    A list of size tensor.shape.ndims containing integers or a scalar tensor.
  """
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.shape(tensor)
  # Prefer the statically-known size; fall back to the dynamic size for any
  # dimension that is None at graph-construction time.
  return [
      dim if dim is not None else dynamic_shape[index]
      for index, dim in enumerate(static_shape)
  ]
137 |
--------------------------------------------------------------------------------
/utils/np_box_list.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Numpy BoxList classes and functions."""
17 |
18 | import numpy as np
19 |
20 |
class BoxList(object):
  """Box collection.

  BoxList represents a list of bounding boxes as numpy array, where each
  bounding box is represented as a row of 4 numbers,
  [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a
  given list correspond to a single image.

  Optionally, users can add additional related fields (such as
  objectness/classification scores).
  """

  def __init__(self, data):
    """Constructs box collection.

    Args:
      data: a numpy array of shape [N, 4] representing box coordinates

    Raises:
      ValueError: if bbox data is not a numpy array
      ValueError: if invalid dimensions for bbox data
    """
    if not isinstance(data, np.ndarray):
      raise ValueError('data must be a numpy array.')
    if len(data.shape) != 2 or data.shape[1] != 4:
      raise ValueError('Invalid dimensions for box data.')
    if data.dtype != np.float32 and data.dtype != np.float64:
      raise ValueError('Invalid data type for box data: float is required.')
    if not self._is_valid_boxes(data):
      raise ValueError('Invalid box data. data must be a numpy array of '
                       'N*[y_min, x_min, y_max, x_max]')
    self.data = {'boxes': data}

  def num_boxes(self):
    """Return number of boxes held in collections."""
    return self.data['boxes'].shape[0]

  def get_extra_fields(self):
    """Return all non-box fields."""
    return [k for k in self.data.keys() if k != 'boxes']

  def has_field(self, field):
    """Return True if the named field exists in this collection."""
    return field in self.data

  def add_field(self, field, field_data):
    """Add data to a specified field.

    Args:
      field: a string parameter used to specify a related field to be accessed.
      field_data: a numpy array of [N, ...] representing the data associated
        with the field.
    Raises:
      ValueError: if the field already exists or the dimension of the field
        data does not match the number of boxes.
    """
    if self.has_field(field):
      # NOTE: original message was missing the space ('...already exists').
      raise ValueError('Field ' + field + ' already exists')
    if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes():
      raise ValueError('Invalid dimensions for field data')
    self.data[field] = field_data

  def get(self):
    """Convenience function for accessing box coordinates.

    Returns:
      a numpy array of shape [N, 4] representing box corners
    """
    return self.get_field('boxes')

  def get_field(self, field):
    """Accesses data associated with the specified field in the box collection.

    Args:
      field: a string parameter used to specify a related field to be accessed.

    Returns:
      a numpy 1-d array representing data of an associated field

    Raises:
      ValueError: if invalid field
    """
    if not self.has_field(field):
      raise ValueError('field {} does not exist'.format(field))
    return self.data[field]

  def get_coordinates(self):
    """Get corner coordinates of boxes.

    Returns:
      a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max]
    """
    box_coordinates = self.get()
    y_min = box_coordinates[:, 0]
    x_min = box_coordinates[:, 1]
    y_max = box_coordinates[:, 2]
    x_max = box_coordinates[:, 3]
    return [y_min, x_min, y_max, x_max]

  def _is_valid_boxes(self, data):
    """Check whether data fulfills the format of N*[ymin, xmin, ymax, xmax].

    Args:
      data: a numpy array of shape [N, 4] representing box coordinates

    Returns:
      a boolean indicating whether all ymax of boxes are equal or greater than
      ymin, and all xmax of boxes are equal or greater than xmin.
    """
    if data.shape[0] > 0:
      # Vectorized replacement for the original per-row Python loop; uses the
      # same `>` comparisons so NaN handling is unchanged.
      invalid = (data[:, 0] > data[:, 2]) | (data[:, 1] > data[:, 3])
      return not bool(np.any(invalid))
    return True
134 |
--------------------------------------------------------------------------------
/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Contains functions which are convenient for unit testing."""
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 | from object_detection.core import anchor_generator
21 | from object_detection.core import box_coder
22 | from object_detection.core import box_list
23 | from object_detection.core import box_predictor
24 | from object_detection.core import matcher
25 | from object_detection.utils import shape_utils
26 |
27 |
class MockBoxCoder(box_coder.BoxCoder):
  """Simple `difference` BoxCoder."""

  @property
  def code_size(self):
    # Each box is encoded as its 4 corner offsets.
    return 4

  def _encode(self, boxes, anchors):
    # The encoding is the elementwise difference of box and anchor corners.
    corner_offsets = boxes.get() - anchors.get()
    return corner_offsets

  def _decode(self, rel_codes, anchors):
    # Decoding adds the offsets back onto the anchor corners.
    decoded_corners = rel_codes + anchors.get()
    return box_list.BoxList(decoded_corners)
40 |
41 |
class MockBoxPredictor(box_predictor.BoxPredictor):
  """Simple box predictor that ignores inputs and outputs all zeros."""

  def __init__(self, is_training, num_classes):
    super(MockBoxPredictor, self).__init__(is_training, num_classes)

  def _predict(self, image_features, num_predictions_per_location):
    feature_shape = shape_utils.combined_static_and_dynamic_shape(
        image_features)
    batch_size = feature_shape[0]
    num_anchors = feature_shape[1] * feature_shape[2]
    code_size = 4
    # All-zero scalar derived from the input features; presumably kept so the
    # outputs retain a graph dependency on image_features — TODO confirm.
    zero = tf.reduce_sum(0 * image_features)
    box_encodings = tf.zeros(
        (batch_size, num_anchors, 1, code_size), dtype=tf.float32) + zero
    class_predictions_with_background = tf.zeros(
        (batch_size, num_anchors, self.num_classes + 1),
        dtype=tf.float32) + zero
    return {
        box_predictor.BOX_ENCODINGS: box_encodings,
        box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND:
            class_predictions_with_background,
    }
62 |
63 |
class MockAnchorGenerator(anchor_generator.AnchorGenerator):
  """Mock anchor generator."""

  def name_scope(self):
    return 'MockAnchorGenerator'

  def num_anchors_per_location(self):
    return [1]

  def _generate(self, feature_map_shape_list):
    # One anchor per feature-map cell, all with all-zero corners.
    num_anchors = sum(
        shape[0] * shape[1] for shape in feature_map_shape_list)
    return box_list.BoxList(tf.zeros((num_anchors, 4), dtype=tf.float32))
76 |
77 |
class MockMatcher(matcher.Matcher):
  """Simple matcher that matches first anchor to first groundtruth box."""

  def _match(self, similarity_matrix):
    # Anchor 0 matches groundtruth 0; the other three anchors are unmatched
    # (-1) regardless of the similarity scores.
    match_results = tf.constant([0, -1, -1, -1], dtype=tf.int32)
    return match_results
83 |
84 |
def create_diagonal_gradient_image(height, width, depth):
  """Creates pyramid image. Useful for testing.

  For example, pyramid_image(5, 6, 1) looks like:
  # [[[ 5.  4.  3.  2.  1.  0.]
  #   [ 6.  5.  4.  3.  2.  1.]
  #   [ 7.  6.  5.  4.  3.  2.]
  #   [ 8.  7.  6.  5.  4.  3.]
  #   [ 9.  8.  7.  6.  5.  4.]]]

  Args:
    height: height of image
    width: width of image
    depth: depth of image

  Returns:
    pyramid image
  """
  # Outer sum of the row index and the reversed column index yields the
  # diagonal gradient; add a trailing channel axis.
  rows = np.arange(height).reshape((height, 1))
  cols = np.arange(width)[::-1]
  base_layer = (rows + cols)[:, :, np.newaxis]

  # Channel i is the base layer scaled by 10**i; at least one layer is
  # always produced (matching the original for depth <= 1).
  layers = [base_layer * (10 ** i) for i in range(max(depth, 1))]
  return np.concatenate(layers, axis=2).astype(np.float32)
113 |
114 |
def create_random_boxes(num_boxes, max_height, max_width):
  """Creates random bounding boxes of specific maximum height and width.

  Args:
    num_boxes: number of boxes.
    max_height: maximum height of boxes.
    max_width: maximum width of boxes.

  Returns:
    boxes: numpy array of shape [num_boxes, 4]. Each row is in form
        [y_min, x_min, y_max, x_max].
  """
  # Draw two independent samples per axis; the smaller becomes the min
  # coordinate and the larger the max, guaranteeing valid boxes.
  y_a = np.random.uniform(size=(1, num_boxes)) * max_height
  y_b = np.random.uniform(size=(1, num_boxes)) * max_height
  x_a = np.random.uniform(size=(1, num_boxes)) * max_width
  x_b = np.random.uniform(size=(1, num_boxes)) * max_width

  boxes = np.concatenate([np.minimum(y_a, y_b),
                          np.minimum(x_a, x_b),
                          np.maximum(y_a, y_b),
                          np.maximum(x_a, x_b)], axis=0).T
  return boxes.astype(np.float32)
140 |
--------------------------------------------------------------------------------
/protos/losses.proto:
--------------------------------------------------------------------------------
syntax = "proto2";

package object_detection.protos;

// Message for configuring the localization loss, classification loss and hard
// example miner used for training object detection models. See core/losses.py
// for details
message Loss {
  // Localization loss to use.
  optional LocalizationLoss localization_loss = 1;

  // Classification loss to use.
  optional ClassificationLoss classification_loss = 2;

  // If not left to default, applies hard example mining.
  optional HardExampleMiner hard_example_miner = 3;

  // Classification loss weight.
  optional float classification_weight = 4 [default=1.0];

  // Localization loss weight.
  optional float localization_weight = 5 [default=1.0];
}

// Configuration for bounding box localization loss function.
message LocalizationLoss {
  oneof localization_loss {
    WeightedL2LocalizationLoss weighted_l2 = 1;
    WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;
    WeightedIOULocalizationLoss weighted_iou = 3;
  }
}

// L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2
message WeightedL2LocalizationLoss {
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
}

// SmoothL1 (Huber) location loss: .5 * x ^ 2 if |x| < 1 else |x| - .5
message WeightedSmoothL1LocalizationLoss {
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
}

// Intersection over union location loss: 1 - IOU
message WeightedIOULocalizationLoss {
}

// Configuration for class prediction loss function.
message ClassificationLoss {
  oneof classification_loss {
    WeightedSigmoidClassificationLoss weighted_sigmoid = 1;
    WeightedSoftmaxClassificationLoss weighted_softmax = 2;
    BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;
    SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
  }
}

// Classification loss using a sigmoid function over class predictions.
message WeightedSigmoidClassificationLoss {
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
}

// Sigmoid Focal cross entropy loss as described in
// https://arxiv.org/abs/1708.02002
message SigmoidFocalClassificationLoss {
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default = false];
  // modulating factor for the loss.
  optional float gamma = 2 [default = 2.0];
  // alpha weighting factor for the loss.
  optional float alpha = 3;
}

// Classification loss using a softmax function over class predictions.
message WeightedSoftmaxClassificationLoss {
  // Output loss per anchor.
  optional bool anchorwise_output = 1 [default=false];
  // Scale logit (input) value before calculating softmax classification loss.
  // Typically used for softmax distillation.
  optional float logit_scale = 2 [default = 1.0];
}

// Classification loss using a sigmoid function over the class prediction with
// the highest prediction score.
message BootstrappedSigmoidClassificationLoss {
  // Interpolation weight between 0 and 1.
  optional float alpha = 1;

  // Whether hard boot strapping should be used or not. If true, will only use
  // one class favored by model. Otherwise, will use all predicted class
  // probabilities.
  optional bool hard_bootstrap = 2 [default=false];

  // Output loss per anchor.
  optional bool anchorwise_output = 3 [default=false];
}

// Configuration for hard example miner.
message HardExampleMiner {
  // Maximum number of hard examples to be selected per image (prior to
  // enforcing max negative to positive ratio constraint). If set to 0,
  // all examples obtained after NMS are considered.
  optional int32 num_hard_examples = 1 [default=64];

  // Minimum intersection over union for an example to be discarded during NMS.
  optional float iou_threshold = 2 [default=0.7];

  // Whether to use classification losses ('cls', default), localization losses
  // ('loc') or both losses ('both'). In the case of 'both', cls_loss_weight and
  // loc_loss_weight are used to compute weighted sum of the two losses.
  enum LossType {
    BOTH = 0;
    CLASSIFICATION = 1;
    LOCALIZATION = 2;
  }
  optional LossType loss_type = 3 [default=BOTH];

  // Maximum number of negatives to retain for each positive anchor. If
  // num_negatives_per_positive is 0 no prespecified negative:positive ratio is
  // enforced.
  optional int32 max_negatives_per_positive = 4 [default=0];

  // Minimum number of negative anchors to sample for a given image. Setting
  // this to a positive number samples negatives in an image without any
  // positive anchors and thus not bias the model towards having at least one
  // detection per image.
  optional int32 min_negatives_per_image = 5 [default=0];
}
131 |
--------------------------------------------------------------------------------
/utils/variables_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helper functions for manipulating collections of variables during training.
17 | """
18 | import logging
19 | import re
20 |
21 | import tensorflow as tf
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 | # TODO: Consider replacing with tf.contrib.filter_variables in
27 | # tensorflow/contrib/framework/python/ops/variables.py
def filter_variables(variables, filter_regex_list, invert=False):
  """Filters out the variables matching the filter_regex.

  Filter out the variables whose name matches any of the regular
  expressions in filter_regex_list and returns the remaining variables.
  Optionally, if invert=True, the complement set is returned.

  Args:
    variables: a list of tensorflow variables.
    filter_regex_list: a list of string regular expressions.
    invert: (boolean). If True, returns the complement of the filter set; that
      is, all variables matching filter_regex are kept and all others discarded.

  Returns:
    a list of filtered variables.
  """
  kept_vars = []
  # Materialize the patterns as a list: on Python 3, `filter()` returns a
  # one-shot iterator that would be exhausted after the first variable,
  # silently keeping every remaining variable regardless of the patterns.
  variables_to_ignore_patterns = [p for p in filter_regex_list if p]
  for var in variables:
    add = True
    for pattern in variables_to_ignore_patterns:
      if re.match(pattern, var.op.name):
        add = False
        break
    if add != invert:
      kept_vars.append(var)
  return kept_vars
55 |
56 |
def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  """Multiply gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.
    multiplier: A (float) multiplier to apply to each gradient matching the
      regular expression.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples).
  """
  all_vars = [var for _, var in grads_and_vars]
  # invert=True keeps exactly the variables that match the regexes.
  matching_vars = filter_variables(all_vars, regex_list, invert=True)
  grad_multipliers = {}
  for var in matching_vars:
    logging.info('Applying multiplier %f to variable [%s]',
                 multiplier, var.op.name)
    grad_multipliers[var] = float(multiplier)
  return slim.learning.multiply_gradients(grads_and_vars,
                                          grad_multipliers)
77 |
78 |
def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  """Freeze gradients whose variable names match a regular expression.

  Args:
    grads_and_vars: A list of gradient to variable pairs (tuples).
    regex_list: A list of string regular expressions.

  Returns:
    grads_and_vars: A list of gradient to variable pairs (tuples) that do not
      contain the variables and gradients matching the regex.
  """
  all_vars = [var for _, var in grads_and_vars]
  # invert=True keeps exactly the variables that match the regexes; those are
  # the ones to freeze (i.e. drop from the returned pairs).
  matching_vars = filter_variables(all_vars, regex_list, invert=True)
  for var in matching_vars:
    logging.info('Freezing variable [%s]', var.op.name)
  kept_grads_and_vars = [pair for pair in grads_and_vars
                         if pair[1] not in matching_vars]
  return kept_grads_and_vars
97 |
98 |
def get_variables_available_in_checkpoint(variables, checkpoint_path):
  """Returns the subset of variables available in the checkpoint.

  Inspects given checkpoint and returns the subset of variables that are
  available in it.

  TODO: force input and output to be a dictionary.

  Args:
    variables: a list or dictionary of variables to find in checkpoint.
    checkpoint_path: path to the checkpoint to restore variables from.

  Returns:
    A list or dictionary of variables.
  Raises:
    ValueError: if `variables` is not a list or dict.
  """
  if isinstance(variables, list):
    variable_names_map = {v.op.name: v for v in variables}
  elif isinstance(variables, dict):
    variable_names_map = variables
  else:
    raise ValueError('`variables` is expected to be a list or dict.')
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  ckpt_var_names = reader.get_variable_to_shape_map().keys()
  vars_in_ckpt = {}
  for variable_name, variable in sorted(variable_names_map.items()):
    if variable_name not in ckpt_var_names:
      logging.warning('Variable [%s] not available in checkpoint',
                      variable_name)
      continue
    vars_in_ckpt[variable_name] = variable
  # Mirror the input container type: lists in, (dict-)values out.
  if isinstance(variables, list):
    return vars_in_ckpt.values()
  return vars_in_ckpt
134 |
--------------------------------------------------------------------------------
/protos/grid_anchor_generator_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/grid_anchor_generator.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
# NOTE(review): protoc-generated module (DO NOT EDIT by hand). Comments below
# are reader annotations only; regenerate from grid_anchor_generator.proto to
# change behavior. The serialized_pb blob is the compiled .proto definition.
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='object_detection/protos/grid_anchor_generator.proto',
20 | package='object_detection.protos',
21 | serialized_pb=_b('\n3object_detection/protos/grid_anchor_generator.proto\x12\x17object_detection.protos\"\xcd\x01\n\x13GridAnchorGenerator\x12\x13\n\x06height\x18\x01 \x01(\x05:\x03\x32\x35\x36\x12\x12\n\x05width\x18\x02 \x01(\x05:\x03\x32\x35\x36\x12\x19\n\rheight_stride\x18\x03 \x01(\x05:\x02\x31\x36\x12\x18\n\x0cwidth_stride\x18\x04 \x01(\x05:\x02\x31\x36\x12\x18\n\rheight_offset\x18\x05 \x01(\x05:\x01\x30\x12\x17\n\x0cwidth_offset\x18\x06 \x01(\x05:\x01\x30\x12\x0e\n\x06scales\x18\x07 \x03(\x02\x12\x15\n\raspect_ratios\x18\x08 \x03(\x02')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
27 |
# Message descriptor for GridAnchorGenerator: optional int32 height/width
# (default 256), height/width strides (default 16), offsets (default 0),
# plus repeated float scales and aspect_ratios.
28 | _GRIDANCHORGENERATOR = _descriptor.Descriptor(
29 | name='GridAnchorGenerator',
30 | full_name='object_detection.protos.GridAnchorGenerator',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='height', full_name='object_detection.protos.GridAnchorGenerator.height', index=0,
37 | number=1, type=5, cpp_type=1, label=1,
38 | has_default_value=True, default_value=256,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='width', full_name='object_detection.protos.GridAnchorGenerator.width', index=1,
44 | number=2, type=5, cpp_type=1, label=1,
45 | has_default_value=True, default_value=256,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='height_stride', full_name='object_detection.protos.GridAnchorGenerator.height_stride', index=2,
51 | number=3, type=5, cpp_type=1, label=1,
52 | has_default_value=True, default_value=16,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | _descriptor.FieldDescriptor(
57 | name='width_stride', full_name='object_detection.protos.GridAnchorGenerator.width_stride', index=3,
58 | number=4, type=5, cpp_type=1, label=1,
59 | has_default_value=True, default_value=16,
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None),
63 | _descriptor.FieldDescriptor(
64 | name='height_offset', full_name='object_detection.protos.GridAnchorGenerator.height_offset', index=4,
65 | number=5, type=5, cpp_type=1, label=1,
66 | has_default_value=True, default_value=0,
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None),
70 | _descriptor.FieldDescriptor(
71 | name='width_offset', full_name='object_detection.protos.GridAnchorGenerator.width_offset', index=5,
72 | number=6, type=5, cpp_type=1, label=1,
73 | has_default_value=True, default_value=0,
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None),
77 | _descriptor.FieldDescriptor(
78 | name='scales', full_name='object_detection.protos.GridAnchorGenerator.scales', index=6,
79 | number=7, type=2, cpp_type=6, label=3,
80 | has_default_value=False, default_value=[],
81 | message_type=None, enum_type=None, containing_type=None,
82 | is_extension=False, extension_scope=None,
83 | options=None),
84 | _descriptor.FieldDescriptor(
85 | name='aspect_ratios', full_name='object_detection.protos.GridAnchorGenerator.aspect_ratios', index=7,
86 | number=8, type=2, cpp_type=6, label=3,
87 | has_default_value=False, default_value=[],
88 | message_type=None, enum_type=None, containing_type=None,
89 | is_extension=False, extension_scope=None,
90 | options=None),
91 | ],
92 | extensions=[
93 | ],
94 | nested_types=[],
95 | enum_types=[
96 | ],
97 | options=None,
98 | is_extendable=False,
99 | extension_ranges=[],
100 | oneofs=[
101 | ],
102 | serialized_start=81,
103 | serialized_end=286,
104 | )
105 |
# Register the message type on the file descriptor, then build the concrete
# GridAnchorGenerator message class via the reflection metaclass.
106 | DESCRIPTOR.message_types_by_name['GridAnchorGenerator'] = _GRIDANCHORGENERATOR
107 |
108 | GridAnchorGenerator = _reflection.GeneratedProtocolMessageType('GridAnchorGenerator', (_message.Message,), dict(
109 | DESCRIPTOR = _GRIDANCHORGENERATOR,
110 | __module__ = 'object_detection.protos.grid_anchor_generator_pb2'
111 | # @@protoc_insertion_point(class_scope:object_detection.protos.GridAnchorGenerator)
112 | ))
113 | _sym_db.RegisterMessage(GridAnchorGenerator)
114 |
115 |
116 | # @@protoc_insertion_point(module_scope)
117 |
--------------------------------------------------------------------------------
/protos/pipeline_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/pipeline.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | import object_detection.protos.eval_pb2
17 | import object_detection.protos.input_reader_pb2
18 | import object_detection.protos.model_pb2
19 | import object_detection.protos.train_pb2
20 |
21 |
# NOTE(review): protoc-generated module (DO NOT EDIT by hand). Comments below
# are reader annotations only; regenerate from pipeline.proto to change
# behavior. Depends on the eval/input_reader/model/train proto modules.
22 | DESCRIPTOR = _descriptor.FileDescriptor(
23 | name='object_detection/protos/pipeline.proto',
24 | package='object_detection.protos',
25 | serialized_pb=_b('\n&object_detection/protos/pipeline.proto\x12\x17object_detection.protos\x1a\"object_detection/protos/eval.proto\x1a*object_detection/protos/input_reader.proto\x1a#object_detection/protos/model.proto\x1a#object_detection/protos/train.proto\"\xca\x02\n\x17TrainEvalPipelineConfig\x12\x36\n\x05model\x18\x01 \x01(\x0b\x32\'.object_detection.protos.DetectionModel\x12:\n\x0ctrain_config\x18\x02 \x01(\x0b\x32$.object_detection.protos.TrainConfig\x12@\n\x12train_input_reader\x18\x03 \x01(\x0b\x32$.object_detection.protos.InputReader\x12\x38\n\x0b\x65val_config\x18\x04 \x01(\x0b\x32#.object_detection.protos.EvalConfig\x12?\n\x11\x65val_input_reader\x18\x05 \x01(\x0b\x32$.object_detection.protos.InputReader')
26 | ,
27 | dependencies=[object_detection.protos.eval_pb2.DESCRIPTOR,object_detection.protos.input_reader_pb2.DESCRIPTOR,object_detection.protos.model_pb2.DESCRIPTOR,object_detection.protos.train_pb2.DESCRIPTOR,])
28 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
29 |
30 |
31 |
32 |
# TrainEvalPipelineConfig: top-level config message bundling the model,
# train config, eval config, and the two input readers (all message fields).
33 | _TRAINEVALPIPELINECONFIG = _descriptor.Descriptor(
34 | name='TrainEvalPipelineConfig',
35 | full_name='object_detection.protos.TrainEvalPipelineConfig',
36 | filename=None,
37 | file=DESCRIPTOR,
38 | containing_type=None,
39 | fields=[
40 | _descriptor.FieldDescriptor(
41 | name='model', full_name='object_detection.protos.TrainEvalPipelineConfig.model', index=0,
42 | number=1, type=11, cpp_type=10, label=1,
43 | has_default_value=False, default_value=None,
44 | message_type=None, enum_type=None, containing_type=None,
45 | is_extension=False, extension_scope=None,
46 | options=None),
47 | _descriptor.FieldDescriptor(
48 | name='train_config', full_name='object_detection.protos.TrainEvalPipelineConfig.train_config', index=1,
49 | number=2, type=11, cpp_type=10, label=1,
50 | has_default_value=False, default_value=None,
51 | message_type=None, enum_type=None, containing_type=None,
52 | is_extension=False, extension_scope=None,
53 | options=None),
54 | _descriptor.FieldDescriptor(
55 | name='train_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.train_input_reader', index=2,
56 | number=3, type=11, cpp_type=10, label=1,
57 | has_default_value=False, default_value=None,
58 | message_type=None, enum_type=None, containing_type=None,
59 | is_extension=False, extension_scope=None,
60 | options=None),
61 | _descriptor.FieldDescriptor(
62 | name='eval_config', full_name='object_detection.protos.TrainEvalPipelineConfig.eval_config', index=3,
63 | number=4, type=11, cpp_type=10, label=1,
64 | has_default_value=False, default_value=None,
65 | message_type=None, enum_type=None, containing_type=None,
66 | is_extension=False, extension_scope=None,
67 | options=None),
68 | _descriptor.FieldDescriptor(
69 | name='eval_input_reader', full_name='object_detection.protos.TrainEvalPipelineConfig.eval_input_reader', index=4,
70 | number=5, type=11, cpp_type=10, label=1,
71 | has_default_value=False, default_value=None,
72 | message_type=None, enum_type=None, containing_type=None,
73 | is_extension=False, extension_scope=None,
74 | options=None),
75 | ],
76 | extensions=[
77 | ],
78 | nested_types=[],
79 | enum_types=[
80 | ],
81 | options=None,
82 | is_extendable=False,
83 | extension_ranges=[],
84 | oneofs=[
85 | ],
86 | serialized_start=222,
87 | serialized_end=552,
88 | )
89 |
# Resolve cross-file message types for each field, register the message type,
# and build the concrete TrainEvalPipelineConfig class via reflection.
90 | _TRAINEVALPIPELINECONFIG.fields_by_name['model'].message_type = object_detection.protos.model_pb2._DETECTIONMODEL
91 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_config'].message_type = object_detection.protos.train_pb2._TRAINCONFIG
92 | _TRAINEVALPIPELINECONFIG.fields_by_name['train_input_reader'].message_type = object_detection.protos.input_reader_pb2._INPUTREADER
93 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_config'].message_type = object_detection.protos.eval_pb2._EVALCONFIG
94 | _TRAINEVALPIPELINECONFIG.fields_by_name['eval_input_reader'].message_type = object_detection.protos.input_reader_pb2._INPUTREADER
95 | DESCRIPTOR.message_types_by_name['TrainEvalPipelineConfig'] = _TRAINEVALPIPELINECONFIG
96 |
97 | TrainEvalPipelineConfig = _reflection.GeneratedProtocolMessageType('TrainEvalPipelineConfig', (_message.Message,), dict(
98 | DESCRIPTOR = _TRAINEVALPIPELINECONFIG,
99 | __module__ = 'object_detection.protos.pipeline_pb2'
100 | # @@protoc_insertion_point(class_scope:object_detection.protos.TrainEvalPipelineConfig)
101 | ))
102 | _sym_db.RegisterMessage(TrainEvalPipelineConfig)
103 |
104 |
105 | # @@protoc_insertion_point(module_scope)
106 |
--------------------------------------------------------------------------------
/utils/shape_utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.shape_utils."""
17 |
18 | import tensorflow as tf
19 |
20 | from object_detection.utils import shape_utils
21 |
22 |
class UtilTest(tf.test.TestCase):
  """Exercises shape_utils pad/clip helpers with static and tensor lengths."""

  def test_pad_tensor_using_integer_input(self):
    int_in = tf.constant([1], dtype=tf.int32)
    int_padded = shape_utils.pad_tensor(int_in, 2)
    float_in = tf.constant([[0.1, 0.2]], dtype=tf.float32)
    float_padded = shape_utils.pad_tensor(float_in, 2)

    # With a static pad length the leading dimension is statically known.
    self.assertEqual(2, int_padded.get_shape()[0])
    self.assertEqual(2, float_padded.get_shape()[0])

    with self.test_session() as sess:
      int_out, float_out = sess.run([int_padded, float_padded])
      self.assertAllEqual([1, 0], int_out)
      self.assertAllClose([[0.1, 0.2], [0, 0]], float_out)

  def test_pad_tensor_using_tensor_input(self):
    int_in = tf.constant([1], dtype=tf.int32)
    int_padded = shape_utils.pad_tensor(int_in, tf.constant(2))
    float_in = tf.constant([[0.1, 0.2]], dtype=tf.float32)
    float_padded = shape_utils.pad_tensor(float_in, tf.constant(2))

    with self.test_session() as sess:
      int_out, float_out = sess.run([int_padded, float_padded])
      self.assertAllEqual([1, 0], int_out)
      self.assertAllClose([[0.1, 0.2], [0, 0]], float_out)

  def test_clip_tensor_using_integer_input(self):
    int_in = tf.constant([1, 2, 3], dtype=tf.int32)
    int_clipped = shape_utils.clip_tensor(int_in, 2)
    float_in = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]],
                           dtype=tf.float32)
    float_clipped = shape_utils.clip_tensor(float_in, 2)

    self.assertEqual(2, int_clipped.get_shape()[0])
    self.assertEqual(2, float_clipped.get_shape()[0])

    with self.test_session() as sess:
      int_out, float_out = sess.run([int_clipped, float_clipped])
      self.assertAllEqual([1, 2], int_out)
      self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], float_out)

  def test_clip_tensor_using_tensor_input(self):
    int_in = tf.constant([1, 2, 3], dtype=tf.int32)
    int_clipped = shape_utils.clip_tensor(int_in, tf.constant(2))
    float_in = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]],
                           dtype=tf.float32)
    float_clipped = shape_utils.clip_tensor(float_in, tf.constant(2))

    with self.test_session() as sess:
      int_out, float_out = sess.run([int_clipped, float_clipped])
      self.assertAllEqual([1, 2], int_out)
      self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], float_out)

  def test_pad_or_clip_tensor_using_integer_input(self):
    short_int = tf.constant([1], dtype=tf.int32)
    padded_int = shape_utils.pad_or_clip_tensor(short_int, 2)
    short_float = tf.constant([[0.1, 0.2]], dtype=tf.float32)
    padded_float = shape_utils.pad_or_clip_tensor(short_float, 2)

    long_int = tf.constant([1, 2, 3], dtype=tf.int32)
    clipped_int = shape_utils.clip_tensor(long_int, 2)
    long_float = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]],
                             dtype=tf.float32)
    clipped_float = shape_utils.clip_tensor(long_float, 2)

    for tensor in (padded_int, padded_float, clipped_int, clipped_float):
      self.assertEqual(2, tensor.get_shape()[0])

    with self.test_session() as sess:
      results = sess.run(
          [padded_int, padded_float, clipped_int, clipped_float])
      self.assertAllEqual([1, 0], results[0])
      self.assertAllClose([[0.1, 0.2], [0, 0]], results[1])
      self.assertAllEqual([1, 2], results[2])
      self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], results[3])

  def test_pad_or_clip_tensor_using_tensor_input(self):
    short_int = tf.constant([1], dtype=tf.int32)
    padded_int = shape_utils.pad_or_clip_tensor(short_int, tf.constant(2))
    short_float = tf.constant([[0.1, 0.2]], dtype=tf.float32)
    padded_float = shape_utils.pad_or_clip_tensor(short_float, tf.constant(2))

    long_int = tf.constant([1, 2, 3], dtype=tf.int32)
    clipped_int = shape_utils.clip_tensor(long_int, tf.constant(2))
    long_float = tf.constant([[0.1, 0.2], [0.2, 0.4], [0.5, 0.8]],
                             dtype=tf.float32)
    clipped_float = shape_utils.clip_tensor(long_float, tf.constant(2))

    with self.test_session() as sess:
      results = sess.run(
          [padded_int, padded_float, clipped_int, clipped_float])
      self.assertAllEqual([1, 0], results[0])
      self.assertAllClose([[0.1, 0.2], [0, 0]], results[1])
      self.assertAllEqual([1, 2], results[2])
      self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], results[3])

  def test_combines_static_dynamic_shape(self):
    placeholder = tf.placeholder(tf.float32, shape=(None, 2, 3))
    combined_shape = shape_utils.combined_static_and_dynamic_shape(placeholder)
    # The unknown batch dimension comes back as a tensor; static dims as ints.
    self.assertTrue(tf.contrib.framework.is_tensor(combined_shape[0]))
    self.assertListEqual(combined_shape[1:], [2, 3])
124 |
125 |
# Standard TensorFlow test entry point when executed as a script.
126 | if __name__ == '__main__':
127 | tf.test.main()
128 |
--------------------------------------------------------------------------------
/utils/np_box_list_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.np_box_list_test."""
17 |
18 | import numpy as np
19 | import tensorflow as tf
20 |
21 | from object_detection.utils import np_box_list
22 |
23 |
class BoxListTest(tf.test.TestCase):
  """Construction and field-access behavior of np_box_list.BoxList."""

  def _make_boxlist(self):
    # Three well-formed [ymin, xmin, ymax, xmax] float boxes.
    corners = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
                        [0.0, 0.0, 20.0, 20.0]],
                       dtype=float)
    return corners, np_box_list.BoxList(corners)

  def test_invalid_box_data(self):
    # Each input violates the expected box-array contract in a different way;
    # all of them must be rejected with ValueError.
    bad_inputs = [
        [0, 0, 1, 1],
        np.array([[0, 0, 1, 1]], dtype=int),
        np.array([0, 1, 1, 3, 4], dtype=float),
        np.array([[0, 1, 1, 3], [3, 1, 1, 5]], dtype=float),
    ]
    for data in bad_inputs:
      with self.assertRaises(ValueError):
        np_box_list.BoxList(data)

  def test_has_field_with_existed_field(self):
    _, boxlist = self._make_boxlist()
    self.assertTrue(boxlist.has_field('boxes'))

  def test_has_field_with_nonexisted_field(self):
    _, boxlist = self._make_boxlist()
    self.assertFalse(boxlist.has_field('scores'))

  def test_get_field_with_existed_field(self):
    corners, boxlist = self._make_boxlist()
    self.assertTrue(np.allclose(boxlist.get_field('boxes'), corners))

  def test_get_field_with_nonexited_field(self):
    _, boxlist = self._make_boxlist()
    with self.assertRaises(ValueError):
      boxlist.get_field('scores')
67 |
68 |
class AddExtraFieldTest(tf.test.TestCase):
  """Tests adding and reading extra per-box fields on a BoxList."""

  def setUp(self):
    # A fresh three-box BoxList for every test method.
    boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
                      [0.0, 0.0, 20.0, 20.0]],
                     dtype=float)
    self.boxlist = np_box_list.BoxList(boxes)

  def test_add_already_existed_field(self):
    # 'boxes' is reserved for the box coordinates themselves.
    with self.assertRaises(ValueError):
      self.boxlist.add_field('boxes', np.array([[0, 0, 0, 1, 0]], dtype=float))

  def test_add_invalid_field_data(self):
    # Field data whose leading dimension disagrees with the box count.
    with self.assertRaises(ValueError):
      self.boxlist.add_field('scores', np.array([0.5, 0.7], dtype=float))
    with self.assertRaises(ValueError):
      self.boxlist.add_field('scores',
                             np.array([0.5, 0.7, 0.9, 0.1], dtype=float))

  def test_add_single_dimensional_field_data(self):
    boxlist = self.boxlist
    scores = np.array([0.5, 0.7, 0.9], dtype=float)
    boxlist.add_field('scores', scores)
    self.assertTrue(np.allclose(scores, self.boxlist.get_field('scores')))

  def test_add_multi_dimensional_field_data(self):
    boxlist = self.boxlist
    labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
                      dtype=int)
    boxlist.add_field('labels', labels)
    self.assertTrue(np.allclose(labels, self.boxlist.get_field('labels')))

  def test_get_extra_fields(self):
    boxlist = self.boxlist
    self.assertSameElements(boxlist.get_extra_fields(), [])

    scores = np.array([0.5, 0.7, 0.9], dtype=float)
    boxlist.add_field('scores', scores)
    self.assertSameElements(boxlist.get_extra_fields(), ['scores'])

    labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
                      dtype=int)
    boxlist.add_field('labels', labels)
    self.assertSameElements(boxlist.get_extra_fields(), ['scores', 'labels'])

  def test_get_coordinates(self):
    y_min, x_min, y_max, x_max = self.boxlist.get_coordinates()

    expected_y_min = np.array([3.0, 14.0, 0.0], dtype=float)
    expected_x_min = np.array([4.0, 14.0, 0.0], dtype=float)
    expected_y_max = np.array([6.0, 15.0, 20.0], dtype=float)
    expected_x_max = np.array([8.0, 15.0, 20.0], dtype=float)

    self.assertTrue(np.allclose(y_min, expected_y_min))
    self.assertTrue(np.allclose(x_min, expected_x_min))
    self.assertTrue(np.allclose(y_max, expected_y_max))
    self.assertTrue(np.allclose(x_max, expected_x_max))

  def test_num_boxes(self):
    boxes = np.array([[0., 0., 100., 100.], [10., 30., 50., 70.]], dtype=float)
    boxlist = np_box_list.BoxList(boxes)
    expected_num_boxes = 2
    # Fix: assertEquals is a deprecated alias removed in Python 3.12;
    # use the canonical assertEqual.
    self.assertEqual(boxlist.num_boxes(), expected_num_boxes)
132 |
133 |
# Standard TensorFlow test entry point when executed as a script.
134 | if __name__ == '__main__':
135 | tf.test.main()
136 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Functions for computing metrics like precision, recall, CorLoc and etc."""
17 | from __future__ import division
18 |
19 | import numpy as np
20 |
21 |
def compute_precision_recall(scores, labels, num_gt):
  """Compute precision and recall curves from ranked detection scores.

  Args:
    scores: A 1-D float numpy array of detection scores.
    labels: A 1-D bool numpy array of true/false positive labels, aligned
      with `scores`.
    num_gt: Number of ground truth instances.

  Raises:
    ValueError: if the input is not of the correct format.

  Returns:
    precision: Fraction of positive instances over detected ones. This value is
      None if no ground truth labels are present.
    recall: Fraction of detected positive instances over all positive
      instances. This value is None if no ground truth labels are present.
  """
  # Fix: np.bool was a deprecated alias for the builtin bool and was removed
  # in NumPy 1.24; a bool-dtyped array compares equal to the builtin `bool`.
  if not isinstance(
      labels, np.ndarray) or labels.dtype != bool or labels.ndim != 1:
    raise ValueError("labels must be single dimension bool numpy array")

  if not isinstance(scores, np.ndarray) or scores.ndim != 1:
    raise ValueError("scores must be single dimension numpy array")

  if num_gt < np.sum(labels):
    raise ValueError("Number of true positives must be smaller than num_gt.")

  if len(scores) != len(labels):
    raise ValueError("scores and labels must be of the same size.")

  if num_gt == 0:
    return None, None

  # Rank detections by descending score, then accumulate TP/FP counts.
  sorted_indices = np.argsort(scores)[::-1]
  true_positive_labels = labels.astype(int)[sorted_indices]
  false_positive_labels = 1 - true_positive_labels
  cum_true_positives = np.cumsum(true_positive_labels)
  cum_false_positives = np.cumsum(false_positive_labels)
  # Precision at rank k = TP(k) / (TP(k) + FP(k)); recall = TP(k) / num_gt.
  precision = cum_true_positives.astype(float) / (
      cum_true_positives + cum_false_positives)
  recall = cum_true_positives.astype(float) / num_gt
  return precision, recall
68 |
69 |
def compute_average_precision(precision, recall):
  """Compute Average Precision according to the definition in VOCdevkit.

  Precision is modified to ensure that it does not decrease as recall
  decreases.

  Args:
    precision: A float [N, 1] numpy array of precisions.
    recall: A float [N, 1] numpy array of recalls.

  Raises:
    ValueError: if the input is not of the correct format.

  Returns:
    average_precision: The area under the precision recall curve. NaN if
      precision and recall are None.
  """
  if precision is None:
    if recall is not None:
      raise ValueError("If precision is None, recall must also be None")
    # Fix: np.NAN was removed in NumPy 2.0; np.nan is the canonical spelling.
    return np.nan

  if not isinstance(precision, np.ndarray) or not isinstance(recall,
                                                             np.ndarray):
    raise ValueError("precision and recall must be numpy array")
  # Fix: np.float (alias of builtin float) was removed in NumPy 1.24.
  # Accept any floating dtype (float64 as before; float32 is a
  # backward-compatible generalization).
  if not np.issubdtype(precision.dtype, np.floating) or not np.issubdtype(
      recall.dtype, np.floating):
    raise ValueError("input must be float numpy array.")
  if len(precision) != len(recall):
    raise ValueError("precision and recall must be of the same size.")
  if not precision.size:
    return 0.0
  if np.amin(precision) < 0 or np.amax(precision) > 1:
    raise ValueError("Precision must be in the range of [0, 1].")
  if np.amin(recall) < 0 or np.amax(recall) > 1:
    raise ValueError("recall must be in the range of [0, 1].")
  if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)):
    raise ValueError("recall must be a non-decreasing array")

  # Pad with sentinels so the curve starts at recall 0 and ends at recall 1.
  recall = np.concatenate([[0], recall, [1]])
  precision = np.concatenate([[0], precision, [0]])

  # Make precision monotonically non-increasing via a right-to-left running
  # maximum (the standard VOC interpolation step).
  for i in range(len(precision) - 2, -1, -1):
    precision[i] = np.maximum(precision[i], precision[i + 1])

  # Sum the areas of the rectangles under the interpolated curve.
  indices = np.where(recall[1:] != recall[:-1])[0] + 1
  average_precision = np.sum(
      (recall[indices] - recall[indices - 1]) * precision[indices])
  return average_precision
120 |
121 |
def compute_cor_loc(num_gt_imgs_per_class,
                    num_images_correctly_detected_per_class):
  """Compute CorLoc according to the definition in the following paper.

  https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf

  Returns nans if there are no ground truth images for a class.

  Args:
    num_gt_imgs_per_class: 1D array, representing number of images containing
      at least one object instance of a particular class
    num_images_correctly_detected_per_class: 1D array, representing number of
      images that are correctly detected at least one object instance of a
      particular class

  Returns:
    corloc_per_class: A float numpy array represents the corloc score of each
      class
  """
  # Fix: np.where evaluates both branches eagerly, so the division still runs
  # for classes with zero ground-truth images and emits divide-by-zero /
  # invalid-value RuntimeWarnings. Suppress them — those entries are replaced
  # by NaN anyway, so the result is unchanged.
  with np.errstate(divide='ignore', invalid='ignore'):
    return np.where(
        num_gt_imgs_per_class == 0,
        np.nan,
        num_images_correctly_detected_per_class / num_gt_imgs_per_class)
145 |
--------------------------------------------------------------------------------
/protos/box_coder_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/box_coder.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 | import object_detection.protos.faster_rcnn_box_coder_pb2
17 | import object_detection.protos.keypoint_box_coder_pb2
18 | import object_detection.protos.mean_stddev_box_coder_pb2
19 | import object_detection.protos.square_box_coder_pb2
20 |
21 |
# NOTE(review): protoc-generated module (DO NOT EDIT by hand). Comments below
# are reader annotations only; regenerate from box_coder.proto to change
# behavior. BoxCoder is a oneof wrapper over the four concrete coder messages.
22 | DESCRIPTOR = _descriptor.FileDescriptor(
23 | name='object_detection/protos/box_coder.proto',
24 | package='object_detection.protos',
25 | serialized_pb=_b('\n\'object_detection/protos/box_coder.proto\x12\x17object_detection.protos\x1a\x33object_detection/protos/faster_rcnn_box_coder.proto\x1a\x30object_detection/protos/keypoint_box_coder.proto\x1a\x33object_detection/protos/mean_stddev_box_coder.proto\x1a.object_detection/protos/square_box_coder.proto\"\xc7\x02\n\x08\x42oxCoder\x12L\n\x15\x66\x61ster_rcnn_box_coder\x18\x01 \x01(\x0b\x32+.object_detection.protos.FasterRcnnBoxCoderH\x00\x12L\n\x15mean_stddev_box_coder\x18\x02 \x01(\x0b\x32+.object_detection.protos.MeanStddevBoxCoderH\x00\x12\x43\n\x10square_box_coder\x18\x03 \x01(\x0b\x32\'.object_detection.protos.SquareBoxCoderH\x00\x12G\n\x12keypoint_box_coder\x18\x04 \x01(\x0b\x32).object_detection.protos.KeypointBoxCoderH\x00\x42\x11\n\x0f\x62ox_coder_oneof')
26 | ,
27 | dependencies=[object_detection.protos.faster_rcnn_box_coder_pb2.DESCRIPTOR,object_detection.protos.keypoint_box_coder_pb2.DESCRIPTOR,object_detection.protos.mean_stddev_box_coder_pb2.DESCRIPTOR,object_detection.protos.square_box_coder_pb2.DESCRIPTOR,])
28 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
29 |
30 |
31 |
32 |
# BoxCoder message: exactly one of faster_rcnn / mean_stddev / square /
# keypoint box coders may be set (enforced by the box_coder_oneof below).
33 | _BOXCODER = _descriptor.Descriptor(
34 | name='BoxCoder',
35 | full_name='object_detection.protos.BoxCoder',
36 | filename=None,
37 | file=DESCRIPTOR,
38 | containing_type=None,
39 | fields=[
40 | _descriptor.FieldDescriptor(
41 | name='faster_rcnn_box_coder', full_name='object_detection.protos.BoxCoder.faster_rcnn_box_coder', index=0,
42 | number=1, type=11, cpp_type=10, label=1,
43 | has_default_value=False, default_value=None,
44 | message_type=None, enum_type=None, containing_type=None,
45 | is_extension=False, extension_scope=None,
46 | options=None),
47 | _descriptor.FieldDescriptor(
48 | name='mean_stddev_box_coder', full_name='object_detection.protos.BoxCoder.mean_stddev_box_coder', index=1,
49 | number=2, type=11, cpp_type=10, label=1,
50 | has_default_value=False, default_value=None,
51 | message_type=None, enum_type=None, containing_type=None,
52 | is_extension=False, extension_scope=None,
53 | options=None),
54 | _descriptor.FieldDescriptor(
55 | name='square_box_coder', full_name='object_detection.protos.BoxCoder.square_box_coder', index=2,
56 | number=3, type=11, cpp_type=10, label=1,
57 | has_default_value=False, default_value=None,
58 | message_type=None, enum_type=None, containing_type=None,
59 | is_extension=False, extension_scope=None,
60 | options=None),
61 | _descriptor.FieldDescriptor(
62 | name='keypoint_box_coder', full_name='object_detection.protos.BoxCoder.keypoint_box_coder', index=3,
63 | number=4, type=11, cpp_type=10, label=1,
64 | has_default_value=False, default_value=None,
65 | message_type=None, enum_type=None, containing_type=None,
66 | is_extension=False, extension_scope=None,
67 | options=None),
68 | ],
69 | extensions=[
70 | ],
71 | nested_types=[],
72 | enum_types=[
73 | ],
74 | options=None,
75 | is_extendable=False,
76 | extension_ranges=[],
77 | oneofs=[
78 | _descriptor.OneofDescriptor(
79 | name='box_coder_oneof', full_name='object_detection.protos.BoxCoder.box_coder_oneof',
80 | index=0, containing_type=None, fields=[]),
81 | ],
82 | serialized_start=273,
83 | serialized_end=600,
84 | )
85 |
# Resolve cross-file message types, attach each field to the oneof, register
# the message type, and build the concrete BoxCoder class via reflection.
86 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].message_type = object_detection.protos.faster_rcnn_box_coder_pb2._FASTERRCNNBOXCODER
87 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].message_type = object_detection.protos.mean_stddev_box_coder_pb2._MEANSTDDEVBOXCODER
88 | _BOXCODER.fields_by_name['square_box_coder'].message_type = object_detection.protos.square_box_coder_pb2._SQUAREBOXCODER
89 | _BOXCODER.fields_by_name['keypoint_box_coder'].message_type = object_detection.protos.keypoint_box_coder_pb2._KEYPOINTBOXCODER
90 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append(
91 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'])
92 | _BOXCODER.fields_by_name['faster_rcnn_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof']
93 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append(
94 | _BOXCODER.fields_by_name['mean_stddev_box_coder'])
95 | _BOXCODER.fields_by_name['mean_stddev_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof']
96 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append(
97 | _BOXCODER.fields_by_name['square_box_coder'])
98 | _BOXCODER.fields_by_name['square_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof']
99 | _BOXCODER.oneofs_by_name['box_coder_oneof'].fields.append(
100 | _BOXCODER.fields_by_name['keypoint_box_coder'])
101 | _BOXCODER.fields_by_name['keypoint_box_coder'].containing_oneof = _BOXCODER.oneofs_by_name['box_coder_oneof']
102 | DESCRIPTOR.message_types_by_name['BoxCoder'] = _BOXCODER
103 |
104 | BoxCoder = _reflection.GeneratedProtocolMessageType('BoxCoder', (_message.Message,), dict(
105 | DESCRIPTOR = _BOXCODER,
106 | __module__ = 'object_detection.protos.box_coder_pb2'
107 | # @@protoc_insertion_point(class_scope:object_detection.protos.BoxCoder)
108 | ))
109 | _sym_db.RegisterMessage(BoxCoder)
110 |
111 |
112 | # @@protoc_insertion_point(module_scope)
113 |
--------------------------------------------------------------------------------
/utils/label_map_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Label map utility functions."""
17 |
18 | import logging
19 |
20 | import tensorflow as tf
21 | from google.protobuf import text_format
22 | #from object_detection.protos import string_int_label_map_pb2
23 | from protos import string_int_label_map_pb2
24 |
25 |
26 | def _validate_label_map(label_map):
27 | """Checks if a label map is valid.
28 |
29 | Args:
30 | label_map: StringIntLabelMap to validate.
31 |
32 | Raises:
33 | ValueError: if label map is invalid.
34 | """
35 | for item in label_map.item:
36 | if item.id < 1:
37 | raise ValueError('Label map ids should be >= 1.')
38 |
39 |
def create_category_index(categories):
  """Builds a lookup from category id to the full category dict.

  Args:
    categories: a list of dicts, each of which has the following keys:
      'id': (required) an integer id uniquely identifying this category.
      'name': (required) string representing category name
        e.g., 'cat', 'dog', 'pizza'.

  Returns:
    category_index: a dict containing the same entries as categories, but
      keyed by the 'id' field of each category. If two categories share an
      id, the later one wins (matching plain dict assignment).
  """
  return {category['id']: category for category in categories}
57 |
58 |
def convert_label_map_to_categories(label_map,
                                    max_num_classes,
                                    use_display_name=True):
  """Translates a label map proto into an eval-compatible categories list.

  Each returned dict has the following keys:
    'id': (required) an integer id uniquely identifying this category.
    'name': (required) string representing category name
      e.g., 'cat', 'dog', 'pizza'.
  Only items whose id lies in [1, max_num_classes] are kept; when several
  items share an id, only the first occurrence is retained.

  Args:
    label_map: a StringIntLabelMapProto or None. If None, a default categories
      list is created with max_num_classes categories.
    max_num_classes: maximum number of (consecutive) label indices to include.
    use_display_name: (boolean) choose whether to load 'display_name' field
      as category name. If False or if the display_name field does not exist,
      uses 'name' field as category names instead.

  Returns:
    categories: a list of dictionaries representing all possible categories.
  """
  if not label_map:
    # No label map supplied: fabricate generic, consecutively numbered
    # classes starting at id 1.
    label_id_offset = 1
    return [{
        'id': class_idx + label_id_offset,
        'name': 'category_{}'.format(class_idx + label_id_offset)
    } for class_idx in range(max_num_classes)]

  categories = []
  seen_ids = set()
  for item in label_map.item:
    if not 0 < item.id <= max_num_classes:
      logging.info('Ignore item %d since it falls outside of requested '
                   'label range.', item.id)
      continue
    if use_display_name and item.HasField('display_name'):
      category_name = item.display_name
    else:
      category_name = item.name
    if item.id in seen_ids:
      continue
    seen_ids.add(item.id)
    categories.append({'id': item.id, 'name': category_name})
  return categories
107 |
108 |
def load_labelmap(path):
  """Loads a StringIntLabelMap proto from a file and validates it.

  Args:
    path: path to StringIntLabelMap proto text file.

  Returns:
    a StringIntLabelMapProto

  Raises:
    ValueError: if the loaded label map contains an id < 1
      (see _validate_label_map).
  """
  with tf.gfile.GFile(path, 'r') as fid:
    label_map_string = fid.read()
    label_map = string_int_label_map_pb2.StringIntLabelMap()
    try:
      # Prefer the human-readable text format (.pbtxt).
      text_format.Merge(label_map_string, label_map)
    except text_format.ParseError:
      # Fall back to interpreting the contents as a serialized binary proto.
      # NOTE(review): the file is opened in text mode ('r'), so on Python 3
      # label_map_string is a str while ParseFromString expects bytes — this
      # fallback path may fail there. TODO confirm intended usage.
      label_map.ParseFromString(label_map_string)
  _validate_label_map(label_map)
  return label_map
126 |
127 |
def get_label_map_dict(label_map_path, use_display_name=False):
  """Reads a label map file and returns a name -> id mapping.

  Args:
    label_map_path: path to label_map.
    use_display_name: whether to use the label map items' display names as
      keys instead of their 'name' fields.

  Returns:
    A dictionary mapping label names (or display names) to integer ids.
  """
  label_map = load_labelmap(label_map_path)
  if use_display_name:
    return {entry.display_name: entry.id for entry in label_map.item}
  return {entry.name: entry.id for entry in label_map.item}
146 |
147 |
def create_category_index_from_labelmap(label_map_path):
  """Reads a label map file and builds a category index from it.

  Args:
    label_map_path: Path to `StringIntLabelMap` proto text file.

  Returns:
    A category index: a dictionary mapping integer ids to category dicts,
    e.g. {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...}
  """
  label_map = load_labelmap(label_map_path)
  # Use the largest id in the label map as the number of classes so every
  # listed item is eligible for inclusion.
  num_classes = max(entry.id for entry in label_map.item)
  return create_category_index(
      convert_label_map_to_categories(label_map, num_classes))
163 |
164 |
def create_class_agnostic_category_index():
  """Creates a category index with a single `object` class."""
  object_category = {'id': 1, 'name': 'object'}
  return {object_category['id']: object_category}
168 |
--------------------------------------------------------------------------------
/protos/faster_rcnn.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto2";
2 |
3 | package object_detection.protos;
4 |
5 | import "object_detection/protos/anchor_generator.proto";
6 | import "object_detection/protos/box_predictor.proto";
7 | import "object_detection/protos/hyperparams.proto";
8 | import "object_detection/protos/image_resizer.proto";
9 | import "object_detection/protos/losses.proto";
10 | import "object_detection/protos/post_processing.proto";
11 |
12 | // Configuration for Faster R-CNN models.
13 | // See meta_architectures/faster_rcnn_meta_arch.py and models/model_builder.py
14 | //
15 | // Naming conventions:
16 | // Faster R-CNN models have two stages: a first stage region proposal network
17 | // (or RPN) and a second stage box classifier. We thus use the prefixes
18 | // `first_stage_` and `second_stage_` to indicate the stage to which each
19 | // parameter pertains when relevant.
message FasterRcnn {

  // Whether to construct only the Region Proposal Network (RPN).
  optional bool first_stage_only = 1 [default=false];

  // Number of classes to predict.
  // (Field number 2 is unused in this message.)
  optional int32 num_classes = 3;

  // Image resizer for preprocessing the input image.
  optional ImageResizer image_resizer = 4;

  // Feature extractor config.
  optional FasterRcnnFeatureExtractor feature_extractor = 5;


  // (First stage) region proposal network (RPN) parameters.

  // Anchor generator to compute RPN anchors.
  optional AnchorGenerator first_stage_anchor_generator = 6;

  // Atrous rate for the convolution op applied to the
  // `first_stage_features_to_crop` tensor to obtain box predictions.
  optional int32 first_stage_atrous_rate = 7 [default=1];

  // Hyperparameters for the convolutional RPN box predictor.
  optional Hyperparams first_stage_box_predictor_conv_hyperparams = 8;

  // Kernel size to use for the convolution op just prior to RPN box
  // predictions.
  optional int32 first_stage_box_predictor_kernel_size = 9 [default=3];

  // Output depth for the convolution op just prior to RPN box predictions.
  optional int32 first_stage_box_predictor_depth = 10 [default=512];

  // The batch size to use for computing the first stage objectness and
  // location losses.
  optional int32 first_stage_minibatch_size = 11 [default=256];

  // Fraction of positive examples per image for the RPN.
  optional float first_stage_positive_balance_fraction = 12 [default=0.5];

  // Non max suppression score threshold applied to first stage RPN proposals.
  optional float first_stage_nms_score_threshold = 13 [default=0.0];

  // Non max suppression IOU threshold applied to first stage RPN proposals.
  optional float first_stage_nms_iou_threshold = 14 [default=0.7];

  // Maximum number of RPN proposals retained after first stage postprocessing.
  optional int32 first_stage_max_proposals = 15 [default=300];

  // First stage RPN localization loss weight.
  optional float first_stage_localization_loss_weight = 16 [default=1.0];

  // First stage RPN objectness loss weight.
  optional float first_stage_objectness_loss_weight = 17 [default=1.0];


  // Per-region cropping parameters.
  // Note that if a R-FCN model is constructed the per region cropping
  // parameters below are ignored.

  // Output size (width and height are set to be the same) of the initial
  // bilinear interpolation based cropping during ROI pooling.
  optional int32 initial_crop_size = 18;

  // Kernel size of the max pool op on the cropped feature map during
  // ROI pooling.
  optional int32 maxpool_kernel_size = 19;

  // Stride of the max pool op on the cropped feature map during ROI pooling.
  optional int32 maxpool_stride = 20;


  // (Second stage) box classifier parameters.

  // Hyperparameters for the second stage box predictor. If box predictor type
  // is set to rfcn_box_predictor, a R-FCN model is constructed, otherwise a
  // Faster R-CNN model is constructed.
  optional BoxPredictor second_stage_box_predictor = 21;

  // The batch size per image used for computing the classification and refined
  // location loss of the box classifier.
  // Note that this field is ignored if `hard_example_miner` is configured.
  optional int32 second_stage_batch_size = 22 [default=64];

  // Fraction of positive examples to use per image for the box classifier.
  optional float second_stage_balance_fraction = 23 [default=0.25];

  // Post processing to apply on the second stage box classifier predictions.
  // Note: the `score_converter` provided to the FasterRCNNMetaArch constructor
  // is taken from this `second_stage_post_processing` proto.
  optional PostProcessing second_stage_post_processing = 24;

  // Second stage refined localization loss weight.
  optional float second_stage_localization_loss_weight = 25 [default=1.0];

  // Second stage classification loss weight.
  optional float second_stage_classification_loss_weight = 26 [default=1.0];

  // Second stage instance mask loss weight. Note that this is only applicable
  // when `MaskRCNNBoxPredictor` is selected for second stage and configured to
  // predict instance masks.
  optional float second_stage_mask_prediction_loss_weight = 27 [default=1.0];

  // If not left to default, applies hard example mining only to classification
  // and localization loss.
  optional HardExampleMiner hard_example_miner = 28;

  // Loss for second stage box classifiers, supports Softmax and Sigmoid.
  // Note that score converter must be consistent with loss type.
  // When there are multiple labels assigned to the same boxes, recommend
  // to use sigmoid loss and enable merge_multiple_label_boxes.
  // If not specified, Softmax loss is used as default.
  optional ClassificationLoss second_stage_classification_loss = 29;
}
135 |
136 |
// Feature extractor configuration for Faster R-CNN; referenced by the
// `feature_extractor` field of the FasterRcnn message.
message FasterRcnnFeatureExtractor {
  // Type of Faster R-CNN model (e.g., 'faster_rcnn_resnet101';
  // See builders/model_builder.py for expected types).
  optional string type = 1;

  // Output stride of extracted RPN feature map.
  optional int32 first_stage_features_stride = 2 [default=16];

  // Whether to update batch norm parameters during training or not.
  // When training with a relative large batch size (e.g. 8), it could be
  // desirable to enable batch norm update.
  optional bool batch_norm_trainable = 3 [default=false];
}
150 |
--------------------------------------------------------------------------------
/utils/label_map_util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests for object_detection.utils.label_map_util."""
17 |
18 | import os
19 | import tensorflow as tf
20 |
21 | from google.protobuf import text_format
22 | from object_detection.protos import string_int_label_map_pb2
23 | from object_detection.utils import label_map_util
24 |
25 |
class LabelMapUtilTest(tf.test.TestCase):
  """Unit tests for the helpers in object_detection/utils/label_map_util.py."""

  def _generate_label_map(self, num_classes):
    # Builds an in-memory label map with ids 1..num_classes where each item's
    # name is 'label_<id>' and its display_name is '<id>'.
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def test_get_label_map_dict(self):
    # Round-trips a pbtxt label map through a temp file and checks the
    # name -> id mapping.
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    # Same as above, but keys the dict by the items' display_name fields.
    label_map_string = """
      item {
        id:2
        display_name:'cat'
      }
      item {
        id:1
        display_name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    # An id of 0 is invalid (ids must be >= 1), so loading must raise.
    label_map_string = """
      item {
        id:0
        name:'class that should not be indexed at zero'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)

    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_keep_categories_with_unique_id(self):
    # When several items share an id, only the first occurrence is kept.
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'child'
      }
      item {
        id:1
        name:'person'
      }
      item {
        id:1
        name:'n00007846'
      }
    """
    text_format.Merge(label_map_string, label_map_proto)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    # A None label map yields generic 'category_<id>' entries.
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_coco_categories(self):
    # display_name is preferred by default; items above max_num_classes
    # are dropped.
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_coco_categories_with_few_classes(self):
    # max_num_classes smaller than the label map truncates the result.
    label_map_proto = self._generate_label_map(num_classes=4)
    cat_no_offset = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, cat_no_offset)

  def test_create_category_index(self):
    # The index keys each category dict by its 'id' field.
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap(self):
    # End-to-end: pbtxt file -> category index keyed by id.
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)

    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)


if __name__ == '__main__':
  tf.test.main()
218 |
--------------------------------------------------------------------------------
/utils/learning_schedules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Library of common learning rate schedules."""
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 |
20 |
def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0):
  """Exponential decay schedule with burn-in period.

  Holds the learning rate fixed at burnin_learning_rate for the first
  burnin_steps steps, then switches to a staircase exponential decay
  schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay
      learning rate.
    burnin_learning_rate: initial learning rate during burn-in period. If
      0.0 (which is the default), then the burn-in learning rate is simply
      set to learning_rate_base.
    burnin_steps: number of steps to use burnin learning rate.

  Returns:
    a (scalar) float tensor representing learning rate
  """
  # A burn-in rate of exactly 0 means "reuse the base rate during burn-in".
  effective_burnin_rate = (
      learning_rate_base if burnin_learning_rate == 0
      else burnin_learning_rate)
  decayed_learning_rate = tf.train.exponential_decay(
      learning_rate_base,
      global_step,
      learning_rate_decay_steps,
      learning_rate_decay_factor,
      staircase=True)
  in_burnin_phase = tf.less(global_step, burnin_steps)
  return tf.cond(
      in_burnin_phase,
      lambda: tf.convert_to_tensor(effective_burnin_rate),
      lambda: decayed_learning_rate)
60 |
61 |
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm
    Restarts. ICLR 2017. https://arxiv.org/abs/1608.03983
  The learning rate grows linearly from warmup_learning_rate to
  learning_rate_base over warmup_steps, then follows a cosine decay down
  over the remaining steps.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.

  Returns:
    a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if warmup_learning_rate > learning_rate_base:
    raise ValueError('learning_rate_base must be larger '
                     'or equal to warmup_learning_rate.')
  if warmup_steps > total_steps:
    raise ValueError('total_steps must be larger or equal to '
                     'warmup_steps.')
  cosine_learning_rate = 0.5 * learning_rate_base * (
      1 + tf.cos(np.pi * tf.cast(
          global_step - warmup_steps, tf.float32
      ) / float(total_steps - warmup_steps)))
  if warmup_steps <= 0:
    return cosine_learning_rate
  # Linear ramp from warmup_learning_rate up to learning_rate_base.
  warmup_slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
  warmup_learning_rate_tensor = warmup_slope * tf.cast(
      global_step, tf.float32) + warmup_learning_rate
  return tf.cond(
      tf.less(global_step, warmup_steps),
      lambda: warmup_learning_rate_tensor,
      lambda: cosine_learning_rate)
108 |
109 |
def manual_stepping(global_step, boundaries, rates):
  """Manually stepped learning rate schedule.

  Gives fine grained control over learning rates: a sequence of learning
  rates plus a set of integer steps at which the current rate transitions
  to the next. For example, if boundaries = [5, 10] and
  rates = [.1, .01, .001], the returned learning rate is .1 for
  global_step=0,...,4, .01 for global_step=5...9, and .001 from
  global_step=10 onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates. This list is assumed to consist of increasing positive integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries. The length of this list must be exactly
      len(boundaries) + 1.

  Returns:
    a (scalar) float tensor representing learning rate
  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
  """
  if any(b < 0 for b in boundaries) or any(
      not isinstance(b, int) for b in boundaries):
    raise ValueError('boundaries must be a list of positive integers')
  if any(later <= earlier
         for earlier, later in zip(boundaries[:-1], boundaries[1:])):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if not all(isinstance(r, float) for r in rates):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')
  step_boundaries = tf.constant(boundaries, tf.int64)
  learning_rates = tf.constant(rates, tf.float32)
  # Indices of boundaries the step has not yet reached; append
  # len(boundaries) as a sentinel so the final rate applies after the
  # last boundary.
  unreached_boundaries = tf.reshape(
      tf.where(tf.greater(step_boundaries, global_step)), [-1])
  unreached_boundaries = tf.concat(
      [unreached_boundaries, [len(boundaries)]], 0)
  rate_index = tf.reshape(tf.reduce_min(unreached_boundaries), [1])
  return tf.reshape(tf.slice(learning_rates, rate_index, [1]), [])
152 |
--------------------------------------------------------------------------------
/protos/eval_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/eval.proto
3 |
4 | import sys
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='object_detection/protos/eval.proto',
20 | package='object_detection.protos',
21 | serialized_pb=_b('\n\"object_detection/protos/eval.proto\x12\x17object_detection.protos\"\x80\x03\n\nEvalConfig\x12\x1e\n\x12num_visualizations\x18\x01 \x01(\r:\x02\x31\x30\x12\x1a\n\x0cnum_examples\x18\x02 \x01(\r:\x04\x35\x30\x30\x30\x12\x1f\n\x12\x65val_interval_secs\x18\x03 \x01(\r:\x03\x33\x30\x30\x12\x14\n\tmax_evals\x18\x04 \x01(\r:\x01\x30\x12\x19\n\nsave_graph\x18\x05 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18visualization_export_dir\x18\x06 \x01(\t:\x00\x12\x15\n\x0b\x65val_master\x18\x07 \x01(\t:\x00\x12\'\n\x0bmetrics_set\x18\x08 \x01(\t:\x12pascal_voc_metrics\x12\x15\n\x0b\x65xport_path\x18\t \x01(\t:\x00\x12!\n\x12ignore_groundtruth\x18\n \x01(\x08:\x05\x66\x61lse\x12\"\n\x13use_moving_averages\x18\x0b \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x65val_instance_masks\x18\x0c \x01(\x08:\x05\x66\x61lse')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
27 |
28 | _EVALCONFIG = _descriptor.Descriptor(
29 | name='EvalConfig',
30 | full_name='object_detection.protos.EvalConfig',
31 | filename=None,
32 | file=DESCRIPTOR,
33 | containing_type=None,
34 | fields=[
35 | _descriptor.FieldDescriptor(
36 | name='num_visualizations', full_name='object_detection.protos.EvalConfig.num_visualizations', index=0,
37 | number=1, type=13, cpp_type=3, label=1,
38 | has_default_value=True, default_value=10,
39 | message_type=None, enum_type=None, containing_type=None,
40 | is_extension=False, extension_scope=None,
41 | options=None),
42 | _descriptor.FieldDescriptor(
43 | name='num_examples', full_name='object_detection.protos.EvalConfig.num_examples', index=1,
44 | number=2, type=13, cpp_type=3, label=1,
45 | has_default_value=True, default_value=5000,
46 | message_type=None, enum_type=None, containing_type=None,
47 | is_extension=False, extension_scope=None,
48 | options=None),
49 | _descriptor.FieldDescriptor(
50 | name='eval_interval_secs', full_name='object_detection.protos.EvalConfig.eval_interval_secs', index=2,
51 | number=3, type=13, cpp_type=3, label=1,
52 | has_default_value=True, default_value=300,
53 | message_type=None, enum_type=None, containing_type=None,
54 | is_extension=False, extension_scope=None,
55 | options=None),
56 | _descriptor.FieldDescriptor(
57 | name='max_evals', full_name='object_detection.protos.EvalConfig.max_evals', index=3,
58 | number=4, type=13, cpp_type=3, label=1,
59 | has_default_value=True, default_value=0,
60 | message_type=None, enum_type=None, containing_type=None,
61 | is_extension=False, extension_scope=None,
62 | options=None),
63 | _descriptor.FieldDescriptor(
64 | name='save_graph', full_name='object_detection.protos.EvalConfig.save_graph', index=4,
65 | number=5, type=8, cpp_type=7, label=1,
66 | has_default_value=True, default_value=False,
67 | message_type=None, enum_type=None, containing_type=None,
68 | is_extension=False, extension_scope=None,
69 | options=None),
70 | _descriptor.FieldDescriptor(
71 | name='visualization_export_dir', full_name='object_detection.protos.EvalConfig.visualization_export_dir', index=5,
72 | number=6, type=9, cpp_type=9, label=1,
73 | has_default_value=True, default_value=_b("").decode('utf-8'),
74 | message_type=None, enum_type=None, containing_type=None,
75 | is_extension=False, extension_scope=None,
76 | options=None),
77 | _descriptor.FieldDescriptor(
78 | name='eval_master', full_name='object_detection.protos.EvalConfig.eval_master', index=6,
79 | number=7, type=9, cpp_type=9, label=1,
80 | has_default_value=True, default_value=_b("").decode('utf-8'),
81 | message_type=None, enum_type=None, containing_type=None,
82 | is_extension=False, extension_scope=None,
83 | options=None),
84 | _descriptor.FieldDescriptor(
85 | name='metrics_set', full_name='object_detection.protos.EvalConfig.metrics_set', index=7,
86 | number=8, type=9, cpp_type=9, label=1,
87 | has_default_value=True, default_value=_b("pascal_voc_metrics").decode('utf-8'),
88 | message_type=None, enum_type=None, containing_type=None,
89 | is_extension=False, extension_scope=None,
90 | options=None),
91 | _descriptor.FieldDescriptor(
92 | name='export_path', full_name='object_detection.protos.EvalConfig.export_path', index=8,
93 | number=9, type=9, cpp_type=9, label=1,
94 | has_default_value=True, default_value=_b("").decode('utf-8'),
95 | message_type=None, enum_type=None, containing_type=None,
96 | is_extension=False, extension_scope=None,
97 | options=None),
98 | _descriptor.FieldDescriptor(
99 | name='ignore_groundtruth', full_name='object_detection.protos.EvalConfig.ignore_groundtruth', index=9,
100 | number=10, type=8, cpp_type=7, label=1,
101 | has_default_value=True, default_value=False,
102 | message_type=None, enum_type=None, containing_type=None,
103 | is_extension=False, extension_scope=None,
104 | options=None),
105 | _descriptor.FieldDescriptor(
106 | name='use_moving_averages', full_name='object_detection.protos.EvalConfig.use_moving_averages', index=10,
107 | number=11, type=8, cpp_type=7, label=1,
108 | has_default_value=True, default_value=False,
109 | message_type=None, enum_type=None, containing_type=None,
110 | is_extension=False, extension_scope=None,
111 | options=None),
112 | _descriptor.FieldDescriptor(
113 | name='eval_instance_masks', full_name='object_detection.protos.EvalConfig.eval_instance_masks', index=11,
114 | number=12, type=8, cpp_type=7, label=1,
115 | has_default_value=True, default_value=False,
116 | message_type=None, enum_type=None, containing_type=None,
117 | is_extension=False, extension_scope=None,
118 | options=None),
119 | ],
120 | extensions=[
121 | ],
122 | nested_types=[],
123 | enum_types=[
124 | ],
125 | options=None,
126 | is_extendable=False,
127 | extension_ranges=[],
128 | oneofs=[
129 | ],
130 | serialized_start=64,
131 | serialized_end=448,
132 | )
133 |
# --- protoc boilerplate: class generation & registration -------------------
# Expose the EvalConfig message descriptor on the file descriptor so it can
# be resolved by name.
134 | DESCRIPTOR.message_types_by_name['EvalConfig'] = _EVALCONFIG
135 |
# Build the concrete EvalConfig message class at import time from its
# descriptor via the protobuf reflection metaclass, then register it with
# the default symbol database so other generated modules can look it up.
136 | EvalConfig = _reflection.GeneratedProtocolMessageType('EvalConfig', (_message.Message,), dict(
137 | DESCRIPTOR = _EVALCONFIG,
138 | __module__ = 'object_detection.protos.eval_pb2'
139 | # @@protoc_insertion_point(class_scope:object_detection.protos.EvalConfig)
140 | ))
141 | _sym_db.RegisterMessage(EvalConfig)
142 |
143 |
144 | # @@protoc_insertion_point(module_scope)
145 |
--------------------------------------------------------------------------------
/protos/post_processing_pb2.py:
--------------------------------------------------------------------------------
1 | # Generated by the protocol buffer compiler. DO NOT EDIT!
2 | # source: object_detection/protos/post_processing.proto
3 |
4 | import sys
# _b: identity on Python 2 (str literals are already bytes); on Python 3 it
# re-encodes the literal as latin-1 so serialized_pb below yields the exact
# raw bytes protoc emitted.
5 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
6 | from google.protobuf import descriptor as _descriptor
7 | from google.protobuf import message as _message
8 | from google.protobuf import reflection as _reflection
9 | from google.protobuf import symbol_database as _symbol_database
10 | from google.protobuf import descriptor_pb2
11 | # @@protoc_insertion_point(imports)
12 |
# Process-wide symbol database; all descriptors/messages in this module are
# registered here so they can be resolved by fully-qualified name.
13 | _sym_db = _symbol_database.Default()
14 |
15 |
16 |
17 |
# File-level descriptor. serialized_pb is the wire-encoded
# FileDescriptorProto for post_processing.proto; the serialized_start /
# serialized_end offsets used by the descriptors below index into these
# bytes, so this blob and those offsets must stay in sync.
18 | DESCRIPTOR = _descriptor.FileDescriptor(
19 | name='object_detection/protos/post_processing.proto',
20 | package='object_detection.protos',
21 | serialized_pb=_b('\n-object_detection/protos/post_processing.proto\x12\x17object_detection.protos\"\x9a\x01\n\x16\x42\x61tchNonMaxSuppression\x12\x1a\n\x0fscore_threshold\x18\x01 \x01(\x02:\x01\x30\x12\x1a\n\riou_threshold\x18\x02 \x01(\x02:\x03\x30.6\x12%\n\x18max_detections_per_class\x18\x03 \x01(\x05:\x03\x31\x30\x30\x12!\n\x14max_total_detections\x18\x05 \x01(\x05:\x03\x31\x30\x30\"\x91\x02\n\x0ePostProcessing\x12R\n\x19\x62\x61tch_non_max_suppression\x18\x01 \x01(\x0b\x32/.object_detection.protos.BatchNonMaxSuppression\x12Y\n\x0fscore_converter\x18\x02 \x01(\x0e\x32\x36.object_detection.protos.PostProcessing.ScoreConverter:\x08IDENTITY\x12\x16\n\x0blogit_scale\x18\x03 \x01(\x02:\x01\x31\"8\n\x0eScoreConverter\x12\x0c\n\x08IDENTITY\x10\x00\x12\x0b\n\x07SIGMOID\x10\x01\x12\x0b\n\x07SOFTMAX\x10\x02')
22 | )
23 | _sym_db.RegisterFileDescriptor(DESCRIPTOR)
24 |
25 |
26 |
# Enum descriptor for PostProcessing.ScoreConverter (nested enum — its
# containing_type is patched to _POSTPROCESSING later in this module).
# Values: IDENTITY=0 (default), SIGMOID=1, SOFTMAX=2.
27 | _POSTPROCESSING_SCORECONVERTER = _descriptor.EnumDescriptor(
28 | name='ScoreConverter',
29 | full_name='object_detection.protos.PostProcessing.ScoreConverter',
30 | filename=None,
31 | file=DESCRIPTOR,
32 | values=[
33 | _descriptor.EnumValueDescriptor(
34 | name='IDENTITY', index=0, number=0,
35 | options=None,
36 | type=None),
37 | _descriptor.EnumValueDescriptor(
38 | name='SIGMOID', index=1, number=1,
39 | options=None,
40 | type=None),
41 | _descriptor.EnumValueDescriptor(
42 | name='SOFTMAX', index=2, number=2,
43 | options=None,
44 | type=None),
45 | ],
46 | containing_type=None,
47 | options=None,
# Byte offsets of this enum's definition inside DESCRIPTOR's serialized_pb.
48 | serialized_start=449,
49 | serialized_end=505,
50 | )
51 | _sym_db.RegisterEnumDescriptor(_POSTPROCESSING_SCORECONVERTER)
53 |
# Message descriptor for BatchNonMaxSuppression. All fields are optional
# (label=1) scalars: type=2/cpp_type=6 is float, type=5/cpp_type=1 is int32
# (protobuf FieldDescriptor type/cpp_type constants). Defaults mirror the
# .proto: score_threshold=0, iou_threshold=0.6, 100 detections per class,
# 100 total. Note the field numbers skip 4 (max_total_detections is tag 5).
54 | _BATCHNONMAXSUPPRESSION = _descriptor.Descriptor(
55 | name='BatchNonMaxSuppression',
56 | full_name='object_detection.protos.BatchNonMaxSuppression',
57 | filename=None,
58 | file=DESCRIPTOR,
59 | containing_type=None,
60 | fields=[
61 | _descriptor.FieldDescriptor(
62 | name='score_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.score_threshold', index=0,
63 | number=1, type=2, cpp_type=6, label=1,
64 | has_default_value=True, default_value=0,
65 | message_type=None, enum_type=None, containing_type=None,
66 | is_extension=False, extension_scope=None,
67 | options=None),
68 | _descriptor.FieldDescriptor(
69 | name='iou_threshold', full_name='object_detection.protos.BatchNonMaxSuppression.iou_threshold', index=1,
70 | number=2, type=2, cpp_type=6, label=1,
71 | has_default_value=True, default_value=0.6,
72 | message_type=None, enum_type=None, containing_type=None,
73 | is_extension=False, extension_scope=None,
74 | options=None),
75 | _descriptor.FieldDescriptor(
76 | name='max_detections_per_class', full_name='object_detection.protos.BatchNonMaxSuppression.max_detections_per_class', index=2,
77 | number=3, type=5, cpp_type=1, label=1,
78 | has_default_value=True, default_value=100,
79 | message_type=None, enum_type=None, containing_type=None,
80 | is_extension=False, extension_scope=None,
81 | options=None),
82 | _descriptor.FieldDescriptor(
83 | name='max_total_detections', full_name='object_detection.protos.BatchNonMaxSuppression.max_total_detections', index=3,
84 | number=5, type=5, cpp_type=1, label=1,
85 | has_default_value=True, default_value=100,
86 | message_type=None, enum_type=None, containing_type=None,
87 | is_extension=False, extension_scope=None,
88 | options=None),
89 | ],
90 | extensions=[
91 | ],
92 | nested_types=[],
93 | enum_types=[
94 | ],
95 | options=None,
96 | is_extendable=False,
97 | extension_ranges=[],
98 | oneofs=[
99 | ],
# Byte offsets of this message's definition inside DESCRIPTOR's serialized_pb.
100 | serialized_start=75,
101 | serialized_end=229,
102 | )
103 |
104 |
# Message descriptor for PostProcessing. Fields (all optional, label=1):
#   batch_non_max_suppression — type=11/cpp_type=10 (message); its
#     message_type is patched to _BATCHNONMAXSUPPRESSION after construction.
#   score_converter — type=14/cpp_type=8 (enum); default 0 = IDENTITY; its
#     enum_type is patched to _POSTPROCESSING_SCORECONVERTER below.
#   logit_scale — type=2/cpp_type=6 (float), default 1.
105 | _POSTPROCESSING = _descriptor.Descriptor(
106 | name='PostProcessing',
107 | full_name='object_detection.protos.PostProcessing',
108 | filename=None,
109 | file=DESCRIPTOR,
110 | containing_type=None,
111 | fields=[
112 | _descriptor.FieldDescriptor(
113 | name='batch_non_max_suppression', full_name='object_detection.protos.PostProcessing.batch_non_max_suppression', index=0,
114 | number=1, type=11, cpp_type=10, label=1,
115 | has_default_value=False, default_value=None,
116 | message_type=None, enum_type=None, containing_type=None,
117 | is_extension=False, extension_scope=None,
118 | options=None),
119 | _descriptor.FieldDescriptor(
120 | name='score_converter', full_name='object_detection.protos.PostProcessing.score_converter', index=1,
121 | number=2, type=14, cpp_type=8, label=1,
122 | has_default_value=True, default_value=0,
123 | message_type=None, enum_type=None, containing_type=None,
124 | is_extension=False, extension_scope=None,
125 | options=None),
126 | _descriptor.FieldDescriptor(
127 | name='logit_scale', full_name='object_detection.protos.PostProcessing.logit_scale', index=2,
128 | number=3, type=2, cpp_type=6, label=1,
129 | has_default_value=True, default_value=1,
130 | message_type=None, enum_type=None, containing_type=None,
131 | is_extension=False, extension_scope=None,
132 | options=None),
133 | ],
134 | extensions=[
135 | ],
136 | nested_types=[],
# ScoreConverter is declared nested inside this message.
137 | enum_types=[
138 | _POSTPROCESSING_SCORECONVERTER,
139 | ],
140 | options=None,
141 | is_extendable=False,
142 | extension_ranges=[],
143 | oneofs=[
144 | ],
# Byte offsets of this message's definition inside DESCRIPTOR's serialized_pb.
145 | serialized_start=232,
146 | serialized_end=505,
147 | )
148 |
# --- protoc boilerplate: descriptor cross-linking & class generation --------
# Resolve the forward references the Descriptor constructors above could not
# express: point the message/enum fields at their descriptors and attach the
# nested enum to its containing message.
149 | _POSTPROCESSING.fields_by_name['batch_non_max_suppression'].message_type = _BATCHNONMAXSUPPRESSION
150 | _POSTPROCESSING.fields_by_name['score_converter'].enum_type = _POSTPROCESSING_SCORECONVERTER
151 | _POSTPROCESSING_SCORECONVERTER.containing_type = _POSTPROCESSING
152 | DESCRIPTOR.message_types_by_name['BatchNonMaxSuppression'] = _BATCHNONMAXSUPPRESSION
153 | DESCRIPTOR.message_types_by_name['PostProcessing'] = _POSTPROCESSING
154 |
# Build the concrete message classes at import time from their descriptors
# via the protobuf reflection metaclass and register each with the default
# symbol database.
155 | BatchNonMaxSuppression = _reflection.GeneratedProtocolMessageType('BatchNonMaxSuppression', (_message.Message,), dict(
156 | DESCRIPTOR = _BATCHNONMAXSUPPRESSION,
157 | __module__ = 'object_detection.protos.post_processing_pb2'
158 | # @@protoc_insertion_point(class_scope:object_detection.protos.BatchNonMaxSuppression)
159 | ))
160 | _sym_db.RegisterMessage(BatchNonMaxSuppression)
161 |
162 | PostProcessing = _reflection.GeneratedProtocolMessageType('PostProcessing', (_message.Message,), dict(
163 | DESCRIPTOR = _POSTPROCESSING,
164 | __module__ = 'object_detection.protos.post_processing_pb2'
165 | # @@protoc_insertion_point(class_scope:object_detection.protos.PostProcessing)
166 | ))
167 | _sym_db.RegisterMessage(PostProcessing)
168 |
169 |
170 | # @@protoc_insertion_point(module_scope)
171 |
--------------------------------------------------------------------------------