├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── .gitkeep ├── pascal_label_map.pbtxt ├── raw_data │ └── .gitkeep ├── tfrecords │ └── .gitkeep ├── train │ ├── .gitkeep │ ├── jpg │ │ └── .gitkeep │ └── xml │ │ └── .gitkeep └── validate │ ├── .gitkeep │ ├── jpg │ └── .gitkeep │ └── xml │ └── .gitkeep ├── inference.py ├── model_file ├── .gitkeep ├── frozen_pb │ └── .gitkeep ├── pretrain │ └── .gitkeep └── train │ └── .gitkeep ├── model_train.py ├── object_detection ├── CONTRIBUTING.md ├── README.md ├── __init__.py ├── anchor_generators │ ├── __init__.py │ ├── flexible_grid_anchor_generator.py │ ├── flexible_grid_anchor_generator_test.py │ ├── grid_anchor_generator.py │ ├── grid_anchor_generator_test.py │ ├── multiple_grid_anchor_generator.py │ ├── multiple_grid_anchor_generator_test.py │ ├── multiscale_grid_anchor_generator.py │ └── multiscale_grid_anchor_generator_test.py ├── box_coders │ ├── __init__.py │ ├── faster_rcnn_box_coder.py │ ├── faster_rcnn_box_coder_test.py │ ├── keypoint_box_coder.py │ ├── keypoint_box_coder_test.py │ ├── mean_stddev_box_coder.py │ ├── mean_stddev_box_coder_test.py │ ├── square_box_coder.py │ └── square_box_coder_test.py ├── builders │ ├── __init__.py │ ├── anchor_generator_builder.py │ ├── anchor_generator_builder_test.py │ ├── box_coder_builder.py │ ├── box_coder_builder_test.py │ ├── box_predictor_builder.py │ ├── box_predictor_builder_test.py │ ├── calibration_builder.py │ ├── calibration_builder_test.py │ ├── dataset_builder.py │ ├── dataset_builder_test.py │ ├── graph_rewriter_builder.py │ ├── graph_rewriter_builder_test.py │ ├── hyperparams_builder.py │ ├── hyperparams_builder_test.py │ ├── image_resizer_builder.py │ ├── image_resizer_builder_test.py │ ├── input_reader_builder.py │ ├── input_reader_builder_test.py │ ├── losses_builder.py │ ├── losses_builder_test.py │ ├── matcher_builder.py │ ├── matcher_builder_test.py │ ├── model_builder.py │ ├── model_builder_test.py │ ├── optimizer_builder.py │ ├── 
optimizer_builder_test.py │ ├── post_processing_builder.py │ ├── post_processing_builder_test.py │ ├── preprocessor_builder.py │ ├── preprocessor_builder_test.py │ ├── region_similarity_calculator_builder.py │ ├── region_similarity_calculator_builder_test.py │ ├── target_assigner_builder.py │ └── target_assigner_builder_test.py ├── core │ ├── __init__.py │ ├── anchor_generator.py │ ├── balanced_positive_negative_sampler.py │ ├── balanced_positive_negative_sampler_test.py │ ├── batch_multiclass_nms_test.py │ ├── batcher.py │ ├── batcher_test.py │ ├── box_coder.py │ ├── box_coder_test.py │ ├── box_list.py │ ├── box_list_ops.py │ ├── box_list_ops_test.py │ ├── box_list_test.py │ ├── box_predictor.py │ ├── class_agnostic_nms_test.py │ ├── data_decoder.py │ ├── data_parser.py │ ├── freezable_batch_norm.py │ ├── freezable_batch_norm_test.py │ ├── keypoint_ops.py │ ├── keypoint_ops_test.py │ ├── losses.py │ ├── losses_test.py │ ├── matcher.py │ ├── matcher_test.py │ ├── minibatch_sampler.py │ ├── minibatch_sampler_test.py │ ├── model.py │ ├── multiclass_nms_test.py │ ├── post_processing.py │ ├── prefetcher.py │ ├── prefetcher_test.py │ ├── preprocessor.py │ ├── preprocessor_cache.py │ ├── preprocessor_test.py │ ├── region_similarity_calculator.py │ ├── region_similarity_calculator_test.py │ ├── standard_fields.py │ ├── target_assigner.py │ └── target_assigner_test.py ├── data │ ├── ava_label_map_v2.1.pbtxt │ ├── face_label_map.pbtxt │ ├── fgvc_2854_classes_label_map.pbtxt │ ├── kitti_label_map.pbtxt │ ├── mscoco_complete_label_map.pbtxt │ ├── mscoco_label_map.pbtxt │ ├── oid_bbox_trainable_label_map.pbtxt │ ├── oid_object_detection_challenge_500_label_map.pbtxt │ ├── oid_v4_label_map.pbtxt │ ├── pascal_label_map.pbtxt │ └── pet_label_map.pbtxt ├── data_decoders │ ├── __init__.py │ ├── tf_example_decoder.py │ └── tf_example_decoder_test.py ├── dataset_tools │ ├── __init__.py │ ├── create_coco_tf_record.py │ ├── create_coco_tf_record_test.py │ ├── create_kitti_tf_record.py 
│ ├── create_kitti_tf_record_test.py │ ├── create_oid_tf_record.py │ ├── create_pascal_tf_record.py │ ├── create_pascal_tf_record_test.py │ ├── create_pet_tf_record.py │ ├── create_pycocotools_package.sh │ ├── download_and_preprocess_mscoco.sh │ ├── oid_hierarchical_labels_expansion.py │ ├── oid_hierarchical_labels_expansion_test.py │ ├── oid_tfrecord_creation.py │ ├── oid_tfrecord_creation_test.py │ ├── tf_record_creation_util.py │ └── tf_record_creation_util_test.py ├── dockerfiles │ └── android │ │ ├── Dockerfile │ │ └── README.md ├── eval_util.py ├── eval_util_test.py ├── export_inference_graph.py ├── export_tflite_ssd_graph.py ├── export_tflite_ssd_graph_lib.py ├── export_tflite_ssd_graph_lib_test.py ├── exporter.py ├── exporter_test.py ├── g3doc │ ├── challenge_evaluation.md │ ├── configuring_jobs.md │ ├── defining_your_own_model.md │ ├── detection_model_zoo.md │ ├── evaluation_protocols.md │ ├── exporting_models.md │ ├── faq.md │ ├── img │ │ ├── dataset_explorer.png │ │ ├── groupof_case_eval.png │ │ ├── kites_with_segment_overlay.png │ │ ├── nongroupof_case_eval.png │ │ ├── oxford_pet.png │ │ ├── tensorboard.png │ │ ├── tensorboard2.png │ │ └── tf-od-api-logo.png │ ├── installation.md │ ├── instance_segmentation.md │ ├── oid_inference_and_evaluation.md │ ├── preparing_inputs.md │ ├── running_locally.md │ ├── running_notebook.md │ ├── running_on_cloud.md │ ├── running_on_mobile_tensorflowlite.md │ ├── running_pets.md │ ├── tpu_compatibility.md │ ├── tpu_exporters.md │ └── using_your_own_dataset.md ├── inference │ ├── __init__.py │ ├── detection_inference.py │ ├── detection_inference_test.py │ └── infer_detections.py ├── inputs.py ├── inputs_test.py ├── legacy │ ├── __init__.py │ ├── eval.py │ ├── evaluator.py │ ├── train.py │ ├── trainer.py │ └── trainer_test.py ├── matchers │ ├── __init__.py │ ├── argmax_matcher.py │ ├── argmax_matcher_test.py │ ├── bipartite_matcher.py │ └── bipartite_matcher_test.py ├── meta_architectures │ ├── __init__.py │ ├── 
faster_rcnn_meta_arch.py │ ├── faster_rcnn_meta_arch_test.py │ ├── faster_rcnn_meta_arch_test_lib.py │ ├── rfcn_meta_arch.py │ ├── rfcn_meta_arch_test.py │ ├── ssd_meta_arch.py │ ├── ssd_meta_arch_test.py │ └── ssd_meta_arch_test_lib.py ├── metrics │ ├── __init__.py │ ├── calibration_evaluation.py │ ├── calibration_evaluation_test.py │ ├── calibration_metrics.py │ ├── calibration_metrics_test.py │ ├── coco_evaluation.py │ ├── coco_evaluation_test.py │ ├── coco_tools.py │ ├── coco_tools_test.py │ ├── io_utils.py │ ├── offline_eval_map_corloc.py │ ├── offline_eval_map_corloc_test.py │ ├── oid_challenge_evaluation.py │ ├── oid_challenge_evaluation_utils.py │ ├── oid_challenge_evaluation_utils_test.py │ ├── oid_vrd_challenge_evaluation.py │ ├── oid_vrd_challenge_evaluation_utils.py │ ├── oid_vrd_challenge_evaluation_utils_test.py │ ├── tf_example_parser.py │ └── tf_example_parser_test.py ├── model_hparams.py ├── model_lib.py ├── model_lib_test.py ├── model_lib_v2.py ├── model_lib_v2_test.py ├── model_main.py ├── model_tpu_main.py ├── models │ ├── __init__.py │ ├── embedded_ssd_mobilenet_v1_feature_extractor.py │ ├── embedded_ssd_mobilenet_v1_feature_extractor_test.py │ ├── faster_rcnn_inception_resnet_v2_feature_extractor.py │ ├── faster_rcnn_inception_resnet_v2_feature_extractor_test.py │ ├── faster_rcnn_inception_resnet_v2_keras_feature_extractor.py │ ├── faster_rcnn_inception_resnet_v2_keras_feature_extractor_test.py │ ├── faster_rcnn_inception_v2_feature_extractor.py │ ├── faster_rcnn_inception_v2_feature_extractor_test.py │ ├── faster_rcnn_mobilenet_v1_feature_extractor.py │ ├── faster_rcnn_mobilenet_v1_feature_extractor_test.py │ ├── faster_rcnn_nas_feature_extractor.py │ ├── faster_rcnn_nas_feature_extractor_test.py │ ├── faster_rcnn_pnas_feature_extractor.py │ ├── faster_rcnn_pnas_feature_extractor_test.py │ ├── faster_rcnn_resnet_v1_feature_extractor.py │ ├── faster_rcnn_resnet_v1_feature_extractor_test.py │ ├── feature_map_generators.py │ ├── 
feature_map_generators_test.py │ ├── keras_models │ │ ├── __init__.py │ │ ├── base_models │ │ │ └── original_mobilenet_v2.py │ │ ├── inception_resnet_v2.py │ │ ├── inception_resnet_v2_test.py │ │ ├── mobilenet_v1.py │ │ ├── mobilenet_v1_test.py │ │ ├── mobilenet_v2.py │ │ ├── mobilenet_v2_test.py │ │ ├── model_utils.py │ │ ├── resnet_v1.py │ │ ├── resnet_v1_test.py │ │ └── test_utils.py │ ├── ssd_feature_extractor_test.py │ ├── ssd_inception_v2_feature_extractor.py │ ├── ssd_inception_v2_feature_extractor_test.py │ ├── ssd_inception_v3_feature_extractor.py │ ├── ssd_inception_v3_feature_extractor_test.py │ ├── ssd_mobilenet_edgetpu_feature_extractor.py │ ├── ssd_mobilenet_edgetpu_feature_extractor_test.py │ ├── ssd_mobilenet_edgetpu_feature_extractor_testbase.py │ ├── ssd_mobilenet_v1_feature_extractor.py │ ├── ssd_mobilenet_v1_feature_extractor_test.py │ ├── ssd_mobilenet_v1_fpn_feature_extractor.py │ ├── ssd_mobilenet_v1_fpn_feature_extractor_test.py │ ├── ssd_mobilenet_v1_fpn_keras_feature_extractor.py │ ├── ssd_mobilenet_v1_keras_feature_extractor.py │ ├── ssd_mobilenet_v1_ppn_feature_extractor.py │ ├── ssd_mobilenet_v1_ppn_feature_extractor_test.py │ ├── ssd_mobilenet_v2_feature_extractor.py │ ├── ssd_mobilenet_v2_feature_extractor_test.py │ ├── ssd_mobilenet_v2_fpn_feature_extractor.py │ ├── ssd_mobilenet_v2_fpn_feature_extractor_test.py │ ├── ssd_mobilenet_v2_fpn_keras_feature_extractor.py │ ├── ssd_mobilenet_v2_keras_feature_extractor.py │ ├── ssd_mobilenet_v3_feature_extractor.py │ ├── ssd_mobilenet_v3_feature_extractor_test.py │ ├── ssd_mobilenet_v3_feature_extractor_testbase.py │ ├── ssd_pnasnet_feature_extractor.py │ ├── ssd_pnasnet_feature_extractor_test.py │ ├── ssd_resnet_v1_fpn_feature_extractor.py │ ├── ssd_resnet_v1_fpn_feature_extractor_test.py │ ├── ssd_resnet_v1_fpn_feature_extractor_testbase.py │ ├── ssd_resnet_v1_fpn_keras_feature_extractor.py │ ├── ssd_resnet_v1_ppn_feature_extractor.py │ ├── ssd_resnet_v1_ppn_feature_extractor_test.py │ └── 
ssd_resnet_v1_ppn_feature_extractor_testbase.py ├── predictors │ ├── __init__.py │ ├── convolutional_box_predictor.py │ ├── convolutional_box_predictor_test.py │ ├── convolutional_keras_box_predictor.py │ ├── convolutional_keras_box_predictor_test.py │ ├── heads │ │ ├── __init__.py │ │ ├── box_head.py │ │ ├── box_head_test.py │ │ ├── class_head.py │ │ ├── class_head_test.py │ │ ├── head.py │ │ ├── keras_box_head.py │ │ ├── keras_box_head_test.py │ │ ├── keras_class_head.py │ │ ├── keras_class_head_test.py │ │ ├── keras_mask_head.py │ │ ├── keras_mask_head_test.py │ │ ├── keypoint_head.py │ │ ├── keypoint_head_test.py │ │ ├── mask_head.py │ │ └── mask_head_test.py │ ├── mask_rcnn_box_predictor.py │ ├── mask_rcnn_box_predictor_test.py │ ├── mask_rcnn_keras_box_predictor.py │ ├── mask_rcnn_keras_box_predictor_test.py │ ├── rfcn_box_predictor.py │ ├── rfcn_box_predictor_test.py │ ├── rfcn_keras_box_predictor.py │ └── rfcn_keras_box_predictor_test.py ├── protos │ ├── __init__.py │ ├── anchor_generator_pb2.py │ ├── argmax_matcher_pb2.py │ ├── bipartite_matcher_pb2.py │ ├── box_coder_pb2.py │ ├── box_predictor_pb2.py │ ├── calibration_pb2.py │ ├── eval_pb2.py │ ├── faster_rcnn_box_coder_pb2.py │ ├── faster_rcnn_pb2.py │ ├── flexible_grid_anchor_generator_pb2.py │ ├── graph_rewriter_pb2.py │ ├── grid_anchor_generator_pb2.py │ ├── hyperparams_pb2.py │ ├── image_resizer_pb2.py │ ├── input_reader_pb2.py │ ├── keypoint_box_coder_pb2.py │ ├── losses_pb2.py │ ├── matcher_pb2.py │ ├── mean_stddev_box_coder_pb2.py │ ├── model_pb2.py │ ├── multiscale_anchor_generator_pb2.py │ ├── optimizer_pb2.py │ ├── pipeline_pb2.py │ ├── post_processing_pb2.py │ ├── preprocessor_pb2.py │ ├── region_similarity_calculator_pb2.py │ ├── square_box_coder_pb2.py │ ├── ssd_anchor_generator_pb2.py │ ├── ssd_pb2.py │ ├── string_int_label_map_pb2.py │ ├── target_assigner_pb2.py │ └── train_pb2.py ├── samples │ ├── cloud │ │ └── cloud.yml │ └── configs │ │ ├── embedded_ssd_mobilenet_v1_coco.config │ │ ├── 
facessd_mobilenet_v2_quantized_320x320_open_image_v4.config │ │ ├── faster_rcnn_inception_resnet_v2_atrous_coco.config │ │ ├── faster_rcnn_inception_resnet_v2_atrous_cosine_lr_coco.config │ │ ├── faster_rcnn_inception_resnet_v2_atrous_oid.config │ │ ├── faster_rcnn_inception_resnet_v2_atrous_oid_v4.config │ │ ├── faster_rcnn_inception_resnet_v2_atrous_pets.config │ │ ├── faster_rcnn_inception_v2_coco.config │ │ ├── faster_rcnn_inception_v2_pets.config │ │ ├── faster_rcnn_nas_coco.config │ │ ├── faster_rcnn_resnet101_atrous_coco.config │ │ ├── faster_rcnn_resnet101_ava_v2.1.config │ │ ├── faster_rcnn_resnet101_coco.config │ │ ├── faster_rcnn_resnet101_fgvc.config │ │ ├── faster_rcnn_resnet101_kitti.config │ │ ├── faster_rcnn_resnet101_pets.config │ │ ├── faster_rcnn_resnet101_voc07.config │ │ ├── faster_rcnn_resnet152_coco.config │ │ ├── faster_rcnn_resnet152_pets.config │ │ ├── faster_rcnn_resnet50_coco.config │ │ ├── faster_rcnn_resnet50_fgvc.config │ │ ├── faster_rcnn_resnet50_pets.config │ │ ├── mask_rcnn_inception_resnet_v2_atrous_coco.config │ │ ├── mask_rcnn_inception_v2_coco.config │ │ ├── mask_rcnn_resnet101_atrous_coco.config │ │ ├── mask_rcnn_resnet101_pets.config │ │ ├── mask_rcnn_resnet50_atrous_coco.config │ │ ├── rfcn_resnet101_coco.config │ │ ├── rfcn_resnet101_pets.config │ │ ├── ssd_inception_v2_coco.config │ │ ├── ssd_inception_v2_pets.config │ │ ├── ssd_inception_v3_pets.config │ │ ├── ssd_mobilenet_v1_0.75_depth_300x300_coco14_sync.config │ │ ├── ssd_mobilenet_v1_0.75_depth_quantized_300x300_coco14_sync.config │ │ ├── ssd_mobilenet_v1_0.75_depth_quantized_300x300_pets_sync.config │ │ ├── ssd_mobilenet_v1_300x300_coco14_sync.config │ │ ├── ssd_mobilenet_v1_coco.config │ │ ├── ssd_mobilenet_v1_focal_loss_pets.config │ │ ├── ssd_mobilenet_v1_focal_loss_pets_inference.config │ │ ├── ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync.config │ │ ├── ssd_mobilenet_v1_pets.config │ │ ├── 
ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync.config │ │ ├── ssd_mobilenet_v1_quantized_300x300_coco14_sync.config │ │ ├── ssd_mobilenet_v2_coco.config │ │ ├── ssd_mobilenet_v2_fpnlite_quantized_shared_box_predictor_256x256_depthmultiplier_75_coco14_sync.config │ │ ├── ssd_mobilenet_v2_fullyconv_coco.config │ │ ├── ssd_mobilenet_v2_oid_v4.config │ │ ├── ssd_mobilenet_v2_pets_keras.config │ │ ├── ssd_mobilenet_v2_quantized_300x300_coco.config │ │ ├── ssd_resnet101_v1_fpn_shared_box_predictor_oid_512x512_sync.config │ │ ├── ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync.config │ │ ├── ssdlite_mobilenet_edgetpu_320x320_coco.config │ │ ├── ssdlite_mobilenet_edgetpu_320x320_coco_quant.config │ │ ├── ssdlite_mobilenet_v1_coco.config │ │ ├── ssdlite_mobilenet_v2_coco.config │ │ ├── ssdlite_mobilenet_v3_large_320x320_coco.config │ │ └── ssdlite_mobilenet_v3_small_320x320_coco.config ├── tpu_exporters │ ├── __init__.py │ ├── export_saved_model_tpu.py │ ├── export_saved_model_tpu_lib.py │ ├── export_saved_model_tpu_lib_test.py │ ├── faster_rcnn.py │ ├── ssd.py │ ├── testdata │ │ ├── __init__.py │ │ ├── faster_rcnn │ │ │ └── faster_rcnn_resnet101_atrous_coco.config │ │ └── ssd │ │ │ └── ssd_pipeline.config │ ├── utils.py │ └── utils_test.py └── utils │ ├── __init__.py │ ├── autoaugment_utils.py │ ├── category_util.py │ ├── category_util_test.py │ ├── config_util.py │ ├── config_util_test.py │ ├── context_manager.py │ ├── context_manager_test.py │ ├── dataset_util.py │ ├── dataset_util_test.py │ ├── json_utils.py │ ├── json_utils_test.py │ ├── label_map_util.py │ ├── label_map_util_test.py │ ├── learning_schedules.py │ ├── learning_schedules_test.py │ ├── metrics.py │ ├── metrics_test.py │ ├── model_util.py │ ├── model_util_test.py │ ├── np_box_list.py │ ├── np_box_list_ops.py │ ├── np_box_list_ops_test.py │ ├── np_box_list_test.py │ ├── np_box_mask_list.py │ ├── np_box_mask_list_ops.py │ ├── np_box_mask_list_ops_test.py │ ├── 
np_box_mask_list_test.py │ ├── np_box_ops.py │ ├── np_box_ops_test.py │ ├── np_mask_ops.py │ ├── np_mask_ops_test.py │ ├── object_detection_evaluation.py │ ├── object_detection_evaluation_test.py │ ├── ops.py │ ├── ops_test.py │ ├── patch_ops.py │ ├── patch_ops_test.py │ ├── per_image_evaluation.py │ ├── per_image_evaluation_test.py │ ├── per_image_vrd_evaluation.py │ ├── per_image_vrd_evaluation_test.py │ ├── shape_utils.py │ ├── shape_utils_test.py │ ├── spatial_transform_ops.py │ ├── spatial_transform_ops_test.py │ ├── static_shape.py │ ├── static_shape_test.py │ ├── test_case.py │ ├── test_utils.py │ ├── test_utils_test.py │ ├── variables_helper.py │ ├── variables_helper_test.py │ ├── visualization_utils.py │ ├── visualization_utils_test.py │ ├── vrd_evaluation.py │ └── vrd_evaluation_test.py ├── slim ├── BUILD ├── PRINCIPLES.md ├── README.md ├── WORKSPACE ├── __init__.py ├── data │ ├── README.md │ ├── __init__.py │ ├── data_decoder.py │ ├── data_provider.py │ ├── dataset.py │ ├── dataset_data_provider.py │ ├── dataset_data_provider_test.py │ ├── parallel_reader.py │ ├── parallel_reader_test.py │ ├── prefetch_queue.py │ ├── prefetch_queue_test.py │ ├── test_utils.py │ ├── tfexample_decoder.py │ └── tfexample_decoder_test.py ├── datasets │ ├── __init__.py │ ├── build_imagenet_data.py │ ├── cifar10.py │ ├── dataset_factory.py │ ├── dataset_utils.py │ ├── download_and_convert_cifar10.py │ ├── download_and_convert_flowers.py │ ├── download_and_convert_imagenet.sh │ ├── download_and_convert_mnist.py │ ├── download_and_convert_visualwakewords.py │ ├── download_and_convert_visualwakewords_lib.py │ ├── download_imagenet.sh │ ├── flowers.py │ ├── imagenet.py │ ├── mnist.py │ ├── preprocess_imagenet_validation_data.py │ ├── process_bounding_boxes.py │ └── visualwakewords.py ├── deployment │ ├── __init__.py │ ├── model_deploy.py │ └── model_deploy_test.py ├── download_and_convert_data.py ├── eval_image_classifier.py ├── evaluation.py ├── evaluation_test.py ├── 
export_inference_graph.py ├── export_inference_graph_test.py ├── layers │ ├── __init__.py │ ├── bucketization_op.py │ ├── initializers.py │ ├── initializers_test.py │ ├── layers.py │ ├── layers_test.py │ ├── normalization.py │ ├── normalization_test.py │ ├── optimizers.py │ ├── optimizers_test.py │ ├── regularizers.py │ ├── regularizers_test.py │ ├── rev_block_lib.py │ ├── rev_block_lib_test.py │ ├── sparse_ops.py │ ├── sparse_ops_test.py │ ├── summaries.py │ ├── summaries_test.py │ ├── utils.py │ └── utils_test.py ├── learning.py ├── learning_test.py ├── losses │ ├── __init__.py │ ├── loss_ops.py │ ├── loss_ops_test.py │ ├── metric_learning.py │ └── metric_learning_test.py ├── metrics │ ├── __init__.py │ ├── classification.py │ ├── classification_test.py │ ├── histogram_ops.py │ ├── histogram_ops_test.py │ ├── metric_ops.py │ ├── metric_ops_large_test.py │ └── metric_ops_test.py ├── model_analyzer.py ├── nets │ ├── __init__.py │ ├── alexnet.py │ ├── alexnet_test.py │ ├── cifarnet.py │ ├── cyclegan.py │ ├── cyclegan_test.py │ ├── dcgan.py │ ├── dcgan_test.py │ ├── i3d.py │ ├── i3d_test.py │ ├── i3d_utils.py │ ├── inception.py │ ├── inception_resnet_v2.py │ ├── inception_resnet_v2_test.py │ ├── inception_utils.py │ ├── inception_v1.py │ ├── inception_v1_test.py │ ├── inception_v2.py │ ├── inception_v2_test.py │ ├── inception_v3.py │ ├── inception_v3_test.py │ ├── inception_v4.py │ ├── inception_v4_test.py │ ├── lenet.py │ ├── mobilenet │ │ ├── README.md │ │ ├── __init__.py │ │ ├── conv_blocks.py │ │ ├── g3doc │ │ │ ├── edgetpu_latency.png │ │ │ ├── latency_pixel1.png │ │ │ └── madds_top1_accuracy.png │ │ ├── mnet_v1_vs_v2_pixel1_latency.png │ │ ├── mobilenet.py │ │ ├── mobilenet_example.ipynb │ │ ├── mobilenet_v2.py │ │ ├── mobilenet_v2_test.py │ │ ├── mobilenet_v3.py │ │ └── mobilenet_v3_test.py │ ├── mobilenet_v1.md │ ├── mobilenet_v1.png │ ├── mobilenet_v1.py │ ├── mobilenet_v1_eval.py │ ├── mobilenet_v1_test.py │ ├── mobilenet_v1_train.py │ ├── nasnet │ │ ├── 
README.md │ │ ├── __init__.py │ │ ├── nasnet.py │ │ ├── nasnet_test.py │ │ ├── nasnet_utils.py │ │ ├── nasnet_utils_test.py │ │ ├── pnasnet.py │ │ └── pnasnet_test.py │ ├── nets_factory.py │ ├── nets_factory_test.py │ ├── overfeat.py │ ├── overfeat_test.py │ ├── pix2pix.py │ ├── pix2pix_test.py │ ├── post_training_quantization.py │ ├── resnet_utils.py │ ├── resnet_v1.py │ ├── resnet_v1_test.py │ ├── resnet_v2.py │ ├── resnet_v2_test.py │ ├── s3dg.py │ ├── s3dg_test.py │ ├── vgg.py │ └── vgg_test.py ├── ops │ ├── __init__.py │ ├── arg_scope.py │ ├── arg_scope_test.py │ ├── framework_ops.py │ ├── framework_ops_test.py │ ├── variables.py │ └── variables_test.py ├── preprocessing │ ├── __init__.py │ ├── cifarnet_preprocessing.py │ ├── inception_preprocessing.py │ ├── lenet_preprocessing.py │ ├── preprocessing_factory.py │ └── vgg_preprocessing.py ├── queues.py ├── scripts │ ├── export_mobilenet.sh │ ├── finetune_inception_resnet_v2_on_flowers.sh │ ├── finetune_inception_v1_on_flowers.sh │ ├── finetune_inception_v3_on_flowers.sh │ ├── finetune_resnet_v1_50_on_flowers.sh │ ├── train_cifarnet_on_cifar10.sh │ └── train_lenet_on_mnist.sh ├── setup.py ├── slim_walkthrough.ipynb ├── summaries.py ├── summaries_test.py ├── train_image_classifier.py └── training │ ├── __init__.py │ ├── evaluation.py │ ├── evaluation_test.py │ ├── training.py │ └── training_test.py ├── ssd_resnet50_v1_fpn.config ├── tf_datatools ├── __init__.py ├── create_pascal_tf_record.py ├── pascal_label_map.pbtxt └── utils │ ├── __init__.py │ ├── dataset_util.py │ ├── label_map_util.py │ └── protos │ ├── __init__.py │ ├── anchor_generator_pb2.py │ ├── argmax_matcher_pb2.py │ ├── bipartite_matcher_pb2.py │ ├── box_coder_pb2.py │ ├── box_predictor_pb2.py │ ├── calibration_pb2.py │ ├── eval_pb2.py │ ├── faster_rcnn_box_coder_pb2.py │ ├── faster_rcnn_pb2.py │ ├── flexible_grid_anchor_generator_pb2.py │ ├── graph_rewriter_pb2.py │ ├── grid_anchor_generator_pb2.py │ ├── hyperparams_pb2.py │ ├── 
image_resizer_pb2.py │ ├── input_reader_pb2.py │ ├── keypoint_box_coder_pb2.py │ ├── losses_pb2.py │ ├── matcher_pb2.py │ ├── mean_stddev_box_coder_pb2.py │ ├── model_pb2.py │ ├── multiscale_anchor_generator_pb2.py │ ├── optimizer_pb2.py │ ├── pipeline_pb2.py │ ├── post_processing_pb2.py │ ├── preprocessor_pb2.py │ ├── region_similarity_calculator_pb2.py │ ├── square_box_coder_pb2.py │ ├── ssd_anchor_generator_pb2.py │ ├── ssd_pb2.py │ ├── string_int_label_map_pb2.py │ ├── target_assigner_pb2.py │ └── train_pb2.py └── tfrecord_generator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | .dmypy.json 113 | dmypy.json 114 | 115 | # Pyre type checker 116 | .pyre/ 117 | 118 | # Pycharm 119 | /.idea/ 120 | 121 | # MacOS 122 | .DS_Store 123 | 124 | *.xml 125 | *.jpg 126 | *.proto 127 | *.record 128 | *.txt 129 | /object_detection/test_ckpt/ 130 | /object_detection/test_data/ 131 | /object_detection/test_images/ 132 | *.meta 133 | *.index 134 | model_file/pretrain/checkpoint 135 | *.pb 136 | *.data-00000-of-00001 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sonar_baseline_with_tensorflow 2 | ## 项目说明 3 | 本项目用于[和鲸平台比赛](https://www.kesci.com/home/competition/5e535a612537a0002ca864ac),实践了用户 
[Pumpkin](https://www.kesci.com/home/user/profile/5da7e869048089002c7d2f58) 分享的 [baseline](https://www.kesci.com/home/project/5e6331644b7a30002c98895e) 4 | 5 | ## 参考项目链接 6 | 1. [Protocol Buffer](https://github.com/protocolbuffers/protobuf.git) 7 | 2. [Tensorflow 开源模型](https://github.com/tensorflow/models.git) 8 | 3. [TF-slim](https://github.com/google-research/tf-slim.git) 9 | 10 | ## 环境 11 | - CUDA 10.0 12 | - cuDNN 7.6.0.64 13 | - Tensorflow 1.15.2 14 | - lxml 4.5.0 15 | - Pillow 7.0.0 16 | - matplotlib 3.2.0 17 | - Cython 0.29.15 18 | - pycocotools 2.0 19 | 20 | 21 | ## 使用方法 22 | 1. 运行 `tfrecord_generator.py`, 采用 `-path` 参数传入大赛数据集的**压缩包** 23 | 2. 运行 `model_train.py`, 采用 `-path` 参数传入预训练模型的文件夹地址 24 | 3. 运行 `inference.py`, 采用 `-step`指定希望被用于推理的训练步数, `-path`指定被推理图片放置的文件夹路径 -------------------------------------------------------------------------------- /dataset/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/.gitkeep -------------------------------------------------------------------------------- /dataset/pascal_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: '目标物' 4 | } -------------------------------------------------------------------------------- /dataset/raw_data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/raw_data/.gitkeep -------------------------------------------------------------------------------- /dataset/tfrecords/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/tfrecords/.gitkeep 
-------------------------------------------------------------------------------- /dataset/train/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/train/.gitkeep -------------------------------------------------------------------------------- /dataset/train/jpg/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/train/jpg/.gitkeep -------------------------------------------------------------------------------- /dataset/train/xml/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/train/xml/.gitkeep -------------------------------------------------------------------------------- /dataset/validate/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/validate/.gitkeep -------------------------------------------------------------------------------- /dataset/validate/jpg/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/validate/jpg/.gitkeep -------------------------------------------------------------------------------- /dataset/validate/xml/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/dataset/validate/xml/.gitkeep 
"""Fine-tune the SSD ResNet50 FPN detector starting from a pre-trained model.

Usage:
    python model_train.py -path /absolute/path/to/pretrained_model_folder

The script
  1. changes into the directory that contains this file, so that the
     relative paths below resolve no matter where it is launched from;
  2. copies the pre-trained model given with ``-path`` into
     ``model_file/pretrain`` unless that folder is already populated;
  3. installs the bundled ``slim`` package in editable mode;
  4. launches ``object_detection/legacy/train.py`` with the
     ``ssd_resnet50_v1_fpn.config`` pipeline.
"""
import argparse
import os
import shutil
import subprocess
import sys

# Paths are relative to the directory that contains this script.
PRETRAIN_DIR = 'model_file/pretrain'
TRAIN_DIR = 'model_file/train'
PIPELINE_CONFIG = 'ssd_resnet50_v1_fpn.config'
TRAIN_SCRIPT = 'object_detection/legacy/train.py'


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace. ``path`` is ``None`` when ``-path`` is omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-path",
        help="Specify the absolute path to the folder of pre-train model")
    return parser.parse_args(argv)


def pretrain_is_populated(pretrain_dir=PRETRAIN_DIR):
    """Return True when ``pretrain_dir`` holds more than one entry.

    The repository ships a single ``.gitkeep`` placeholder in that folder,
    so one entry (or none, or a missing folder) means no model was copied yet.
    """
    if not os.path.isdir(pretrain_dir):
        return False
    return len(os.listdir(pretrain_dir)) > 1


def copy_pretrained_model(source, pretrain_dir=PRETRAIN_DIR):
    """Copy a pre-trained model into ``pretrain_dir``.

    Args:
        source: Folder whose contents are copied, or a single file.
        pretrain_dir: Destination folder; created when missing.

    Raises:
        OSError: If ``source`` does not exist or cannot be read
            (``FileNotFoundError`` for a missing path).
    """
    os.makedirs(pretrain_dir, exist_ok=True)

    if not os.path.isdir(source):
        # Single file: same call the script always made.
        shutil.copy(source, pretrain_dir)
        return

    # shutil.copy() cannot copy a directory, so copy entry by entry.
    for entry in os.listdir(source):
        src = os.path.join(source, entry)
        dst = os.path.join(pretrain_dir, entry)
        if os.path.isdir(src):
            # copytree() refuses to overwrite; keep an existing sub-folder.
            if not os.path.exists(dst):
                shutil.copytree(src, dst)
        else:
            shutil.copy(src, dst)


def main(argv=None):
    """Prepare the pre-trained model and start training.

    Returns:
        Process exit status: 1 when ``-path`` is required but missing,
        otherwise the exit status of the training process.
    """
    args = parse_args(argv)
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    if not pretrain_is_populated():
        # argparse always defines ``path`` (as None when the flag is absent),
        # so the value itself has to be tested, not the attribute's presence.
        if args.path is None:
            print("Please specify the absolute path with \"-path\"")
            return 1
        copy_pretrained_model(args.path)

    # Use the running interpreter so pip and training share one environment.
    pip_status = subprocess.call(
        [sys.executable, '-m', 'pip', 'install', '-e', 'slim'])
    if pip_status != 0:
        # Best effort, as before: slim may already be importable.
        print("Warning: 'pip install -e slim' exited with status "
              "{}".format(pip_status), file=sys.stderr)

    return subprocess.call([
        sys.executable, TRAIN_SCRIPT,
        '--train_dir=' + TRAIN_DIR,
        '--pipeline_config_path=' + PIPELINE_CONFIG,
    ])


if __name__ == '__main__':
    sys.exit(main())
14 | -------------------------------------------------------------------------------- /object_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/__init__.py -------------------------------------------------------------------------------- /object_detection/anchor_generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/anchor_generators/__init__.py -------------------------------------------------------------------------------- /object_detection/box_coders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/box_coders/__init__.py -------------------------------------------------------------------------------- /object_detection/box_coders/mean_stddev_box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class MeanStddevBoxCoder(box_coder.BoxCoder):
  """Box coder that scales corner offsets from the anchors by a fixed stddev.

  Encoding: rel_code = (box_corner - anchor_corner) / stddev.
  Decoding: box_corner = rel_code * stddev + anchor_corner.
  """

  def __init__(self, stddev=0.01):
    """Stores the scaling constant.

    Args:
      stddev: The standard deviation used to encode and decode boxes.
    """
    self._stddev = stddev

  @property
  def code_size(self):
    """Number of values in one encoded box."""
    return 4

  @staticmethod
  def _reject_stddev_field(anchors):
    """Raises ValueError if `anchors` still carries the deprecated field."""
    if anchors.has_field('stddev'):
      raise ValueError("'stddev' is a parameter of MeanStddevBoxCoder and "
                       "should not be specified in the box list.")

  def _encode(self, boxes, anchors):
    """Encodes boxes relative to the anchors.

    Args:
      boxes: BoxList holding N boxes to be encoded.
      anchors: BoxList of N anchors.

    Returns:
      a tensor representing N anchor-encoded boxes

    Raises:
      ValueError: if the anchors still have deprecated stddev field.
    """
    corners = boxes.get()
    self._reject_stddev_field(anchors)
    return (corners - anchors.get()) / self._stddev

  def _decode(self, rel_codes, anchors):
    """Recovers box corners from anchor-relative codes.

    Args:
      rel_codes: a tensor representing N anchor-encoded boxes.
      anchors: BoxList of anchors.

    Returns:
      boxes: BoxList holding N bounding boxes

    Raises:
      ValueError: if the anchors still have deprecated stddev field.
    """
    anchor_corners = anchors.get()
    self._reject_stddev_field(anchors)
    return box_list.BoxList(rel_codes * self._stddev + anchor_corners)
class MeanStddevBoxCoderTest(tf.test.TestCase):
  """Checks MeanStddevBoxCoder encode/decode against hand-computed values."""

  # Fixtures shared by both directions: encoding _BOX_CORNERS against
  # _PRIOR_CORNERS with stddev=0.1 yields _REL_CODES, and decoding reverses it.
  _BOX_CORNERS = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]]
  _PRIOR_CORNERS = [[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]]
  _REL_CODES = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]

  def testGetCorrectRelativeCodesAfterEncoding(self):
    boxes = box_list.BoxList(tf.constant(self._BOX_CORNERS))
    priors = box_list.BoxList(tf.constant(self._PRIOR_CORNERS))

    coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
    encoded = coder.encode(boxes, priors)
    with self.test_session() as sess:
      encoded_out = sess.run(encoded)
      self.assertAllClose(encoded_out, self._REL_CODES)

  def testGetCorrectBoxesAfterDecoding(self):
    codes = tf.constant(self._REL_CODES)
    priors = box_list.BoxList(tf.constant(self._PRIOR_CORNERS))

    coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
    decoded_corners = coder.decode(codes, priors).get()
    with self.test_session() as sess:
      decoded_out = sess.run(decoded_corners)
      self.assertAllClose(decoded_out, self._BOX_CORNERS)
def build(box_coder_config):
  """Creates the box coder selected by the `box_coder_oneof` field.

  Args:
    box_coder_config: A box_coder_pb2.BoxCoder proto describing the desired
      box coder.

  Returns:
    The BoxCoder instance matching the populated oneof entry.

  Raises:
    ValueError: If `box_coder_config` is not a box_coder_pb2.BoxCoder, or if
      none of the supported oneof entries is set.
  """
  if not isinstance(box_coder_config, box_coder_pb2.BoxCoder):
    raise ValueError('box_coder_config not of type box_coder_pb2.BoxCoder.')

  coder_type = box_coder_config.WhichOneof('box_coder_oneof')

  if coder_type == 'faster_rcnn_box_coder':
    params = box_coder_config.faster_rcnn_box_coder
    scale_factors = [
        params.y_scale,
        params.x_scale,
        params.height_scale,
        params.width_scale,
    ]
    return faster_rcnn_box_coder.FasterRcnnBoxCoder(
        scale_factors=scale_factors)

  if coder_type == 'keypoint_box_coder':
    params = box_coder_config.keypoint_box_coder
    scale_factors = [
        params.y_scale,
        params.x_scale,
        params.height_scale,
        params.width_scale,
    ]
    return keypoint_box_coder.KeypointBoxCoder(
        params.num_keypoints, scale_factors=scale_factors)

  if coder_type == 'mean_stddev_box_coder':
    params = box_coder_config.mean_stddev_box_coder
    return mean_stddev_box_coder.MeanStddevBoxCoder(stddev=params.stddev)

  if coder_type == 'square_box_coder':
    params = box_coder_config.square_box_coder
    scale_factors = [params.y_scale, params.x_scale, params.length_scale]
    return square_box_coder.SquareBoxCoder(scale_factors=scale_factors)

  raise ValueError('Empty box coder.')
def build(graph_rewriter_config, is_training):
  """Creates a callable that rewrites the default graph for quantization.

  Args:
    graph_rewriter_config: graph_rewriter_pb2.GraphRewriter proto.
    is_training: whether the graph is built for training (True) or eval.

  Returns:
    A zero-argument function; nothing is validated or rewritten until it is
    called.
  """
  def graph_rewrite_fn():
    """Inserts quantize ops for weights and activations.

    Raises:
      ValueError: if weight or activation bits are configured to anything
        other than 8.
    """
    quantization = graph_rewriter_config.quantization
    uses_8bit = (quantization.weight_bits == 8 and
                 quantization.activation_bits == 8)
    if not uses_8bit:
      raise ValueError('Only 8bit quantization is supported')

    # The rewriter functions are looked up at call time, not at build time.
    if not is_training:
      tf.contrib.quantize.experimental_create_eval_graph(
          input_graph=tf.get_default_graph())
    else:
      tf.contrib.quantize.experimental_create_training_graph(
          input_graph=tf.get_default_graph(),
          quant_delay=quantization.delay)

    tf.contrib.layers.summarize_collection('quant_vars')

  return graph_rewrite_fn
class QuantizationBuilderTest(tf.test.TestCase):
  """Verifies the arguments graph_rewriter_builder passes to tf.contrib."""

  def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
    quant_patch = mock.patch.object(
        tf.contrib.quantize, 'experimental_create_training_graph')
    summarize_patch = mock.patch.object(
        tf.contrib.layers, 'summarize_collection')
    with quant_patch as mock_quant_fn, summarize_patch as mock_summarize_col:
      config = graph_rewriter_pb2.GraphRewriter()
      config.quantization.delay = 10
      config.quantization.weight_bits = 8
      config.quantization.activation_bits = 8

      rewrite = graph_rewriter_builder.build(config, is_training=True)
      rewrite()

      _, kwargs = mock_quant_fn.call_args
      self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
      self.assertEqual(kwargs['quant_delay'], 10)
      mock_summarize_col.assert_called_with('quant_vars')

  def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
    quant_patch = mock.patch.object(
        tf.contrib.quantize, 'experimental_create_eval_graph')
    summarize_patch = mock.patch.object(
        tf.contrib.layers, 'summarize_collection')
    with quant_patch as mock_quant_fn, summarize_patch as mock_summarize_col:
      config = graph_rewriter_pb2.GraphRewriter()
      config.quantization.delay = 10

      rewrite = graph_rewriter_builder.build(config, is_training=False)
      rewrite()

      _, kwargs = mock_quant_fn.call_args
      self.assertEqual(kwargs['input_graph'], tf.get_default_graph())
      mock_summarize_col.assert_called_with('quant_vars')
def build(input_reader_config):
  """Creates the decoded tensor dictionary described by an InputReader proto.

  Args:
    input_reader_config: A input_reader_pb2.InputReader object.

  Returns:
    The result of decoding the string tensor produced by the parallel
    TFRecord reader with a TfExampleDecoder.

  Raises:
    ValueError: If the config is not an input_reader_pb2.InputReader, if its
      `input_reader` oneof is not `tf_record_input_reader`, or if no input
      paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  reader_type = input_reader_config.WhichOneof('input_reader')
  if reader_type != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  record_config = input_reader_config.tf_record_input_reader
  if not record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  # A falsy num_epochs (e.g. 0) is forwarded as None.
  num_epochs = input_reader_config.num_epochs or None
  # Slicing converts the `RepeatedScalarContainer` to a list.
  input_paths = record_config.input_path[:]
  _, serialized_example = parallel_reader.parallel_read(
      input_paths,
      reader_class=tf.TFRecordReader,
      num_epochs=num_epochs,
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)

  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  else:
    label_map_proto_file = None

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file)
  return example_decoder.decode(serialized_example)
def build(matcher_config):
  """Creates the matcher selected by the `matcher_oneof` field.

  Args:
    matcher_config: A matcher_pb2.Matcher proto describing the desired
      Matcher.

  Returns:
    An ArgMaxMatcher or a GreedyBipartiteMatcher, depending on the config.

  Raises:
    ValueError: If `matcher_config` is not a matcher_pb2.Matcher, or if
      neither supported oneof entry is set.
  """
  if not isinstance(matcher_config, matcher_pb2.Matcher):
    raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

  matcher_type = matcher_config.WhichOneof('matcher_oneof')

  if matcher_type == 'argmax_matcher':
    argmax_config = matcher_config.argmax_matcher
    if argmax_config.ignore_thresholds:
      # Both thresholds are passed as None when they are to be ignored.
      matched, unmatched = None, None
    else:
      matched = argmax_config.matched_threshold
      unmatched = argmax_config.unmatched_threshold
    return argmax_matcher.ArgMaxMatcher(
        matched_threshold=matched,
        unmatched_threshold=unmatched,
        negatives_lower_than_unmatched=(
            argmax_config.negatives_lower_than_unmatched),
        force_match_for_each_row=argmax_config.force_match_for_each_row,
        use_matmul_gather=argmax_config.use_matmul_gather)

  if matcher_type == 'bipartite_matcher':
    bipartite_config = matcher_config.bipartite_matcher
    return bipartite_matcher.GreedyBipartiteMatcher(
        bipartite_config.use_matmul_gather)

  raise ValueError('Empty matcher.')
def build(region_similarity_calculator_config):
  """Builds region similarity calculator based on the configuration.

  Builds one of [IouSimilarity, IoaSimilarity, NegSqDistSimilarity,
  ThresholdedIouSimilarity], selected by the `region_similarity` oneof of the
  config proto.

  Args:
    region_similarity_calculator_config: RegionSimilarityCalculator
      configuration proto.

  Returns:
    region_similarity_calculator: RegionSimilarityCalculator object.

  Raises:
    ValueError: If the config is not a
      region_similarity_calculator_pb2.RegionSimilarityCalculator, or if no
      known similarity calculator is set in it.
  """
  if not isinstance(
      region_similarity_calculator_config,
      region_similarity_calculator_pb2.RegionSimilarityCalculator):
    # The message names the actual proto type (the previous text misspelled it
    # as "RegionsSimilarityCalculator").
    raise ValueError(
        'region_similarity_calculator_config not of type '
        'region_similarity_calculator_pb2.RegionSimilarityCalculator')

  similarity_calculator = region_similarity_calculator_config.WhichOneof(
      'region_similarity')
  if similarity_calculator == 'iou_similarity':
    return region_similarity_calculator.IouSimilarity()
  if similarity_calculator == 'ioa_similarity':
    return region_similarity_calculator.IoaSimilarity()
  if similarity_calculator == 'neg_sq_dist_similarity':
    return region_similarity_calculator.NegSqDistSimilarity()
  if similarity_calculator == 'thresholded_iou_similarity':
    # The only variant that takes a parameter from the proto.
    return region_similarity_calculator.ThresholdedIouSimilarity(
        region_similarity_calculator_config.thresholded_iou_similarity
        .iou_threshold)

  raise ValueError('Unknown region similarity calculator.')
class RegionSimilarityCalculatorBuilderTest(tf.test.TestCase):
  """Checks that each oneof entry builds the matching calculator class."""

  def _build_from_text_proto(self, text_proto):
    """Parses `text_proto` into a config proto and runs the builder on it."""
    similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator()
    text_format.Merge(text_proto, similarity_calc_proto)
    return region_similarity_calculator_builder.build(similarity_calc_proto)

  def testBuildIoaSimilarityCalculator(self):
    similarity_calc = self._build_from_text_proto("""
      ioa_similarity {
      }
    """)
    self.assertIsInstance(similarity_calc,
                          region_similarity_calculator.IoaSimilarity)

  def testBuildIouSimilarityCalculator(self):
    similarity_calc = self._build_from_text_proto("""
      iou_similarity {
      }
    """)
    self.assertIsInstance(similarity_calc,
                          region_similarity_calculator.IouSimilarity)

  def testBuildNegSqDistSimilarityCalculator(self):
    similarity_calc = self._build_from_text_proto("""
      neg_sq_dist_similarity {
      }
    """)
    self.assertIsInstance(similarity_calc,
                          region_similarity_calculator.NegSqDistSimilarity)

  def testBuildThresholdedIouSimilarityCalculator(self):
    # Covers the remaining builder branch, which reads `iou_threshold`.
    similarity_calc = self._build_from_text_proto("""
      thresholded_iou_similarity {
        iou_threshold: 0.5
      }
    """)
    self.assertIsInstance(
        similarity_calc,
        region_similarity_calculator.ThresholdedIouSimilarity)
32 | """ 33 | matcher_instance = matcher_builder.build(target_assigner_config.matcher) 34 | similarity_calc_instance = region_similarity_calculator_builder.build( 35 | target_assigner_config.similarity_calculator) 36 | box_coder = box_coder_builder.build(target_assigner_config.box_coder) 37 | return target_assigner.TargetAssigner( 38 | matcher=matcher_instance, 39 | similarity_calc=similarity_calc_instance, 40 | box_coder_instance=box_coder) 41 | -------------------------------------------------------------------------------- /object_detection/builders/target_assigner_builder_test.py: -------------------------------------------------------------------------------- 1 | """Tests for google3.third_party.tensorflow_models.object_detection.builders.target_assigner_builder.""" 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class TargetAssignerBuilderTest(tf.test.TestCase):
  """Smoke test for target_assigner_builder.build."""

  def test_build_a_target_assigner(self):
    # Minimal config: one matcher, one similarity calculator, one box coder.
    config_text = """
      matcher {
        argmax_matcher {matched_threshold: 0.5}
      }
      similarity_calculator {
        iou_similarity {}
      }
      box_coder {
        faster_rcnn_box_coder {}
      }
    """
    config = target_assigner_pb2.TargetAssigner()
    text_format.Merge(config_text, config)

    built = target_assigner_builder.build(config)

    self.assertIsInstance(built, target_assigner.TargetAssigner)
class MockBoxCoder(box_coder.BoxCoder):
  """Test BoxCoder that encodes/decodes using the multiply-by-two function."""

  @property
  def code_size(self):
    """Number of values in one encoded box.

    Exposed as a property, matching how concrete coders such as
    MeanStddevBoxCoder expose `code_size`.
    """
    return 4

  def _encode(self, boxes, anchors):
    """Returns the box corners doubled; `anchors` is ignored."""
    return 2.0 * boxes.get()

  def _decode(self, rel_codes, anchors):
    """Returns a BoxList of the codes halved; `anchors` is ignored."""
    return box_list.BoxList(rel_codes / 2.0)


class BoxCoderTest(tf.test.TestCase):
  """Round-trips a batch of boxes through box_coder.batch_decode."""

  def test_batch_decode(self):
    mock_anchor_corners = tf.constant(
        [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
    mock_anchors = box_list.BoxList(mock_anchor_corners)
    mock_box_coder = MockBoxCoder()

    # Two batch entries with two boxes each.
    expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]],
                      [[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]]

    # Encode each batch entry on its own, then stack so that batch_decode has
    # to undo the encoding across the leading dimension.
    encoded_boxes_list = [mock_box_coder.encode(
        box_list.BoxList(tf.constant(boxes)), mock_anchors)
                          for boxes in expected_boxes]
    encoded_boxes = tf.stack(encoded_boxes_list)
    decoded_boxes = box_coder.batch_decode(
        encoded_boxes, mock_box_coder, mock_anchors)

    with self.test_session() as sess:
      decoded_boxes_result = sess.run(decoded_boxes)
      self.assertAllClose(expected_boxes, decoded_boxes_result)
class DataDecoder(six.with_metaclass(ABCMeta, object)):
  """Abstract base class that every data decoder implements."""

  @abstractmethod
  def decode(self, data):
    """Decodes one serialized example into a dictionary of tensors.

    Subclasses must override this method; the base implementation does
    nothing and returns None.

    Args:
      data: a string tensor holding a serialized protocol buffer that
        describes a single image.

    Returns:
      tensor_dict: a dictionary of tensors for the image and its labels.
        Possible keys are defined in reader.Fields.
    """
class FreezableBatchNorm(tf.keras.layers.BatchNormalization):
  """Batch normalization layer (Ioffe and Szegedy, 2014) that can be frozen.

  This `freezable` batch norm layer lets the `training` behaviour be fixed in
  the constructor instead of being supplied through the Keras learning phase
  or the `call` argument. All other parameters are forwarded unchanged to the
  default Keras `BatchNormalization` layer.

  This class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it was
  evaluation time, despite still training (and potentially using dropout
  layers).

  Like the default Keras BatchNormalization layer, this normalizes the
  activations of the previous layer at each batch, i.e. applies a
  transformation that maintains the mean activation close to 0 and the
  activation standard deviation close to 1.

  Arguments:
    training: If False, the layer will normalize using the moving average and
      std. dev, without updating the learned avg and std. dev.
      If None or True, the layer will follow the keras BatchNormalization layer
      strategy of checking the Keras learning phase at `call` time to decide
      what to do.
    **kwargs: The keyword arguments to forward to the keras BatchNormalization
      layer constructor.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  def __init__(self, training=None, **kwargs):
    super(FreezableBatchNorm, self).__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Only an explicitly frozen layer (training=False at construction time)
    # overrides the caller; None and True both defer to the `training`
    # argument received here.
    frozen = self._training is False  # pylint: disable=g-bool-id-comparison
    effective_training = False if frozen else training
    return super(FreezableBatchNorm, self).call(
        inputs, training=effective_training)
33 | prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20) 34 | tensor_dict = prefetch_queue.dequeue() 35 | outputs = Model(tensor_dict) 36 | ... 37 | ---------------------------------------------------- 38 | 39 | For input pipelines with batching, refer to core/batcher.py 40 | 41 | Args: 42 | tensor_dict: a dictionary of tensors to prefetch. 43 | capacity: the size of the prefetch queue. 44 | 45 | Returns: 46 | a FIFO prefetcher queue 47 | """ 48 | names = list(tensor_dict.keys()) 49 | dtypes = [t.dtype for t in tensor_dict.values()] 50 | shapes = [t.get_shape() for t in tensor_dict.values()] 51 | prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes, 52 | shapes=shapes, 53 | names=names, 54 | name='prefetch_queue') 55 | enqueue_op = prefetch_queue.enqueue(tensor_dict) 56 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner( 57 | prefetch_queue, [enqueue_op])) 58 | tf.summary.scalar( 59 | 'queue/%s/fraction_of_%d_full' % (prefetch_queue.name, capacity), 60 | tf.cast(prefetch_queue.size(), dtype=tf.float32) * (1. 
/ capacity)) 61 | return prefetch_queue 62 | -------------------------------------------------------------------------------- /object_detection/data/face_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "face" 3 | id: 1 4 | display_name: "face" 5 | } 6 | 7 | -------------------------------------------------------------------------------- /object_detection/data/kitti_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'car' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'pedestrian' 9 | } 10 | -------------------------------------------------------------------------------- /object_detection/data/pascal_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'aeroplane' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'bicycle' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'bird' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'boat' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'bottle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'bus' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'car' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'cat' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'chair' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'cow' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'diningtable' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'dog' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'horse' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'motorbike' 69 | } 70 | 71 | item { 72 | id: 15 73 | name: 'person' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'pottedplant' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'sheep' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'sofa' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'train' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'tvmonitor' 99 | } 100 | 
-------------------------------------------------------------------------------- /object_detection/data/pet_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'Abyssinian' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'american_bulldog' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'american_pit_bull_terrier' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'basset_hound' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'beagle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'Bengal' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'Birman' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'Bombay' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'boxer' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'British_Shorthair' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'chihuahua' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'Egyptian_Mau' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'english_cocker_spaniel' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'english_setter' 69 | } 70 | 71 | item { 72 | id: 15 73 | name: 'german_shorthaired' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'great_pyrenees' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'havanese' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'japanese_chin' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'keeshond' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'leonberger' 99 | } 100 | 101 | item { 102 | id: 21 103 | name: 'Maine_Coon' 104 | } 105 | 106 | item { 107 | id: 22 108 | name: 'miniature_pinscher' 109 | } 110 | 111 | item { 112 | id: 23 113 | name: 'newfoundland' 114 | } 115 | 116 | item { 117 | id: 24 118 | name: 'Persian' 119 | } 120 | 121 | item { 122 | id: 25 123 | name: 'pomeranian' 124 | } 125 | 126 | item { 127 | id: 26 128 | name: 'pug' 129 | } 130 | 131 | item { 132 | id: 27 133 | name: 'Ragdoll' 134 | } 135 | 136 | item { 137 | id: 28 138 | name: 'Russian_Blue' 139 | } 140 | 141 | item { 142 | id: 29 143 | name: 'saint_bernard' 
#!/bin/bash
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Script to download pycocotools and make package for CMLE jobs.
#
# usage:
#  bash object_detection/dataset_tools/create_pycocotools_package.sh \
#    /tmp/pycocotools
#
# Produces pycocotools-2.0.tar.gz inside the given output directory.
set -e

if [ -z "$1" ]; then
  # A missing argument is an error: report on stderr and exit non-zero so
  # callers and CI notice the failure.
  echo "usage create_pycocotools_package.sh [output dir]" >&2
  exit 1
fi

# Create the output directory and resolve it to an absolute path. The script
# changes directory several times below, so a relative path would otherwise
# point at the wrong place by the time it is used again.
mkdir -p "$1"
OUTPUT_DIR="$(cd "$1" && pwd)"
SCRATCH_DIR="${OUTPUT_DIR%/}/raw"
mkdir -p "${SCRATCH_DIR}"

cd "${SCRATCH_DIR}"
git clone https://github.com/cocodataset/cocoapi.git
cd cocoapi/PythonAPI && mv ../common ./

# The common/ sources now live next to setup.py, so rewrite "../common"
# references to "common".
sed "s/\.\.\/common/common/g" setup.py > setup.py.updated
cp -f setup.py.updated setup.py
rm setup.py.updated

sed "s/\.\.\/common/common/g" pycocotools/_mask.pyx > _mask.pyx.updated
cp -f _mask.pyx.updated pycocotools/_mask.pyx
rm _mask.pyx.updated

# Select the non-interactive Agg matplotlib backend before pyplot is imported.
sed "s/import matplotlib\.pyplot as plt/import matplotlib;matplotlib\.use\(\'Agg\'\);import matplotlib\.pyplot as plt/g" pycocotools/coco.py > coco.py.updated
cp -f coco.py.updated pycocotools/coco.py
rm coco.py.updated

cd "${OUTPUT_DIR}"
tar -czf pycocotools-2.0.tar.gz -C "${SCRATCH_DIR}/cocoapi/" PythonAPI/
rm -rf "${SCRATCH_DIR}"
def open_sharded_output_tfrecords(exit_stack, base_path, num_shards):
  """Opens one TFRecord writer per shard and registers each on an exit stack.

  Shard k is written to '<base_path>-<k>-of-<num_shards>', with both numbers
  zero-padded to five digits.

  Args:
    exit_stack: An ExitStack-like object (e.g. contextlib2.ExitStack). Every
      writer is passed to its `enter_context`, so the writers are closed
      automatically when the stack unwinds.
    base_path: The base path for all shards.
    num_shards: The number of shards.

  Returns:
    The list of opened TFRecord writers. Position k in the list corresponds to
    shard k.
  """
  writers = []
  for shard_index in range(num_shards):
    shard_path = '{}-{:05d}-of-{:05d}'.format(
        base_path, shard_index, num_shards)
    writer = tf.python_io.TFRecordWriter(shard_path)
    writers.append(exit_stack.enter_context(writer))
  return writers
class OpenOutputTfrecordsTests(tf.test.TestCase):
  """Tests for tf_record_creation_util.open_sharded_output_tfrecords."""

  def test_sharded_tfrecord_writes(self):
    """Each shard file contains exactly the single record written to it."""
    num_shards = 10
    base_path = os.path.join(tf.test.get_temp_dir(), 'test.tfrec')

    with contextlib2.ExitStack() as tf_record_close_stack:
      output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
          tf_record_close_stack, base_path, num_shards)
      for idx in range(num_shards):
        # TFRecord payloads are raw bytes; encode explicitly so the test
        # behaves the same on Python 2 and Python 3.
        output_tfrecords[idx].write('test_{}'.format(idx).encode('utf-8'))

    # The writers are closed once the ExitStack exits, so the shards can now
    # be read back.
    for idx in range(num_shards):
      tf_record_path = '{}-{:05d}-of-{:05d}'.format(
          base_path, idx, num_shards)
      records = list(tf.python_io.tf_record_iterator(tf_record_path))
      # Records are read back as bytes, so compare against bytes.
      self.assertAllEqual(records, ['test_{}'.format(idx).encode('utf-8')])
Docker containers do not have persistent storage. This means that any changes 13 | you make to files inside the container will not persist if you restart 14 | the container. When running through the tutorial, 15 | **do not close the container**. 16 | 2. To be able to deploy the [Android app]( 17 | https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android) 18 | (which you will build at the end of the tutorial), 19 | you will need to kill any instances of `adb` running on the host machine. You 20 | can accomplish this by closing all instances of Android Studio, and then 21 | running `adb kill-server`. 22 | 23 | You can install Docker by following the [instructions here]( 24 | https://docs.docker.com/install/). 25 | 26 | ## Running The Container 27 | 28 | From this directory, build the Dockerfile as follows (this takes a while): 29 | 30 | ``` 31 | docker build --tag detect-tf . 32 | ``` 33 | 34 | Run the container: 35 | 36 | ``` 37 | docker run --rm -it --privileged -p 6006:6006 detect-tf 38 | ``` 39 | 40 | When running the container, you will find yourself inside the `/tensorflow` 41 | directory, which is the path to the TensorFlow [source 42 | tree](https://github.com/tensorflow/tensorflow). 43 | 44 | ## Text Editing 45 | 46 | The tutorial also 47 | requires you to occasionally edit files inside the source tree. 48 | This Docker image comes with `vim`, `nano`, and `emacs` preinstalled for your 49 | convenience. 50 | 51 | ## What's In This Container 52 | 53 | This container is derived from the nightly build of TensorFlow, and contains the 54 | sources for TensorFlow at `/tensorflow`, as well as the 55 | [TensorFlow Models](https://github.com/tensorflow/models) which are available at 56 | `/tensorflow/models` (and contain the Object Detection API as a subdirectory 57 | at `/tensorflow/models/research/object_detection`).
58 | The Oxford-IIIT Pets dataset, the COCO pre-trained SSD + MobileNet (v1) 59 | checkpoint, and an example 60 | trained model are all available in `/tmp` in their respective folders. 61 | 62 | This container also has the `gsutil` and `gcloud` utilities, the `bazel` build 63 | tool, and all dependencies necessary to use the Object Detection API, and 64 | compile and install the TensorFlow Lite Android demo app. 65 | 66 | At various points throughout the tutorial, you may see references to the 67 | *research directory*. This refers to the `research` folder within the 68 | models repository, located at 69 | `/tensorflow/models/research`. 70 | -------------------------------------------------------------------------------- /object_detection/g3doc/exporting_models.md: -------------------------------------------------------------------------------- 1 | # Exporting a trained model for inference 2 | 3 | After your model has been trained, you should export it to a Tensorflow 4 | graph proto. A checkpoint will typically consist of three files: 5 | 6 | * model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001 7 | * model.ckpt-${CHECKPOINT_NUMBER}.index 8 | * model.ckpt-${CHECKPOINT_NUMBER}.meta 9 | 10 | After you've identified a candidate checkpoint to export, run the following 11 | command from tensorflow/models/research: 12 | 13 | ``` bash 14 | # From tensorflow/models/research/ 15 | INPUT_TYPE=image_tensor 16 | PIPELINE_CONFIG_PATH={path to pipeline config file} 17 | TRAINED_CKPT_PREFIX={path to model.ckpt} 18 | EXPORT_DIR={path to folder that will be used for export} 19 | python object_detection/export_inference_graph.py \ 20 | --input_type=${INPUT_TYPE} \ 21 | --pipeline_config_path=${PIPELINE_CONFIG_PATH} \ 22 | --trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \ 23 | --output_directory=${EXPORT_DIR} 24 | ``` 25 | 26 | NOTE: We are configuring our exported model to ingest 4-D image tensors.
We can 27 | also configure the exported model to take encoded images or serialized 28 | `tf.Example`s. 29 | 30 | After export, you should see the directory ${EXPORT_DIR} containing the following: 31 | 32 | * saved_model/, a directory containing the saved model format of the exported model 33 | * frozen_inference_graph.pb, the frozen graph format of the exported model 34 | * model.ckpt.*, the model checkpoints used for exporting 35 | * checkpoint, a file specifying to restore included checkpoint files 36 | * pipeline.config, pipeline config file for the exported model 37 | -------------------------------------------------------------------------------- /object_detection/g3doc/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | ## Q: How can I ensure that all the groundtruth boxes are used during train and eval? 4 | A: For the object detection framework to be TPU-compliant, we must pad our input 5 | tensors to static shapes. This means that we must pad to a fixed number of 6 | bounding boxes, configured by `InputReader.max_number_of_boxes`. It is 7 | important to set this value to a number larger than the maximum number of 8 | groundtruth boxes in the dataset. If an image is encountered with more 9 | bounding boxes, the excess boxes will be clipped. 10 | 11 | ## Q: AttributeError: 'module' object has no attribute 'BackupHandler' 12 | A: This BackupHandler (tf.contrib.slim.tfexample_decoder.BackupHandler) was 13 | introduced in tensorflow 1.5.0 so running with earlier versions may cause this 14 | issue. It now has been replaced by 15 | object_detection.data_decoders.tf_example_decoder.BackupHandler. Whoever sees 16 | this issue should be able to resolve it by syncing your fork to HEAD. 17 | Same for LookupTensor. 18 | 19 | ## Q: AttributeError: 'module' object has no attribute 'LookupTensor' 20 | A: Similar to BackupHandler, syncing your fork to HEAD should make it work.
21 | 22 | ## Q: Why can't I get the inference time as reported in model zoo? 23 | A: The inference time reported in model zoo is mean time of testing hundreds of 24 | images with an internal machine. As mentioned in 25 | [Tensorflow detection model zoo](detection_model_zoo.md), this speed depends 26 | highly on one's specific hardware configuration and should be treated more as 27 | relative timing. 28 | -------------------------------------------------------------------------------- /object_detection/g3doc/img/dataset_explorer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/dataset_explorer.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/groupof_case_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/groupof_case_eval.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/kites_with_segment_overlay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/kites_with_segment_overlay.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/nongroupof_case_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/nongroupof_case_eval.png 
-------------------------------------------------------------------------------- /object_detection/g3doc/img/oxford_pet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/oxford_pet.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/tensorboard.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/tensorboard2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/tensorboard2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/tf-od-api-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/g3doc/img/tf-od-api-logo.png -------------------------------------------------------------------------------- /object_detection/g3doc/preparing_inputs.md: -------------------------------------------------------------------------------- 1 | # Preparing Inputs 2 | 3 | Tensorflow Object Detection API reads data using the TFRecord file format. Two 4 | sample scripts (`create_pascal_tf_record.py` and `create_pet_tf_record.py`) are 5 | provided to convert from the PASCAL VOC dataset and Oxford-IIIT Pet dataset to 6 | TFRecords. 
7 | 8 | ## Generating the PASCAL VOC TFRecord files. 9 | 10 | The raw 2012 PASCAL VOC data set is located 11 | [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar). 12 | To download, extract and convert it to TFRecords, run the following commands 13 | below: 14 | 15 | ```bash 16 | # From tensorflow/models/research/ 17 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 18 | tar -xvf VOCtrainval_11-May-2012.tar 19 | python object_detection/dataset_tools/create_pascal_tf_record.py \ 20 | --label_map_path=object_detection/data/pascal_label_map.pbtxt \ 21 | --data_dir=VOCdevkit --year=VOC2012 --set=train \ 22 | --output_path=pascal_train.record 23 | python object_detection/dataset_tools/create_pascal_tf_record.py \ 24 | --label_map_path=object_detection/data/pascal_label_map.pbtxt \ 25 | --data_dir=VOCdevkit --year=VOC2012 --set=val \ 26 | --output_path=pascal_val.record 27 | ``` 28 | 29 | You should end up with two TFRecord files named `pascal_train.record` and 30 | `pascal_val.record` in the `tensorflow/models/research/` directory. 31 | 32 | The label map for the PASCAL VOC data set can be found at 33 | `object_detection/data/pascal_label_map.pbtxt`. 34 | 35 | ## Generating the Oxford-IIIT Pet TFRecord files. 36 | 37 | The Oxford-IIIT Pet data set is located 38 | [here](http://www.robots.ox.ac.uk/~vgg/data/pets/). 
To download, extract and 39 | convert it to TFRecords, run the following commands below: 40 | 41 | ```bash 42 | # From tensorflow/models/research/ 43 | wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz 44 | wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz 45 | tar -xvf annotations.tar.gz 46 | tar -xvf images.tar.gz 47 | python object_detection/dataset_tools/create_pet_tf_record.py \ 48 | --label_map_path=object_detection/data/pet_label_map.pbtxt \ 49 | --data_dir=`pwd` \ 50 | --output_dir=`pwd` 51 | ``` 52 | 53 | You should end up with two 10-sharded TFRecord files named 54 | `pet_faces_train.record-?????-of-00010` and 55 | `pet_faces_val.record-?????-of-00010` in the `tensorflow/models/research/` 56 | directory. 57 | 58 | The label map for the Pet dataset can be found at 59 | `object_detection/data/pet_label_map.pbtxt`. 60 | -------------------------------------------------------------------------------- /object_detection/g3doc/running_locally.md: -------------------------------------------------------------------------------- 1 | # Running Locally 2 | 3 | This page walks through the steps required to train an object detection model 4 | on a local machine. It assumes the reader has completed the 5 | following prerequisites: 6 | 7 | 1. The Tensorflow Object Detection API has been installed as documented in the 8 | [installation instructions](installation.md). This includes installing library 9 | dependencies, compiling the configuration protobufs and setting up the Python 10 | environment. 11 | 2. A valid data set has been created. See [this page](preparing_inputs.md) for 12 | instructions on how to generate a dataset for the PASCAL VOC challenge or the 13 | Oxford-IIIT Pet dataset. 14 | 3. An Object Detection pipeline configuration has been written. See 15 | [this page](configuring_jobs.md) for details on how to write a pipeline configuration.
16 | 17 | ## Recommended Directory Structure for Training and Evaluation 18 | 19 | ``` 20 | +data 21 | -label_map file 22 | -train TFRecord file 23 | -eval TFRecord file 24 | +models 25 | + model 26 | -pipeline config file 27 | +train 28 | +eval 29 | ``` 30 | 31 | ## Running the Training Job 32 | 33 | A local training job can be run with the following command: 34 | 35 | ```bash 36 | # From the tensorflow/models/research/ directory 37 | PIPELINE_CONFIG_PATH={path to pipeline config file} 38 | MODEL_DIR={path to model directory} 39 | NUM_TRAIN_STEPS=50000 40 | SAMPLE_1_OF_N_EVAL_EXAMPLES=1 41 | python object_detection/model_main.py \ 42 | --pipeline_config_path=${PIPELINE_CONFIG_PATH} \ 43 | --model_dir=${MODEL_DIR} \ 44 | --num_train_steps=${NUM_TRAIN_STEPS} \ 45 | --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \ 46 | --alsologtostderr 47 | ``` 48 | 49 | where `${PIPELINE_CONFIG_PATH}` points to the pipeline config and 50 | `${MODEL_DIR}` points to the directory in which training checkpoints 51 | and events will be written to. Note that this binary will interleave both 52 | training and evaluation. 53 | 54 | ## Running Tensorboard 55 | 56 | Progress for training and eval jobs can be inspected using Tensorboard. If 57 | using the recommended directory structure, Tensorboard can be run using the 58 | following command: 59 | 60 | ```bash 61 | tensorboard --logdir=${MODEL_DIR} 62 | ``` 63 | 64 | where `${MODEL_DIR}` points to the directory that contains the 65 | train and eval directories. Please note it may take Tensorboard a couple minutes 66 | to populate with data. 
67 | -------------------------------------------------------------------------------- /object_detection/g3doc/running_notebook.md: -------------------------------------------------------------------------------- 1 | # Quick Start: Jupyter notebook for off-the-shelf inference 2 | 3 | If you'd like to hit the ground running and run detection on a few example 4 | images right out of the box, we recommend trying out the Jupyter notebook demo. 5 | To run the Jupyter notebook, run the following command from 6 | `tensorflow/models/research/object_detection`: 7 | 8 | ``` 9 | # From tensorflow/models/research/object_detection 10 | jupyter notebook 11 | ``` 12 | 13 | The notebook should open in your favorite web browser. Click the 14 | [`object_detection_tutorial.ipynb`](../object_detection_tutorial.ipynb) link to 15 | open the demo. 16 | -------------------------------------------------------------------------------- /object_detection/g3doc/tpu_exporters.md: -------------------------------------------------------------------------------- 1 | # Object Detection TPU Inference Exporter 2 | 3 | This package contains SavedModel Exporter for TPU Inference of object detection 4 | models. 5 | 6 | ## Usage 7 | 8 | This Exporter is intended for users who have trained models with CPUs / GPUs, 9 | but would like to use them for inference on TPU without changing their code or 10 | re-training their models. 11 | 12 | Users are assumed to have: 13 | 14 | + `PIPELINE_CONFIG`: A pipeline_pb2.TrainEvalPipelineConfig config file; 15 | + `CHECKPOINT`: A model checkpoint trained on any device; 16 | 17 | and need to correctly set: 18 | 19 | + `EXPORT_DIR`: Path to export SavedModel; 20 | + `INPUT_PLACEHOLDER`: Name of input placeholder in model's signature_def_map; 21 | + `INPUT_TYPE`: Type of input node, which can be one of 'image_tensor', 22 | 'encoded_image_string_tensor', or 'tf_example'; 23 | + `USE_BFLOAT16`: Whether to use bfloat16 instead of float32 on TPU. 
24 | 25 | The model can be exported with: 26 | 27 | ``` 28 | python object_detection/tpu_exporters/export_saved_model_tpu.py \ 29 | --pipeline_config_file= \ 30 | --ckpt_path= \ 31 | --export_dir= \ 32 | --input_placeholder_name= \ 33 | --input_type= \ 34 | --use_bfloat16= 35 | ``` 36 | -------------------------------------------------------------------------------- /object_detection/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/inference/__init__.py -------------------------------------------------------------------------------- /object_detection/legacy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/legacy/__init__.py -------------------------------------------------------------------------------- /object_detection/matchers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/matchers/__init__.py -------------------------------------------------------------------------------- /object_detection/matchers/bipartite_matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Bipartite matcher implementation.""" 17 | 18 | import tensorflow as tf 19 | 20 | from tensorflow.contrib.image.python.ops import image_ops 21 | from object_detection.core import matcher 22 | 23 | 24 | class GreedyBipartiteMatcher(matcher.Matcher): 25 | """Wraps a Tensorflow greedy bipartite matcher.""" 26 | 27 | def __init__(self, use_matmul_gather=False): 28 | """Constructs a Matcher. 29 | 30 | Args: 31 | use_matmul_gather: Force constructed match objects to use matrix 32 | multiplication based gather instead of standard tf.gather. 33 | (Default: False). 34 | """ 35 | super(GreedyBipartiteMatcher, self).__init__( 36 | use_matmul_gather=use_matmul_gather) 37 | 38 | def _match(self, similarity_matrix, valid_rows): 39 | """Greedily bipartite matches a collection of rows and columns. 40 | 41 | TODO(rathodv): Add num_valid_columns options to match only that many columns 42 | with all the rows. 43 | 44 | Args: 45 | similarity_matrix: Float tensor of shape [N, M] with pairwise similarity 46 | where higher values mean more similar. 47 | valid_rows: A boolean tensor of shape [N] indicating the rows that are 48 | valid. 49 | 50 | Returns: 51 | match_results: int32 tensor of shape [M] with match_results[i]=-1 52 | meaning that column i is not matched and otherwise that it is matched to 53 | row match_results[i].
54 | """ 55 | valid_row_sim_matrix = tf.gather(similarity_matrix, 56 | tf.squeeze(tf.where(valid_rows), axis=-1)) 57 | invalid_row_sim_matrix = tf.gather( 58 | similarity_matrix, 59 | tf.squeeze(tf.where(tf.logical_not(valid_rows)), axis=-1)) 60 | similarity_matrix = tf.concat( 61 | [valid_row_sim_matrix, invalid_row_sim_matrix], axis=0) # Rows are regrouped: valid rows first, then invalid rows. 62 | # Convert similarity matrix to distance matrix as image_ops.bipartite_match 63 | # tries to find minimum distance matches. 64 | distance_matrix = -1 * similarity_matrix 65 | num_valid_rows = tf.reduce_sum(tf.cast(valid_rows, dtype=tf.float32)) 66 | _, match_results = image_ops.bipartite_match( 67 | distance_matrix, num_valid_rows=num_valid_rows) # NOTE(review): no index remapping follows, so row ids presumably refer to the regrouped (valid-first) matrix - confirm callers expect that. 68 | match_results = tf.reshape(match_results, [-1]) 69 | match_results = tf.cast(match_results, tf.int32) 70 | return match_results 71 | -------------------------------------------------------------------------------- /object_detection/meta_architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/meta_architectures/__init__.py -------------------------------------------------------------------------------- /object_detection/meta_architectures/rfcn_meta_arch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.meta_architectures.rfcn_meta_arch.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib 21 | from object_detection.meta_architectures import rfcn_meta_arch 22 | 23 | 24 | class RFCNMetaArchTest( 25 | faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase): # Specializes the FasterRCNNMetaArchTestBase hooks below for RFCNMetaArch. 26 | 27 | def _get_second_stage_box_predictor_text_proto( 28 | self, share_box_across_classes=False): # Returns the rfcn_box_predictor text proto used as the second-stage predictor config. 29 | del share_box_across_classes # Accepted but unused: the proto below does not depend on it. 30 | box_predictor_text_proto = """ 31 | rfcn_box_predictor { 32 | conv_hyperparams { 33 | op: CONV 34 | activation: NONE 35 | regularizer { 36 | l2_regularizer { 37 | weight: 0.0005 38 | } 39 | } 40 | initializer { 41 | variance_scaling_initializer { 42 | factor: 1.0 43 | uniform: true 44 | mode: FAN_AVG 45 | } 46 | } 47 | } 48 | } 49 | """ 50 | return box_predictor_text_proto 51 | 52 | def _get_model(self, box_predictor, **common_kwargs): # Builds the RFCNMetaArch under test; common_kwargs are passed through untouched. 53 | return rfcn_meta_arch.RFCNMetaArch( 54 | second_stage_rfcn_box_predictor=box_predictor, **common_kwargs) 55 | 56 | def _get_box_classifier_features_shape(self, 57 | image_size, 58 | batch_size, 59 | max_num_proposals, 60 | initial_crop_size, 61 | maxpool_stride, 62 | num_features): 63 | return (batch_size, image_size, image_size, num_features) # max_num_proposals, initial_crop_size and maxpool_stride are not used here. 64 | 65 | 66 | if __name__ == '__main__': 67 | tf.test.main() 68 | -------------------------------------------------------------------------------- /object_detection/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/metrics/__init__.py
-------------------------------------------------------------------------------- /object_detection/metrics/io_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Common IO utils used in offline metric computation. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import csv 23 | 24 | 25 | def write_csv(fid, metrics): 26 | """Writes metrics key-value pairs to a CSV file, one row per metric. 27 | 28 | Args: 29 | fid: An opened, writable file object; it is handed to csv.writer. 30 | metrics: A dictionary with metrics to be written; values go through str(). 31 | """ 32 | metrics_writer = csv.writer(fid, delimiter=',') 33 | for metric_name, metric_value in metrics.items(): 34 | metrics_writer.writerow([metric_name, str(metric_value)]) 35 | -------------------------------------------------------------------------------- /object_detection/metrics/offline_eval_map_corloc_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for utilities in offline_eval_map_corloc binary.""" 16 | 17 | import tensorflow as tf 18 | 19 | from object_detection.metrics import offline_eval_map_corloc as offline_eval 20 | 21 | 22 | class OfflineEvalMapCorlocTest(tf.test.TestCase): 23 | 24 | def test_generateShardedFilenames(self): 25 | test_filename = '/path/to/file' # No '@N' shard spec: expected back unchanged. 26 | result = offline_eval._generate_sharded_filenames(test_filename) 27 | self.assertEqual(result, [test_filename]) 28 | 29 | test_filename = '/path/to/file-00000-of-00050' # Already-expanded shard name: expected back unchanged. 30 | result = offline_eval._generate_sharded_filenames(test_filename) 31 | self.assertEqual(result, [test_filename]) 32 | 33 | result = offline_eval._generate_sharded_filenames('/path/to/@3.record') # '@3' is expected to expand into three shard names, keeping the suffix. 34 | self.assertEqual(result, [ 35 | '/path/to/-00000-of-00003.record', '/path/to/-00001-of-00003.record', 36 | '/path/to/-00002-of-00003.record' 37 | ]) 38 | 39 | result = offline_eval._generate_sharded_filenames('/path/to/abc@3') # Same expansion with a non-empty prefix and no suffix. 40 | self.assertEqual(result, [ 41 | '/path/to/abc-00000-of-00003', '/path/to/abc-00001-of-00003', 42 | '/path/to/abc-00002-of-00003' 43 | ]) 44 | 45 | result = offline_eval._generate_sharded_filenames('/path/to/@1') # Single-shard edge case. 46 | self.assertEqual(result, ['/path/to/-00000-of-00001']) 47 | 48 | def test_generateFilenames(self): 49 | test_filenames = ['/path/to/file', '/path/to/@3.record'] # A plain name mixed with an '@3' spec; the result is expected as one flat list. 50 | result = offline_eval._generate_filenames(test_filenames) 51 | self.assertEqual(result, [ 52 | '/path/to/file',
'/path/to/-00000-of-00003.record', 53 | '/path/to/-00001-of-00003.record', '/path/to/-00002-of-00003.record' 54 | ]) 55 | 56 | 57 | if __name__ == '__main__': 58 | tf.test.main() 59 | -------------------------------------------------------------------------------- /object_detection/model_hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Hyperparameters for the object detection model in TF.learn. 16 | 17 | This file consolidates and documents the hyperparameters used by the model. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import tensorflow as tf 25 | 26 | 27 | def create_hparams(hparams_overrides=None): 28 | """Returns hyperparameters, including any flag value overrides. 29 | 30 | Args: 31 | hparams_overrides: Optional hparams overrides, represented as a 32 | string containing comma-separated hparam_name=value pairs. When None or empty, the defaults are kept. 33 | 34 | Returns: 35 | The hyperparameters as a tf.contrib.training.HParams object. 36 | """ 37 | hparams = tf.contrib.training.HParams( 38 | # Whether a fine tuning checkpoint (provided in the pipeline config) 39 | # should be loaded for training.
40 | load_pretrained=True) 41 | # Override any of the preceding hyperparameter values with the parsed overrides, if given. 42 | if hparams_overrides: 43 | hparams = hparams.parse(hparams_overrides) 44 | return hparams 45 | -------------------------------------------------------------------------------- /object_detection/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/models/__init__.py -------------------------------------------------------------------------------- /object_detection/models/keras_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/models/keras_models/__init__.py -------------------------------------------------------------------------------- /object_detection/models/keras_models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Utils for Keras models.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import collections 23 | import tensorflow as tf 24 | 25 | # This is to specify the custom config of model structures. For example, 26 | # ConvDefs(conv_name='conv_pw_12', filters=512) for Mobilenet V1 is to specify 27 | # the filters of the conv layer with name 'conv_pw_12' as 512. 28 | ConvDefs = collections.namedtuple('ConvDefs', ['conv_name', 'filters']) 29 | 30 | 31 | def get_conv_def(conv_defs, layer_name): 32 | """Get the custom config for some layer of the model structure. 33 | 34 | Args: 35 | conv_defs: An iterable of `ConvDefs` named tuples specifying the custom config of the model 36 | network. See `ConvDefs` for details. 37 | layer_name: A string, the name of the layer to be customized. 38 | 39 | Returns: 40 | The number of filters for the layer, or `None` if there is no custom 41 | config for the requested layer. 42 | """ 43 | for conv_def in conv_defs: 44 | if layer_name == conv_def.conv_name: 45 | return conv_def.filters # First entry whose conv_name matches wins. 46 | return None 47 | 48 | 49 | def input_layer(shape, placeholder_with_default): # Returns a Keras Input: built from `shape` when executing eagerly, otherwise wrapping `placeholder_with_default` as its tensor. 50 | if tf.executing_eagerly(): 51 | return tf.keras.layers.Input(shape=shape) 52 | else: 53 | return tf.keras.layers.Input(tensor=placeholder_with_default) 54 | -------------------------------------------------------------------------------- /object_detection/models/ssd_mobilenet_edgetpu_feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """SSDFeatureExtractor for MobileNetEdgeTPU features.""" 16 | 17 | import tensorflow as tf 18 | 19 | from object_detection.models import ssd_mobilenet_v3_feature_extractor 20 | from nets.mobilenet import mobilenet_v3 21 | 22 | slim = tf.contrib.slim 23 | 24 | 25 | class SSDMobileNetEdgeTPUFeatureExtractor( 26 | ssd_mobilenet_v3_feature_extractor.SSDMobileNetV3FeatureExtractorBase): 27 | """MobileNetEdgeTPU feature extractor built on SSDMobileNetV3FeatureExtractorBase.""" 28 | 29 | def __init__(self, 30 | is_training, 31 | depth_multiplier, 32 | min_depth, 33 | pad_to_multiple, 34 | conv_hyperparams_fn, 35 | reuse_weights=None, 36 | use_explicit_padding=False, 37 | use_depthwise=False, 38 | override_base_feature_extractor_hyperparams=False, 39 | scope_name='MobilenetEdgeTPU'): # Every argument is forwarded unchanged to the base class constructor. 40 | super(SSDMobileNetEdgeTPUFeatureExtractor, self).__init__( 41 | conv_defs=mobilenet_v3.V3_EDGETPU, # Fixed here, not caller-configurable. 42 | from_layer=['layer_18/expansion_output', 'layer_23'], # Fixed here, not caller-configurable. 43 | is_training=is_training, 44 | depth_multiplier=depth_multiplier, 45 | min_depth=min_depth, 46 | pad_to_multiple=pad_to_multiple, 47 | conv_hyperparams_fn=conv_hyperparams_fn, 48 | reuse_weights=reuse_weights, 49 | use_explicit_padding=use_explicit_padding, 50 | use_depthwise=use_depthwise, 51 | override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams, 52 | scope_name=scope_name 53 | ) 54 | --------------------------------------------------------------------------------
/object_detection/models/ssd_mobilenet_edgetpu_feature_extractor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for ssd_mobilenet_edgetpu_feature_extractor.""" 16 | 17 | import tensorflow as tf 18 | 19 | from tensorflow.contrib import slim as contrib_slim 20 | from object_detection.models import ssd_mobilenet_edgetpu_feature_extractor 21 | from object_detection.models import ssd_mobilenet_edgetpu_feature_extractor_testbase 22 | 23 | slim = contrib_slim 24 | 25 | 26 | class SsdMobilenetEdgeTPUFeatureExtractorTest( 27 | ssd_mobilenet_edgetpu_feature_extractor_testbase 28 | ._SsdMobilenetEdgeTPUFeatureExtractorTestBase): 29 | 30 | def _get_input_sizes(self): 31 | """Return first two input feature map sizes.""" 32 | return [384, 192] 33 | 34 | def _create_feature_extractor(self, 35 | depth_multiplier, 36 | pad_to_multiple, 37 | use_explicit_padding=False, 38 | use_keras=False): 39 | """Constructs a new MobileNetEdgeTPU feature extractor. 40 | 41 | Args: 42 | depth_multiplier: float depth multiplier for feature extractor. 43 | pad_to_multiple: the nearest multiple to zero pad the input height and 44 | width dimensions to.
45 | use_explicit_padding: use 'VALID' padding for convolutions, but prepad 46 | inputs so that the output dimensions are the same as if 'SAME' padding 47 | were used. 48 | use_keras: not read by this implementation; the same extractor class is 49 | built whatever its value. 50 | 51 | Returns: 52 | an ssd_meta_arch.SSDFeatureExtractor object. 53 | """ 54 | min_depth = 32 55 | return (ssd_mobilenet_edgetpu_feature_extractor 56 | .SSDMobileNetEdgeTPUFeatureExtractor( 57 | False, # is_training 58 | depth_multiplier, 59 | min_depth, 60 | pad_to_multiple, 61 | self.conv_hyperparams_fn, 62 | use_explicit_padding=use_explicit_padding)) 63 | 64 | 65 | if __name__ == '__main__': 66 | tf.test.main() 67 | -------------------------------------------------------------------------------- /object_detection/predictors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/predictors/__init__.py -------------------------------------------------------------------------------- /object_detection/predictors/heads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/object_detection/predictors/heads/__init__.py -------------------------------------------------------------------------------- /object_detection/predictors/heads/head.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Base head class. 17 | 18 | All the different kinds of prediction heads in different models will inherit 19 | from this class. What is in common between all head classes is that they have a 20 | `predict` function that receives `features` as its first argument. 21 | 22 | How to add a new prediction head to an existing meta architecture? 23 | For example, how can we add a `3d shape` prediction head to Mask RCNN? 24 | 25 | We have to take the following steps to add a new prediction head to an 26 | existing meta arch: 27 | (a) Add a class for predicting the head. This class should inherit from the 28 | `Head` class below and have a `predict` function that receives the features 29 | and predicts the output. The output is always a tf.float32 tensor. 30 | (b) Add the head to the meta architecture. For example in case of Mask RCNN, 31 | go to box_predictor_builder and put in the logic for adding the new head to the 32 | Mask RCNN box predictor. 33 | (c) Add the logic for computing the loss for the new head. 34 | (d) Add the necessary metrics for the new head. 35 | (e) (optional) Add visualization for the new head. 
36 | """ 37 | from abc import abstractmethod 38 | 39 | import tensorflow as tf 40 | 41 | 42 | class Head(object): 43 | """Mask RCNN head base class; subclasses are expected to override `predict`.""" 44 | 45 | def __init__(self): 46 | """Constructor.""" 47 | pass 48 | 49 | @abstractmethod # NOTE: Head derives from plain `object` (no ABCMeta), so this marker is not enforced at instantiation. 50 | def predict(self, features, num_predictions_per_location): 51 | """Returns the head's predictions. 52 | 53 | Args: 54 | features: A float tensor of features. 55 | num_predictions_per_location: Int containing number of predictions per 56 | location. 57 | 58 | Returns: 59 | A tf.float32 tensor. 60 | """ 61 | pass 62 | 63 | 64 | class KerasHead(tf.keras.Model): 65 | """Keras head base class; `call` forwards to the subclass-provided `_predict`.""" 66 | 67 | def call(self, features): 68 | """The Keras model call will delegate to the `_predict` method.""" 69 | return self._predict(features) 70 | 71 | @abstractmethod # NOTE(review): only enforced if tf.keras.Model uses an ABCMeta-based metaclass - confirm. 72 | def _predict(self, features): 73 | """Returns the head's predictions. 74 | 75 | Args: 76 | features: A float tensor of features. 77 | 78 | Returns: 79 | A tf.float32 tensor. 80 | """ 81 | pass 82 | -------------------------------------------------------------------------------- /object_detection/predictors/heads/keypoint_head_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
class MaskRCNNKeypointHeadTest(test_case.TestCase):
  """Static-shape check for `keypoint_head.MaskRCNNKeypointHead`."""

  def _build_arg_scope_with_hyperparams(self,
                                        op_type=hyperparams_pb2.Hyperparams.FC):
    """Builds hyperparams for `op_type` via `hyperparams_builder.build`.

    Args:
      op_type: Value assigned to the proto's `op` field after the text proto
        has been merged; defaults to `Hyperparams.FC`.

    Returns:
      Whatever `hyperparams_builder.build(..., is_training=True)` returns.
    """
    proto_text = """
      activation: NONE
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, proto)
    proto.op = op_type
    return hyperparams_builder.build(proto, is_training=True)

  def test_prediction_size(self):
    """`predict` on [64, 14, 14, 1024] features yields [64, 1, 17, 56, 56]."""
    head = keypoint_head.MaskRCNNKeypointHead(
        conv_hyperparams_fn=self._build_arg_scope_with_hyperparams())
    feature_shape = [64, 14, 14, 1024]
    roi_features = tf.random_uniform(
        feature_shape, minval=-2.0, maxval=2.0, dtype=tf.float32)
    output = head.predict(
        features=roi_features, num_predictions_per_location=1)
    expected_shape = [64, 1, 17, 56, 56]
    self.assertAllEqual(expected_shape, output.get_shape().as_list())
class RfcnBoxPredictorTest(test_case.TestCase):
  """Output-shape test for `rfcn_box_predictor.RfcnBoxPredictor`."""

  def _build_arg_scope_with_conv_hyperparams(self):
    """Returns conv hyperparams built by `hyperparams_builder.build`.

    The text proto only names an L2 regularizer and a truncated-normal
    initializer (no fields set) and is built with `is_training=True`.
    """
    proto_text = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, proto)
    return hyperparams_builder.build(proto, is_training=True)

  def test_get_correct_box_encoding_and_class_prediction_shapes(self):
    """Checks the shapes of box encodings and class predictions."""

    def graph_fn(image_features, proposal_boxes):
      # Build the predictor and concatenate its per-feature-map outputs
      # along axis 1.
      predictor = box_predictor.RfcnBoxPredictor(
          is_training=False,
          num_classes=2,
          conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
          num_spatial_bins=[3, 3],
          depth=4,
          crop_size=[12, 12],
          box_code_size=4)
      outputs = predictor.predict(
          [image_features],
          num_predictions_per_location=[1],
          scope='BoxPredictor',
          proposal_boxes=proposal_boxes)
      encodings = tf.concat(outputs[box_predictor.BOX_ENCODINGS], axis=1)
      class_scores = tf.concat(
          outputs[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1)
      return (encodings, class_scores)

    features_np = np.random.rand(4, 8, 8, 64).astype(np.float32)
    boxes_np = np.random.rand(4, 2, 4).astype(np.float32)
    encodings_out, class_scores_out = self.execute(
        graph_fn, [features_np, boxes_np])

    # Leading dim 8 is presumably 4 images x 2 proposals, and the trailing 3
    # presumably 2 classes + background -- confirm against the predictor.
    self.assertAllEqual(encodings_out.shape, [8, 1, 2, 4])
    self.assertAllEqual(class_scores_out.shape, [8, 1, 3])
class RfcnKerasBoxPredictorTest(test_case.TestCase):
  """Output-shape test for `rfcn_keras_box_predictor.RfcnKerasBoxPredictor`."""

  def _build_conv_hyperparams(self):
    """Returns a `KerasLayerHyperparams` wrapping a minimal conv proto.

    The text proto only names an L2 regularizer and a truncated-normal
    initializer, with no fields set.
    """
    proto_text = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(proto_text, proto)
    return hyperparams_builder.KerasLayerHyperparams(proto)

  def test_get_correct_box_encoding_and_class_prediction_shapes(self):
    """Checks the shapes of box encodings and class predictions."""

    def graph_fn(image_features, proposal_boxes):
      # The Keras predictor is invoked by calling it directly; its outputs
      # are concatenated along axis 1.
      predictor = box_predictor.RfcnKerasBoxPredictor(
          is_training=False,
          num_classes=2,
          conv_hyperparams=self._build_conv_hyperparams(),
          freeze_batchnorm=False,
          num_spatial_bins=[3, 3],
          depth=4,
          crop_size=[12, 12],
          box_code_size=4)
      outputs = predictor([image_features], proposal_boxes=proposal_boxes)
      encodings = tf.concat(outputs[box_predictor.BOX_ENCODINGS], axis=1)
      class_scores = tf.concat(
          outputs[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1)
      return (encodings, class_scores)

    features_np = np.random.rand(4, 8, 8, 64).astype(np.float32)
    boxes_np = np.random.rand(4, 2, 4).astype(np.float32)
    encodings_out, class_scores_out = self.execute(
        graph_fn, [features_np, boxes_np])

    # Leading dim 8 is presumably 4 images x 2 proposals, and the trailing 3
    # presumably 2 classes + background -- confirm against the predictor.
    self.assertAllEqual(encodings_out.shape, [8, 1, 2, 4])
    self.assertAllEqual(class_scores_out.shape, [8, 1, 3])
33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='use_matmul_gather', full_name='object_detection.protos.BipartiteMatcher.use_matmul_gather', index=0, 36 | number=6, type=8, cpp_type=7, label=1, 37 | has_default_value=True, default_value=False, 38 | message_type=None, enum_type=None, containing_type=None, 39 | is_extension=False, extension_scope=None, 40 | serialized_options=None, file=DESCRIPTOR), 41 | ], 42 | extensions=[ 43 | ], 44 | nested_types=[], 45 | enum_types=[ 46 | ], 47 | serialized_options=None, 48 | is_extendable=False, 49 | syntax='proto2', 50 | extension_ranges=[], 51 | oneofs=[ 52 | ], 53 | serialized_start=76, 54 | serialized_end=128, 55 | ) 56 | 57 | DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER 58 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 59 | 60 | BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), { 61 | 'DESCRIPTOR' : _BIPARTITEMATCHER, 62 | '__module__' : 'object_detection.protos.bipartite_matcher_pb2' 63 | # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher) 64 | }) 65 | _sym_db.RegisterMessage(BipartiteMatcher) 66 | 67 | 68 | # @@protoc_insertion_point(module_scope) 69 | -------------------------------------------------------------------------------- /object_detection/protos/mean_stddev_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: object_detection/protos/mean_stddev_box_coder.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='object_detection/protos/mean_stddev_box_coder.proto', 18 | package='object_detection.protos', 19 | syntax='proto2', 20 | serialized_options=None, 21 | serialized_pb=b'\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"*\n\x12MeanStddevBoxCoder\x12\x14\n\x06stddev\x18\x01 \x01(\x02:\x04\x30.01' 22 | ) 23 | 24 | 25 | 26 | 27 | _MEANSTDDEVBOXCODER = _descriptor.Descriptor( 28 | name='MeanStddevBoxCoder', 29 | full_name='object_detection.protos.MeanStddevBoxCoder', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | containing_type=None, 33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='stddev', full_name='object_detection.protos.MeanStddevBoxCoder.stddev', index=0, 36 | number=1, type=2, cpp_type=6, label=1, 37 | has_default_value=True, default_value=float(0.01), 38 | message_type=None, enum_type=None, containing_type=None, 39 | is_extension=False, extension_scope=None, 40 | serialized_options=None, file=DESCRIPTOR), 41 | ], 42 | extensions=[ 43 | ], 44 | nested_types=[], 45 | enum_types=[ 46 | ], 47 | serialized_options=None, 48 | is_extendable=False, 49 | syntax='proto2', 50 | extension_ranges=[], 51 | oneofs=[ 52 | ], 53 | serialized_start=80, 54 | serialized_end=122, 55 | ) 56 | 57 | DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER 58 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 59 | 60 | MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), { 61 | 'DESCRIPTOR' : _MEANSTDDEVBOXCODER, 62 
| '__module__' : 'object_detection.protos.mean_stddev_box_coder_pb2' 63 | # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder) 64 | }) 65 | _sym_db.RegisterMessage(MeanStddevBoxCoder) 66 | 67 | 68 | # @@protoc_insertion_point(module_scope) 69 | -------------------------------------------------------------------------------- /object_detection/protos/square_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: object_detection/protos/square_box_coder.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='object_detection/protos/square_box_coder.proto', 18 | package='object_detection.protos', 19 | syntax='proto2', 20 | serialized_options=None, 21 | serialized_pb=b'\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35' 22 | ) 23 | 24 | 25 | 26 | 27 | _SQUAREBOXCODER = _descriptor.Descriptor( 28 | name='SquareBoxCoder', 29 | full_name='object_detection.protos.SquareBoxCoder', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | containing_type=None, 33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0, 36 | number=1, type=2, cpp_type=6, label=1, 37 | has_default_value=True, default_value=float(10), 38 | message_type=None, enum_type=None, containing_type=None, 39 | 
is_extension=False, extension_scope=None, 40 | serialized_options=None, file=DESCRIPTOR), 41 | _descriptor.FieldDescriptor( 42 | name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1, 43 | number=2, type=2, cpp_type=6, label=1, 44 | has_default_value=True, default_value=float(10), 45 | message_type=None, enum_type=None, containing_type=None, 46 | is_extension=False, extension_scope=None, 47 | serialized_options=None, file=DESCRIPTOR), 48 | _descriptor.FieldDescriptor( 49 | name='length_scale', full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2, 50 | number=3, type=2, cpp_type=6, label=1, 51 | has_default_value=True, default_value=float(5), 52 | message_type=None, enum_type=None, containing_type=None, 53 | is_extension=False, extension_scope=None, 54 | serialized_options=None, file=DESCRIPTOR), 55 | ], 56 | extensions=[ 57 | ], 58 | nested_types=[], 59 | enum_types=[ 60 | ], 61 | serialized_options=None, 62 | is_extendable=False, 63 | syntax='proto2', 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=75, 68 | serialized_end=158, 69 | ) 70 | 71 | DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER 72 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 73 | 74 | SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), { 75 | 'DESCRIPTOR' : _SQUAREBOXCODER, 76 | '__module__' : 'object_detection.protos.square_box_coder_pb2' 77 | # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder) 78 | }) 79 | _sym_db.RegisterMessage(SquareBoxCoder) 80 | 81 | 82 | # @@protoc_insertion_point(module_scope) 83 | -------------------------------------------------------------------------------- /object_detection/samples/cloud/cloud.yml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | runtimeVersion: "1.12" 3 | scaleTier: CUSTOM 4 | masterType: standard_gpu 5 | workerCount: 5 6 | workerType: 
standard_gpu 7 | parameterServerCount: 3 8 | parameterServerType: standard 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /object_detection/tpu_exporters/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /object_detection/tpu_exporters/export_saved_model_tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def main(argv):
  """Exports a SavedModel using the values of the command-line flags.

  Args:
    argv: Command-line arguments left after flag parsing; anything beyond the
      first element is rejected.

  Raises:
    tf.app.UsageError: If more than one positional argument remains.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise tf.app.UsageError('Too many command-line arguments.')

  export_saved_model_tpu_lib.export(
      FLAGS.pipeline_config_file,
      FLAGS.ckpt_path,
      FLAGS.export_dir,
      FLAGS.input_placeholder_name,
      FLAGS.input_type,
      FLAGS.use_bfloat16)
def get_path(path_suffix):
  """Joins `path_suffix` onto the `testdata` dir under the data-files path."""
  data_root = tf.resource_loader.get_data_files_path()
  return os.path.join(data_root, 'testdata', path_suffix)


class ExportSavedModelTPUTest(tf.test.TestCase, parameterized.TestCase):
  """Exports a model for TPU and runs inference from the SavedModel."""

  @parameterized.named_parameters(
      ('ssd', get_path('ssd/ssd_pipeline.config'), 'image_tensor', True, 20),
      ('faster_rcnn',
       get_path('faster_rcnn/faster_rcnn_resnet101_atrous_coco.config'),
       'image_tensor', True, 20))
  def testExportAndLoad(self,
                        pipeline_config_file,
                        input_type='image_tensor',
                        use_bfloat16=False,
                        repeat=1):
    """Exports with no checkpoint, then runs inference on a random image.

    Args:
      pipeline_config_file: Path to the pipeline config to export.
      input_type: Input node type forwarded to the exporter.
      use_bfloat16: Forwarded to the exporter.
      repeat: Forwarded to `run_inference_from_saved_model`.
    """
    placeholder_name = 'placeholder_tensor'
    saved_model_dir = os.path.join(FLAGS.test_tmpdir, 'tpu_saved_model')
    # Start from a clean export directory.
    if tf.gfile.Exists(saved_model_dir):
      tf.gfile.DeleteRecursively(saved_model_dir)

    checkpoint = None
    export_saved_model_tpu_lib.export(
        pipeline_config_file, checkpoint, saved_model_dir, placeholder_name,
        input_type, use_bfloat16)

    image = np.random.rand(256, 256, 3)
    results = export_saved_model_tpu_lib.run_inference_from_saved_model(
        image, saved_model_dir, placeholder_name, repeat)
    # The outputs are only logged; no assertions are made on their values.
    for name, value in results.items():
      tf.logging.info('{}: {}'.format(name, value))
first_stage_objectness_loss_weight: 1.0 44 | initial_crop_size: 14 45 | maxpool_kernel_size: 2 46 | maxpool_stride: 2 47 | second_stage_box_predictor { 48 | mask_rcnn_box_predictor { 49 | use_dropout: false 50 | dropout_keep_probability: 1.0 51 | fc_hyperparams { 52 | op: FC 53 | regularizer { 54 | l2_regularizer { 55 | weight: 0.0 56 | } 57 | } 58 | initializer { 59 | variance_scaling_initializer { 60 | factor: 1.0 61 | uniform: true 62 | mode: FAN_AVG 63 | } 64 | } 65 | } 66 | } 67 | } 68 | second_stage_post_processing { 69 | batch_non_max_suppression { 70 | score_threshold: 0.0 71 | iou_threshold: 0.6 72 | max_detections_per_class: 100 73 | max_total_detections: 300 74 | } 75 | score_converter: SOFTMAX 76 | } 77 | second_stage_localization_loss_weight: 2.0 78 | second_stage_classification_loss_weight: 1.0 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /object_detection/tpu_exporters/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def bfloat16_to_float32(tensor):
  """Converts a tensor to tf.float32 only if it is tf.bfloat16.

  Args:
    tensor: An object exposing a `dtype` attribute (normally a tf.Tensor).

  Returns:
    `tf.cast(tensor, tf.float32)` when `tensor.dtype` is tf.bfloat16,
    otherwise `tensor` itself, untouched.
  """
  if tensor.dtype == tf.bfloat16:
    return tf.cast(tensor, dtype=tf.float32)
  return tensor


def bfloat16_to_float32_nested(bfloat16_tensor_dict):
  """Converts bfloat16 tensors in a nested structure to float32.

  Other tensors not of dtype bfloat16 will be left as is.

  Args:
    bfloat16_tensor_dict: A Python dict, values being Tensor or Python
      list/tuple of Tensor.

  Returns:
    A new Python dict with the same keys as `bfloat16_tensor_dict`, with all
    bfloat16 tensors converted to float32. List and tuple values are both
    returned as lists. Values of any other type are carried over unchanged
    (previously they were silently dropped from the result).
  """
  float32_tensor_dict = {}
  for k, v in bfloat16_tensor_dict.items():
    if isinstance(v, tf.Tensor):
      float32_tensor_dict[k] = bfloat16_to_float32(v)
    elif isinstance(v, (list, tuple)):
      float32_tensor_dict[k] = [bfloat16_to_float32(t) for t in v]
    else:
      # Keep the entry so the output has the same keys as the input; there is
      # nothing to cast on a value that is not a Tensor or a sequence of them.
      float32_tensor_dict[k] = v
  return float32_tensor_dict
class UtilsTest(tf.test.TestCase):
  """Dtype checks for the bfloat16 -> float32 helpers in `utils`."""

  def testBfloat16ToFloat32(self):
    """A bfloat16 tensor comes back as float32."""
    source = tf.random.uniform([2, 3], dtype=tf.bfloat16)
    converted = utils.bfloat16_to_float32(source)
    self.assertEqual(converted.dtype, tf.float32)

  def testOtherDtypesNotConverted(self):
    """A non-bfloat16 (int32) tensor keeps its dtype."""
    source = tf.ones([2, 3], dtype=tf.int32)
    converted = utils.bfloat16_to_float32(source)
    self.assertEqual(converted.dtype, tf.int32)

  def testBfloat16ToFloat32Nested(self):
    """Dict values (single tensors and lists of tensors) are converted."""
    bfloat16_list = [
        tf.random.uniform([1, 2], dtype=tf.bfloat16) for _ in range(3)
    ]
    inputs = {
        'key1': tf.random.uniform([2, 3], dtype=tf.bfloat16),
        'key2': bfloat16_list,
        'key3': tf.ones([2, 3], dtype=tf.int32),
    }
    outputs = utils.bfloat16_to_float32_nested(inputs)

    self.assertEqual(outputs['key1'].dtype, tf.float32)
    for element in outputs['key2']:
      self.assertEqual(element.dtype, tf.float32)
    # int32 entries must be left alone.
    self.assertEqual(outputs['key3'].dtype, tf.int32)
def load_categories_from_csv_file(csv_path):
  """Loads categories from a csv file.

  The CSV file should have one comma delimited numeric category id and string
  category name pair per line. For example:

  0,"cat"
  1,"dog"
  2,"bird"
  ...

  Empty rows are skipped.

  Args:
    csv_path: Path to the csv file to be parsed into categories.
  Returns:
    categories: A list of dictionaries representing all possible categories.
                The categories will contain an integer 'id' field and a string
                'name' field.
  Raises:
    ValueError: If the csv file is incorrectly formatted, i.e. a non-empty row
      does not have exactly 2 fields or its first field is not an integer.
  """
  categories = []

  with tf.gfile.Open(csv_path, 'r') as csvfile:
    reader = csv.reader(csvfile, delimiter=',', quotechar='"')
    for row in reader:
      if not row:
        continue

      if len(row) != 2:
        raise ValueError('Expected 2 fields per row in csv: %s' % ','.join(row))

      # int() raises ValueError itself when the id is not numeric.
      category_id = int(row[0])
      category_name = row[1]
      categories.append({'id': category_id, 'name': category_name})

  return categories


def save_categories_to_csv_file(categories, csv_path):
  """Saves categories to a csv file, one `id,name` row per category.

  Rows are written in ascending order of 'id'. The `categories` argument is
  not modified: a sorted copy is written, so the caller's list keeps its
  original order.

  Args:
    categories: An iterable of dictionaries representing categories to save to
      file. Each category must contain an 'id' and 'name' field.
    csv_path: Path to the csv file to be written.
  """
  # sorted() returns a new list, so the caller's list is not reordered as a
  # side effect of saving.
  ordered_categories = sorted(categories, key=lambda x: x['id'])
  with tf.gfile.Open(csv_path, 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter=',', quotechar='"')
    for category in ordered_categories:
      writer.writerow([category['id'], category['name']])
class EvalUtilTest(tf.test.TestCase):
  """Exercises the csv load/save helpers of category_util."""

  def test_load_categories_from_csv_file(self):
    # Only spaces are stripped, so the literal keeps its leading newline; the
    # loader is expected to skip the resulting blank row.
    csv_data = """
        0,"cat"
        1,"dog"
        2,"bird"
    """.strip(' ')
    csv_path = os.path.join(self.get_temp_dir(), 'test.csv')
    with tf.gfile.Open(csv_path, 'wb') as f:
      f.write(csv_data)

    loaded = category_util.load_categories_from_csv_file(csv_path)
    self.assertIn({'id': 0, 'name': 'cat'}, loaded)
    self.assertIn({'id': 1, 'name': 'dog'}, loaded)
    self.assertIn({'id': 2, 'name': 'bird'}, loaded)

  def test_save_categories_to_csv_file(self):
    # Saving and then loading must give back the same categories.
    categories = [
        {'id': 0, 'name': 'cat'},
        {'id': 1, 'name': 'dog'},
        {'id': 2, 'name': 'bird'},
    ]
    csv_path = os.path.join(self.get_temp_dir(), 'test.csv')
    category_util.save_categories_to_csv_file(categories, csv_path)
    round_tripped = category_util.load_categories_from_csv_file(csv_path)
    self.assertEqual(round_tripped, categories)
class IdentityContextManager(object):
  """A no-op context manager.

  Useful as the "else" branch of a conditional `with` statement:

    with slim.arg_scope(x) if use_slim_scope else IdentityContextManager():
      do_stuff()

  Entering binds `None`; exiting never suppresses an exception.
  """

  def __enter__(self):
    # Nothing is acquired, so there is nothing to hand to the `as` target.
    return None

  def __exit__(self, exec_type, exec_value, traceback):
    # The exception details are intentionally unused.
    del exec_type, exec_value, traceback
    # False tells Python to propagate any in-flight exception.
    return False
class ContextManagerTest(tf.test.TestCase):
  """Checks the behavior of context_manager.IdentityContextManager."""

  def test_identity_context_manager(self):
    # The identity manager binds None to the `as` target.
    manager = context_manager.IdentityContextManager()
    with manager as identity_context:
      self.assertIsNone(identity_context)
def int64_feature(value):
  """Wraps a single integer in an int64 `tf.train.Feature`."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
  """Wraps a sequence of integers in an int64 `tf.train.Feature`."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
  """Wraps a single bytes value in a bytes `tf.train.Feature`."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
  """Wraps a sequence of bytes values in a bytes `tf.train.Feature`."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
  """Wraps a sequence of floats in a float `tf.train.Feature`."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def read_examples_list(path):
  """Read list of training or validation examples.

  The file is assumed to contain a single example per line where the first
  token in the line is an identifier that allows us to find the image and
  annotation xml for that example.

  For example, the line:
  xyz 3
  would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).

  Args:
    path: absolute path to examples list file.

  Returns:
    list of example identifiers (strings).
  """
  with tf.gfile.GFile(path) as fid:
    lines = fid.readlines()
  # Only the first space-separated token of each line is the identifier.
  return [line.strip().split(' ')[0] for line in lines]


def recursive_parse_xml_to_dict(xml):
  """Recursively parses XML contents to python dict.

  We assume that `object` tags are the only ones that can appear
  multiple times at the same level of a tree.

  Args:
    xml: xml tree obtained by parsing XML file contents using lxml.etree

  Returns:
    Python dictionary holding XML contents. A leaf element maps its tag to
    its text; any other element maps its tag to a dict of its children, with
    all `object` children collected into a list.
  """
  # A leaf has no children. The child count is tested explicitly because
  # truth-testing an element is deprecated in both lxml and xml.etree.
  if len(xml) == 0:
    return {xml.tag: xml.text}
  result = {}
  for child in xml:
    child_result = recursive_parse_xml_to_dict(child)
    if child.tag != 'object':
      # Non-repeating tag: a later sibling with the same tag overwrites.
      result[child.tag] = child_result[child.tag]
    else:
      result.setdefault(child.tag, []).append(child_result[child.tag])
  return {xml.tag: result}
class DatasetUtilTest(tf.test.TestCase):
  """Tests for the example-list reader in dataset_util."""

  def test_read_examples_list(self):
    # Two lines, each "<identifier> <ignored token>".
    example_list_data = """example1 1\nexample2 2"""
    example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt')
    with tf.gfile.Open(example_list_path, 'wb') as f:
      f.write(example_list_data)

    identifiers = dataset_util.read_examples_list(example_list_path)
    self.assertListEqual(['example1', 'example2'], identifiers)
def _FormatFloats(json_str, float_digits):
  """Rewrites every float literal in a JSON string with fixed precision.

  Args:
    json_str: A JSON document as produced by json.dumps.
    float_digits: Non-negative number of digits to keep after the decimal
      point.

  Returns:
    The JSON string with each float token rendered as '%.<float_digits>f'.
    Integers, NaN/Infinity tokens and the contents of string literals
    (including object keys) are left untouched.
  """
  import re  # Local import: only needed when float formatting is requested.

  float_format = '{:.%df}' % float_digits
  # String literals are matched first and returned verbatim so that digits
  # embedded in keys or string values are never rewritten. The capturing group
  # matches number tokens that have a fraction and/or an exponent.
  token_pattern = re.compile(
      r'"(?:[^"\\]|\\.)*"'
      r'|(-?\d+(?:\.\d+(?:[eE][-+]?\d+)?|[eE][-+]?\d+))')

  def _Replace(match):
    number = match.group(1)
    if number is None:  # Matched a string literal.
      return match.group(0)
    return float_format.format(float(number))

  return token_pattern.sub(_Replace, json_str)


def Dump(obj, fid, float_digits=-1, **params):
  """Wrapper of json.dump that allows specifying the float precision used.

  Args:
    obj: The object to dump.
    fid: The file id to write to.
    float_digits: The number of digits of precision when writing floats out.
      Negative values (the default) keep json's own float representation.
    **params: Additional parameters to pass to json.dumps.
  """
  # Delegating to Dumps keeps both entry points formatting floats identically.
  fid.write(Dumps(obj, float_digits, **params))


def Dumps(obj, float_digits=-1, **params):
  """Wrapper of json.dumps that allows specifying the float precision used.

  The precision is applied to the serialized text rather than by patching
  json.encoder globals: that patch is ignored by the Python 3 encoder and
  changes process-wide state.

  Args:
    obj: The object to dump.
    float_digits: The number of digits of precision when writing floats out.
      Negative values (the default) keep json's own float representation.
    **params: Additional parameters to pass to json.dumps.

  Returns:
    output: JSON string representation of obj.
  """
  output = json.dumps(obj, **params)
  if float_digits >= 0:
    output = _FormatFloats(output, float_digits)
  return output


def PrettyParams(**params):
  """Returns parameters for use with Dump and Dumps to output pretty json.

  Example usage:
    ```json_str = json_utils.Dumps(obj, **json_utils.PrettyParams())```
    ```json_str = json_utils.Dumps(
           obj, **json_utils.PrettyParams(allow_nan=False))```

  Args:
    **params: Additional params to pass to json.dump or json.dumps. The keys
      'float_digits', 'sort_keys', 'indent' and 'separators' are overwritten.

  Returns:
    params: Parameters that are compatible with json_utils.Dump and
      json_utils.Dumps.
  """
  params['float_digits'] = 4
  params['sort_keys'] = True
  params['indent'] = 2
  params['separators'] = (',', ': ')
  return params
class ExtractSubmodelUtilTest(tf.test.TestCase):
  """Tests extracting a sub-graph of a Keras model with model_util."""

  def test_simple_model(self):
    inputs = tf.keras.Input(shape=(256,))  # Returns a placeholder tensor

    # Stack five named ReLU layers; each layer is called on the previous
    # layer's output tensor.
    x = inputs
    for units, layer_name in ((128, 'a'), (64, 'b'), (32, 'c'), (16, 'd'),
                              (8, 'e')):
      x = tf.keras.layers.Dense(units, activation='relu', name=layer_name)(x)
    predictions = tf.keras.layers.Dense(10, activation='softmax')(x)

    model = tf.keras.Model(inputs=inputs, outputs=predictions)

    # The submodel runs from the input of layer 'b' to the output of layer 'd'.
    submodel_input = model.get_layer(name='b').input
    submodel_output = model.get_layer(name='d').output

    new_model = model_util.extract_submodel(
        model=model,
        inputs=submodel_input,
        outputs=submodel_output)

    # Layer 'b' consumes 128 features and layer 'd' emits 16.
    batch_size = 3
    final_out = new_model(tf.ones((batch_size, 128)))
    self.assertAllEqual(final_out.shape, (batch_size, 16))
class BoxMaskList(np_box_list.BoxList):
  """A BoxList that additionally stores one mask per box.

  The constructor takes both the boxes and their masks and stores the masks
  under the 'masks' field. Masks are expected to cover the full image.
  """

  def __init__(self, box_data, mask_data):
    """Builds the collection from boxes and their masks.

    Args:
      box_data: a numpy array of shape [N, 4] representing box coordinates
      mask_data: a numpy uint8 array of shape [N, height, width] representing
        masks with values in {0,1}. The masks correspond to the full image,
        so height and width equal the image height and width.

    Raises:
      ValueError: if bbox data is not a numpy array
      ValueError: if invalid dimensions for bbox data
      ValueError: if mask data is not a numpy array
      ValueError: if invalid dimension or data type for mask data
      ValueError: if the number of masks differs from the number of boxes
    """
    # Box validation is done by the parent constructor.
    super(BoxMaskList, self).__init__(box_data)
    if not isinstance(mask_data, np.ndarray):
      raise ValueError('Mask data must be a numpy array.')
    if mask_data.ndim != 3:
      raise ValueError('Invalid dimensions for mask data.')
    if mask_data.dtype != np.uint8:
      raise ValueError('Invalid data type for mask data: uint8 is required.')
    num_masks = mask_data.shape[0]
    num_boxes = box_data.shape[0]
    if num_masks != num_boxes:
      raise ValueError('There should be the same number of boxes and masks.')
    self.data['masks'] = mask_data

  def get_masks(self):
    """Returns the stored masks.

    Returns:
      a numpy array of shape [N, height, width] representing masks
    """
    return self.get_field('masks')
class BoxOpsTests(tf.test.TestCase):
  """Tests for area, intersection, IOU and IOA in np_box_ops."""

  def setUp(self):
    # Two boxes compared against three boxes in the pairwise tests below.
    self.boxes1 = np.array([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
                           dtype=float)
    self.boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
                            [0.0, 0.0, 20.0, 20.0]],
                           dtype=float)

  def testArea(self):
    expected_areas = np.array([6.0, 5.0], dtype=float)
    areas = np_box_ops.area(self.boxes1)
    self.assertAllClose(expected_areas, areas)

  def testIntersection(self):
    expected_intersection = np.array([[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]],
                                     dtype=float)
    intersection = np_box_ops.intersection(self.boxes1, self.boxes2)
    self.assertAllClose(intersection, expected_intersection)

  def testIOU(self):
    expected_iou = np.array([[2.0 / 16.0, 0.0, 6.0 / 400.0],
                             [1.0 / 16.0, 0.0, 5.0 / 400.0]],
                            dtype=float)
    iou = np_box_ops.iou(self.boxes1, self.boxes2)
    self.assertAllClose(iou, expected_iou)

  def testIOA(self):
    boxes1 = np.array([[0.25, 0.25, 0.75, 0.75],
                       [0.0, 0.0, 0.5, 0.75]],
                      dtype=np.float32)
    boxes2 = np.array([[0.5, 0.25, 1.0, 1.0],
                       [0.0, 0.0, 1.0, 1.0]],
                      dtype=np.float32)
    expected_ioa21 = np.array([[0.5, 0.0],
                               [1.0, 1.0]],
                              dtype=np.float32)
    ioa21 = np_box_ops.ioa(boxes2, boxes1)
    self.assertAllClose(ioa21, expected_ioa21)
def get_dim_as_int(dim):
  """Utility to get v1 or v2 TensorShape dim as an int.

  Args:
    dim: The TensorShape dimension to get as an int

  Returns:
    `dim.value` when the dimension exposes a `value` attribute, otherwise
    `dim` itself; i.e. None or an int.
  """
  # Falls back to `dim` when looking up `.value` raises AttributeError.
  return getattr(dim, 'value', dim)


def _get_rank4_dim(tensor_shape, index):
  """Returns dimension `index` of a shape after asserting it has rank 4."""
  tensor_shape.assert_has_rank(rank=4)
  return get_dim_as_int(tensor_shape[index])


def get_batch_size(tensor_shape):
  """Returns batch size from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the batch size of the tensor.
  """
  return _get_rank4_dim(tensor_shape, 0)


def get_height(tensor_shape):
  """Returns height from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the height of the tensor.
  """
  return _get_rank4_dim(tensor_shape, 1)


def get_width(tensor_shape):
  """Returns width from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the width of the tensor.
  """
  return _get_rank4_dim(tensor_shape, 2)


def get_depth(tensor_shape):
  """Returns depth from the tensor shape.

  Args:
    tensor_shape: A rank 4 TensorShape.

  Returns:
    An integer representing the depth of the tensor.
  """
  return _get_rank4_dim(tensor_shape, 3)
class StaticShapeTest(tf.test.TestCase):
  """Tests for the rank-4 dimension accessors in static_shape."""

  def test_return_correct_batchSize(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(32, static_shape.get_batch_size(tensor_shape))

  def test_return_correct_height(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(299, static_shape.get_height(tensor_shape))

  def test_return_correct_width(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(384, static_shape.get_width(tensor_shape))

  def test_return_correct_depth(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
    self.assertEqual(3, static_shape.get_depth(tensor_shape))

  def test_die_on_tensor_shape_with_rank_three(self):
    tensor_shape = tf.TensorShape(dims=[32, 299, 384])
    # Each accessor gets its own assertRaises block: inside a shared block the
    # first raising call would end the block and the rest would never run.
    for getter in (static_shape.get_batch_size,
                   static_shape.get_height,
                   static_shape.get_width,
                   static_shape.get_depth):
      with self.assertRaises(ValueError):
        getter(tensor_shape)
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/__init__.py -------------------------------------------------------------------------------- /slim/data/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2016 The TF-Slim Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /slim/data/data_decoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@six.add_metaclass(abc.ABCMeta)
class DataDecoder(object):
  """Abstract interface for decoders that turn provider data into tensors."""

  @abc.abstractmethod
  def decode(self, data, items):
    """Decodes `data` and returns the tensors named by `items`.

    Args:
      data: A possibly encoded data format.
      items: A list of strings, each of which indicate a particular data type.

    Returns:
      A list of `Tensors` with the same length as `items`; each `Tensor`
      corresponds to the item at the same position.

    Raises:
      ValueError: If any of the items cannot be satisfied.
    """

  @abc.abstractmethod
  def list_items(self):
    """Lists the names of the items that the decoder can decode.

    Returns:
      A list of string names.
    """
class Dataset(object):
  """A plain record describing a dataset: sources, reader, decoder, size."""

  def __init__(self, data_sources, reader, decoder, num_samples,
               items_to_descriptions, **kwargs):
    """Stores every argument as an attribute of the same name.

    Args:
      data_sources: A list of files that make up the dataset.
      reader: The reader class, a subclass of BaseReader such as TextLineReader
        or TFRecordReader.
      decoder: An instance of a data_decoder.
      num_samples: The number of samples in the dataset.
      items_to_descriptions: A map from the items that the dataset provides to
        the descriptions of those items.
      **kwargs: Any remaining dataset-specific fields; each keyword becomes an
        attribute as well.
    """
    # Start from the dataset-specific extras, then layer the required fields
    # on top before publishing everything as instance attributes.
    attributes = dict(kwargs)
    attributes.update(
        data_sources=data_sources,
        reader=reader,
        decoder=decoder,
        num_samples=num_samples,
        items_to_descriptions=items_to_descriptions)
    vars(self).update(attributes)
def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
  """Looks up a dataset module by name and returns the requested split.

  Args:
    name: String, the name of the dataset; must be a key of `datasets_map`.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A `Dataset` class, as produced by the dataset module's `get_split`.

  Raises:
    ValueError: If the dataset `name` is unknown.
  """
  if name in datasets_map:
    dataset_module = datasets_map[name]
    return dataset_module.get_split(
        split_name, dataset_dir, file_pattern, reader)
  raise ValueError('Name of dataset unknown %s' % name)
if __name__ == '__main__':
  # Expects two positional arguments: the directory holding the flat
  # validation JPEGs and the text file with one synset label per image.
  if len(sys.argv) < 3:
    print('Invalid usage\n'
          'usage: preprocess_imagenet_validation_data.py '
          '<validation data dir> <validation labels file>')
    sys.exit(-1)
  data_dir = sys.argv[1]
  validation_labels_file = sys.argv[2]

  # Read in the 50000 synsets associated with the validation data set.
  # The context manager makes sure the labels file is closed after reading.
  with open(validation_labels_file) as labels_file:
    labels = [line.strip() for line in labels_file]
  unique_labels = set(labels)

  # Make all sub-directories in the validation data dir.
  for label in unique_labels:
    labeled_data_dir = os.path.join(data_dir, label)
    os.makedirs(labeled_data_dir)

  # Move all of the image to the appropriate sub-directory.
  # Line k of the labels file (1-based) labels image number k.
  for image_number, label in enumerate(labels, 1):
    basename = 'ILSVRC2012_val_000%.5d.JPEG' % image_number
    original_filename = os.path.join(data_dir, basename)
    if not os.path.exists(original_filename):
      print('Failed to find: ', original_filename)
      sys.exit(-1)
    new_filename = os.path.join(data_dir, label, basename)
    os.rename(original_filename, new_filename)
class ExportInferenceGraphTest(tf.test.TestCase):
  """Checks that running the exporter produces the requested output file."""

  def testExportInferenceGraph(self):
    temp_dir = self.get_temp_dir()
    graph_path = os.path.join(temp_dir, 'inception_v3.pb')

    # Set the global flags before invoking the exporter's main().
    export_flags = tf.app.flags.FLAGS
    export_flags.output_file = graph_path
    export_flags.model_name = 'inception_v3'
    export_flags.dataset_dir = temp_dir

    export_inference_graph.main(None)

    graph_was_written = gfile.Exists(graph_path)
    self.assertTrue(graph_was_written)


if __name__ == '__main__':
  tf.test.main()
15 | # ============================================================================== 16 | """layers module with higher level NN primitives.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # pylint: disable=wildcard-import 23 | from slim.layers.bucketization_op import * 24 | from slim.layers.initializers import * 25 | from slim.layers.layers import * 26 | from slim.layers.normalization import * 27 | from slim.layers.optimizers import * 28 | from slim.layers.regularizers import * 29 | from slim.layers.rev_block_lib import * 30 | from slim.layers.summaries import * 31 | 32 | # pylint: enable=wildcard-import 33 | -------------------------------------------------------------------------------- /slim/layers/bucketization_op.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def bucketize(input_tensor, boundaries, name=None):
  """Assigns each value of `input_tensor` to a bucket defined by `boundaries`.

  This is a thin wrapper that forwards to `math_ops._bucketize`; see
  bucketize_op.cc for more details.

  Args:
    input_tensor: A `Tensor` whose values will be bucketized.
    boundaries: A list of floats giving the bucket boundaries. It has to be
      sorted.
    name: A name prefix for the returned tensors (optional).

  Returns:
    A `Tensor` with type int32 which indicates the corresponding bucket for
    each value in `input_tensor`.

  Raises:
    TypeError: If boundaries is not a list.
  """
  bucketize_op = math_ops._bucketize  # pylint: disable=protected-access
  return bucketize_op(input_tensor, boundaries=boundaries, name=name)
15 | # ============================================================================== 16 | 17 | """Ops for building neural network losses. 18 | 19 | See [Contrib Losses](https://tensorflow.org/api_guides/python/contrib.losses). 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | 27 | from tf_slim.losses import metric_learning 28 | # pylint: disable=wildcard-import 29 | from tf_slim.losses.loss_ops import * 30 | from tf_slim.losses.metric_learning import * 31 | 32 | # pylint: disable=g-direct-tensorflow-import 33 | from tensorflow.python.util.all_util import remove_undocumented 34 | 35 | _allowed_symbols = [ 36 | 'absolute_difference', 37 | 'add_loss', 38 | 'cluster_loss', 39 | 'compute_weighted_loss', 40 | 'contrastive_loss', 41 | 'cosine_distance', 42 | 'get_losses', 43 | 'get_regularization_losses', 44 | 'get_total_loss', 45 | 'hinge_loss', 46 | 'lifted_struct_loss', 47 | 'log_loss', 48 | 'mean_pairwise_squared_error', 49 | 'mean_squared_error', 50 | 'metric_learning', 51 | 'npairs_loss', 52 | 'npairs_loss_multilabel', 53 | 'sigmoid_cross_entropy', 54 | 'softmax_cross_entropy', 55 | 'sparse_softmax_cross_entropy', 56 | 'triplet_semihard_loss', 57 | ] 58 | remove_undocumented(__name__, _allowed_symbols) 59 | -------------------------------------------------------------------------------- /slim/metrics/metric_ops_large_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def setUpModule():
  """Disables eager execution before any test in this module runs."""
  tf.disable_eager_execution()


class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
  """Large-input test for `metric_ops.precision_recall_at_equal_thresholds`."""

  def setUp(self):
    super(StreamingPrecisionRecallAtEqualThresholdsLargeTest, self).setUp()
    np.random.seed(1)

  def testLargeCase(self):
    """Accumulates many large updates and checks tp + fp at threshold 0."""
    shape = [32, 512, 256, 1]
    predictions = random_ops.random_uniform(
        shape, 0.0, 1.0, dtype=dtypes_lib.float32)
    labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)

    result, update_op = metric_ops.precision_recall_at_equal_thresholds(
        labels=labels, predictions=predictions, num_thresholds=201)
    # Run many updates, enough to cause highly inaccurate values if the
    # code used float32 for accumulation.
    num_updates = 71

    with self.cached_session() as sess:
      sess.run(variables.local_variables_initializer())
      for _ in xrange(num_updates):
        sess.run(update_op)

      prdata = sess.run(result)

      # Since we use random values, we won't know the tp/fp/tn/fn values, but
      # tp and fp at threshold 0 should be the total number of positive and
      # negative labels, hence their sum should be total number of pixels.
      # np.prod is used instead of the np.product alias, which was deprecated
      # and then removed in NumPy 2.0.
      expected_value = 1.0 * np.prod(shape) * num_updates
      got_value = prdata.tp[0] + prdata.fp[0]
      # They should be at least within 1.
      self.assertNear(got_value, expected_value, 1.0)


if __name__ == '__main__':
  test.main()
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /slim/nets/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet/__init__.py -------------------------------------------------------------------------------- /slim/nets/mobilenet/g3doc/edgetpu_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet/g3doc/edgetpu_latency.png -------------------------------------------------------------------------------- 
/slim/nets/mobilenet/g3doc/latency_pixel1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet/g3doc/latency_pixel1.png -------------------------------------------------------------------------------- /slim/nets/mobilenet/g3doc/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet/g3doc/madds_top1_accuracy.png -------------------------------------------------------------------------------- /slim/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /slim/nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implemented.
One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /slim/nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class NasnetUtilsTest(tf.test.TestCase):
  """Unit tests for the helper functions in `nasnet_utils`."""

  def testCalcReductionLayers(self):
    reduction_layers = nasnet_utils.calc_reduction_layers(18, 2)
    self.assertEqual(len(reduction_layers), 2)
    # 18 cells with 2 reduction layers: expected at cell indices 6 and 12.
    for position, expected_layer in enumerate((6, 12)):
      self.assertEqual(reduction_layers[position], expected_layer)

  def testGetChannelIndex(self):
    # Channels are the last axis in NHWC and the second axis in NCHW.
    expected_index = {'NHWC': 3, 'NCHW': 1}
    for data_format in ('NHWC', 'NCHW'):
      self.assertEqual(
          nasnet_utils.get_channel_index(data_format),
          expected_index[data_format])

  def testGetChannelDim(self):
    shape = [10, 20, 30, 40]
    for data_format, channel_axis in (('NHWC', 3), ('NCHW', 1)):
      channel_dim = nasnet_utils.get_channel_dim(shape, data_format)
      self.assertEqual(channel_dim, shape[channel_axis])

  def testGlobalAvgPool(self):
    inputs = tf.compat.v1.placeholder(tf.float32, (5, 10, 20, 10))
    for data_format in ('NHWC', 'NCHW'):
      pooled = nasnet_utils.global_avg_pool(inputs, data_format)
      self.assertEqual(pooled.shape, [5, 10])


if __name__ == '__main__':
  tf.test.main()
__all__ = ['get_graph_from_inputs',
           'get_name_scope']


def get_graph_from_inputs(op_input_list, graph=None):
  """Picks the graph that the given operation inputs belong to.

  Delegates to `ops._get_graph_from_inputs`. The resolution rules are:

  1. When `graph` is given, every input in `op_input_list` is validated to be
     from that graph.
  2. Otherwise a graph is selected from the first Operation- or Tensor-valued
     input in `op_input_list`, and all other such inputs are validated to be
     in the same graph.
  3. When no graph was given and none could be inferred from `op_input_list`,
     the default graph is used.

  Args:
    op_input_list: A list of inputs to an operation, which may include `Tensor`,
      `Operation`, and other objects that may be converted to a graph element.
    graph: (Optional) The explicit graph to use.

  Raises:
    TypeError: If `op_input_list` is not a list or tuple, or if graph is not a
      Graph.
    ValueError: If a graph is explicitly passed and not all inputs are from it,
      or if the inputs are from multiple graphs, or we could not find a graph
      and there was no default graph.

  Returns:
    The appropriate graph to use for the given inputs.
  """
  resolve_graph = ops._get_graph_from_inputs  # pylint: disable=protected-access
  return resolve_graph(op_input_list, graph)


def get_name_scope():
  """Reports the name scope currently active on the default graph.

  For example:

  ```python
  with tf.name_scope('scope1'):
    with tf.name_scope('scope2'):
      print(tf.contrib.framework.get_name_scope())
  ```
  would print the string `scope1/scope2`.

  Returns:
    A string representing the current name scope.
  """
  default_graph = ops.get_default_graph()
  return default_graph.get_name_scope()
def setUpModule():
  """Switches TensorFlow to graph mode before any test in the module runs."""
  tf.disable_eager_execution()


class OpsTest(test.TestCase):
  """Tests for `get_graph_from_inputs` and `get_name_scope`."""

  def testGetGraphFromEmptyInputs(self):
    # With no inputs to inspect, the default graph must be returned.
    with ops.Graph().as_default() as g0:
      self.assertIs(g0, ops_lib.get_graph_from_inputs([]))

  def testGetGraphFromValidInputs(self):
    g0 = ops.Graph()
    with g0.as_default():
      values = [constant_op.constant(0.0), constant_op.constant(1.0)]
    self.assertIs(g0, ops_lib.get_graph_from_inputs(values))
    self.assertIs(g0, ops_lib.get_graph_from_inputs(values, g0))
    # The inputs' graph wins even while a different graph is the default.
    with ops.Graph().as_default():
      self.assertIs(g0, ops_lib.get_graph_from_inputs(values))
      self.assertIs(g0, ops_lib.get_graph_from_inputs(values, g0))

  def testGetGraphFromInvalidInputs(self):
    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which was removed from `unittest.TestCase` in Python 3.12.
    g0 = ops.Graph()
    with g0.as_default():
      values = [constant_op.constant(0.0), constant_op.constant(1.0)]
    g1 = ops.Graph()
    with self.assertRaisesRegex(ValueError, "not from the passed-in graph"):
      ops_lib.get_graph_from_inputs(values, g1)
    # From here on `values` mixes tensors from two different graphs.
    with g1.as_default():
      values.append(constant_op.constant(2.0))
    with self.assertRaisesRegex(ValueError, "must be from the same graph"):
      ops_lib.get_graph_from_inputs(values)
    with self.assertRaisesRegex(ValueError, "not from the passed-in graph"):
      ops_lib.get_graph_from_inputs(values, g0)
    with self.assertRaisesRegex(ValueError, "not from the passed-in graph"):
      ops_lib.get_graph_from_inputs(values, g1)

  def testGetNameScope(self):
    # The reported scope must follow scope entry and exit exactly.
    with ops.name_scope("scope1"):
      with ops.name_scope("scope2"):
        with ops.name_scope("scope3"):
          self.assertEqual("scope1/scope2/scope3", ops_lib.get_name_scope())
        self.assertEqual("scope1/scope2", ops_lib.get_name_scope())
      self.assertEqual("scope1", ops_lib.get_name_scope())
    self.assertEqual("", ops_lib.get_name_scope())
def preprocess_image(image,
                     output_height,
                     output_width,
                     is_training,
                     use_grayscale=False):
  """Preprocesses the given image for LeNet.

  The image is cast to float, optionally converted to grayscale, centrally
  cropped or zero-padded to `output_height` x `output_width`, and finally
  rescaled as `(pixel - 128) / 128`.

  Args:
    image: A `Tensor` representing an image of arbitrary size.
    output_height: The height of the image after preprocessing.
    output_width: The width of the image after preprocessing.
    is_training: `True` if we're preprocessing the image for training and
      `False` otherwise. Accepted for interface parity with the other
      preprocessing modules; it does not affect the result.
    use_grayscale: Whether to convert the image from RGB to grayscale.

  Returns:
    A preprocessed image.
  """
  del is_training  # Unused argument
  image = tf.to_float(image)
  if use_grayscale:
    image = tf.image.rgb_to_grayscale(image)
  # resize_image_with_crop_or_pad expects (image, target_height, target_width).
  # Height must come first, otherwise non-square outputs are transposed.
  image = tf.image.resize_image_with_crop_or_pad(
      image, output_height, output_width)
  image = tf.subtract(image, 128.0)
  image = tf.div(image, 128.0)
  return image
class NestedQueueRunnerError(Exception):
  """Raised when a `QueueRunners` context is entered while one is active."""


@contextlib.contextmanager
def QueueRunners(session):
  """Creates a context manager that handles starting and stopping queue runners.

  On entry, threads are created and started for every queue runner registered
  in the `QUEUE_RUNNERS` collection. On exit, the coordinator is asked to
  stop and the threads are joined with a 120 second grace period; if the join
  raises `RuntimeError`, the session is closed instead.

  Only one `QueueRunners` context may be active at a time, enforced with a
  module-level lock.

  Args:
    session: the currently running session.

  Yields:
    a context in which queues are run.

  Raises:
    NestedQueueRunnerError: if a QueueRunners context is nested within another.
  """
  if not _queue_runner_lock.acquire(False):
    raise NestedQueueRunnerError('QueueRunners cannot be nested')

  # The outer try/finally guarantees the lock is released on every exit path,
  # including a failure while creating threads or an unexpected exception
  # from `coord.join`. Without it, one failure would make every later call
  # raise NestedQueueRunnerError.
  try:
    coord = coordinator.Coordinator()
    threads = []
    try:
      # Thread creation sits inside the inner try so that threads already
      # started are still stopped if a later queue runner fails to start.
      for qr in ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        threads.extend(
            qr.create_threads(
                session, coord=coord, daemon=True, start=True))
      yield
    finally:
      coord.request_stop()
      try:
        coord.join(threads, stop_grace_period_secs=120)
      except RuntimeError:
        # The join failed within the grace period; close the session instead.
        session.close()
  finally:
    _queue_runner_lock.release()
31 | DATASET_DIR=/tmp/cifar10 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=cifar10 \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=cifar10 \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=cifarnet \ 45 | --preprocessing_name=cifarnet \ 46 | --max_number_of_steps=100000 \ 47 | --batch_size=128 \ 48 | --save_interval_secs=120 \ 49 | --save_summaries_secs=120 \ 50 | --log_every_n_steps=100 \ 51 | --optimizer=sgd \ 52 | --learning_rate=0.1 \ 53 | --learning_rate_decay_factor=0.1 \ 54 | --num_epochs_per_decay=200 \ 55 | --weight_decay=0.004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=cifar10 \ 62 | --dataset_split_name=test \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=cifarnet 65 | -------------------------------------------------------------------------------- /slim/scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the MNIST dataset
# 2. Trains a LeNet model on the MNIST training set.
# 3. Evaluates the model on the MNIST testing set.
#
# Usage (the python entry points below are resolved relative to slim/):
# cd slim
# ./scripts/train_lenet_on_mnist.sh
set -e

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/lenet-model

# Where the dataset is saved to.
DATASET_DIR=/tmp/mnist

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=mnist \
  --dataset_dir="${DATASET_DIR}"

# Run training.
python train_image_classifier.py \
  --train_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=train \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet \
  --preprocessing_name=lenet \
  --max_number_of_steps=20000 \
  --batch_size=50 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=sgd \
  --learning_rate_decay_type=fixed \
  --weight_decay=0

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path="${TRAIN_DIR}" \
  --eval_dir="${TRAIN_DIR}" \
  --dataset_name=mnist \
  --dataset_split_name=test \
  --dataset_dir="${DATASET_DIR}" \
  --model_name=lenet
# Package metadata for the `slim` distribution.
setup(
    name='slim',
    version='0.1',
    # Ship non-Python files declared in the package manifest as well.
    include_package_data=True,
    # Discover every package under the directory containing this script.
    packages=find_packages(),
    description='tf-slim',
)
15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /tf_datatools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/tf_datatools/__init__.py -------------------------------------------------------------------------------- /tf_datatools/pascal_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: '目标物' 4 | } -------------------------------------------------------------------------------- /tf_datatools/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WidgetA/sonar_baseline_with_tensorflow/c15fb1b0344f658f83b3ffdfb50700d51ca0af18/tf_datatools/utils/__init__.py -------------------------------------------------------------------------------- /tf_datatools/utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def int64_feature(value):
  """Wraps a single integer in a `tf.train.Feature` holding an `Int64List`."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def int64_list_feature(value):
  """Wraps an iterable of integers in a `tf.train.Feature` (`Int64List`)."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
  """Wraps a single bytes value in a `tf.train.Feature` (`BytesList`)."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def bytes_list_feature(value):
  """Wraps an iterable of bytes values in a `tf.train.Feature` (`BytesList`)."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def float_list_feature(value):
  """Wraps an iterable of floats in a `tf.train.Feature` (`FloatList`)."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def read_examples_list(path):
  """Read list of training or validation examples.

  The file is assumed to contain a single example per line where the first
  token in the line is an identifier that allows us to find the image and
  annotation xml for that example.

  For example, the line:
  xyz 3
  would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored).

  Args:
    path: absolute path to examples list file.

  Returns:
    list of example identifiers (strings).
  """
  with tf.compat.v1.gfile.GFile(path) as fid:
    lines = fid.readlines()
  # Only the first space-separated token of each line is the identifier.
  return [line.strip().split(' ')[0] for line in lines]


def recursive_parse_xml_to_dict(xml):
  """Recursively parses XML contents to python dict.

  We assume that `object` tags are the only ones that can appear
  multiple times at the same level of a tree: their parsed contents are
  collected in a list, whereas any other repeated tag keeps only its last
  occurrence.

  Args:
    xml: xml tree obtained by parsing XML file contents using lxml.etree

  Returns:
    Python dictionary holding XML contents. A leaf element maps its tag to
    its text (`None` when the element is empty); any other element maps its
    tag to a dict of its parsed children.
  """
  # Test for "no children" explicitly: relying on the truth value of an
  # element is deprecated (FutureWarning in lxml, DeprecationWarning in
  # xml.etree since Python 3.12).
  if len(xml) == 0:
    return {xml.tag: xml.text}
  result = {}
  for child in xml:
    child_result = recursive_parse_xml_to_dict(child)
    if child.tag != 'object':
      result[child.tag] = child_result[child.tag]
    else:
      result.setdefault(child.tag, []).append(child_result[child.tag])
  return {xml.tag: result}
3 | # source: object_detection/protos/bipartite_matcher.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='object_detection/protos/bipartite_matcher.proto', 18 | package='object_detection.protos', 19 | syntax='proto2', 20 | serialized_options=None, 21 | serialized_pb=b'\n/object_detection/protos/bipartite_matcher.proto\x12\x17object_detection.protos\"4\n\x10\x42ipartiteMatcher\x12 \n\x11use_matmul_gather\x18\x06 \x01(\x08:\x05\x66\x61lse' 22 | ) 23 | 24 | 25 | 26 | 27 | _BIPARTITEMATCHER = _descriptor.Descriptor( 28 | name='BipartiteMatcher', 29 | full_name='object_detection.protos.BipartiteMatcher', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | containing_type=None, 33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='use_matmul_gather', full_name='object_detection.protos.BipartiteMatcher.use_matmul_gather', index=0, 36 | number=6, type=8, cpp_type=7, label=1, 37 | has_default_value=True, default_value=False, 38 | message_type=None, enum_type=None, containing_type=None, 39 | is_extension=False, extension_scope=None, 40 | serialized_options=None, file=DESCRIPTOR), 41 | ], 42 | extensions=[ 43 | ], 44 | nested_types=[], 45 | enum_types=[ 46 | ], 47 | serialized_options=None, 48 | is_extendable=False, 49 | syntax='proto2', 50 | extension_ranges=[], 51 | oneofs=[ 52 | ], 53 | serialized_start=76, 54 | serialized_end=128, 55 | ) 56 | 57 | DESCRIPTOR.message_types_by_name['BipartiteMatcher'] = _BIPARTITEMATCHER 58 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 59 | 60 | BipartiteMatcher = _reflection.GeneratedProtocolMessageType('BipartiteMatcher', (_message.Message,), { 61 | 'DESCRIPTOR' : _BIPARTITEMATCHER, 62 
| '__module__' : 'object_detection.protos.bipartite_matcher_pb2' 63 | # @@protoc_insertion_point(class_scope:object_detection.protos.BipartiteMatcher) 64 | }) 65 | _sym_db.RegisterMessage(BipartiteMatcher) 66 | 67 | 68 | # @@protoc_insertion_point(module_scope) 69 | -------------------------------------------------------------------------------- /tf_datatools/utils/protos/mean_stddev_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: object_detection/protos/mean_stddev_box_coder.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='object_detection/protos/mean_stddev_box_coder.proto', 18 | package='object_detection.protos', 19 | syntax='proto2', 20 | serialized_options=None, 21 | serialized_pb=b'\n3object_detection/protos/mean_stddev_box_coder.proto\x12\x17object_detection.protos\"*\n\x12MeanStddevBoxCoder\x12\x14\n\x06stddev\x18\x01 \x01(\x02:\x04\x30.01' 22 | ) 23 | 24 | 25 | 26 | 27 | _MEANSTDDEVBOXCODER = _descriptor.Descriptor( 28 | name='MeanStddevBoxCoder', 29 | full_name='object_detection.protos.MeanStddevBoxCoder', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | containing_type=None, 33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='stddev', full_name='object_detection.protos.MeanStddevBoxCoder.stddev', index=0, 36 | number=1, type=2, cpp_type=6, label=1, 37 | has_default_value=True, default_value=float(0.01), 38 | message_type=None, enum_type=None, containing_type=None, 39 | is_extension=False, extension_scope=None, 40 | serialized_options=None, 
file=DESCRIPTOR), 41 | ], 42 | extensions=[ 43 | ], 44 | nested_types=[], 45 | enum_types=[ 46 | ], 47 | serialized_options=None, 48 | is_extendable=False, 49 | syntax='proto2', 50 | extension_ranges=[], 51 | oneofs=[ 52 | ], 53 | serialized_start=80, 54 | serialized_end=122, 55 | ) 56 | 57 | DESCRIPTOR.message_types_by_name['MeanStddevBoxCoder'] = _MEANSTDDEVBOXCODER 58 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 59 | 60 | MeanStddevBoxCoder = _reflection.GeneratedProtocolMessageType('MeanStddevBoxCoder', (_message.Message,), { 61 | 'DESCRIPTOR' : _MEANSTDDEVBOXCODER, 62 | '__module__' : 'object_detection.protos.mean_stddev_box_coder_pb2' 63 | # @@protoc_insertion_point(class_scope:object_detection.protos.MeanStddevBoxCoder) 64 | }) 65 | _sym_db.RegisterMessage(MeanStddevBoxCoder) 66 | 67 | 68 | # @@protoc_insertion_point(module_scope) 69 | -------------------------------------------------------------------------------- /tf_datatools/utils/protos/square_box_coder_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: object_detection/protos/square_box_coder.proto 4 | 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='object_detection/protos/square_box_coder.proto', 18 | package='object_detection.protos', 19 | syntax='proto2', 20 | serialized_options=None, 21 | serialized_pb=b'\n.object_detection/protos/square_box_coder.proto\x12\x17object_detection.protos\"S\n\x0eSquareBoxCoder\x12\x13\n\x07y_scale\x18\x01 \x01(\x02:\x02\x31\x30\x12\x13\n\x07x_scale\x18\x02 \x01(\x02:\x02\x31\x30\x12\x17\n\x0clength_scale\x18\x03 \x01(\x02:\x01\x35' 22 | ) 23 | 24 | 25 | 26 | 27 | _SQUAREBOXCODER = _descriptor.Descriptor( 28 | name='SquareBoxCoder', 29 | full_name='object_detection.protos.SquareBoxCoder', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | containing_type=None, 33 | fields=[ 34 | _descriptor.FieldDescriptor( 35 | name='y_scale', full_name='object_detection.protos.SquareBoxCoder.y_scale', index=0, 36 | number=1, type=2, cpp_type=6, label=1, 37 | has_default_value=True, default_value=float(10), 38 | message_type=None, enum_type=None, containing_type=None, 39 | is_extension=False, extension_scope=None, 40 | serialized_options=None, file=DESCRIPTOR), 41 | _descriptor.FieldDescriptor( 42 | name='x_scale', full_name='object_detection.protos.SquareBoxCoder.x_scale', index=1, 43 | number=2, type=2, cpp_type=6, label=1, 44 | has_default_value=True, default_value=float(10), 45 | message_type=None, enum_type=None, containing_type=None, 46 | is_extension=False, extension_scope=None, 47 | serialized_options=None, file=DESCRIPTOR), 48 | _descriptor.FieldDescriptor( 49 | name='length_scale', 
full_name='object_detection.protos.SquareBoxCoder.length_scale', index=2, 50 | number=3, type=2, cpp_type=6, label=1, 51 | has_default_value=True, default_value=float(5), 52 | message_type=None, enum_type=None, containing_type=None, 53 | is_extension=False, extension_scope=None, 54 | serialized_options=None, file=DESCRIPTOR), 55 | ], 56 | extensions=[ 57 | ], 58 | nested_types=[], 59 | enum_types=[ 60 | ], 61 | serialized_options=None, 62 | is_extendable=False, 63 | syntax='proto2', 64 | extension_ranges=[], 65 | oneofs=[ 66 | ], 67 | serialized_start=75, 68 | serialized_end=158, 69 | ) 70 | 71 | DESCRIPTOR.message_types_by_name['SquareBoxCoder'] = _SQUAREBOXCODER 72 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 73 | 74 | SquareBoxCoder = _reflection.GeneratedProtocolMessageType('SquareBoxCoder', (_message.Message,), { 75 | 'DESCRIPTOR' : _SQUAREBOXCODER, 76 | '__module__' : 'object_detection.protos.square_box_coder_pb2' 77 | # @@protoc_insertion_point(class_scope:object_detection.protos.SquareBoxCoder) 78 | }) 79 | _sym_db.RegisterMessage(SquareBoxCoder) 80 | 81 | 82 | # @@protoc_insertion_point(module_scope) 83 | --------------------------------------------------------------------------------