├── .gitignore
├── AUTHORS
├── CODEOWNERS
├── CONTRIBUTING.md
├── ISSUE_TEMPLATE.md
├── LICENSE
├── README.md
├── WORKSPACE
├── doc
│   ├── .gitkeep
│   ├── cvpr19_mtl_ssl_poster.pdf
│   ├── cvpr19_mtl_ssl_poster.pptx
│   ├── cvpr19_poster_v5.pdf
│   ├── mtl-ssl.pdf
│   └── mtl-ssl.pptx
├── fonts
│   └── Ubuntu Mono derivative Powerline.ttf
├── global_utils
│   ├── __init__.py
│   └── custom_utils.py
├── object_detection
│   ├── .gitignore
│   ├── BUILD
│   ├── CONTRIBUTING.md
│   ├── README.md
│   ├── README_org.md
│   ├── __init__.py
│   ├── anchor_generators
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── grid_anchor_generator.py
│   │   ├── grid_anchor_generator_test.py
│   │   ├── multiple_grid_anchor_generator.py
│   │   └── multiple_grid_anchor_generator_test.py
│   ├── box_coders
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── faster_rcnn_box_coder.py
│   │   ├── faster_rcnn_box_coder_test.py
│   │   ├── keypoint_box_coder.py
│   │   ├── keypoint_box_coder_test.py
│   │   ├── mean_stddev_box_coder.py
│   │   ├── mean_stddev_box_coder_test.py
│   │   ├── square_box_coder.py
│   │   └── square_box_coder_test.py
│   ├── builders
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── anchor_generator_builder.py
│   │   ├── anchor_generator_builder_test.py
│   │   ├── box_coder_builder.py
│   │   ├── box_coder_builder_test.py
│   │   ├── box_predictor_builder.py
│   │   ├── box_predictor_builder_test.py
│   │   ├── hyperparams_builder.py
│   │   ├── hyperparams_builder_test.py
│   │   ├── image_resizer_builder.py
│   │   ├── image_resizer_builder_test.py
│   │   ├── initializer_builder.py
│   │   ├── input_reader_builder.py
│   │   ├── input_reader_builder_test.py
│   │   ├── losses_builder.py
│   │   ├── losses_builder_test.py
│   │   ├── mask_predictor_builder.py
│   │   ├── matcher_builder.py
│   │   ├── matcher_builder_test.py
│   │   ├── model_builder.py
│   │   ├── model_builder_test.py
│   │   ├── optimizer_builder.py
│   │   ├── optimizer_builder_test.py
│   │   ├── post_processing_builder.py
│   │   ├── post_processing_builder_test.py
│   │   ├── preprocessor_builder.py
│   │   ├── preprocessor_builder_test.py
│   │   ├── region_similarity_calculator_builder.py
│   │   └── region_similarity_calculator_builder_test.py
│   ├── checkpoints
│   │   └── .gitignore
│   ├── configs
│   │   ├── .gitignore
│   │   ├── faster_rcnn_resnet101_pets.config
│   │   ├── ssd_inception_v2_pets.config
│   │   ├── ssd_mobilenet_v1_pets.config
│   │   ├── ssd_vgg_16_pets.config
│   │   ├── test
│   │   │   ├── .gitignore
│   │   │   ├── model11.config
│   │   │   ├── model12.config
│   │   │   ├── model21.config
│   │   │   ├── model22.config
│   │   │   ├── model31.config
│   │   │   ├── model32.config
│   │   │   ├── model41.config
│   │   │   ├── model42.config
│   │   │   ├── model51.config
│   │   │   ├── model52.config
│   │   │   ├── model61.config
│   │   │   ├── model62.config
│   │   │   ├── model71.config
│   │   │   ├── model72.config
│   │   │   ├── model81.config
│   │   │   ├── model82.config
│   │   │   ├── model91.config
│   │   │   └── model92.config
│   │   └── voc_evaluation.config
│   ├── core
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── anchor_generator.py
│   │   ├── balanced_positive_negative_sampler.py
│   │   ├── balanced_positive_negative_sampler_test.py
│   │   ├── batcher.py
│   │   ├── batcher_test.py
│   │   ├── box_coder.py
│   │   ├── box_coder_test.py
│   │   ├── box_list.py
│   │   ├── box_list_ops.py
│   │   ├── box_list_ops_test.py
│   │   ├── box_list_test.py
│   │   ├── box_ops.py
│   │   ├── box_predictor.py
│   │   ├── box_predictor_test.py
│   │   ├── data_decoder.py
│   │   ├── keypoint_ops.py
│   │   ├── keypoint_ops_test.py
│   │   ├── losses.py
│   │   ├── losses_test.py
│   │   ├── mask_predictor.py
│   │   ├── matcher.py
│   │   ├── matcher_test.py
│   │   ├── minibatch_sampler.py
│   │   ├── minibatch_sampler_test.py
│   │   ├── model.py
│   │   ├── post_processing.py
│   │   ├── post_processing_test.py
│   │   ├── prefetcher.py
│   │   ├── prefetcher_test.py
│   │   ├── preprocessor.py
│   │   ├── preprocessor_test.py
│   │   ├── region_similarity_calculator.py
│   │   ├── region_similarity_calculator_test.py
│   │   ├── standard_fields.py
│   │   ├── target_assigner.py
│   │   └── target_assigner_test.py
│   ├── create_records
│   │   ├── create_all_pascal_tf_records.sh
│   │   ├── create_mscoco_tf_record.py
│   │   ├── create_pascal_tf_record.py
│   │   ├── create_pascal_tf_record_test.py
│   │   └── create_pet_tf_record.py
│   ├── data
│   │   ├── .gitignore
│   │   ├── caltech_label_map.pbtxt
│   │   ├── mobis_label_map.pbtxt
│   │   ├── mscoco
│   │   │   ├── .gitignore
│   │   │   └── README.md
│   │   ├── mscoco_label_map.pbtxt
│   │   ├── pascal_label_map.pbtxt
│   │   ├── pet_label_map.pbtxt
│   │   └── voc
│   │       ├── .gitignore
│   │       └── README.md
│   ├── data_decoders
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── tf_example_decoder.py
│   │   └── tf_example_decoder_test.py
│   ├── eval.py
│   ├── eval_util.py
│   ├── evaluator.py
│   ├── export_inference_graph.py
│   ├── exporter.py
│   ├── exporter_test.py
│   ├── g3doc
│   │   ├── configuring_jobs.md
│   │   ├── defining_your_own_model.md
│   │   ├── detection_model_zoo.md
│   │   ├── exporting_models.md
│   │   ├── img
│   │   │   ├── dogs_detections_output.jpg
│   │   │   ├── example_cat.jpg
│   │   │   ├── kites_detections_output.jpg
│   │   │   ├── mtl_1.png
│   │   │   ├── mtl_2.png
│   │   │   ├── mtl_ssl_detection.png
│   │   │   ├── ours_1.png
│   │   │   ├── ours_2.png
│   │   │   ├── oxford_pet.png
│   │   │   ├── qr_1.png
│   │   │   ├── qr_2.png
│   │   │   ├── qr_3.png
│   │   │   ├── qr_4.png
│   │   │   ├── qr_5.png
│   │   │   ├── results_1.png
│   │   │   ├── results_2.png
│   │   │   ├── results_3.png
│   │   │   ├── results_4.png
│   │   │   ├── reuse_1.png
│   │   │   ├── reuse_2.png
│   │   │   ├── ssl_1.png
│   │   │   ├── ssl_2.png
│   │   │   ├── tensorboard.png
│   │   │   └── tensorboard2.png
│   │   ├── installation.md
│   │   ├── preparation.md
│   │   ├── preparing_inputs.md
│   │   ├── running_locally.md
│   │   ├── running_notebook.md
│   │   ├── running_on_cloud.md
│   │   ├── running_pets.md
│   │   ├── train_and_eval.md
│   │   └── using_your_own_dataset.md
│   ├── matchers
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── argmax_matcher.py
│   │   ├── argmax_matcher_test.py
│   │   ├── bipartite_matcher.py
│   │   └── bipartite_matcher_test.py
│   ├── meta_architectures
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── faster_rcnn_meta_arch.py
│   │   ├── faster_rcnn_meta_arch_test.py
│   │   ├── faster_rcnn_meta_arch_test_lib.py
│   │   ├── rfcn_meta_arch.py
│   │   ├── rfcn_meta_arch_test.py
│   │   ├── ssd_meta_arch.py
│   │   └── ssd_meta_arch_test.py
│   ├── models
│   │   ├── BUILD
│   │   ├── __init__.py
│   │   ├── faster_rcnn_inception_resnet_v2_feature_extractor.py
│   │   ├── faster_rcnn_inception_resnet_v2_feature_extractor_test.py
│   │   ├── faster_rcnn_mobilenet_v1_feature_extractor.py
│   │   ├── faster_rcnn_mobilenet_v1_feature_extractor_test.py
│   │   ├── faster_rcnn_resnet_v1_feature_extractor.py
│   │   ├── faster_rcnn_resnet_v1_feature_extractor_test.py
│   │   ├── faster_rcnn_vgg_16_feature_extractor.py
│   │   ├── feature_map_generators.py
│   │   ├── feature_map_generators_test.py
│   │   ├── ssd_feature_extractor_test.py
│   │   ├── ssd_inception_v2_feature_extractor.py
│   │   ├── ssd_inception_v2_feature_extractor_test.py
│   │   ├── ssd_mobilenet_v1_feature_extractor.py
│   │   ├── ssd_mobilenet_v1_feature_extractor_test.py
│   │   └── ssd_vgg_16_feature_extractor.py
│   ├── notebooks
│   │   └── object_detection_tutorial.ipynb
│   ├── object_detection_tutorial.ipynb
│   ├── protos
│   │   ├── .gitignore
│   │   ├── BUILD
│   │   ├── anchor_generator.proto
│   │   ├── argmax_matcher.proto
│   │   ├── bipartite_matcher.proto
│   │   ├── box_coder.proto
│   │   ├── box_predictor.proto
│   │   ├── eval.proto
│   │   ├── faster_rcnn.proto
│   │   ├── faster_rcnn_box_coder.proto
│   │   ├── grid_anchor_generator.proto
│   │   ├── hyperparams.proto
│   │   ├── image_resizer.proto
│   │   ├── input_reader.proto
│   │   ├── losses.proto
│   │   ├── mask_predictor.proto
│   │   ├── matcher.proto
│   │   ├── mean_stddev_box_coder.proto
│   │   ├── model.proto
│   │   ├── optimizer.proto
│   │   ├── pipeline.proto
│   │   ├── post_processing.proto
│   │   ├── preprocessor.proto
│   │   ├── region_similarity_calculator.proto
│   │   ├── square_box_coder.proto
│   │   ├── ssd.proto
│   │   ├── ssd_anchor_generator.proto
│   │   ├── string_int_label_map.proto
│   │   └── train.proto
│   ├── samples
│   │   ├── cloud
│   │   │   └── cloud.yml
│   │   └── configs
│   │       ├── faster_rcnn_inception_resnet_v2_atrous_pets.config
│   │       ├── faster_rcnn_resnet101_pets.config
│   │       ├── faster_rcnn_resnet101_voc07.config
│   │       ├── faster_rcnn_resnet152_pets.config
│   │       ├── faster_rcnn_resnet50_pets.config
│   │       ├── rfcn_resnet101_pets.config
│   │       ├── ssd_inception_v2_pets.config
│   │       └── ssd_mobilenet_v1_pets.config
│   ├── scripts
│   │   ├── .gitignore
│   │   ├── run_eval.sh
│   │   └── run_train.sh
│   ├── test_images
│   │   ├── image1.jpg
│   │   ├── image2.jpg
│   │   └── image_info.txt
│   ├── train.py
│   ├── trainer.py
│   ├── trainer_test.py
│   └── utils
│       ├── BUILD
│       ├── __init__.py
│       ├── category_util.py
│       ├── category_util_test.py
│       ├── dataset_util.py
│       ├── dataset_util_test.py
│       ├── debug_utils.py
│       ├── kwargs_util.py
│       ├── label_map_util.py
│       ├── label_map_util_test.py
│       ├── learning_schedules.py
│       ├── learning_schedules_test.py
│       ├── metrics.py
│       ├── metrics_test.py
│       ├── mtl_util.py
│       ├── np_box_list.py
│       ├── np_box_list_ops.py
│       ├── np_box_list_ops_test.py
│       ├── np_box_list_test.py
│       ├── np_box_ops.py
│       ├── np_box_ops_test.py
│       ├── object_detection_evaluation.py
│       ├── object_detection_evaluation_test.py
│       ├── ops.py
│       ├── ops_test.py
│       ├── per_image_evaluation.py
│       ├── per_image_evaluation_test.py
│       ├── shape_utils.py
│       ├── shape_utils_test.py
│       ├── static_shape.py
│       ├── static_shape_test.py
│       ├── test_utils.py
│       ├── test_utils_test.py
│       ├── variables_helper.py
│       ├── variables_helper_test.py
│       ├── visualization_utils.py
│       └── visualization_utils_test.py
├── requirements.txt
├── setup.py
└── slim
    ├── BUILD
    ├── README.md
    ├── WORKSPACE
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── build_imagenet_data.py
    │   ├── cifar10.py
    │   ├── dataset_factory.py
    │   ├── dataset_utils.py
    │   ├── download_and_convert_cifar10.py
    │   ├── download_and_convert_flowers.py
    │   ├── download_and_convert_imagenet.sh
    │   ├── download_and_convert_mnist.py
    │   ├── download_imagenet.sh
    │   ├── flowers.py
    │   ├── imagenet.py
    │   ├── imagenet_2012_validation_synset_labels.txt
    │   ├── imagenet_lsvrc_2015_synsets.txt
    │   ├── imagenet_metadata.txt
    │   ├── mnist.py
    │   ├── preprocess_imagenet_validation_data.py
    │   └── process_bounding_boxes.py
    ├── deployment
    │   ├── __init__.py
    │   ├── model_deploy.py
    │   └── model_deploy_test.py
    ├── download_and_convert_data.py
    ├── eval_image_classifier.py
    ├── export_inference_graph.py
    ├── export_inference_graph_test.py
    ├── learning.py
    ├── nets
    │   ├── __init__.py
    │   ├── alexnet.py
    │   ├── alexnet_test.py
    │   ├── cifarnet.py
    │   ├── inception.py
    │   ├── inception_resnet_v2.py
    │   ├── inception_resnet_v2_test.py
    │   ├── inception_utils.py
    │   ├── inception_v1.py
    │   ├── inception_v1_test.py
    │   ├── inception_v2.py
    │   ├── inception_v2_test.py
    │   ├── inception_v3.py
    │   ├── inception_v3_test.py
    │   ├── inception_v4.py
    │   ├── inception_v4_test.py
    │   ├── lenet.py
    │   ├── mobilenet_v1.md
    │   ├── mobilenet_v1.png
    │   ├── mobilenet_v1.py
    │   ├── mobilenet_v1_test.py
    │   ├── nets_factory.py
    │   ├── nets_factory_test.py
    │   ├── overfeat.py
    │   ├── overfeat_test.py
    │   ├── resnet_utils.py
    │   ├── resnet_v1.py
    │   ├── resnet_v1_test.py
    │   ├── resnet_v2.py
    │   ├── resnet_v2_test.py
    │   ├── vgg.py
    │   └── vgg_test.py
    ├── preprocessing
    │   ├── __init__.py
    │   ├── cifarnet_preprocessing.py
    │   ├── inception_preprocessing.py
    │   ├── lenet_preprocessing.py
    │   ├── preprocessing_factory.py
    │   └── vgg_preprocessing.py
    ├── scripts
    │   ├── export_mobilenet.sh
    │   ├── finetune_inception_resnet_v2_on_flowers.sh
    │   ├── finetune_inception_v1_on_flowers.sh
    │   ├── finetune_inception_v3_on_flowers.sh
    │   ├── finetune_resnet_v1_50_on_flowers.sh
    │   ├── train_cifarnet_on_cifar10.sh
    │   └── train_lenet_on_mnist.sh
    ├── setup.py
    ├── slim_walkthrough.ipynb
    └── train_image_classifier.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # JetBrains 92 | .idea 93 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of authors for copyright purposes. 2 | # This file is distinct from the CONTRIBUTORS files. 3 | # See the latter for an explanation. 4 | 5 | # Names should be added to this file as: 6 | # Name or Organization 7 | # The email address is not required for organizations. 8 | 9 | Google Inc. 
10 | David Dao
11 | 
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | adversarial_crypto/* @dave-andersen
2 | adversarial_text/* @rsepassi
3 | adv_imagenet_models/* @AlexeyKurakin
4 | attention_ocr/* @alexgorban
5 | audioset/* @plakal @dpwe
6 | autoencoders/* @snurkabill
7 | cognitive_mapping_and_planning/* @s-gupta
8 | compression/* @nmjohn
9 | differential_privacy/* @panyx0718
10 | domain_adaptation/* @bousmalis @ddohan
11 | im2txt/* @cshallue
12 | inception/* @shlens @vincentvanhoucke
13 | learning_to_remember_rare_events/* @lukaszkaiser @ofirnachum
14 | lfads/* @jazcollins @susillo
15 | lm_1b/* @oriolvinyals @panyx0718
16 | namignizer/* @knathanieltucker
17 | neural_gpu/* @lukaszkaiser
18 | neural_programmer/* @arvind2505
19 | next_frame_prediction/* @panyx0718
20 | object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
21 | pcl_rl/* @ofirnachum
22 | ptn/* @xcyan @arkanath @hellojas @honglaklee
23 | real_nvp/* @laurent-dinh
24 | rebar/* @gjtucker
25 | resnet/* @panyx0718
26 | skip_thoughts/* @cshallue
27 | slim/* @sguada @nathansilberman
28 | street/* @theraysmith
29 | swivel/* @waterson
30 | syntaxnet/* @calberti @andorardo
31 | textsum/* @panyx0718 @peterjliu
32 | transformer/* @daviddao
33 | tutorials/embedding/* @zffchen78 @a-dai
34 | tutorials/image/* @sherrym @shlens
35 | tutorials/rnn/* @lukaszkaiser @ebrevdo
36 | video_prediction/* @cbfinn
37 | 
38 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 | 
3 | If you have created a model and would like to publish it here, please send us a
4 | pull request. For those just getting started with pull requests, GitHub has a
5 | [howto](https://help.github.com/articles/using-pull-requests/).
6 | 
7 | The code for any model in this repository is licensed under the Apache License
8 | 2.0.
9 | 
10 | In order to accept your code, we have to make sure that we can publish it:
11 | You have to sign a Contributor License Agreement (CLA).
12 | 
13 | ### Contributor License Agreements
14 | 
15 | Please fill out either the individual or corporate Contributor License Agreement (CLA).
16 | 
17 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
18 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
19 | 
20 | Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests.
21 | 
22 | ***NOTE***: Only original source code from you and other people that have signed the CLA can be accepted into the repository.
23 | 
24 | 
--------------------------------------------------------------------------------
/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | Please go to Stack Overflow for help and support:
2 | 
3 | http://stackoverflow.com/questions/tagged/tensorflow
4 | 
5 | Also, please understand that many of the models included in this repository are experimental and research-style code.
If you open a GitHub issue, here is our policy: 6 | 7 | 1. It must be a bug or a feature request. 8 | 2. The form below must be filled out. 9 | 10 | **Here's why we have that policy**: TensorFlow developers respond to issues. We want to focus on work that benefits the whole community, e.g., fixing bugs and adding features. Support only helps individuals. GitHub also notifies thousands of people when issues are filed. We want them to see you communicating an interesting problem, rather than being redirected to Stack Overflow. 11 | 12 | ------------------------ 13 | 14 | ### System information 15 | - **What is the top-level directory of the model you are using**: 16 | - **Have I written custom code (as opposed to using a stock example script provided in TensorFlow)**: 17 | - **OS Platform and Distribution (e.g., Linux Ubuntu 16.04)**: 18 | - **TensorFlow installed from (source or binary)**: 19 | - **TensorFlow version (use command below)**: 20 | - **Bazel version (if compiling from source)**: 21 | - **CUDA/cuDNN version**: 22 | - **GPU model and memory**: 23 | - **Exact command to reproduce**: 24 | 25 | You can collect some of this information using our environment capture script: 26 | 27 | https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh 28 | 29 | You can obtain the TensorFlow version with 30 | 31 | python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)" 32 | 33 | ### Describe the problem 34 | Describe the problem clearly here. Be sure to convey here why it's a bug in TensorFlow or a feature request. 35 | 36 | ### Source code / logs 37 | Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. Try to provide a reproducible test case that is the bare minimum necessary to generate the problem. 
38 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/WORKSPACE -------------------------------------------------------------------------------- /doc/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /doc/cvpr19_mtl_ssl_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/doc/cvpr19_mtl_ssl_poster.pdf -------------------------------------------------------------------------------- /doc/cvpr19_mtl_ssl_poster.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/doc/cvpr19_mtl_ssl_poster.pptx -------------------------------------------------------------------------------- /doc/cvpr19_poster_v5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/doc/cvpr19_poster_v5.pdf -------------------------------------------------------------------------------- /doc/mtl-ssl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/doc/mtl-ssl.pdf -------------------------------------------------------------------------------- /doc/mtl-ssl.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/doc/mtl-ssl.pptx -------------------------------------------------------------------------------- /fonts/Ubuntu Mono derivative Powerline.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/fonts/Ubuntu Mono derivative Powerline.ttf -------------------------------------------------------------------------------- /global_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/global_utils/__init__.py -------------------------------------------------------------------------------- /global_utils/custom_utils.py: -------------------------------------------------------------------------------- 1 | """ Common utilities. 
""" 2 | 3 | # Logging 4 | # ======= 5 | 6 | import logging 7 | import os, os.path 8 | from colorlog import ColoredFormatter 9 | 10 | ch = logging.StreamHandler() 11 | ch.setLevel(logging.DEBUG) 12 | 13 | formatter = ColoredFormatter( 14 | "%(log_color)s[%(asctime)s] %(message)s", 15 | # datefmt='%H:%M:%S.%f', 16 | datefmt=None, 17 | reset=True, 18 | log_colors={ 19 | 'DEBUG': 'cyan', 20 | 'INFO': 'white,bold', 21 | 'INFOV': 'cyan,bold', 22 | 'WARNING': 'yellow', 23 | 'ERROR': 'red,bold', 24 | 'CRITICAL': 'red,bg_white', 25 | }, 26 | secondary_log_colors={}, 27 | style='%' 28 | ) 29 | ch.setFormatter(formatter) 30 | 31 | log = logging.getLogger('small') 32 | log.setLevel(logging.DEBUG) 33 | log.handlers = [] # No duplicated handlers 34 | log.propagate = False # workaround for duplicated logs in ipython 35 | log.addHandler(ch) 36 | 37 | logging.addLevelName(logging.INFO + 1, 'INFOV') 38 | def _infov(self, msg, *args, **kwargs): 39 | self.log(logging.INFO + 1, msg, *args, **kwargs) 40 | logging.Logger.infov = _infov 41 | 42 | 43 | # Etc 44 | # === 45 | 46 | def get_tempdir(): 47 | import getpass, tempfile 48 | user = getpass.getuser() 49 | 50 | for t in ('/data1/' + user, 51 | '/data/' + user, 52 | tempfile.gettempdir()): 53 | if os.path.exists(t): 54 | return mkdir_p(t + '/small.tmp') 55 | return None 56 | 57 | 58 | def get_specific_dir(name): 59 | import getpass 60 | 61 | assert name is not None, 'Need to specify directory name.' 62 | user = getpass.getuser() 63 | 64 | for t in ('/data2/' + user, 65 | '/data1/' + user, 66 | '/data/' + user): 67 | if os.path.exists(t): 68 | return mkdir_p(t + '/' + name) 69 | return None 70 | 71 | def mkdir_p(path): 72 | if not os.path.exists(path): 73 | os.mkdir(path) 74 | return path 75 | 76 | 77 | __all__ = ( 78 | 'log', 'get_tempdir', 'get_specific_dir', 'mkdir_p', 79 | ) 80 | -------------------------------------------------------------------------------- /object_detection/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/.gitignore -------------------------------------------------------------------------------- /object_detection/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to the Tensorflow Object Detection API 2 | 3 | Patches to Tensorflow Object Detection API are welcome! 4 | 5 | We require contributors to fill out either the individual or corporate 6 | Contributor License Agreement (CLA). 7 | 8 | * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). 9 | * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 10 | 11 | Please follow the 12 | [Tensorflow contributing guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md) 13 | when submitting pull requests. 
14 | -------------------------------------------------------------------------------- /object_detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/__init__.py -------------------------------------------------------------------------------- /object_detection/anchor_generators/BUILD: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection API: Anchor Generator implementations. 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | licenses(["notice"]) 8 | 9 | # Apache 2.0 10 | py_library( 11 | name = "grid_anchor_generator", 12 | srcs = [ 13 | "grid_anchor_generator.py", 14 | ], 15 | deps = [ 16 | "//tensorflow", 17 | "//tensorflow_models/object_detection/core:anchor_generator", 18 | "//tensorflow_models/object_detection/core:box_list", 19 | "//tensorflow_models/object_detection/utils:ops", 20 | ], 21 | ) 22 | 23 | py_test( 24 | name = "grid_anchor_generator_test", 25 | srcs = [ 26 | "grid_anchor_generator_test.py", 27 | ], 28 | deps = [ 29 | ":grid_anchor_generator", 30 | "//tensorflow", 31 | ], 32 | ) 33 | 34 | py_library( 35 | name = "multiple_grid_anchor_generator", 36 | srcs = [ 37 | "multiple_grid_anchor_generator.py", 38 | ], 39 | deps = [ 40 | ":grid_anchor_generator", 41 | "//tensorflow", 42 | "//tensorflow_models/object_detection/core:anchor_generator", 43 | "//tensorflow_models/object_detection/core:box_list_ops", 44 | ], 45 | ) 46 | 47 | py_test( 48 | name = "multiple_grid_anchor_generator_test", 49 | srcs = [ 50 | "multiple_grid_anchor_generator_test.py", 51 | ], 52 | deps = [ 53 | ":multiple_grid_anchor_generator", 54 | "//third_party/py/numpy", 55 | ], 56 | ) 57 | -------------------------------------------------------------------------------- /object_detection/anchor_generators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/anchor_generators/__init__.py -------------------------------------------------------------------------------- /object_detection/anchor_generators/grid_anchor_generator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.grid_anchor_generator.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.anchor_generators import grid_anchor_generator 21 | 22 | 23 | class GridAnchorGeneratorTest(tf.test.TestCase): 24 | 25 | def test_construct_single_anchor(self): 26 | """Builds a 1x1 anchor grid to test the size of the output boxes.""" 27 | scales = [0.5, 1.0, 2.0] 28 | aspect_ratios = [0.25, 1.0, 4.0] 29 | anchor_offset = [7, -3] 30 | exp_anchor_corners = [[-121, -35, 135, 29], [-249, -67, 263, 61], 31 | [-505, -131, 519, 125], [-57, -67, 71, 61], 32 | [-121, -131, 135, 125], [-249, -259, 263, 253], 33 | [-25, -131, 39, 125], [-57, -259, 71, 253], 34 | [-121, -515, 135, 509]] 35 | 36 | anchor_generator = grid_anchor_generator.GridAnchorGenerator( 37 | scales, aspect_ratios, 38 | anchor_offset=anchor_offset) 39 | anchors = anchor_generator.generate(feature_map_shape_list=[(1, 1)]) 40 | anchor_corners = anchors.get() 41 | 42 | with self.test_session(): 43 | anchor_corners_out = anchor_corners.eval() 44 | self.assertAllClose(anchor_corners_out, exp_anchor_corners) 45 | 46 | def test_construct_anchor_grid(self): 47 | base_anchor_size = [10, 10] 48 | anchor_stride = [19, 19] 49 | anchor_offset = [0, 0] 50 | scales = [0.5, 1.0, 2.0] 51 | aspect_ratios = [1.0] 52 | 53 | exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.], 54 | [-10., -10., 10., 10.], [-2.5, 16.5, 2.5, 21.5], 55 | [-5., 14., 5, 24], [-10., 9., 10, 29], 56 | [16.5, -2.5, 21.5, 2.5], [14., -5., 24, 5], 57 | [9., -10., 29, 10], [16.5, 16.5, 21.5, 21.5], 58 | [14., 14., 24, 24], [9., 9., 29, 29]] 59 | 60 | anchor_generator = grid_anchor_generator.GridAnchorGenerator( 61 | scales, 62 | aspect_ratios, 63 | base_anchor_size=base_anchor_size, 64 | anchor_stride=anchor_stride, 65 | anchor_offset=anchor_offset) 66 | 67 | anchors = anchor_generator.generate(feature_map_shape_list=[(2, 2)]) 68 | anchor_corners = anchors.get() 69 | 70 | with self.test_session(): 71 | anchor_corners_out = anchor_corners.eval() 72 | self.assertAllClose(anchor_corners_out, exp_anchor_corners) 73 | 74 | 75 | if __name__ == '__main__': 76 | tf.test.main() 77 | -------------------------------------------------------------------------------- /object_detection/box_coders/BUILD: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection API: Box Coder implementations. 
2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | licenses(["notice"]) 8 | 9 | # Apache 2.0 10 | py_library( 11 | name = "faster_rcnn_box_coder", 12 | srcs = [ 13 | "faster_rcnn_box_coder.py", 14 | ], 15 | deps = [ 16 | "//tensorflow_models/object_detection/core:box_coder", 17 | "//tensorflow_models/object_detection/core:box_list", 18 | ], 19 | ) 20 | 21 | py_test( 22 | name = "faster_rcnn_box_coder_test", 23 | srcs = [ 24 | "faster_rcnn_box_coder_test.py", 25 | ], 26 | deps = [ 27 | ":faster_rcnn_box_coder", 28 | "//tensorflow", 29 | "//tensorflow_models/object_detection/core:box_list", 30 | ], 31 | ) 32 | 33 | py_library( 34 | name = "keypoint_box_coder", 35 | srcs = [ 36 | "keypoint_box_coder.py", 37 | ], 38 | deps = [ 39 | "//tensorflow_models/object_detection/core:box_coder", 40 | "//tensorflow_models/object_detection/core:box_list", 41 | "//tensorflow_models/object_detection/core:standard_fields", 42 | ], 43 | ) 44 | 45 | py_test( 46 | name = "keypoint_box_coder_test", 47 | srcs = [ 48 | "keypoint_box_coder_test.py", 49 | ], 50 | deps = [ 51 | ":keypoint_box_coder", 52 | "//tensorflow", 53 | "//tensorflow_models/object_detection/core:box_list", 54 | "//tensorflow_models/object_detection/core:standard_fields", 55 | ], 56 | ) 57 | 58 | py_library( 59 | name = "mean_stddev_box_coder", 60 | srcs = [ 61 | "mean_stddev_box_coder.py", 62 | ], 63 | deps = [ 64 | "//tensorflow_models/object_detection/core:box_coder", 65 | "//tensorflow_models/object_detection/core:box_list", 66 | ], 67 | ) 68 | 69 | py_test( 70 | name = "mean_stddev_box_coder_test", 71 | srcs = [ 72 | "mean_stddev_box_coder_test.py", 73 | ], 74 | deps = [ 75 | ":mean_stddev_box_coder", 76 | "//tensorflow", 77 | "//tensorflow_models/object_detection/core:box_list", 78 | ], 79 | ) 80 | 81 | py_library( 82 | name = "square_box_coder", 83 | srcs = [ 84 | "square_box_coder.py", 85 | ], 86 | deps = [ 87 | "//tensorflow_models/object_detection/core:box_coder", 88 | "//tensorflow_models/object_detection/core:box_list", 89 | ], 90 | ) 91 | 92 | py_test( 93 | name = "square_box_coder_test", 94 | srcs = [ 95 | "square_box_coder_test.py", 96 | ], 97 | deps = [ 98 | ":square_box_coder", 99 | "//tensorflow", 100 | "//tensorflow_models/object_detection/core:box_list", 101 | ], 102 | ) 103 | -------------------------------------------------------------------------------- /object_detection/box_coders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/box_coders/__init__.py -------------------------------------------------------------------------------- /object_detection/box_coders/mean_stddev_box_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Mean stddev box coder.
17 | 
18 | This box coder uses the following coding schema to encode boxes:
19 | rel_code = (box_corner - anchor_corner_mean) / anchor_corner_stddev.
20 | """
21 | from object_detection.core import box_coder
22 | from object_detection.core import box_list
23 | 
24 | 
25 | class MeanStddevBoxCoder(box_coder.BoxCoder):
26 |   """Mean stddev box coder."""
27 | 
28 |   @property
29 |   def code_size(self):
30 |     return 4
31 | 
32 |   def _encode(self, boxes, anchors):
33 |     """Encode a box collection with respect to anchor collection.
34 | 
35 |     Args:
36 |       boxes: BoxList holding N boxes to be encoded.
37 |       anchors: BoxList of N anchors. We assume that anchors has an associated
38 |         stddev field.
39 | 
40 |     Returns:
41 |       a tensor representing N anchor-encoded boxes
42 |     Raises:
43 |       ValueError: if the anchors BoxList does not have a stddev field
44 |     """
45 |     if not anchors.has_field('stddev'):
46 |       raise ValueError('anchors must have a stddev field')
47 |     box_corners = boxes.get()
48 |     means = anchors.get()
49 |     stddev = anchors.get_field('stddev')
50 |     return (box_corners - means) / stddev
51 | 
52 |   def _decode(self, rel_codes, anchors):
53 |     """Decode.
54 | 
55 |     Args:
56 |       rel_codes: a tensor representing N anchor-encoded boxes.
57 |       anchors: BoxList of anchors. We assume that anchors has an associated
58 |         stddev field.
59 | 
60 |     Returns:
61 |       boxes: BoxList holding N bounding boxes
62 |     Raises:
63 |       ValueError: if the anchors BoxList does not have a stddev field
64 |     """
65 |     if not anchors.has_field('stddev'):
66 |       raise ValueError('anchors must have a stddev field')
67 |     means = anchors.get()
68 |     stddevs = anchors.get_field('stddev')
69 |     box_corners = rel_codes * stddevs + means
70 |     return box_list.BoxList(box_corners)
71 | 
--------------------------------------------------------------------------------
/object_detection/box_coders/mean_stddev_box_coder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.box_coder.mean_stddev_boxcoder.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.box_coders import mean_stddev_box_coder 21 | from object_detection.core import box_list 22 | 23 | 24 | class MeanStddevBoxCoderTest(tf.test.TestCase): 25 | 26 | def testGetCorrectRelativeCodesAfterEncoding(self): 27 | box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]] 28 | boxes = box_list.BoxList(tf.constant(box_corners)) 29 | expected_rel_codes = [[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]] 30 | prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]]) 31 | prior_stddevs = tf.constant(2 * [4 * [.1]]) 32 | priors = box_list.BoxList(prior_means) 33 | priors.add_field('stddev', prior_stddevs) 34 | 35 | coder = mean_stddev_box_coder.MeanStddevBoxCoder() 36 | rel_codes = coder.encode(boxes, priors) 37 | with self.test_session() as sess: 38 | rel_codes_out = sess.run(rel_codes) 39 | self.assertAllClose(rel_codes_out, expected_rel_codes) 40 | 41 | def testGetCorrectBoxesAfterDecoding(self): 42 | rel_codes = tf.constant([[0.0, 0.0, 0.0, 0.0], [-5.0, -5.0, -5.0, -3.0]]) 43 | expected_box_corners = [[0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 0.5]] 44 | prior_means = tf.constant([[0.0, 0.0, 0.5, 0.5], [0.5, 0.5, 1.0, 0.8]]) 45 | prior_stddevs = tf.constant(2 * [4 * [.1]]) 46 | priors = box_list.BoxList(prior_means) 47 | priors.add_field('stddev', prior_stddevs) 48 | 49 | coder = mean_stddev_box_coder.MeanStddevBoxCoder() 50 | decoded_boxes = coder.decode(rel_codes, priors) 51 | decoded_box_corners = decoded_boxes.get() 52 | with self.test_session() as sess: 53 | decoded_out = sess.run(decoded_box_corners) 54 | self.assertAllClose(decoded_out, expected_box_corners) 55 | 56 | 57 | if __name__ == '__main__': 58 | tf.test.main() 59 | -------------------------------------------------------------------------------- /object_detection/builders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/builders/__init__.py -------------------------------------------------------------------------------- /object_detection/builders/box_coder_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """A function to build an object detection box coder from configuration.""" 17 | from object_detection.box_coders import faster_rcnn_box_coder 18 | from object_detection.box_coders import mean_stddev_box_coder 19 | from object_detection.box_coders import square_box_coder 20 | from object_detection.protos import box_coder_pb2 21 | 22 | 23 | def build(box_coder_config): 24 | """Builds a box coder object based on the box coder config. 25 | 26 | Args: 27 | box_coder_config: A box_coder.proto object containing the config for the 28 | desired box coder. 29 | 30 | Returns: 31 | BoxCoder based on the config. 32 | 33 | Raises: 34 | ValueError: On empty box coder proto. 35 | """ 36 | if not isinstance(box_coder_config, box_coder_pb2.BoxCoder): 37 | raise ValueError('box_coder_config not of type box_coder_pb2.BoxCoder.') 38 | 39 | if box_coder_config.WhichOneof('box_coder_oneof') == 'faster_rcnn_box_coder': 40 | return faster_rcnn_box_coder.FasterRcnnBoxCoder(scale_factors=[ 41 | box_coder_config.faster_rcnn_box_coder.y_scale, 42 | box_coder_config.faster_rcnn_box_coder.x_scale, 43 | box_coder_config.faster_rcnn_box_coder.height_scale, 44 | box_coder_config.faster_rcnn_box_coder.width_scale 45 | ]) 46 | if (box_coder_config.WhichOneof('box_coder_oneof') == 47 | 'mean_stddev_box_coder'): 48 | return mean_stddev_box_coder.MeanStddevBoxCoder() 49 | if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder': 50 | return square_box_coder.SquareBoxCoder(scale_factors=[ 51 | box_coder_config.square_box_coder.y_scale, 52 | box_coder_config.square_box_coder.x_scale, 53 | box_coder_config.square_box_coder.length_scale 54 | ]) 55 | raise ValueError('Empty box coder.') 56 | -------------------------------------------------------------------------------- /object_detection/builders/image_resizer_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Builder function for image resizing operations.""" 17 | import functools 18 | 19 | from object_detection.core import preprocessor 20 | from object_detection.protos import image_resizer_pb2 21 | 22 | 23 | def build(image_resizer_config): 24 | """Builds callable for image resizing operations. 25 | 26 | Args: 27 | image_resizer_config: image_resizer.proto object containing parameters for 28 | an image resizing operation. 29 | 30 | Returns: 31 | image_resizer_fn: Callable for image resizing. This callable always takes 32 | a rank-3 image tensor (corresponding to a single image) and returns a 33 | rank-3 image tensor, possibly with new spatial dimensions. 34 | 35 | Raises: 36 | ValueError: if `image_resizer_config` is of incorrect type. 
37 |     ValueError: if `image_resizer_config.image_resizer_oneof` is of an
38 |       unexpected type.
39 |     ValueError: if min_dimension > max_dimension when keep_aspect_ratio_resizer
40 |       is used.
41 |   """
42 |   if not isinstance(image_resizer_config, image_resizer_pb2.ImageResizer):
43 |     raise ValueError('image_resizer_config not of type '
44 |                      'image_resizer_pb2.ImageResizer.')
45 | 
46 |   if image_resizer_config.WhichOneof(
47 |       'image_resizer_oneof') == 'keep_aspect_ratio_resizer':
48 |     keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer
49 |     if not (keep_aspect_ratio_config.min_dimension
50 |             <= keep_aspect_ratio_config.max_dimension):
51 |       raise ValueError('min_dimension > max_dimension')
52 |     return functools.partial(
53 |         preprocessor.resize_to_range,
54 |         min_dimension=keep_aspect_ratio_config.min_dimension,
55 |         max_dimension=keep_aspect_ratio_config.max_dimension)
56 |   if image_resizer_config.WhichOneof(
57 |       'image_resizer_oneof') == 'fixed_shape_resizer':
58 |     fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer
59 |     return functools.partial(preprocessor.resize_image,
60 |                              new_height=fixed_shape_resizer_config.height,
61 |                              new_width=fixed_shape_resizer_config.width)
62 |   raise ValueError('Invalid image resizer option.')
63 | 
--------------------------------------------------------------------------------
/object_detection/builders/image_resizer_builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.builders.image_resizer_builder.""" 17 | import tensorflow as tf 18 | from google.protobuf import text_format 19 | from object_detection.builders import image_resizer_builder 20 | from object_detection.protos import image_resizer_pb2 21 | 22 | 23 | class ImageResizerBuilderTest(tf.test.TestCase): 24 | 25 | def _shape_of_resized_random_image_given_text_proto( 26 | self, input_shape, text_proto): 27 | image_resizer_config = image_resizer_pb2.ImageResizer() 28 | text_format.Merge(text_proto, image_resizer_config) 29 | image_resizer_fn = image_resizer_builder.build(image_resizer_config) 30 | images = tf.to_float(tf.random_uniform( 31 | input_shape, minval=0, maxval=255, dtype=tf.int32)) 32 | resized_images = image_resizer_fn(images) 33 | with self.test_session() as sess: 34 | return sess.run(resized_images).shape 35 | 36 | def test_built_keep_aspect_ratio_resizer_returns_expected_shape(self): 37 | image_resizer_text_proto = """ 38 | keep_aspect_ratio_resizer { 39 | min_dimension: 10 40 | max_dimension: 20 41 | } 42 | """ 43 | input_shape = (50, 25, 3) 44 | expected_output_shape = (20, 10, 3) 45 | output_shape = self._shape_of_resized_random_image_given_text_proto( 46 | input_shape, image_resizer_text_proto) 47 | self.assertEqual(output_shape, expected_output_shape) 48 | 49 | def test_built_fixed_shape_resizer_returns_expected_shape(self): 50 | image_resizer_text_proto = """ 51 | fixed_shape_resizer { 52 | height: 10 53 | width: 20 54 | } 55 | """ 56 | input_shape = (50, 25, 3) 57 | expected_output_shape = (10, 20, 3) 58 | output_shape = self._shape_of_resized_random_image_given_text_proto( 59 | input_shape, image_resizer_text_proto) 60 | self.assertEqual(output_shape, expected_output_shape) 61 | 62 | def test_raises_error_on_invalid_input(self): 63 | invalid_input = 'invalid_input' 64 | with self.assertRaises(ValueError): 65 | image_resizer_builder.build(invalid_input) 66 | 67 | 68 | if __name__ == '__main__': 69 | tf.test.main() 70 | 71 | -------------------------------------------------------------------------------- /object_detection/builders/initializer_builder.py: -------------------------------------------------------------------------------- 1 | import joblib 2 | import numpy as np 3 | import tensorflow as tf 4 | from global_utils.custom_utils import log 5 | 6 | def build(init_file): 7 | """Build initializers based on the init file""" 8 | initial_values = joblib.load(init_file) 9 | 10 | initializers = {} 11 | for name, values_dict in initial_values.iteritems(): 12 | log.infov('Build initializer for layer [%s].', name) 13 | initializers[name] = build_layer_initializers(values_dict) 14 | return initializers 15 | 16 | def build_layer_initializers(values_dict): 17 | layer_initializers = {} 18 | for k, v in values_dict.iteritems(): 19 | if isinstance(v, np.ndarray): 20 | layer_initializers[k] = tf.constant_initializer(v) 21 | elif isinstance(v, dict): 22 | layer_initializers[k] = build_layer_initializers(v) 23 | else: 24 | raise ValueError('Cannot change type of [%s]: %s.' % 25 | (k, type(v))) 26 | return layer_initializers 27 | -------------------------------------------------------------------------------- /object_detection/builders/input_reader_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Input reader builder.
17 | 
18 | Creates data sources for DetectionModels from an InputReader config. See
19 | input_reader.proto for options.
20 | 
21 | Note: If users wish to also use their own InputReaders with the Object
22 | Detection configuration framework, they should define their own builder
23 | function that wraps the build function.
24 | """
25 | 
26 | import tensorflow as tf
27 | 
28 | from object_detection.data_decoders import tf_example_decoder
29 | from object_detection.protos import input_reader_pb2
30 | 
31 | parallel_reader = tf.contrib.slim.parallel_reader
32 | 
33 | 
34 | def build(input_reader_config):
35 |   """Builds a tensor dictionary based on the InputReader config.
36 | 
37 |   Args:
38 |     input_reader_config: An input_reader_pb2.InputReader object.
39 | 
40 |   Returns:
41 |     A tensor dict based on the input_reader_config.
42 | 
43 |   Raises:
44 |     ValueError: On invalid input reader proto.
45 |   """
46 |   if not isinstance(input_reader_config, input_reader_pb2.InputReader):
47 |     raise ValueError('input_reader_config not of type '
48 |                      'input_reader_pb2.InputReader.')
49 | 
50 |   if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
51 |     config = input_reader_config.tf_record_input_reader
52 |     _, string_tensor = parallel_reader.parallel_read(
53 |         config.input_path,
54 |         reader_class=tf.TFRecordReader,
55 |         num_epochs=(input_reader_config.num_epochs
56 |                     if input_reader_config.num_epochs else None),
57 |         num_readers=input_reader_config.num_readers,
58 |         shuffle=input_reader_config.shuffle,
59 |         dtypes=[tf.string, tf.string],
60 |         capacity=input_reader_config.queue_capacity,
61 |         min_after_dequeue=input_reader_config.min_after_dequeue)
62 | 
63 |     return tf_example_decoder.TfExampleDecoder().decode(string_tensor)
64 | 
65 |   raise ValueError('Unsupported input_reader_config.')
66 | 
--------------------------------------------------------------------------------
/object_detection/builders/mask_predictor_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Function to build a mask predictor from configuration."""
17 | 
18 | from object_detection.builders import hyperparams_builder
19 | from object_detection.core import mask_predictor
20 | from object_detection.protos import box_predictor_pb2
21 | 
22 | 
23 | def build(argscope_fn, mask_predictor_config, is_training, num_classes, reuse_weights=None, channels=1):
24 |   """Builds a mask predictor based on the configuration.
25 | 
26 |   Builds a mask predictor based on the configuration. See mask_predictor.proto
27 |   for configurable options and mask_predictor.py for more details.
28 | 
29 |   Args:
30 |     argscope_fn: A function that takes the following inputs:
31 |         * hyperparams_pb2.Hyperparams proto
32 |         * a boolean indicating if the model is in training mode.
33 |       and returns a tf slim argscope for Conv and FC hyperparameters.
34 |     mask_predictor_config: MaskPredictor proto containing
35 |       configuration.
36 |     is_training: Whether the model is in training mode.
37 |     num_classes: Number of classes to predict.
38 |     reuse_weights: Whether to reuse variables in the predictor's scope.
39 |     channels: Number of mask channels to predict.
40 | 
41 |   Returns:
42 |     mask_predictor_object: mask_predictor.MaskPredictor object.
43 |   """
44 | 
45 |   conv_hyperparams = argscope_fn(mask_predictor_config.conv_hyperparams,
46 |                                  is_training)
47 | 
48 |   mask_predictor_object = mask_predictor.MaskPredictor(
49 |       conv_hyperparams=conv_hyperparams,
50 |       is_training=is_training,
51 |       num_classes=num_classes,
52 |       kernel_size=mask_predictor_config.kernel_size,
53 |       reuse_weights=reuse_weights,
54 |       channels=channels)
55 |   return mask_predictor_object
56 | 
57 | 
--------------------------------------------------------------------------------
/object_detection/builders/matcher_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """A function to build an object detection matcher from configuration."""
17 | 
18 | from object_detection.matchers import argmax_matcher
19 | from object_detection.matchers import bipartite_matcher
20 | from object_detection.protos import matcher_pb2
21 | 
22 | 
23 | def build(matcher_config):
24 |   """Builds a matcher object based on the matcher config.
25 | 
26 |   Args:
27 |     matcher_config: A matcher.proto object containing the config for the desired
28 |       Matcher.
29 | 
30 |   Returns:
31 |     Matcher based on the config.
32 | 
33 |   Raises:
34 |     ValueError: On empty matcher proto.
35 | """ 36 | if not isinstance(matcher_config, matcher_pb2.Matcher): 37 | raise ValueError('matcher_config not of type matcher_pb2.Matcher.') 38 | if matcher_config.WhichOneof('matcher_oneof') == 'argmax_matcher': 39 | matcher = matcher_config.argmax_matcher 40 | matched_threshold = unmatched_threshold = None 41 | if not matcher.ignore_thresholds: 42 | matched_threshold = matcher.matched_threshold 43 | unmatched_threshold = matcher.unmatched_threshold 44 | return argmax_matcher.ArgMaxMatcher( 45 | matched_threshold=matched_threshold, 46 | unmatched_threshold=unmatched_threshold, 47 | negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched, 48 | force_match_for_each_row=matcher.force_match_for_each_row) 49 | if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher': 50 | return bipartite_matcher.GreedyBipartiteMatcher() 51 | raise ValueError('Empty matcher.') 52 | -------------------------------------------------------------------------------- /object_detection/builders/post_processing_builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for post_processing_builder.""" 17 | 18 | import tensorflow as tf 19 | from google.protobuf import text_format 20 | from object_detection.builders import post_processing_builder 21 | from object_detection.protos import post_processing_pb2 22 | 23 | 24 | class PostProcessingBuilderTest(tf.test.TestCase): 25 | 26 | def test_build_non_max_suppressor_with_correct_parameters(self): 27 | post_processing_text_proto = """ 28 | batch_non_max_suppression { 29 | score_threshold: 0.7 30 | iou_threshold: 0.6 31 | max_detections_per_class: 100 32 | max_total_detections: 300 33 | } 34 | """ 35 | post_processing_config = post_processing_pb2.PostProcessing() 36 | text_format.Merge(post_processing_text_proto, post_processing_config) 37 | non_max_suppressor, _ = post_processing_builder.build( 38 | post_processing_config) 39 | self.assertEqual(non_max_suppressor.keywords['max_size_per_class'], 100) 40 | self.assertEqual(non_max_suppressor.keywords['max_total_size'], 300) 41 | self.assertAlmostEqual(non_max_suppressor.keywords['score_thresh'], 0.7) 42 | self.assertAlmostEqual(non_max_suppressor.keywords['iou_thresh'], 0.6) 43 | 44 | def test_build_identity_score_converter(self): 45 | post_processing_text_proto = """ 46 | score_converter: IDENTITY 47 | """ 48 | post_processing_config = post_processing_pb2.PostProcessing() 49 | text_format.Merge(post_processing_text_proto, post_processing_config) 50 | _, score_converter = post_processing_builder.build(post_processing_config) 51 | self.assertEqual(score_converter, tf.identity) 52 | 53 | def test_build_sigmoid_score_converter(self): 54 | post_processing_text_proto = """ 55 | score_converter: SIGMOID 56 | """ 57 | post_processing_config = post_processing_pb2.PostProcessing() 58 | text_format.Merge(post_processing_text_proto, post_processing_config) 59 | _, score_converter = post_processing_builder.build(post_processing_config) 60 | self.assertEqual(score_converter, tf.sigmoid) 61 | 62 | def test_build_softmax_score_converter(self): 63 | post_processing_text_proto = """ 64 | score_converter: SOFTMAX 65 | """ 66 | post_processing_config = post_processing_pb2.PostProcessing() 67 | text_format.Merge(post_processing_text_proto, post_processing_config) 68 | _, score_converter = post_processing_builder.build(post_processing_config) 69 | self.assertEqual(score_converter, tf.nn.softmax) 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /object_detection/builders/region_similarity_calculator_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Builder for region similarity calculators.""" 17 | 18 | from object_detection.core import region_similarity_calculator 19 | from object_detection.protos import region_similarity_calculator_pb2 20 | 21 | 22 | def build(region_similarity_calculator_config): 23 | """Builds region similarity calculator based on the configuration. 24 | 25 | Builds one of [IouSimilarity, IoaSimilarity, NegSqDistSimilarity] objects. See 26 | core/region_similarity_calculator.py for details. 27 | 28 | Args: 29 | region_similarity_calculator_config: RegionSimilarityCalculator 30 | configuration proto. 31 | 32 | Returns: 33 | region_similarity_calculator: RegionSimilarityCalculator object. 34 | 35 | Raises: 36 | ValueError: On unknown region similarity calculator. 37 | """ 38 | 39 | if not isinstance( 40 | region_similarity_calculator_config, 41 | region_similarity_calculator_pb2.RegionSimilarityCalculator): 42 | raise ValueError( 43 | 'region_similarity_calculator_config not of type ' 44 | 'region_similarity_calculator_pb2.RegionSimilarityCalculator') 45 | 46 | similarity_calculator = region_similarity_calculator_config.WhichOneof( 47 | 'region_similarity') 48 | if similarity_calculator == 'iou_similarity': 49 | return region_similarity_calculator.IouSimilarity() 50 | if similarity_calculator == 'ioa_similarity': 51 | return region_similarity_calculator.IoaSimilarity() 52 | if similarity_calculator == 'neg_sq_dist_similarity': 53 | return region_similarity_calculator.NegSqDistSimilarity() 54 | 55 | raise ValueError('Unknown region similarity calculator.') 56 | 57 | -------------------------------------------------------------------------------- /object_detection/builders/region_similarity_calculator_builder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for region_similarity_calculator_builder.""" 17 | 18 | import tensorflow as tf 19 | 20 | from google.protobuf import text_format 21 | from object_detection.builders import region_similarity_calculator_builder 22 | from object_detection.core import region_similarity_calculator 23 | from object_detection.protos import region_similarity_calculator_pb2 as sim_calc_pb2 24 | 25 | 26 | class RegionSimilarityCalculatorBuilderTest(tf.test.TestCase): 27 | 28 | def testBuildIoaSimilarityCalculator(self): 29 | similarity_calc_text_proto = """ 30 | ioa_similarity { 31 | } 32 | """ 33 | similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator() 34 | text_format.Merge(similarity_calc_text_proto, similarity_calc_proto) 35 | similarity_calc = region_similarity_calculator_builder.build( 36 | similarity_calc_proto) 37 | self.assertTrue(isinstance(similarity_calc, 38 | region_similarity_calculator.IoaSimilarity)) 39 | 40 | def testBuildIouSimilarityCalculator(self): 41 | similarity_calc_text_proto = """ 42 | iou_similarity { 43 | } 44 | """ 45 | similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator() 46 | text_format.Merge(similarity_calc_text_proto, similarity_calc_proto) 47 | similarity_calc = region_similarity_calculator_builder.build( 48 | similarity_calc_proto) 49 | self.assertTrue(isinstance(similarity_calc, 50 | region_similarity_calculator.IouSimilarity)) 51 | 52 | def testBuildNegSqDistSimilarityCalculator(self): 53 | similarity_calc_text_proto = """ 54 | neg_sq_dist_similarity { 55 | } 56 | """ 57 | similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator() 58 | text_format.Merge(similarity_calc_text_proto, similarity_calc_proto) 59 | similarity_calc = region_similarity_calculator_builder.build( 60 | similarity_calc_proto) 61 | self.assertTrue(isinstance(similarity_calc, 62 | region_similarity_calculator. 
63 | NegSqDistSimilarity)) 64 | 65 | 66 | if __name__ == '__main__': 67 | tf.test.main() 68 | -------------------------------------------------------------------------------- /object_detection/checkpoints/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/checkpoints/.gitignore -------------------------------------------------------------------------------- /object_detection/configs/.gitignore: -------------------------------------------------------------------------------- 1 | bak 2 | -------------------------------------------------------------------------------- /object_detection/configs/test/.gitignore: -------------------------------------------------------------------------------- 1 | bak 2 | -------------------------------------------------------------------------------- /object_detection/configs/voc_evaluation.config: -------------------------------------------------------------------------------- 1 | eval_config: { 2 | num_examples: 50000 3 | metrics_set: "pascal_voc_metrics" 4 | } 5 | 6 | eval_input_reader: { 7 | tf_record_input_reader { 8 | input_path: "../data/voc/voc0712_val.record" 9 | } 10 | label_map_path: "../data/pascal_label_map.pbtxt" 11 | shuffle: false 12 | num_readers: 1 13 | } 14 | -------------------------------------------------------------------------------- /object_detection/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/core/__init__.py -------------------------------------------------------------------------------- /object_detection/core/balanced_positive_negative_sampler_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.core.balanced_positive_negative_sampler.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.core import balanced_positive_negative_sampler 22 | 23 | 24 | class BalancedPositiveNegativeSamplerTest(tf.test.TestCase): 25 | 26 | def test_subsample_all_examples(self): 27 | numpy_labels = np.random.permutation(300) 28 | indicator = tf.constant(np.ones(300) == 1) 29 | numpy_labels = (numpy_labels - 200) > 0 30 | 31 | labels = tf.constant(numpy_labels) 32 | 33 | sampler = (balanced_positive_negative_sampler. 
34 | BalancedPositiveNegativeSampler()) 35 | is_sampled = sampler.subsample(indicator, 64, labels) 36 | with self.test_session() as sess: 37 | is_sampled = sess.run(is_sampled) 38 | self.assertTrue(sum(is_sampled) == 64) 39 | self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32) 40 | self.assertTrue(sum(np.logical_and( 41 | np.logical_not(numpy_labels), is_sampled)) == 32) 42 | 43 | def test_subsample_selection(self): 44 | # Test random sampling when only some examples can be sampled: 45 | # 100 samples, 20 positives, 10 positives cannot be sampled 46 | numpy_labels = np.arange(100) 47 | numpy_indicator = numpy_labels < 90 48 | indicator = tf.constant(numpy_indicator) 49 | numpy_labels = (numpy_labels - 80) >= 0 50 | 51 | labels = tf.constant(numpy_labels) 52 | 53 | sampler = (balanced_positive_negative_sampler. 54 | BalancedPositiveNegativeSampler()) 55 | is_sampled = sampler.subsample(indicator, 64, labels) 56 | with self.test_session() as sess: 57 | is_sampled = sess.run(is_sampled) 58 | self.assertTrue(sum(is_sampled) == 64) 59 | self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10) 60 | self.assertTrue(sum(np.logical_and( 61 | np.logical_not(numpy_labels), is_sampled)) == 54) 62 | self.assertAllEqual(is_sampled, np.logical_and(is_sampled, 63 | numpy_indicator)) 64 | 65 | def test_raises_error_with_incorrect_label_shape(self): 66 | labels = tf.constant([[True, False, False]]) 67 | indicator = tf.constant([True, False, True]) 68 | sampler = (balanced_positive_negative_sampler. 69 | BalancedPositiveNegativeSampler()) 70 | with self.assertRaises(ValueError): 71 | sampler.subsample(indicator, 64, labels) 72 | 73 | def test_raises_error_with_incorrect_indicator_shape(self): 74 | labels = tf.constant([True, False, False]) 75 | indicator = tf.constant([[True, False, True]]) 76 | sampler = (balanced_positive_negative_sampler. 77 | BalancedPositiveNegativeSampler()) 78 | with self.assertRaises(ValueError): 79 | sampler.subsample(indicator, 64, labels) 80 | 81 | 82 | if __name__ == '__main__': 83 | tf.test.main() 84 | -------------------------------------------------------------------------------- /object_detection/core/box_coder_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.core.box_coder.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.core import box_coder 21 | from object_detection.core import box_list 22 | 23 | 24 | class MockBoxCoder(box_coder.BoxCoder): 25 | """Test BoxCoder that encodes/decodes using the multiply-by-two function.""" 26 | 27 | def code_size(self): 28 | return 4 29 | 30 | def _encode(self, boxes, anchors): 31 | return 2.0 * boxes.get() 32 | 33 | def _decode(self, rel_codes, anchors): 34 | return box_list.BoxList(rel_codes / 2.0) 35 | 36 | 37 | class BoxCoderTest(tf.test.TestCase): 38 | 39 | def test_batch_decode(self): 40 | mock_anchor_corners = tf.constant( 41 | [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32) 42 | mock_anchors = box_list.BoxList(mock_anchor_corners) 43 | mock_box_coder = MockBoxCoder() 44 | 45 | expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]], 46 | [[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]] 47 | 48 | encoded_boxes_list = [mock_box_coder.encode( 49 | box_list.BoxList(tf.constant(boxes)), mock_anchors) 50 | for boxes in expected_boxes] 51 | encoded_boxes = tf.stack(encoded_boxes_list) 52 | decoded_boxes = box_coder.batch_decode( 53 | encoded_boxes, mock_box_coder, mock_anchors) 54 | 55 | with self.test_session() as sess: 56 | decoded_boxes_result = sess.run(decoded_boxes) 57 | self.assertAllClose(expected_boxes, decoded_boxes_result) 58 | 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /object_detection/core/data_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Interface for data decoders. 17 | 18 | Data decoders decode the input data and return a dictionary of tensors keyed by 19 | the entries in core.reader.Fields. 20 | """ 21 | from abc import ABCMeta 22 | from abc import abstractmethod 23 | 24 | 25 | class DataDecoder(object): 26 | """Interface for data decoders.""" 27 | __metaclass__ = ABCMeta 28 | 29 | @abstractmethod 30 | def decode(self, data): 31 | """Return a single image and associated labels. 32 | 33 | Args: 34 | data: a string tensor holding a serialized protocol buffer corresponding 35 | to data for a single image. 36 | 37 | Returns: 38 | tensor_dict: a dictionary containing tensors. Possible keys are defined in 39 | reader.Fields. 40 | """ 41 | pass 42 | -------------------------------------------------------------------------------- /object_detection/core/minibatch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Base minibatch sampler module. 17 | 18 | The job of the minibatch_sampler is to subsample a minibatch based on some 19 | criterion. 20 | 21 | The main function call is: 22 | subsample(indicator, batch_size, **params). 23 | Indicator is a 1d boolean tensor where True denotes which examples can be 24 | sampled. It returns a boolean indicator where True denotes an example has been 25 | sampled. 26 | 27 | Subclasses should implement the Subsample function and can make use of the 28 | @staticmethod SubsampleIndicator. 29 | """ 30 | 31 | from abc import ABCMeta 32 | from abc import abstractmethod 33 | 34 | import tensorflow as tf 35 | 36 | from object_detection.utils import ops 37 | 38 | 39 | class MinibatchSampler(object): 40 | """Abstract base class for subsampling minibatches.""" 41 | __metaclass__ = ABCMeta 42 | 43 | def __init__(self): 44 | """Constructs a minibatch sampler.""" 45 | pass 46 | 47 | @abstractmethod 48 | def subsample(self, indicator, batch_size, **params): 49 | """Returns subsample of entries in indicator. 50 | 51 | Args: 52 | indicator: boolean tensor of shape [N] whose True entries can be sampled. 53 | batch_size: desired batch size. 54 | **params: additional keyword arguments for specific implementations of 55 | the MinibatchSampler. 56 | 57 | Returns: 58 | sample_indicator: boolean tensor of shape [N] whose True entries have been 59 | sampled. If sum(indicator) >= batch_size, sum(is_sampled) = batch_size 60 | """ 61 | pass 62 | 63 | @staticmethod 64 | def subsample_indicator(indicator, num_samples): 65 | """Subsample indicator vector. 66 | 67 | Given a boolean indicator vector with M elements set to `True`, the function 68 | assigns all but `num_samples` of these previously `True` elements to 69 | `False`. If `num_samples` is greater than M, the original indicator vector 70 | is returned. 71 | 72 | Args: 73 | indicator: a 1-dimensional boolean tensor indicating which elements 74 | are allowed to be sampled and which are not. 75 | num_samples: int32 scalar tensor 76 | 77 | Returns: 78 | a boolean tensor with the same shape as input (indicator) tensor 79 | """ 80 | indices = tf.where(indicator) 81 | indices = tf.random_shuffle(indices) 82 | indices = tf.reshape(indices, [-1]) 83 | 84 | num_samples = tf.minimum(tf.size(indices), num_samples) 85 | selected_indices = tf.slice(indices, [0], tf.reshape(num_samples, [1])) 86 | 87 | selected_indicator = ops.indices_to_dense_vector(selected_indices, 88 | tf.shape(indicator)[0]) 89 | 90 | return tf.equal(selected_indicator, 1) 91 | -------------------------------------------------------------------------------- /object_detection/core/prefetcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Provides functions to prefetch tensors to feed into models.""" 17 | import tensorflow as tf 18 | 19 | 20 | def prefetch(tensor_dict, capacity): 21 | """Creates a prefetch queue for tensors. 22 | 23 | Creates a FIFO queue to asynchronously enqueue tensor_dicts and returns a 24 | dequeue op that evaluates to a tensor_dict. This function is useful in 25 | prefetching preprocessed tensors so that the data is readily available for 26 | consumers. 27 | 28 | Example input pipeline when you don't need batching: 29 | ---------------------------------------------------- 30 | key, string_tensor = slim.parallel_reader.parallel_read(...) 31 | tensor_dict = decoder.decode(string_tensor) 32 | tensor_dict = preprocessor.preprocess(tensor_dict, ...) 33 | prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20) 34 | tensor_dict = prefetch_queue.dequeue() 35 | outputs = Model(tensor_dict) 36 | ... 37 | ---------------------------------------------------- 38 | 39 | For input pipelines with batching, refer to core/batcher.py 40 | 41 | Args: 42 | tensor_dict: a dictionary of tensors to prefetch. 43 | capacity: the size of the prefetch queue. 44 | 45 | Returns: 46 | a FIFO prefetcher queue 47 | """ 48 | names = list(tensor_dict.keys()) 49 | dtypes = [t.dtype for t in tensor_dict.values()] 50 | shapes = [t.get_shape() for t in tensor_dict.values()] 51 | prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes, 52 | shapes=shapes, 53 | names=names, 54 | name='prefetch_queue') 55 | enqueue_op = prefetch_queue.enqueue(tensor_dict) 56 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner( 57 | prefetch_queue, [enqueue_op])) 58 | tf.summary.scalar('queue/%s/fraction_of_%d_full' % (prefetch_queue.name, 59 | capacity), 60 | tf.to_float(prefetch_queue.size()) * (1. 
/ capacity)) 61 | return prefetch_queue 62 | -------------------------------------------------------------------------------- /object_detection/create_records/create_all_pascal_tf_records.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | export LD_PRELOAD="/usr/lib/libtcmalloc.so" 3 | 4 | #DATA_DIR=$1 5 | #OUTPUT_DIR=$2 6 | 7 | #python create_pascal_tf_record.py \ 8 | # --data_dir=../data/voc/VOCdevkit \ 9 | # --year=VOC2007 \ 10 | # --set=train \ 11 | # --output_path=../data/voc/voc2007_train.record \ 12 | # --label_map_path=../data/pascal_label_map.pbtxt 13 | 14 | #python create_pascal_tf_record.py \ 15 | # --data_dir=../data/voc/VOCdevkit \ 16 | # --year=VOC2007 \ 17 | # --set=val \ 18 | # --output_path=../data/voc/voc2007_val.record \ 19 | # --label_map_path=../data/pascal_label_map.pbtxt 20 | 21 | python create_pascal_tf_record.py \ 22 | --data_dir=../data/voc/VOCdevkit \ 23 | --year=VOC2007 \ 24 | --set=trainval \ 25 | --output_path=../data/voc/voc2007_trainval.record \ 26 | --label_map_path=../data/pascal_label_map.pbtxt 27 | 28 | python create_pascal_tf_record.py \ 29 | --data_dir=../data/voc/VOCdevkit \ 30 | --year=VOC2007 \ 31 | --set=test \ 32 | --output_path=../data/voc/voc2007_test.record \ 33 | --label_map_path=../data/pascal_label_map.pbtxt 34 | 35 | #python create_pascal_tf_record.py \ 36 | # --data_dir=../data/voc/VOCdevkit \ 37 | # --year=VOC2012 \ 38 | # --set=train \ 39 | # --output_path=../data/voc/voc2012_train.record \ 40 | # --label_map_path=../data/pascal_label_map.pbtxt 41 | 42 | #python create_pascal_tf_record.py \ 43 | # --data_dir=../data/voc/VOCdevkit \ 44 | # --year=VOC2012 \ 45 | # --set=val \ 46 | # --output_path=../data/voc/voc2012_val.record \ 47 | # --label_map_path=../data/pascal_label_map.pbtxt 48 | 49 | #python create_pascal_tf_record.py \ 50 | # --data_dir=../data/voc/VOCdevkit \ 51 | # --year=VOC2012 \ 52 | # --set=trainval \ 53 | # --output_path=../data/voc/voc2012_trainval.record \ 54 | # --label_map_path=../data/pascal_label_map.pbtxt 55 | 56 | #python create_pascal_tf_record.py \ 57 | # --data_dir=../data/voc/VOCdevkit \ 58 | # --year=merged \ 59 | # --set=train \ 60 | # --output_path=../data/voc/voc0712_train.record \ 61 | # --label_map_path=../data/pascal_label_map.pbtxt 62 | 63 | #python create_pascal_tf_record.py \ 64 | # --data_dir=../data/voc/VOCdevkit \ 65 | # --year=merged \ 66 | # --set=val \ 67 | # --output_path=../data/voc/voc0712_val.record \ 68 | # --label_map_path=../data/pascal_label_map.pbtxt 69 | 70 | python create_pascal_tf_record.py \ 71 | --data_dir=../data/voc/VOCdevkit \ 72 | --year=merged \ 73 | --set=trainval \ 74 | --output_path=../data/voc/voc0712_trainval.record \ 75 | --label_map_path=../data/pascal_label_map.pbtxt 76 | 77 | #python create_pascal_tf_record.py \ 78 | # --data_dir=../data/voc/VOCdevkit \ 79 | # --year=VOC2012 \ 80 | # --set=test \ 81 | # --output_path=../data/voc/voc2012_test.record \ 82 | # --label_map_path=../data/pascal_label_map.pbtxt -------------------------------------------------------------------------------- /object_detection/data/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/data/.gitignore -------------------------------------------------------------------------------- /object_detection/data/caltech_label_map.pbtxt: 
-------------------------------------------------------------------------------- 1 | 2 | item { 3 | id: 1 4 | name: 'person' 5 | } 6 | -------------------------------------------------------------------------------- /object_detection/data/mobis_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'person' 4 | } 5 | -------------------------------------------------------------------------------- /object_detection/data/mscoco/.gitignore: -------------------------------------------------------------------------------- 1 | annotations 2 | cocoapi 3 | images 4 | *.record 5 | *.zip 6 | 7 | -------------------------------------------------------------------------------- /object_detection/data/mscoco/README.md: -------------------------------------------------------------------------------- 1 | ## MS COCO dataset 2 | 3 | dataset homepage: http://cocodataset.org/ 4 | 5 | 6 | Download and unzip MS COCO images 7 | ``` bash 8 | # from mtl-ssl-detection/object_detection/data/mscoco/ 9 | mkdir images 10 | cd images 11 | wget http://images.cocodataset.org/zips/train2017.zip 12 | wget http://images.cocodataset.org/zips/val2017.zip 13 | wget http://images.cocodataset.org/zips/test2017.zip 14 | unzip train2017.zip 15 | unzip val2017.zip 16 | unzip test2017.zip 17 | 18 | ``` 19 | 20 | Download and unzip MS COCO annotations 21 | 22 | ``` bash 23 | # from mtl-ssl-detection/object_detection/data/mscoco/ 24 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 25 | wget http://images.cocodataset.org/annotations/image_info_test2017.zip 26 | unzip annotations_trainval2017.zip 27 | unzip image_info_test2017.zip 28 | ``` 29 | 30 | 31 | Install [COCO API](https://github.com/cocodataset/cocoapi) 32 | ``` bash 33 | # from mtl-ssl-detection/object_detection/data/mscoco/ 34 | git clone https://github.com/cocodataset/cocoapi.git 35 | cd cocoapi/PythonAPI 36 | make install 37 | ``` 38 | 39 | Generating the MS COCO TFRecord files 40 | ``` bash 41 | # from mtl-ssl-detection/object_detection/create_records/ 42 | python create_mscoco_tf_record.py 43 | ``` -------------------------------------------------------------------------------- /object_detection/data/pascal_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'aeroplane' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'bicycle' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'bird' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'boat' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'bottle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'bus' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'car' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'cat' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'chair' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'cow' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'diningtable' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'dog' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'horse' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'motorbike' 69 | } 70 | 71 | item { 72 | id: 15 73 | name: 'person' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'pottedplant' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'sheep' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'sofa' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'train' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'tvmonitor' 99 | } 100 | -------------------------------------------------------------------------------- 
/object_detection/data/pet_label_map.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: 'Abyssinian' 4 | } 5 | 6 | item { 7 | id: 2 8 | name: 'american_bulldog' 9 | } 10 | 11 | item { 12 | id: 3 13 | name: 'american_pit_bull_terrier' 14 | } 15 | 16 | item { 17 | id: 4 18 | name: 'basset_hound' 19 | } 20 | 21 | item { 22 | id: 5 23 | name: 'beagle' 24 | } 25 | 26 | item { 27 | id: 6 28 | name: 'Bengal' 29 | } 30 | 31 | item { 32 | id: 7 33 | name: 'Birman' 34 | } 35 | 36 | item { 37 | id: 8 38 | name: 'Bombay' 39 | } 40 | 41 | item { 42 | id: 9 43 | name: 'boxer' 44 | } 45 | 46 | item { 47 | id: 10 48 | name: 'British_Shorthair' 49 | } 50 | 51 | item { 52 | id: 11 53 | name: 'chihuahua' 54 | } 55 | 56 | item { 57 | id: 12 58 | name: 'Egyptian_Mau' 59 | } 60 | 61 | item { 62 | id: 13 63 | name: 'english_cocker_spaniel' 64 | } 65 | 66 | item { 67 | id: 14 68 | name: 'english_setter' 69 | } 70 | 71 | item { 72 | id: 15 73 | name: 'german_shorthaired' 74 | } 75 | 76 | item { 77 | id: 16 78 | name: 'great_pyrenees' 79 | } 80 | 81 | item { 82 | id: 17 83 | name: 'havanese' 84 | } 85 | 86 | item { 87 | id: 18 88 | name: 'japanese_chin' 89 | } 90 | 91 | item { 92 | id: 19 93 | name: 'keeshond' 94 | } 95 | 96 | item { 97 | id: 20 98 | name: 'leonberger' 99 | } 100 | 101 | item { 102 | id: 21 103 | name: 'Maine_Coon' 104 | } 105 | 106 | item { 107 | id: 22 108 | name: 'miniature_pinscher' 109 | } 110 | 111 | item { 112 | id: 23 113 | name: 'newfoundland' 114 | } 115 | 116 | item { 117 | id: 24 118 | name: 'Persian' 119 | } 120 | 121 | item { 122 | id: 25 123 | name: 'pomeranian' 124 | } 125 | 126 | item { 127 | id: 26 128 | name: 'pug' 129 | } 130 | 131 | item { 132 | id: 27 133 | name: 'Ragdoll' 134 | } 135 | 136 | item { 137 | id: 28 138 | name: 'Russian_Blue' 139 | } 140 | 141 | item { 142 | id: 29 143 | name: 'saint_bernard' 144 | } 145 | 146 | item { 147 | id: 30 148 | name: 'samoyed' 149 | } 150 | 151 | item { 152 | id: 31 153 | name: 'scottish_terrier' 154 | } 155 | 156 | item { 157 | id: 32 158 | name: 'shiba_inu' 159 | } 160 | 161 | item { 162 | id: 33 163 | name: 'Siamese' 164 | } 165 | 166 | item { 167 | id: 34 168 | name: 'Sphynx' 169 | } 170 | 171 | item { 172 | id: 35 173 | name: 'staffordshire_bull_terrier' 174 | } 175 | 176 | item { 177 | id: 36 178 | name: 'wheaten_terrier' 179 | } 180 | 181 | item { 182 | id: 37 183 | name: 'yorkshire_terrier' 184 | } 185 | -------------------------------------------------------------------------------- /object_detection/data/voc/.gitignore: -------------------------------------------------------------------------------- 1 | VOCdevkit 2 | *.record 3 | *.zip 4 | *.tar 5 | 6 | -------------------------------------------------------------------------------- /object_detection/data/voc/README.md: -------------------------------------------------------------------------------- 1 | ## PASCAL VOC dataset 2 | 3 | dataset homepage: http://host.robots.ox.ac.uk/pascal/VOC/ 4 | 5 | 6 | Download and unzip PASCAL VOC dataset 7 | ``` bash 8 | # from mtl-ssl-detection/object_detection/data/voc/ 9 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 10 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar 11 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar 12 | tar -xvf VOCtrainval_11-May-2012.tar 13 | tar -xvf VOCtrainval_06-Nov-2007.tar 14 | tar -xvf VOCtest_06-Nov-2007.tar 15 | ``` 16 | 17 | Generating the PASCAL 
VOC TFRecord files 18 | ``` bash 19 | # from mtl-ssl-detection/object_detection/create_records/ 20 | bash create_all_pascal_tf_records.sh 21 | ``` -------------------------------------------------------------------------------- /object_detection/data_decoders/BUILD: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection API: data decoders. 2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | licenses(["notice"]) 8 | # Apache 2.0 9 | 10 | py_library( 11 | name = "tf_example_decoder", 12 | srcs = ["tf_example_decoder.py"], 13 | deps = [ 14 | "//tensorflow", 15 | "//tensorflow_models/object_detection/core:data_decoder", 16 | "//tensorflow_models/object_detection/core:standard_fields", 17 | ], 18 | ) 19 | 20 | py_test( 21 | name = "tf_example_decoder_test", 22 | srcs = ["tf_example_decoder_test.py"], 23 | deps = [ 24 | ":tf_example_decoder", 25 | "//tensorflow", 26 | "//tensorflow_models/object_detection/core:standard_fields", 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /object_detection/data_decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/data_decoders/__init__.py -------------------------------------------------------------------------------- /object_detection/g3doc/detection_model_zoo.md: -------------------------------------------------------------------------------- 1 | # Tensorflow detection model zoo 2 | 3 | We provide a collection of detection models pre-trained on the 4 | [COCO dataset](http://mscoco.org). 5 | These models can be useful for out-of-the-box inference if you are interested 6 | in categories already in COCO (e.g., humans, cars, etc). 7 | They are also useful for initializing your models when training on novel 8 | datasets. 9 | 10 | In the table below, we list each such pre-trained model including: 11 | 12 | * a model name that corresponds to a config file that was used to train this 13 | model in the `samples/configs` directory, 14 | * a download link to a tar.gz file containing the pre-trained model, 15 | * model speed (one of {slow, medium, fast}), 16 | * detector performance on COCO data as measured by the COCO mAP measure. 17 | Here, higher is better, and we only report bounding box mAP rounded to the 18 | nearest integer. 19 | * Output types (currently only `Boxes`) 20 | 21 | You can un-tar each tar.gz file via, e.g.,: 22 | 23 | ``` 24 | tar -xzvf ssd_mobilenet_v1_coco.tar.gz 25 | ``` 26 | 27 | Inside the un-tar'ed directory, you will find: 28 | 29 | * a graph proto (`graph.pbtxt`) 30 | * a checkpoint 31 | (`model.ckpt.data-00000-of-00001`, `model.ckpt.index`, `model.ckpt.meta`) 32 | * a frozen graph proto with weights baked into the graph as constants 33 | (`frozen_inference_graph.pb`) to be used for out of the box inference 34 | (try this out in the Jupyter notebook!) 
35 | 36 | | Model name | Speed | COCO mAP | Outputs | 37 | | ------------ | :--------------: | :--------------: | :-------------: | 38 | | [ssd_mobilenet_v1_coco](http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz) | fast | 21 | Boxes | 39 | | [ssd_inception_v2_coco](http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_11_06_2017.tar.gz) | fast | 24 | Boxes | 40 | | [rfcn_resnet101_coco](http://download.tensorflow.org/models/object_detection/rfcn_resnet101_coco_11_06_2017.tar.gz) | medium | 30 | Boxes | 41 | | [faster_rcnn_resnet101_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_11_06_2017.tar.gz) | medium | 32 | Boxes | 42 | | [faster_rcnn_inception_resnet_v2_atrous_coco](http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_11_06_2017.tar.gz) | slow | 37 | Boxes | 43 | -------------------------------------------------------------------------------- /object_detection/g3doc/exporting_models.md: -------------------------------------------------------------------------------- 1 | # Exporting a trained model for inference 2 | 3 | After your model has been trained, you should export it to a Tensorflow 4 | graph proto. A checkpoint will typically consist of three files: 5 | 6 | * model.ckpt-${CHECKPOINT_NUMBER}.data-00000-of-00001 7 | * model.ckpt-${CHECKPOINT_NUMBER}.index 8 | * model.ckpt-${CHECKPOINT_NUMBER}.meta 9 | 10 | After you've identified a candidate checkpoint to export, run the following 11 | command from mtl-ssl-detection/object_detection: 12 | 13 | ``` bash 14 | # from mtl-ssl-detection/object_detection 15 | python export_inference_graph.py \ 16 | --input_type image_tensor \ 17 | --pipeline_config_path ${PIPELINE_CONFIG_PATH} \ 18 | --trained_checkpoint_prefix ${CHECKPOINT_FILE_PATH} \ 19 | --output_directory ${OUTPUT_DIR_PATH} 20 | ``` 21 | (e.g.) 22 | ``` bash 23 | # from mtl-ssl-detection/object_detection 24 | python export_inference_graph.py \ 25 | --input_type image_tensor \ 26 | --pipeline_config_path ./configs/test/model11.config \ 27 | --trained_checkpoint_prefix ./checkpoints/train/model11/model.ckpt \ 28 | --output_directory ./checkpoints/frozen/model11 29 | ``` 30 | 31 | Afterwards, you should see a graph named output_inference_graph.pb in the output directory.
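To sanity-check the exported graph, you can load it back and run a single forward pass. The following is a minimal sketch rather than part of the repository: it assumes the export above succeeded, that the graph exposes the standard Object Detection API tensor names (`image_tensor`, `detection_boxes`, `detection_scores`, `detection_classes`, `num_detections`), and it feeds a dummy all-zeros batch in place of a real image.

``` python
# Minimal TF 1.x sketch: load the exported graph and run one inference.
# The graph path and the zero-filled input are illustrative placeholders.
import numpy as np
import tensorflow as tf

frozen_graph_path = './checkpoints/frozen/model11/output_inference_graph.pb'

graph = tf.Graph()
with graph.as_default():
  graph_def = tf.GraphDef()
  with tf.gfile.GFile(frozen_graph_path, 'rb') as fid:
    graph_def.ParseFromString(fid.read())
  tf.import_graph_def(graph_def, name='')

with graph.as_default(), tf.Session() as sess:
  image = np.zeros((1, 600, 600, 3), dtype=np.uint8)  # stand-in for a real image batch
  boxes, scores, classes, num = sess.run(
      ['detection_boxes:0', 'detection_scores:0',
       'detection_classes:0', 'num_detections:0'],
      feed_dict={'image_tensor:0': image})
  print('highest detection score: %.3f' % scores[0][0])
```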
32 | -------------------------------------------------------------------------------- /object_detection/g3doc/img/dogs_detections_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/dogs_detections_output.jpg -------------------------------------------------------------------------------- /object_detection/g3doc/img/example_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/example_cat.jpg -------------------------------------------------------------------------------- /object_detection/g3doc/img/kites_detections_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/kites_detections_output.jpg -------------------------------------------------------------------------------- /object_detection/g3doc/img/mtl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/mtl_1.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/mtl_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/mtl_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/mtl_ssl_detection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/mtl_ssl_detection.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/ours_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/ours_1.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/ours_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/ours_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/oxford_pet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/oxford_pet.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/qr_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/qr_1.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/qr_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/qr_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/qr_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/qr_3.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/qr_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/qr_4.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/qr_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/qr_5.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/results_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/results_1.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/results_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/results_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/results_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/results_3.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/results_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/results_4.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/reuse_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/reuse_1.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/reuse_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/reuse_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/ssl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/ssl_1.png 
-------------------------------------------------------------------------------- /object_detection/g3doc/img/ssl_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/ssl_2.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/tensorboard.png -------------------------------------------------------------------------------- /object_detection/g3doc/img/tensorboard2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/g3doc/img/tensorboard2.png -------------------------------------------------------------------------------- /object_detection/g3doc/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Dependencies 4 | 5 | Tensorflow Object Detection API depends on the following libraries: 6 | 7 | * Protobuf 2.6 8 | * Pillow 1.0 9 | * lxml 10 | * tf Slim (which is included in the "tensorflow/models" checkout) 11 | * Jupyter notebook 12 | * Matplotlib 13 | * Tensorflow 14 | 15 | For detailed steps to install Tensorflow, follow the 16 | [Tensorflow installation instructions](https://www.tensorflow.org/install/). 17 | A typical user can install Tensorflow using one of the following commands: 18 | 19 | ``` bash 20 | # For CPU 21 | pip install tensorflow 22 | # For GPU 23 | pip install tensorflow-gpu 24 | ``` 25 | 26 | The remaining libraries can be installed on Ubuntu 16.04 via apt-get: 27 | 28 | ``` bash 29 | sudo apt-get install protobuf-compiler python-pil python-lxml 30 | sudo pip install jupyter 31 | sudo pip install matplotlib 32 | ``` 33 | 34 | Alternatively, users can install dependencies using pip: 35 | 36 | ``` bash 37 | sudo pip install pillow 38 | sudo pip install lxml 39 | sudo pip install jupyter 40 | sudo pip install matplotlib 41 | ``` 42 | 43 | ## Protobuf Compilation 44 | 45 | The Tensorflow Object Detection API uses Protobufs to configure model and 46 | training parameters. Before the framework can be used, the Protobuf libraries 47 | must be compiled. This should be done by running the following command from 48 | the tensorflow/models directory: 49 | 50 | 51 | ``` bash 52 | # From tensorflow/models/ 53 | protoc object_detection/protos/*.proto --python_out=. 54 | ``` 55 | 56 | ## Add Libraries to PYTHONPATH 57 | 58 | When running locally, the tensorflow/models/ and slim directories should be 59 | appended to PYTHONPATH. This can be done by running the following from 60 | tensorflow/models/: 61 | 62 | 63 | ``` bash 64 | # From tensorflow/models/ 65 | export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim 66 | ``` 67 | 68 | Note: This command needs to run from every new terminal you start. If you wish 69 | to avoid running this manually, you can add it as a new line to the end of your 70 | ~/.bashrc file.
71 | 72 | # Testing the Installation 73 | 74 | You can test that you have correctly installed the Tensorflow Object Detection 75 | API by running the following command: 76 | 77 | ```bash 78 | python object_detection/builders/model_builder_test.py 79 | ``` 80 | -------------------------------------------------------------------------------- /object_detection/g3doc/preparation.md: -------------------------------------------------------------------------------- 1 | # Preparation 2 | 3 | ## Installation 4 | This has been tested in an Ubuntu 16.04, Python 2.7 environment. 5 | 6 | Source code 7 | ``` bash 8 | git clone https://github.com/wonheeML/mtl-ssl-detection.git 9 | ``` 10 | 11 | 12 | Requirements 13 | ``` bash 14 | # from mtl-ssl-detection/ 15 | pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0-cp27-none-linux_x86_64.whl 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | 20 | Check the protoc version >= 3.5.1 [[protoc](http://google.github.io/proto-lens/installing-protoc.html)] 21 | ``` bash 22 | protoc --version 23 | 24 | # if < 3.5.1, 25 | PROTOC_ZIP=protoc-3.5.1-linux-x86_64.zip 26 | curl -OL https://github.com/google/protobuf/releases/download/v3.5.1/$PROTOC_ZIP 27 | sudo unzip -o $PROTOC_ZIP -d /usr/local bin/protoc 28 | rm -f $PROTOC_ZIP 29 | sudo chmod +rx /usr/local/bin/protoc 30 | 31 | ``` 32 | 33 | 34 | Protobuf Compilation 35 | ``` bash 36 | # from mtl-ssl-detection/ 37 | protoc object_detection/protos/*.proto --python_out=. 38 | ``` 39 | 40 | 41 | Add Libraries to PYTHONPATH 42 | ``` bash 43 | # from mtl-ssl-detection/ 44 | export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim 45 | ``` 46 | 47 | Note: This command needs to run from every new terminal you start. If you wish 48 | to avoid running this manually, you can add it as a new line to the end of your 49 | ~/.bashrc file. 50 | 51 | 52 | ## Testing the Installation 53 | 54 | You can test that you have correctly installed the Tensorflow Object Detection 55 | API by running the following command: 56 | 57 | ```bash 58 | # from mtl-ssl-detection/ 59 | python object_detection/builders/model_builder_test.py 60 | ``` 61 | 62 | 63 | # Dataset 64 | * MS COCO
65 | 66 | * PASCAL VOC
67 | 68 | 69 | # Model Zoo 70 | [tensorflow object detection API](https://github.com/tensorflow/models/tree/master/research/slim) 71 | provides a collection of detection models pre-trained on ImageNet. 72 | The following models were tested. 73 | 74 | ``` bash 75 | # from mtl-ssl-detection/object_detection/checkpoints/detection_model_zoo/ 76 | wget http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz 77 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 78 | wget http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz 79 | tar -xf resnet_v1_101_2016_08_28.tar.gz 80 | tar -xf inception_resnet_v2_2016_08_30.tar.gz 81 | mkdir mobilenet_v1_1.0_224 82 | tar -xvf mobilenet_v1_1.0_224.tgz -C mobilenet_v1_1.0_224 83 | ``` -------------------------------------------------------------------------------- /object_detection/g3doc/preparing_inputs.md: -------------------------------------------------------------------------------- 1 | # Preparing Inputs 2 | 3 | Tensorflow Object Detection API reads data using the TFRecord file format. Two 4 | sample scripts (`create_pascal_tf_record.py` and `create_pet_tf_record.py`) are 5 | provided to convert from the PASCAL VOC dataset and Oxford-IIIT Pet dataset to 6 | TFRecords. 7 | 8 | ## Generating the PASCAL VOC TFRecord files. 9 | 10 | The raw 2012 PASCAL VOC data set is located 11 | [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar). 12 | To download, extract and convert it to TFRecords, run the following commands 13 | below: 14 | 15 | ```bash 16 | # From tensorflow/models 17 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 18 | tar -xvf VOCtrainval_11-May-2012.tar 19 | python object_detection/create_pascal_tf_record.py \ 20 | --label_map_path=object_detection/data/pascal_label_map.pbtxt \ 21 | --data_dir=VOCdevkit --year=VOC2012 --set=train \ 22 | --output_path=pascal_train.record 23 | python object_detection/create_pascal_tf_record.py \ 24 | --label_map_path=object_detection/data/pascal_label_map.pbtxt \ 25 | --data_dir=VOCdevkit --year=VOC2012 --set=val \ 26 | --output_path=pascal_val.record 27 | ``` 28 | 29 | You should end up with two TFRecord files named `pascal_train.record` and 30 | `pascal_val.record` in the `tensorflow/models` directory. 31 | 32 | The label map for the PASCAL VOC data set can be found at 33 | `object_detection/data/pascal_label_map.pbtxt`. 34 | 35 | ## Generating the Oxford-IIIT Pet TFRecord files. 36 | 37 | The Oxford-IIIT Pet data set is located 38 | [here](http://www.robots.ox.ac.uk/~vgg/data/pets/). To download, extract and 39 | convert it to TFRecords, run the following commands below: 40 | 41 | ```bash 42 | # From tensorflow/models 43 | wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz 44 | wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz 45 | tar -xvf annotations.tar.gz 46 | tar -xvf images.tar.gz 47 | python object_detection/create_pet_tf_record.py \ 48 | --label_map_path=object_detection/data/pet_label_map.pbtxt \ 49 | --data_dir=`pwd` \ 50 | --output_dir=`pwd` 51 | ``` 52 | 53 | You should end up with two TFRecord files named `pet_train.record` and 54 | `pet_val.record` in the `tensorflow/models` directory. 55 | 56 | The label map for the Pet dataset can be found at 57 | `object_detection/data/pet_label_map.pbtxt`.
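To verify a generated file, you can iterate over its records and parse one back into a `tf.train.Example`. The snippet below is only a sketch under assumptions: the record path is a placeholder for whichever file you generated, and the feature keys named in the comment reflect the standard Object Detection API schema.

``` python
# Minimal TF 1.x sketch: count the records in a TFRecord file and list the
# feature keys of the last example. The path below is a placeholder.
import tensorflow as tf

record_path = 'pascal_train.record'

count = 0
last_record = None
for last_record in tf.python_io.tf_record_iterator(record_path):
  count += 1
print('%d examples in %s' % (count, record_path))

if last_record is not None:
  example = tf.train.Example()
  example.ParseFromString(last_record)
  # Expected keys include e.g. image/encoded, image/object/bbox/xmin, ...
  print(sorted(example.features.feature.keys()))
```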
58 | -------------------------------------------------------------------------------- /object_detection/g3doc/running_locally.md: -------------------------------------------------------------------------------- 1 | # Running Locally 2 | 3 | This page walks through the steps required to train an object detection model 4 | on a local machine. It assumes the reader has completed the 5 | following prerequisites: 6 | 7 | 1. The Tensorflow Object Detection API has been installed as documented in the 8 | [installation instructions](installation.md). This includes installing library 9 | dependencies, compiling the configuration protobufs and setting up the Python 10 | environment. 11 | 2. A valid data set has been created. See [this page](preparing_inputs.md) for 12 | instructions on how to generate a dataset for the PASCAL VOC challenge or the 13 | Oxford-IIIT Pet dataset. 14 | 3. An Object Detection pipeline configuration has been written. See 15 | [this page](configuring_jobs.md) for details on how to write a pipeline configuration. 16 | 17 | ## Recommended Directory Structure for Training and Evaluation 18 | 19 | ``` 20 | +data 21 | -label_map file 22 | -train TFRecord file 23 | -eval TFRecord file 24 | +models 25 | + model 26 | -pipeline config file 27 | +train 28 | +eval 29 | ``` 30 | 31 | ## Running the Training Job 32 | 33 | A local training job can be run with the following command: 34 | 35 | ```bash 36 | # From the tensorflow/models/ directory 37 | python object_detection/train.py \ 38 | --logtostderr \ 39 | --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \ 40 | --train_dir=${PATH_TO_TRAIN_DIR} 41 | ``` 42 | 43 | where `${PATH_TO_YOUR_PIPELINE_CONFIG}` points to the pipeline config and 44 | `${PATH_TO_TRAIN_DIR}` points to the directory in which training checkpoints 45 | and events will be written. By default, the training job will 46 | run indefinitely until the user kills it. 47 | 48 | ## Running the Evaluation Job 49 | 50 | Evaluation is run as a separate job. The eval job will periodically poll the 51 | train directory for new checkpoints and evaluate them on a test dataset. The 52 | job can be run using the following command: 53 | 54 | ```bash 55 | # From the tensorflow/models/ directory 56 | python object_detection/eval.py \ 57 | --logtostderr \ 58 | --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \ 59 | --checkpoint_dir=${PATH_TO_TRAIN_DIR} \ 60 | --eval_dir=${PATH_TO_EVAL_DIR} 61 | ``` 62 | 63 | where `${PATH_TO_YOUR_PIPELINE_CONFIG}` points to the pipeline config, 64 | `${PATH_TO_TRAIN_DIR}` points to the directory in which training checkpoints 65 | were saved (same as the training job) and `${PATH_TO_EVAL_DIR}` points to the 66 | directory in which evaluation events will be saved. As with the training job, 67 | the eval job runs until terminated by default. 68 | 69 | ## Running Tensorboard 70 | 71 | Progress for training and eval jobs can be inspected using Tensorboard. If 72 | using the recommended directory structure, Tensorboard can be run using the 73 | following command: 74 | 75 | ```bash 76 | tensorboard --logdir=${PATH_TO_MODEL_DIRECTORY} 77 | ``` 78 | 79 | where `${PATH_TO_MODEL_DIRECTORY}` points to the directory that contains the 80 | train and eval directories. Please note it may take Tensorboard a couple of minutes 81 | to populate with data.
82 | -------------------------------------------------------------------------------- /object_detection/g3doc/running_notebook.md: -------------------------------------------------------------------------------- 1 | # Quick Start: Jupyter notebook for off-the-shelf inference 2 | 3 | If you'd like to hit the ground running and run detection on a few example 4 | images right out of the box, we recommend trying out the Jupyter notebook demo. 5 | To run the Jupyter notebook, run the following command from 6 | `tensorflow/models/object_detection`: 7 | 8 | ``` 9 | # From tensorflow/models/object_detection 10 | jupyter notebook 11 | ``` 12 | 13 | The notebook should open in your favorite web browser. Click the 14 | [`object_detection_tutorial.ipynb`](../object_detection_tutorial.ipynb) link to 15 | open the demo. 16 | -------------------------------------------------------------------------------- /object_detection/g3doc/train_and_eval.md: -------------------------------------------------------------------------------- 1 | # Training and Evaluation 2 | 3 | ## Directory Structure for Training and Evaluation 4 | 5 | ``` 6 | # from mtl-ssl-detection/object_detection 7 | 8 | +configs 9 | +test 10 | -pipeline configuration files 11 | +data 12 | -label_map file 13 | +mscoco 14 | -train TFRecord file 15 | -eval TFRecord file 16 | +voc 17 | -train TFRecord file 18 | -eval TFRecord file 19 | +checkpoints 20 | +train 21 | +model 22 | -checkpoints file 23 | +eval 24 | +model 25 | -tensorflow summary file / detection results 26 | ``` 27 | 28 | ## Running the Training Job 29 | 30 | A local training job can be run with the following command: 31 | 32 | ```bash 33 | # from mtl-ssl-detection/object_detection/scripts 34 | bash run_train.sh ${MODEL_NAME} 35 | ``` 36 | 37 | where `${MODEL_NAME}` points to the model name (e.g. model11). 38 | 39 | 40 | ## Running the Evaluation Job 41 | 42 | Evaluation is run as a separate job. The eval job will periodically poll the 43 | train directory for new checkpoints and evaluate them on a test dataset. The 44 | job can be run using the following command: 45 | 46 | ```bash 47 | # from mtl-ssl-detection/object_detection/scripts 48 | bash run_eval.sh ${MODEL_NAME} 49 | ``` 50 | 51 | 52 | 53 | ## Running Tensorboard 54 | 55 | Progress for training and eval jobs can be inspected using Tensorboard. If 56 | using the recommended directory structure, Tensorboard can be run using the 57 | following command: 58 | 59 | ```bash 60 | tensorboard --logdir=${PATH_TO_MODEL_DIRECTORY} 61 | ``` 62 | 63 | where `${PATH_TO_MODEL_DIRECTORY}` points to the directory that contains the 64 | train and eval directories. Please note it may take Tensorboard a couple of minutes 65 | to populate with data. 66 | -------------------------------------------------------------------------------- /object_detection/matchers/BUILD: -------------------------------------------------------------------------------- 1 | # Tensorflow Object Detection API: Matcher implementations.
2 | 3 | package( 4 | default_visibility = ["//visibility:public"], 5 | ) 6 | 7 | licenses(["notice"]) 8 | 9 | # Apache 2.0 10 | py_library( 11 | name = "argmax_matcher", 12 | srcs = [ 13 | "argmax_matcher.py", 14 | ], 15 | deps = [ 16 | "//tensorflow", 17 | "//tensorflow_models/object_detection/core:matcher", 18 | ], 19 | ) 20 | 21 | py_test( 22 | name = "argmax_matcher_test", 23 | srcs = ["argmax_matcher_test.py"], 24 | deps = [ 25 | ":argmax_matcher", 26 | "//tensorflow", 27 | ], 28 | ) 29 | 30 | py_library( 31 | name = "bipartite_matcher", 32 | srcs = [ 33 | "bipartite_matcher.py", 34 | ], 35 | deps = [ 36 | "//tensorflow", 37 | "//tensorflow/contrib/image:image_py", 38 | "//tensorflow_models/object_detection/core:matcher", 39 | ], 40 | ) 41 | 42 | py_test( 43 | name = "bipartite_matcher_test", 44 | srcs = [ 45 | "bipartite_matcher_test.py", 46 | ], 47 | deps = [ 48 | ":bipartite_matcher", 49 | "//tensorflow", 50 | ], 51 | ) 52 | -------------------------------------------------------------------------------- /object_detection/matchers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/matchers/__init__.py -------------------------------------------------------------------------------- /object_detection/matchers/bipartite_matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Bipartite matcher implementation.""" 17 | 18 | import tensorflow as tf 19 | 20 | from tensorflow.contrib.image.python.ops import image_ops 21 | from object_detection.core import matcher 22 | 23 | 24 | class GreedyBipartiteMatcher(matcher.Matcher): 25 | """Wraps a Tensorflow greedy bipartite matcher.""" 26 | 27 | def _match(self, similarity_matrix, num_valid_rows=-1): 28 | """Greedily bipartite matches a collection of rows and columns. 29 | 30 | TODO: Add a num_valid_columns option to match only that many columns with 31 | all the rows. 32 | 33 | Args: 34 | similarity_matrix: Float tensor of shape [N, M] with pairwise similarity 35 | where higher values mean more similar. 36 | num_valid_rows: A scalar or a 1-D tensor with one element describing the 37 | number of valid rows of similarity_matrix to consider for the bipartite 38 | matching. If set to be negative, then all rows from similarity_matrix 39 | are used. 40 | 41 | Returns: 42 | match_results: int32 tensor of shape [M] with match_results[i]=-1 43 | meaning that column i is not matched and otherwise that it is matched to 44 | row match_results[i]. 45 | """ 46 | # Convert the similarity matrix to a distance matrix, as bipartite_match 47 | # looks for minimum distance matches.
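# (bipartite_match greedily picks the smallest remaining entry of its input
# matrix, so negating the similarities turns maximum-similarity matching into
# an equivalent minimum-distance problem.)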
48 | distance_matrix = -1 * similarity_matrix 49 | _, match_results = image_ops.bipartite_match( 50 | distance_matrix, num_valid_rows) 51 | match_results = tf.reshape(match_results, [-1]) 52 | match_results = tf.cast(match_results, tf.int32) 53 | return match_results 54 | -------------------------------------------------------------------------------- /object_detection/matchers/bipartite_matcher_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.core.bipartite_matcher.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.matchers import bipartite_matcher 21 | 22 | 23 | class GreedyBipartiteMatcherTest(tf.test.TestCase): 24 | 25 | def test_get_expected_matches_when_all_rows_are_valid(self): 26 | similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]]) 27 | num_valid_rows = 2 28 | expected_match_results = [-1, 1, 0] 29 | 30 | matcher = bipartite_matcher.GreedyBipartiteMatcher() 31 | match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows) 32 | with self.test_session() as sess: 33 | match_results_out = sess.run(match._match_results) 34 | self.assertAllEqual(match_results_out, expected_match_results) 35 | 36 | def test_get_expected_matches_with_valid_rows_set_to_minus_one(self): 37 | similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]]) 38 | num_valid_rows = -1 39 | expected_match_results = [-1, 1, 0] 40 | 41 | matcher = bipartite_matcher.GreedyBipartiteMatcher() 42 | match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows) 43 | with self.test_session() as sess: 44 | match_results_out = sess.run(match._match_results) 45 | self.assertAllEqual(match_results_out, expected_match_results) 46 | 47 | def test_get_no_matches_with_zero_valid_rows(self): 48 | similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]]) 49 | num_valid_rows = 0 50 | expected_match_results = [-1, -1, -1] 51 | 52 | matcher = bipartite_matcher.GreedyBipartiteMatcher() 53 | match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows) 54 | with self.test_session() as sess: 55 | match_results_out = sess.run(match._match_results) 56 | self.assertAllEqual(match_results_out, expected_match_results) 57 | 58 | def test_get_expected_matches_with_only_one_valid_row(self): 59 | similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]]) 60 | num_valid_rows = 1 61 | expected_match_results = [-1, -1, 0] 62 | 63 | matcher = bipartite_matcher.GreedyBipartiteMatcher() 64 | match = matcher.match(similarity_matrix, num_valid_rows=num_valid_rows) 65 | with self.test_session() as sess: 66 | match_results_out = sess.run(match._match_results) 67 | self.assertAllEqual(match_results_out, expected_match_results) 68 | 69 | 70 | if 
__name__ == '__main__': 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /object_detection/meta_architectures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/meta_architectures/__init__.py -------------------------------------------------------------------------------- /object_detection/meta_architectures/rfcn_meta_arch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.meta_architectures.rfcn_meta_arch.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib 21 | from object_detection.meta_architectures import rfcn_meta_arch 22 | 23 | 24 | class RFCNMetaArchTest( 25 | faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase): 26 | 27 | def _get_second_stage_box_predictor_text_proto(self): 28 | box_predictor_text_proto = """ 29 | rfcn_box_predictor { 30 | conv_hyperparams { 31 | op: CONV 32 | activation: NONE 33 | regularizer { 34 | l2_regularizer { 35 | weight: 0.0005 36 | } 37 | } 38 | initializer { 39 | variance_scaling_initializer { 40 | factor: 1.0 41 | uniform: true 42 | mode: FAN_AVG 43 | } 44 | } 45 | } 46 | } 47 | """ 48 | return box_predictor_text_proto 49 | 50 | def _get_model(self, box_predictor, **common_kwargs): 51 | return rfcn_meta_arch.RFCNMetaArch( 52 | second_stage_rfcn_box_predictor=box_predictor, **common_kwargs) 53 | 54 | 55 | if __name__ == '__main__': 56 | tf.test.main() 57 | -------------------------------------------------------------------------------- /object_detection/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/models/__init__.py -------------------------------------------------------------------------------- /object_detection/models/ssd_vgg_16_feature_extractor.py: -------------------------------------------------------------------------------- 1 | """SSDFeatureExtractor for VGG16 features.""" 2 | import tensorflow as tf 3 | 4 | from object_detection.meta_architectures import ssd_meta_arch 5 | from object_detection.models import feature_map_generators 6 | from nets import vgg 7 | 8 | slim = tf.contrib.slim 9 | 10 | 11 | class SSDVgg16FeatureExtractor( 12 | ssd_meta_arch.SSDFeatureExtractor): 13 | 14 | def __init__(self, 15 | depth_multiplier, 16 | min_depth, 17 | conv_hyperparams, 18 | reuse_weights=None): 19 | """VGG16 Feature Extractor for SSD Models.""" 20 | super(SSDVgg16FeatureExtractor, self).__init__( 21 
| depth_multiplier, min_depth, conv_hyperparams, reuse_weights) 22 | 23 | def preprocess(self, resized_inputs): 24 | """SSD preprocessing.""" 25 | # TODO: Subtract RGB mean instead 26 | return (2.0 / 255.0) * resized_inputs - 1.0 27 | 28 | def extract_features(self, preprocessed_inputs): 29 | """Extract features from preprocessed inputs. 30 | 31 | Args: 32 | preprocessed_inputs: a [batch, height, width, channels] float tensor 33 | representing a batch of images. 34 | 35 | Returns: 36 | feature_maps: a list of tensors where the ith tensor has shape 37 | [batch, height_i, width_i, depth_i] 38 | """ 39 | preprocessed_inputs.get_shape().assert_has_rank(4) 40 | shape_assert = tf.Assert( 41 | tf.logical_and(tf.greater_equal(tf.shape(preprocessed_inputs)[1], 33), 42 | tf.greater_equal(tf.shape(preprocessed_inputs)[2], 33)), 43 | ['image size must be at least 33 in both height and width.']) 44 | 45 | feature_map_layout = { 46 | 'from_layer': ['conv4', '', '', '', '', '', ''], 47 | 'layer_depth': [-1, 1024, 1024, 512, 256, 256, 256], 48 | } 49 | 50 | with tf.control_dependencies([shape_assert]): 51 | with slim.arg_scope(self._conv_hyperparams): 52 | with tf.variable_scope('vgg_16', 53 | reuse=self._reuse_weights) as scope: 54 | net, image_features = vgg.vgg_16_base( 55 | preprocessed_inputs, 56 | final_endpoint='pool5', 57 | trainable=False, 58 | scope=scope) 59 | feature_maps = feature_map_generators.multi_resolution_feature_maps( 60 | feature_map_layout=feature_map_layout, 61 | depth_multiplier=self._depth_multiplier, 62 | min_depth=self._min_depth, 63 | insert_1x1_conv=True, 64 | image_features=image_features) 65 | 66 | return feature_maps.values() 67 | -------------------------------------------------------------------------------- /object_detection/object_detection_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | } 10 | ], 11 | "metadata": {}, 12 | "nbformat": 4, 13 | "nbformat_minor": 0 14 | } 15 | -------------------------------------------------------------------------------- /object_detection/protos/.gitignore: -------------------------------------------------------------------------------- 1 | *.py 2 | -------------------------------------------------------------------------------- /object_detection/protos/anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/grid_anchor_generator.proto"; 6 | import "object_detection/protos/ssd_anchor_generator.proto"; 7 | 8 | // Configuration proto for the anchor generator to use in the object detection 9 | // pipeline. See core/anchor_generator.py for details. 10 | message AnchorGenerator { 11 | oneof anchor_generator_oneof { 12 | GridAnchorGenerator grid_anchor_generator = 1; 13 | SsdAnchorGenerator ssd_anchor_generator = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /object_detection/protos/argmax_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for ArgMaxMatcher. See 6 | // matchers/argmax_matcher.py for details. 7 | message ArgMaxMatcher { 8 | // Threshold for positive matches.
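// Anchors whose best similarity is at or above this value are labeled
// positive; how values between matched_threshold and unmatched_threshold are
// treated is controlled by negatives_lower_than_unmatched below.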
9 | optional float matched_threshold = 1 [default = 0.5]; 10 | 11 | // Threshold for negative matches. 12 | optional float unmatched_threshold = 2 [default = 0.5]; 13 | 14 | // Whether to construct ArgMaxMatcher without thresholds. 15 | optional bool ignore_thresholds = 3 [default = false]; 16 | 17 | // If True then negative matches are the ones below the unmatched_threshold, 18 | // whereas ignored matches are in between the matched and unmatched 19 | // threshold. If False, then negative matches are in between the matched 20 | // and unmatched threshold, and everything lower than unmatched is ignored. 21 | optional bool negatives_lower_than_unmatched = 4 [default = true]; 22 | 23 | // Whether to ensure each row is matched to at least one column. 24 | optional bool force_match_for_each_row = 5 [default = false]; 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/bipartite_matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for bipartite matcher. See 6 | // matchers/bipartite_matcher.py for details. 7 | message BipartiteMatcher { 8 | } 9 | -------------------------------------------------------------------------------- /object_detection/protos/box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn_box_coder.proto"; 6 | import "object_detection/protos/mean_stddev_box_coder.proto"; 7 | import "object_detection/protos/square_box_coder.proto"; 8 | 9 | // Configuration proto for the box coder to be used in the object detection 10 | // pipeline. See core/box_coder.py for details. 11 | message BoxCoder { 12 | oneof box_coder_oneof { 13 | FasterRcnnBoxCoder faster_rcnn_box_coder = 1; 14 | MeanStddevBoxCoder mean_stddev_box_coder = 2; 15 | SquareBoxCoder square_box_coder = 3; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /object_detection/protos/eval.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Message for configuring DetectionModel evaluation jobs (eval.py). 6 | message EvalConfig { 7 | // Number of visualization images to generate. 8 | optional uint32 num_visualizations = 1 [default=10]; 9 | 10 | // Number of examples to process for evaluation. 11 | optional uint32 num_examples = 2 [default=5000]; 12 | 13 | // How often to run evaluation. 14 | optional uint32 eval_interval_secs = 3 [default=120]; 15 | 16 | // Maximum number of times to run evaluation. If set to 0, will run forever. 17 | optional uint32 max_evals = 4 [default=0]; 18 | 19 | // Whether the TensorFlow graph used for evaluation should be saved to disk. 20 | optional bool save_graph = 5 [default=false]; 21 | 22 | // Path to directory to store visualizations in. If empty, visualization 23 | // images are not exported (only shown on Tensorboard). 24 | optional string visualization_export_dir = 6 [default=""]; 25 | 26 | // BNS name of the TensorFlow master. 27 | optional string eval_master = 7 [default=""]; 28 | 29 | // Type of metrics to use for evaluation. Currently supports only Pascal VOC 30 | // detection metrics.
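// COCO-style evaluation is configured separately through the
// coco_eval_options field below.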
31 | optional string metrics_set = 8 [default="pascal_voc_metrics"]; 32 | 33 | // Path to export detections to COCO compatible JSON format. 34 | optional string export_path = 9 [default='']; 35 | 36 | // Option to not read groundtruth labels and only export detections to 37 | // COCO-compatible JSON file. 38 | optional bool ignore_groundtruth = 10 [default=false]; 39 | 40 | // Use exponential moving averages of variables for evaluation. 41 | // TODO: When this is false make sure the model is constructed 42 | // without moving averages in restore_fn. 43 | optional bool use_moving_averages = 11 [default=false]; 44 | 45 | // Whether to evaluate instance masks. 46 | optional bool eval_instance_masks = 12 [default=false]; 47 | 48 | // Float determining the IoU threshold at which a box is considered correct. 49 | optional float iou_threshold = 13 [default=0.5]; 50 | 51 | // Type of NMS (standard|soft-linear|soft-gaussian) 52 | optional string nms_type = 14 [default="standard"]; 53 | 54 | // NMS IoU threshold. 55 | optional float nms_threshold = 15 [default=1.0]; 56 | 57 | // Soft NMS sigma. 58 | optional float soft_nms_sigma = 16 [default=0.5]; 59 | 60 | // COCO eval options. 61 | optional CocoEvalOptions coco_eval_options = 17; 62 | 63 | // Main subset. 64 | optional string main_subset = 18 [default='']; 65 | 66 | optional bool submission_format_output = 19 [default=false]; 67 | 68 | optional bool calc_loss = 20 [default=false]; 69 | } 70 | 71 | message CocoEvalOptions { 72 | // Eval metrics: refer to http://cocodataset.org/#detections-eval 73 | repeated int32 eval_metric_index = 1; 74 | 75 | // Class type used for evaluation. 76 | optional int32 eval_class_type = 2 [default=0]; 77 | 78 | 79 | optional string eval_ann_filename = 3 [default='../data/mscoco/annotations/instances_val2017.json']; 80 | } 81 | 82 | -------------------------------------------------------------------------------- /object_detection/protos/faster_rcnn_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for FasterRCNNBoxCoder. See 6 | // box_coders/faster_rcnn_box_coder.py for details. 7 | message FasterRcnnBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box height. 13 | optional float height_scale = 3 [default = 5.0]; 14 | 15 | // Scale factor for anchor encoded box width. 16 | optional float width_scale = 4 [default = 5.0]; 17 | } 18 | -------------------------------------------------------------------------------- /object_detection/protos/grid_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for GridAnchorGenerator. See 6 | // anchor_generators/grid_anchor_generator.py for details. 7 | message GridAnchorGenerator { 8 | // Anchor height in pixels. 9 | optional int32 height = 1 [default = 256]; 10 | 11 | // Anchor width in pixels. 12 | optional int32 width = 2 [default = 256]; 13 | 14 | // Anchor stride in height dimension in pixels. 15 | optional int32 height_stride = 3 [default = 16]; 16 | 17 | // Anchor stride in width dimension in pixels. 18 | optional int32 width_stride = 4 [default = 16]; 19 | 20 | // Anchor height offset in pixels.
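// Offsets shift the whole anchor grid; roughly, anchor centers are placed at
// (height_offset + i * height_stride, width_offset + j * width_stride).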
21 | optional int32 height_offset = 5 [default = 0]; 22 | 23 | // Anchor width offset in pixels. 24 | optional int32 width_offset = 6 [default = 0]; 25 | 26 | // At any given location, len(scales) * len(aspect_ratios) anchors are 27 | // generated with all possible combinations of scales and aspect ratios. 28 | 29 | // List of scales for the anchors. 30 | repeated float scales = 7; 31 | 32 | // List of aspect ratios for the anchors. 33 | repeated float aspect_ratios = 8; 34 | 35 | // Whether to use scales of height/width directly 36 | // If false, use scales and aspect_ratios 37 | optional bool use_hw_scales = 9 [default = false]; 38 | 39 | // List of scales for the heights of anchors. 40 | repeated float height_scales = 10; 41 | 42 | // List of scales for the widths of anchors. 43 | repeated float width_scales = 11; 44 | 45 | // Whether to align anchors on the left-top corner 46 | // If false, on the center point 47 | optional bool align_lefttop = 12 [default = false]; 48 | } 49 | -------------------------------------------------------------------------------- /object_detection/protos/image_resizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for image resizing operations. 6 | // See builders/image_resizer_builder.py for details. 7 | message ImageResizer { 8 | oneof image_resizer_oneof { 9 | KeepAspectRatioResizer keep_aspect_ratio_resizer = 1; 10 | FixedShapeResizer fixed_shape_resizer = 2; 11 | } 12 | } 13 | 14 | 15 | // Configuration proto for image resizer that keeps aspect ratio. 16 | message KeepAspectRatioResizer { 17 | // Desired size of the smaller image dimension in pixels. 18 | optional int32 min_dimension = 1 [default = 600]; 19 | 20 | // Desired size of the larger image dimension in pixels. 21 | optional int32 max_dimension = 2 [default = 1024]; 22 | } 23 | 24 | 25 | // Configuration proto for image resizer that resizes to a fixed shape. 26 | message FixedShapeResizer { 27 | // Desired height of image in pixels. 28 | optional int32 height = 1 [default = 300]; 29 | 30 | // Desired width of image in pixels. 31 | optional int32 width = 2 [default = 300]; 32 | } 33 | -------------------------------------------------------------------------------- /object_detection/protos/input_reader.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for defining input readers that generate Object Detection 6 | // Examples from input sources. Input readers are expected to generate a 7 | // dictionary of tensors, with the following fields populated: 8 | // 9 | // 'image': an [image_height, image_width, channels] image tensor that detection 10 | // will be run on. 11 | // 'groundtruth_classes': a [num_boxes] int32 tensor storing the class 12 | // labels of detected boxes in the image. 13 | // 'groundtruth_boxes': a [num_boxes, 4] float tensor storing the coordinates of 14 | // detected boxes in the image. 15 | // 'groundtruth_instance_masks': (Optional), a [num_boxes, image_height, 16 | // image_width] float tensor storing binary mask of the objects in boxes. 17 | 18 | message InputReader { 19 | // Path to StringIntLabelMap pbtxt file specifying the mapping from string 20 | // labels to integer ids. 
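// For example, object_detection/data/pascal_label_map.pbtxt.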
21 | optional string label_map_path = 1 [default=""]; 22 | 23 | // Whether the data should be processed in the order it is read, or 24 | // shuffled randomly. 25 | optional bool shuffle = 2 [default=true]; 26 | 27 | // Maximum number of records to keep in reader queue. 28 | optional uint32 queue_capacity = 3 [default=2000]; 29 | 30 | // Minimum number of records to keep in reader queue. A large value is needed 31 | // to generate a good random shuffle. 32 | optional uint32 min_after_dequeue = 4 [default=1000]; 33 | 34 | // The number of times a data source is read. If set to zero, the data source 35 | // will be reused indefinitely. 36 | optional uint32 num_epochs = 5 [default=0]; 37 | 38 | // Number of reader instances to create. 39 | optional uint32 num_readers = 6 [default=8]; 40 | 41 | // Whether to load groundtruth instance masks. 42 | optional bool load_instance_masks = 7 [default = false]; 43 | 44 | oneof input_reader { 45 | TFRecordInputReader tf_record_input_reader = 8; 46 | ExternalInputReader external_input_reader = 9; 47 | } 48 | } 49 | 50 | // An input reader that reads TF Example protos from local TFRecord files. 51 | message TFRecordInputReader { 52 | // Path to TFRecordFile. 53 | optional string input_path = 1 [default=""]; 54 | } 55 | 56 | // An externally defined input reader. Users may define an extension to this 57 | // proto to interface their own input readers. 58 | message ExternalInputReader { 59 | extensions 1 to 999; 60 | } 61 | -------------------------------------------------------------------------------- /object_detection/protos/mask_predictor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/hyperparams.proto"; 6 | 7 | message MaskPredictor { 8 | optional bool trainable = 1 [default=true]; 9 | optional int32 kernel_size = 2 [default=3]; 10 | optional Hyperparams conv_hyperparams = 3; 11 | } 12 | -------------------------------------------------------------------------------- /object_detection/protos/matcher.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/argmax_matcher.proto"; 6 | import "object_detection/protos/bipartite_matcher.proto"; 7 | 8 | // Configuration proto for the matcher to be used in the object detection 9 | // pipeline. See core/matcher.py for details. 10 | message Matcher { 11 | oneof matcher_oneof { 12 | ArgMaxMatcher argmax_matcher = 1; 13 | BipartiteMatcher bipartite_matcher = 2; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /object_detection/protos/mean_stddev_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for MeanStddevBoxCoder. See 6 | // box_coders/mean_stddev_box_coder.py for details.
7 | message MeanStddevBoxCoder { 8 | } 9 | -------------------------------------------------------------------------------- /object_detection/protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/faster_rcnn.proto"; 6 | import "object_detection/protos/ssd.proto"; 7 | 8 | import "object_detection/protos/hyperparams.proto"; 9 | import "object_detection/protos/box_predictor.proto"; 10 | import "object_detection/protos/mask_predictor.proto"; 11 | 12 | // Top level configuration for DetectionModels. 13 | message DetectionModel { 14 | oneof model { 15 | FasterRcnn faster_rcnn = 1; 16 | Ssd ssd = 2; 17 | } 18 | 19 | // Path to a file holding initial values of Variables. 20 | // It is similar to train.fine_tune_checkpoint, but it is not a ckpt file. 21 | optional string init_file = 3 [default=""]; 22 | 23 | // Whether to use multi-task learning. 24 | optional MTL mtl = 4; 25 | } 26 | 27 | message MTL { 28 | // Which auxiliary tasks to use. 29 | optional bool refine = 1 [default=false]; 30 | optional bool window = 2 [default=false]; 31 | optional bool closeness = 4 [default=false]; 32 | optional bool edgemask = 5 [default=false]; 33 | 34 | // Refine 35 | optional int32 refine_num_fc_layers = 6 [default=0]; 36 | optional Hyperparams refiner_fc_hyperparams = 7; 37 | optional bool refine_residue = 8 [default=false]; 38 | optional float refine_dropout_rate = 9 [default=1.0]; 39 | 40 | // Loss weights 41 | optional float window_class_loss_weight = 10 [default=0.0]; 42 | optional float closeness_loss_weight = 11 [default=0.0]; 43 | optional float edgemask_loss_weight = 12 [default=0.0]; 44 | optional float refined_classification_loss_weight = 13 [default=0.0]; 45 | 46 | // Predictor 47 | optional BoxPredictor window_box_predictor = 14; 48 | optional BoxPredictor closeness_box_predictor = 15; 49 | optional MaskPredictor edgemask_predictor = 16; 50 | 51 | // ETC 52 | optional string shared_feature = 17 [default="proposal_feature_maps"]; 53 | optional bool stop_gradient_for_aux_tasks = 18 [default=false]; 54 | optional bool share_second_stage_init = 19 [default=true]; 55 | optional bool stop_gradient_for_prediction_org = 20 [default=false]; 56 | optional bool global_closeness = 21 [default=true]; 57 | optional bool edgemask_weighted = 22 [default=true]; 58 | } 59 | -------------------------------------------------------------------------------- /object_detection/protos/optimizer.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Messages for configuring the optimizing strategy for training object 6 | // detection models. 7 | 8 | // Top level optimizer message.
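// Exactly one of the optimizers below should be set; when use_moving_average
// is true, exponential moving averages of the variables are kept with the
// given moving_average_decay.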
9 | message Optimizer { 10 | oneof optimizer { 11 | RMSPropOptimizer rms_prop_optimizer = 1; 12 | MomentumOptimizer momentum_optimizer = 2; 13 | AdamOptimizer adam_optimizer = 3; 14 | } 15 | optional bool use_moving_average = 4 [default=true]; 16 | optional float moving_average_decay = 5 [default=0.9999]; 17 | } 18 | 19 | // Configuration message for the RMSPropOptimizer 20 | // See: https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer 21 | message RMSPropOptimizer { 22 | optional LearningRate learning_rate = 1; 23 | optional float momentum_optimizer_value = 2 [default=0.9]; 24 | optional float decay = 3 [default=0.9]; 25 | optional float epsilon = 4 [default=1.0]; 26 | } 27 | 28 | // Configuration message for the MomentumOptimizer 29 | // See: https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer 30 | message MomentumOptimizer { 31 | optional LearningRate learning_rate = 1; 32 | optional float momentum_optimizer_value = 2 [default=0.9]; 33 | } 34 | 35 | // Configuration message for the AdamOptimizer 36 | // See: https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 37 | message AdamOptimizer { 38 | optional LearningRate learning_rate = 1; 39 | optional float beta1 = 2 [default=0.9]; 40 | optional float beta2 = 3 [default=0.999]; 41 | optional float epsilon = 4 [default=0.00000001]; 42 | } 43 | 44 | // Configuration message for optimizer learning rate. 45 | message LearningRate { 46 | oneof learning_rate { 47 | ConstantLearningRate constant_learning_rate = 1; 48 | ExponentialDecayLearningRate exponential_decay_learning_rate = 2; 49 | ManualStepLearningRate manual_step_learning_rate = 3; 50 | } 51 | } 52 | 53 | // Configuration message for a constant learning rate. 54 | message ConstantLearningRate { 55 | optional float learning_rate = 1 [default=0.002]; 56 | } 57 | 58 | // Configuration message for an exponentially decaying learning rate. 59 | // See https://www.tensorflow.org/versions/master/api_docs/python/train/ \ 60 | // decaying_the_learning_rate#exponential_decay 61 | message ExponentialDecayLearningRate { 62 | optional float initial_learning_rate = 1 [default=0.002]; 63 | optional uint32 decay_steps = 2 [default=4000000]; 64 | optional float decay_factor = 3 [default=0.95]; 65 | optional bool staircase = 4 [default=true]; 66 | } 67 | 68 | // Configuration message for a manually defined learning rate schedule. 69 | message ManualStepLearningRate { 70 | optional float initial_learning_rate = 1 [default=0.002]; 71 | message LearningRateSchedule { 72 | optional uint32 step = 1; 73 | optional float learning_rate = 2 [default=0.002]; 74 | } 75 | repeated LearningRateSchedule schedule = 2; 76 | } 77 | -------------------------------------------------------------------------------- /object_detection/protos/pipeline.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | import "object_detection/protos/eval.proto"; 6 | import "object_detection/protos/input_reader.proto"; 7 | import "object_detection/protos/model.proto"; 8 | import "object_detection/protos/train.proto"; 9 | 10 | // Convenience message for configuring a training and eval pipeline. Allows all 11 | // of the pipeline parameters to be configured from one file. 
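// This is the proto that train.py and eval.py load through the
// --pipeline_config_path flag (see g3doc/running_locally.md).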
12 | message TrainEvalPipelineConfig { 13 | optional DetectionModel model = 1; 14 | optional TrainConfig train_config = 2; 15 | optional InputReader train_input_reader = 3; 16 | optional EvalConfig eval_config = 4; 17 | optional InputReader eval_input_reader = 5; 18 | } 19 | -------------------------------------------------------------------------------- /object_detection/protos/post_processing.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for non-max-suppression operation on a batch of 6 | // detections. 7 | message BatchNonMaxSuppression { 8 | // Scalar threshold for score (low scoring boxes are removed). 9 | optional float score_threshold = 1 [default = 0.0]; 10 | 11 | // Scalar threshold for IOU (boxes that have high IOU overlap 12 | // with previously selected boxes are removed). 13 | optional float iou_threshold = 2 [default = 0.6]; 14 | 15 | // Maximum number of detections to retain per class. 16 | optional int32 max_detections_per_class = 3 [default = 100]; 17 | 18 | // Maximum number of detections to retain across all classes. 19 | optional int32 max_total_detections = 4 [default = 100]; 20 | } 21 | 22 | // Configuration proto for post-processing predicted boxes and 23 | // scores. 24 | message PostProcessing { 25 | // Non max suppression parameters. 26 | optional BatchNonMaxSuppression batch_non_max_suppression = 1; 27 | 28 | // Enum to specify how to convert the detection scores. 29 | enum ScoreConverter { 30 | // Input scores equals output scores. 31 | IDENTITY = 0; 32 | 33 | // Applies a sigmoid on input scores. 34 | SIGMOID = 1; 35 | 36 | // Applies a softmax on input scores 37 | SOFTMAX = 2; 38 | } 39 | 40 | // Score converter for classification scores. 41 | optional ScoreConverter score_converter = 2 [default = IDENTITY]; 42 | } 43 | -------------------------------------------------------------------------------- /object_detection/protos/region_similarity_calculator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for region similarity calculators. See 6 | // core/region_similarity_calculator.py for details. 7 | message RegionSimilarityCalculator { 8 | oneof region_similarity { 9 | NegSqDistSimilarity neg_sq_dist_similarity = 1; 10 | IouSimilarity iou_similarity = 2; 11 | IoaSimilarity ioa_similarity = 3; 12 | } 13 | } 14 | 15 | // Configuration for negative squared distance similarity calculator. 16 | message NegSqDistSimilarity { 17 | } 18 | 19 | // Configuration for intersection-over-union (IOU) similarity calculator. 20 | message IouSimilarity { 21 | } 22 | 23 | // Configuration for intersection-over-area (IOA) similarity calculator. 24 | message IoaSimilarity { 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/square_box_coder.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SquareBoxCoder. See 6 | // box_coders/square_box_coder.py for details. 7 | message SquareBoxCoder { 8 | // Scale factor for anchor encoded box center. 9 | optional float y_scale = 1 [default = 10.0]; 10 | optional float x_scale = 2 [default = 10.0]; 11 | 12 | // Scale factor for anchor encoded box length. 
13 | optional float length_scale = 3 [default = 5.0]; 14 | } 15 | -------------------------------------------------------------------------------- /object_detection/protos/ssd.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | package object_detection.protos; 3 | 4 | import "object_detection/protos/anchor_generator.proto"; 5 | import "object_detection/protos/box_coder.proto"; 6 | import "object_detection/protos/box_predictor.proto"; 7 | import "object_detection/protos/hyperparams.proto"; 8 | import "object_detection/protos/image_resizer.proto"; 9 | import "object_detection/protos/matcher.proto"; 10 | import "object_detection/protos/losses.proto"; 11 | import "object_detection/protos/post_processing.proto"; 12 | import "object_detection/protos/region_similarity_calculator.proto"; 13 | 14 | // Configuration for Single Shot Detection (SSD) models. 15 | message Ssd { 16 | 17 | // Number of classes to predict. 18 | optional int32 num_classes = 1; 19 | 20 | // Image resizer for preprocessing the input image. 21 | optional ImageResizer image_resizer = 2; 22 | 23 | // Feature extractor config. 24 | optional SsdFeatureExtractor feature_extractor = 3; 25 | 26 | // Box coder to encode the boxes. 27 | optional BoxCoder box_coder = 4; 28 | 29 | // Matcher to match groundtruth with anchors. 30 | optional Matcher matcher = 5; 31 | 32 | // Region similarity calculator to compute similarity of boxes. 33 | optional RegionSimilarityCalculator similarity_calculator = 6; 34 | 35 | // Box predictor to attach to the features. 36 | optional BoxPredictor box_predictor = 7; 37 | 38 | // Anchor generator to compute anchors. 39 | optional AnchorGenerator anchor_generator = 8; 40 | 41 | // Post processing to apply on the predictions. 42 | optional PostProcessing post_processing = 9; 43 | 44 | // Whether to normalize the loss by number of groundtruth boxes that match to 45 | // the anchors. 46 | optional bool normalize_loss_by_num_matches = 10 [default=true]; 47 | 48 | // Loss configuration for training. 49 | optional Loss loss = 11; 50 | } 51 | 52 | 53 | message SsdFeatureExtractor { 54 | // Type of ssd feature extractor. 55 | optional string type = 1; 56 | 57 | // The factor to alter the depth of the channels in the feature extractor. 58 | optional float depth_multiplier = 2 [default=1.0]; 59 | 60 | // Minimum number of the channels in the feature extractor. 61 | optional int32 min_depth = 3 [default=16]; 62 | 63 | // Hyperparameters for the feature extractor. 64 | optional Hyperparams conv_hyperparams = 4; 65 | 66 | // Whether to train or not (freeze). 67 | optional bool trainable = 5 [default=true]; 68 | } 69 | -------------------------------------------------------------------------------- /object_detection/protos/ssd_anchor_generator.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto2"; 2 | 3 | package object_detection.protos; 4 | 5 | // Configuration proto for SSD anchor generator described in 6 | // https://arxiv.org/abs/1512.02325. See 7 | // anchor_generators/multiple_grid_anchor_generator.py for details. 8 | message SsdAnchorGenerator { 9 | // Number of grid layers to create anchors for. 10 | optional int32 num_layers = 1 [default = 6]; 11 | 12 | // Scale of anchors corresponding to finest resolution. 
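// Scales for the remaining layers are interpolated between min_scale and
// max_scale, roughly scale_k = min_scale + (max_scale - min_scale) * k /
// (num_layers - 1).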
13 | optional float min_scale = 2 [default = 0.2]; 14 | 15 | // Scale of anchors corresponding to coarsest resolution. 16 | optional float max_scale = 3 [default = 0.95]; 17 | 18 | // Aspect ratios for anchors at each grid point. 19 | repeated float aspect_ratios = 4; 20 | 21 | // Whether to use the following aspect ratio and scale combination for the 22 | // layer with the finest resolution: (scale=0.1, aspect_ratio=1.0), 23 | // (scale=min_scale, aspect_ratio=2.0), (scale=min_scale, aspect_ratio=0.5). 24 | optional bool reduce_boxes_in_lowest_layer = 5 [default = true]; 25 | } 26 | -------------------------------------------------------------------------------- /object_detection/protos/string_int_label_map.proto: -------------------------------------------------------------------------------- 1 | // Message to store the mapping from class label strings to class id. Datasets 2 | // use string labels to represent classes while the object detection framework 3 | // works with class ids. This message maps them so they can be converted back 4 | // and forth as needed. 5 | syntax = "proto2"; 6 | 7 | package object_detection.protos; 8 | 9 | message StringIntLabelMapItem { 10 | // String name. The most common practice is to set this to a MID or synsets 11 | // id. 12 | optional string name = 1; 13 | 14 | // Integer id that maps to the string name above. Label ids should start from 15 | // 1. 16 | optional int32 id = 2; 17 | 18 | // Human readable string label. 19 | optional string display_name = 3; 20 | }; 21 | 22 | message StringIntLabelMap { 23 | repeated StringIntLabelMapItem item = 1; 24 | }; 25 | -------------------------------------------------------------------------------- /object_detection/samples/cloud/cloud.yml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | runtimeVersion: "1.0" 3 | scaleTier: CUSTOM 4 | masterType: standard_gpu 5 | workerCount: 5 6 | workerType: standard_gpu 7 | parameterServerCount: 3 8 | parameterServerType: standard 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /object_detection/scripts/.gitignore: -------------------------------------------------------------------------------- 1 | *.pickle 2 | -------------------------------------------------------------------------------- /object_detection/scripts/run_eval.sh: -------------------------------------------------------------------------------- 1 | eval_label="$1" 2 | 3 | python ../eval.py \ 4 | --eval_label="$eval_label" 5 | -------------------------------------------------------------------------------- /object_detection/scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | train_label="$1" 2 | 3 | python ../train.py \ 4 | --train_label="$train_label" \ 5 | --num_clones=1 6 | -------------------------------------------------------------------------------- /object_detection/test_images/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/test_images/image1.jpg -------------------------------------------------------------------------------- /object_detection/test_images/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/test_images/image2.jpg
-------------------------------------------------------------------------------- /object_detection/test_images/image_info.txt: -------------------------------------------------------------------------------- 1 | 2 | Image provenance: 3 | image1.jpg: https://commons.wikimedia.org/wiki/File:Baegle_dwa.jpg 4 | image2.jpg: Michael Miley, 5 | https://www.flickr.com/photos/mike_miley/4678754542/in/photolist-88rQHL-88oBVp-88oC2B-88rS6J-88rSqm-88oBLv-88oBC4 -------------------------------------------------------------------------------- /object_detection/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/object_detection/utils/__init__.py -------------------------------------------------------------------------------- /object_detection/utils/category_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for importing/exporting Object Detection categories.""" 17 | import csv 18 | 19 | import tensorflow as tf 20 | 21 | 22 | def load_categories_from_csv_file(csv_path): 23 | """Loads categories from a csv file. 24 | 25 | The CSV file should have one comma delimited numeric category id and string 26 | category name pair per line. For example: 27 | 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | ... 32 | 33 | Args: 34 | csv_path: Path to the csv file to be parsed into categories. 35 | Returns: 36 | categories: A list of dictionaries representing all possible categories. 37 | The categories will contain an integer 'id' field and a string 38 | 'name' field. 39 | Raises: 40 | ValueError: If the csv file is incorrectly formatted. 41 | """ 42 | categories = [] 43 | 44 | with tf.gfile.Open(csv_path, 'r') as csvfile: 45 | reader = csv.reader(csvfile, delimiter=',', quotechar='"') 46 | for row in reader: 47 | if not row: 48 | continue 49 | 50 | if len(row) != 2: 51 | raise ValueError('Expected 2 fields per row in csv: %s' % ','.join(row)) 52 | 53 | category_id = int(row[0]) 54 | category_name = row[1] 55 | categories.append({'id': category_id, 'name': category_name}) 56 | 57 | return categories 58 | 59 | 60 | def save_categories_to_csv_file(categories, csv_path): 61 | """Saves categories to a csv file. 62 | 63 | Args: 64 | categories: A list of dictionaries representing categories to save to file. 65 | Each category must contain an 'id' and 'name' field. 66 | csv_path: Path to the csv file to write the categories to.
67 | """ 68 | categories.sort(key=lambda x: x['id']) 69 | with tf.gfile.Open(csv_path, 'w') as csvfile: 70 | writer = csv.writer(csvfile, delimiter=',', quotechar='"') 71 | for category in categories: 72 | writer.writerow([category['id'], category['name']]) 73 | -------------------------------------------------------------------------------- /object_detection/utils/category_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.category_util.""" 17 | import os 18 | 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import category_util 22 | 23 | 24 | class EvalUtilTest(tf.test.TestCase): 25 | 26 | def test_load_categories_from_csv_file(self): 27 | csv_data = """ 28 | 0,"cat" 29 | 1,"dog" 30 | 2,"bird" 31 | """.strip(' ') 32 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 33 | with tf.gfile.Open(csv_path, 'wb') as f: 34 | f.write(csv_data) 35 | 36 | categories = category_util.load_categories_from_csv_file(csv_path) 37 | self.assertTrue({'id': 0, 'name': 'cat'} in categories) 38 | self.assertTrue({'id': 1, 'name': 'dog'} in categories) 39 | self.assertTrue({'id': 2, 'name': 'bird'} in categories) 40 | 41 | def test_save_categories_to_csv_file(self): 42 | categories = [ 43 | {'id': 0, 'name': 'cat'}, 44 | {'id': 1, 'name': 'dog'}, 45 | {'id': 2, 'name': 'bird'}, 46 | ] 47 | csv_path = os.path.join(self.get_temp_dir(), 'test.csv') 48 | category_util.save_categories_to_csv_file(categories, csv_path) 49 | saved_categories = category_util.load_categories_from_csv_file(csv_path) 50 | self.assertEqual(saved_categories, categories) 51 | 52 | 53 | if __name__ == '__main__': 54 | tf.test.main() 55 | -------------------------------------------------------------------------------- /object_detection/utils/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def read_examples_list(path): 42 | """Read list of training or validation examples. 43 | 44 | The file is assumed to contain a single example per line where the first 45 | token in the line is an identifier that allows us to find the image and 46 | annotation xml for that example. 47 | 48 | For example, the line: 49 | xyz 3 50 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 51 | 52 | Args: 53 | path: absolute path to examples list file. 54 | 55 | Returns: 56 | list of example identifiers (strings). 57 | """ 58 | with tf.gfile.GFile(path) as fid: 59 | lines = fid.readlines() 60 | return [line.strip().split(' ')[0] for line in lines] 61 | 62 | 63 | def recursive_parse_xml_to_dict(xml): 64 | """Recursively parses XML contents to python dict. 65 | 66 | We assume that `object` tags are the only ones that can appear 67 | multiple times at the same level of a tree. 68 | 69 | Args: 70 | xml: xml tree obtained by parsing XML file contents using lxml.etree 71 | 72 | Returns: 73 | Python dictionary holding XML contents. 74 | """ 75 | if not xml: 76 | return {xml.tag: xml.text} 77 | result = {} 78 | for child in xml: 79 | child_result = recursive_parse_xml_to_dict(child) 80 | if child.tag != 'object': 81 | result[child.tag] = child_result[child.tag] 82 | else: 83 | if child.tag not in result: 84 | result[child.tag] = [] 85 | result[child.tag].append(child_result[child.tag]) 86 | return {xml.tag: result} 87 | -------------------------------------------------------------------------------- /object_detection/utils/dataset_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.dataset_util.""" 17 | 18 | import os 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import dataset_util 22 | 23 | 24 | class DatasetUtilTest(tf.test.TestCase): 25 | 26 | def test_read_examples_list(self): 27 | example_list_data = """example1 1\nexample2 2""" 28 | example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt') 29 | with tf.gfile.Open(example_list_path, 'wb') as f: 30 | f.write(example_list_data) 31 | 32 | examples = dataset_util.read_examples_list(example_list_path) 33 | self.assertListEqual(['example1', 'example2'], examples) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /object_detection/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | LOG = 'log_collection' 8 | IMAGE = 'image_collection' 9 | 10 | 11 | def make_print_tensor(t, message=None, first_n=None, summarize=None): 12 | print_tensor = tf.Print(t, [t], message=message, first_n=first_n, summarize=summarize) 13 | tf.add_to_collection(LOG, print_tensor) 14 | return print_tensor 15 | -------------------------------------------------------------------------------- /object_detection/utils/kwargs_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | def get_layer_kwargs(scope, is_training=True, freeze_layer='', batch_norm=True, 7 | initializers=None): 8 | kwargs = {'scope': scope} 9 | if is_training and freeze_layer: 10 | # Currently assumes all frozen layers are 'conv' with 'bn'. 11 | freeze_no = int(freeze_layer[4:]) 12 | trainable = (int(scope[4:]) > freeze_no) 13 | kwargs['trainable'] = trainable # conv 14 | if batch_norm: 15 | kwargs['normalizer_params'] = { # bn 16 | 'trainable': trainable, 17 | 'is_training': trainable 18 | } 19 | if initializers is not None: 20 | if 'weights' in initializers[scope]: 21 | kwargs['weights_initializer'] = initializers[scope]['weights'] 22 | if 'biases' in initializers[scope]: 23 | kwargs['biases_initializer'] = initializers[scope]['biases'] 24 | if 'BatchNorm' in initializers[scope]: 25 | if 'normalizer_params' in kwargs: 26 | kwargs['normalizer_params']['param_initializers'] = initializers[scope]['BatchNorm'] 27 | else: 28 | kwargs['normalizer_params'] = { 29 | 'param_initializers': initializers[scope]['BatchNorm'] 30 | } 31 | return kwargs 32 | -------------------------------------------------------------------------------- /object_detection/utils/learning_schedules_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.dataset_util.""" 17 | 18 | import os 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import dataset_util 22 | 23 | 24 | class DatasetUtilTest(tf.test.TestCase): 25 | 26 | def test_read_examples_list(self): 27 | example_list_data = """example1 1\nexample2 2""" 28 | example_list_path = os.path.join(self.get_temp_dir(), 'examples.txt') 29 | with tf.gfile.Open(example_list_path, 'wb') as f: 30 | f.write(example_list_data) 31 | 32 | examples = dataset_util.read_examples_list(example_list_path) 33 | self.assertListEqual(['example1', 'example2'], examples) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /object_detection/utils/debug_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | LOG = 'log_collection' 8 | IMAGE = 'image_collection' 9 | 10 | 11 | def make_print_tensor(t, message=None, first_n=None, summarize=None): 12 | print_tensor = tf.Print(t, [t], message=message, first_n=first_n, summarize=summarize) 13 | tf.add_to_collection(LOG, print_tensor) 14 | -------------------------------------------------------------------------------- /object_detection/utils/kwargs_util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | def get_layer_kwargs(scope, is_training=True, freeze_layer='', batch_norm=True, 7 | initializers=None): 8 | kwargs = {'scope': scope} 9 | if is_training and freeze_layer: 10 | # currently assume all freezed layers are 'conv' with 'bn'. 11 | freeze_no = int(freeze_layer[4:]) 12 | trainable = (int(scope[4:]) > freeze_no) 13 | kwargs['trainable'] = trainable # conv 14 | if batch_norm: 15 | kwargs['normalizer_params'] = { # bn 16 | 'trainable': trainable, 17 | 'is_training': trainable 18 | } 19 | if initializers is not None: 20 | if 'weights' in initializers[scope]: 21 | kwargs['weights_initializer'] = initializers[scope]['weights'] 22 | if 'biases' in initializers[scope]: 23 | kwargs['biases_initializer'] = initializers[scope]['biases'] 24 | if 'BatchNorm' in initializers[scope]: 25 | if 'normalizer_params' in kwargs: 26 | kwargs['normalizer_params']['param_initializers'] = initializers[scope]['BatchNorm'] 27 | else: 28 | kwargs['normalizer_params'] = { 29 | 'param_initializers': initializers[scope]['BatchNorm'] 30 | } 31 | return kwargs 32 | -------------------------------------------------------------------------------- /object_detection/utils/learning_schedules_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.learning_schedules.""" 17 | import tensorflow as tf 18 | 19 | from object_detection.utils import learning_schedules 20 | 21 | 22 | class LearningSchedulesTest(tf.test.TestCase): 23 | 24 | def testExponentialDecayWithBurnin(self): 25 | global_step = tf.placeholder(tf.int32, []) 26 | learning_rate_base = 1.0 27 | learning_rate_decay_steps = 3 28 | learning_rate_decay_factor = .1 29 | burnin_learning_rate = .5 30 | burnin_steps = 2 31 | exp_rates = [.5, .5, 1, .1, .1, .1, .01, .01] 32 | learning_rate = learning_schedules.exponential_decay_with_burnin( 33 | global_step, learning_rate_base, learning_rate_decay_steps, 34 | learning_rate_decay_factor, burnin_learning_rate, burnin_steps) 35 | with self.test_session() as sess: 36 | output_rates = [] 37 | for input_global_step in range(8): 38 | output_rate = sess.run(learning_rate, 39 | feed_dict={global_step: input_global_step}) 40 | output_rates.append(output_rate) 41 | self.assertAllClose(output_rates, exp_rates) 42 | 43 | def testManualStepping(self): 44 | global_step = tf.placeholder(tf.int64, []) 45 | boundaries = [2, 3, 7] 46 | rates = [1.0, 2.0, 3.0, 4.0] 47 | exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0] 48 | learning_rate = learning_schedules.manual_stepping(global_step, boundaries, 49 | rates) 50 | with self.test_session() as sess: 51 | output_rates = [] 52 | for input_global_step in range(10): 53 | output_rate = sess.run(learning_rate, 54 | feed_dict={global_step: input_global_step}) 55 | output_rates.append(output_rate) 56 | self.assertAllClose(output_rates, exp_rates) 57 | 58 | if __name__ == '__main__': 59 | tf.test.main() 60 | -------------------------------------------------------------------------------- /object_detection/utils/np_box_ops_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for object_detection.np_box_ops.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import np_box_ops 22 | 23 | 24 | class BoxOpsTests(tf.test.TestCase): 25 | 26 | def setUp(self): 27 | boxes1 = np.array([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], 28 | dtype=float) 29 | boxes2 = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], 30 | [0.0, 0.0, 20.0, 20.0]], 31 | dtype=float) 32 | self.boxes1 = boxes1 33 | self.boxes2 = boxes2 34 | 35 | def testArea(self): 36 | areas = np_box_ops.area(self.boxes1) 37 | expected_areas = np.array([6.0, 5.0], dtype=float) 38 | self.assertAllClose(expected_areas, areas) 39 | 40 | def testIntersection(self): 41 | intersection = np_box_ops.intersection(self.boxes1, self.boxes2) 42 | expected_intersection = np.array([[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]], 43 | dtype=float) 44 | self.assertAllClose(intersection, expected_intersection) 45 | 46 | def testIOU(self): 47 | iou = np_box_ops.iou(self.boxes1, self.boxes2) 48 | expected_iou = np.array([[2.0 / 16.0, 0.0, 6.0 / 400.0], 49 | [1.0 / 16.0, 0.0, 5.0 / 400.0]], 50 | dtype=float) 51 | self.assertAllClose(iou, expected_iou) 52 | 53 | def testIOA(self): 54 | boxes1 = np.array([[0.25, 0.25, 0.75, 0.75], 55 | [0.0, 0.0, 0.5, 0.75]], 56 | dtype=np.float32) 57 | boxes2 = np.array([[0.5, 0.25, 1.0, 1.0], 58 | [0.0, 0.0, 1.0, 1.0]], 59 | dtype=np.float32) 60 | ioa21 = np_box_ops.ioa(boxes2, boxes1) 61 | expected_ioa21 = np.array([[0.5, 0.0], 62 | [1.0, 1.0]], 63 | dtype=np.float32) 64 | self.assertAllClose(ioa21, expected_ioa21) 65 | 66 | 67 | if __name__ == '__main__': 68 | tf.test.main() 69 | -------------------------------------------------------------------------------- /object_detection/utils/static_shape.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Helper functions to access TensorShape values. 17 | 18 | The rank 4 tensor_shape must be of the form [batch_size, height, width, depth]. 19 | """ 20 | 21 | 22 | def get_batch_size(tensor_shape): 23 | """Returns batch size from the tensor shape. 24 | 25 | Args: 26 | tensor_shape: A rank 4 TensorShape. 27 | 28 | Returns: 29 | An integer representing the batch size of the tensor. 30 | """ 31 | tensor_shape.assert_has_rank(rank=4) 32 | return tensor_shape[0].value 33 | 34 | 35 | def get_height(tensor_shape): 36 | """Returns height from the tensor shape. 37 | 38 | Args: 39 | tensor_shape: A rank 4 TensorShape. 40 | 41 | Returns: 42 | An integer representing the height of the tensor. 
43 | """ 44 | tensor_shape.assert_has_rank(rank=4) 45 | return tensor_shape[1].value 46 | 47 | 48 | def get_width(tensor_shape): 49 | """Returns width from the tensor shape. 50 | 51 | Args: 52 | tensor_shape: A rank 4 TensorShape. 53 | 54 | Returns: 55 | An integer representing the width of the tensor. 56 | """ 57 | tensor_shape.assert_has_rank(rank=4) 58 | return tensor_shape[2].value 59 | 60 | 61 | def get_depth(tensor_shape): 62 | """Returns depth from the tensor shape. 63 | 64 | Args: 65 | tensor_shape: A rank 4 TensorShape. 66 | 67 | Returns: 68 | An integer representing the depth of the tensor. 69 | """ 70 | tensor_shape.assert_has_rank(rank=4) 71 | return tensor_shape[3].value 72 | -------------------------------------------------------------------------------- /object_detection/utils/static_shape_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.static_shape.""" 17 | 18 | import tensorflow as tf 19 | 20 | from object_detection.utils import static_shape 21 | 22 | 23 | class StaticShapeTest(tf.test.TestCase): 24 | 25 | def test_return_correct_batchSize(self): 26 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 27 | self.assertEqual(32, static_shape.get_batch_size(tensor_shape)) 28 | 29 | def test_return_correct_height(self): 30 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 31 | self.assertEqual(299, static_shape.get_height(tensor_shape)) 32 | 33 | def test_return_correct_width(self): 34 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 35 | self.assertEqual(384, static_shape.get_width(tensor_shape)) 36 | 37 | def test_return_correct_depth(self): 38 | tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3]) 39 | self.assertEqual(3, static_shape.get_depth(tensor_shape)) 40 | 41 | def test_die_on_tensor_shape_with_rank_three(self): 42 | tensor_shape = tf.TensorShape(dims=[32, 299, 384]) 43 | with self.assertRaises(ValueError): 44 | static_shape.get_batch_size(tensor_shape) 45 | static_shape.get_height(tensor_shape) 46 | static_shape.get_width(tensor_shape) 47 | static_shape.get_depth(tensor_shape) 48 | 49 | if __name__ == '__main__': 50 | tf.test.main() 51 | -------------------------------------------------------------------------------- /object_detection/utils/test_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for object_detection.utils.test_utils.""" 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | 21 | from object_detection.utils import test_utils 22 | 23 | 24 | class TestUtilsTest(tf.test.TestCase): 25 | 26 | def test_diagonal_gradient_image(self): 27 | """Tests if a good pyramid image is created.""" 28 | pyramid_image = test_utils.create_diagonal_gradient_image(3, 4, 2) 29 | 30 | # Test which is easy to understand. 31 | expected_first_channel = np.array([[3, 2, 1, 0], 32 | [4, 3, 2, 1], 33 | [5, 4, 3, 2]], dtype=np.float32) 34 | self.assertAllEqual(np.squeeze(pyramid_image[:, :, 0]), 35 | expected_first_channel) 36 | 37 | # Actual test. 38 | expected_image = np.array([[[3, 30], 39 | [2, 20], 40 | [1, 10], 41 | [0, 0]], 42 | [[4, 40], 43 | [3, 30], 44 | [2, 20], 45 | [1, 10]], 46 | [[5, 50], 47 | [4, 40], 48 | [3, 30], 49 | [2, 20]]], dtype=np.float32) 50 | 51 | self.assertAllEqual(pyramid_image, expected_image) 52 | 53 | def test_random_boxes(self): 54 | """Tests if valid random boxes are created.""" 55 | num_boxes = 1000 56 | max_height = 3 57 | max_width = 5 58 | boxes = test_utils.create_random_boxes(num_boxes, 59 | max_height, 60 | max_width) 61 | 62 | true_column = np.ones(shape=(num_boxes)) == 1 63 | self.assertAllEqual(boxes[:, 0] < boxes[:, 2], true_column) 64 | self.assertAllEqual(boxes[:, 1] < boxes[:, 3], true_column) 65 | 66 | self.assertTrue(boxes[:, 0].min() >= 0) 67 | self.assertTrue(boxes[:, 1].min() >= 0) 68 | self.assertTrue(boxes[:, 2].max() <= max_height) 69 | self.assertTrue(boxes[:, 3].max() <= max_width) 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter>=1.0.0 2 | ipython>=5.0.0 3 | numpy>=1.11.1 4 | #scipy>=0.17.1 5 | matplotlib>=2.2.2 6 | easydict>=1.6 7 | pillow>=3.2.0 8 | scikit-learn>=0.17 9 | scikit-image>=0.11.3 10 | pandas>=0.17.1 11 | colorlog>=2.6.0 12 | colored>=1.2.2 13 | keras>=0.3.1 14 | h5py>=2.5.0 15 | Cython>=0.25.0 16 | nltk>=3.2 17 | tqdm>=4.8.2 18 | lxml>=3.8.0 19 | protobuf-to-dict>=0.1.0 20 | joblib>=0.11 21 | bleach==1.5.0 22 | 23 | # TensorFlow 1.7 linux cpu 24 | # https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.7.0-cp27-none-linux_x86_64.whl 25 | # TensorFlow 1.7 linux gpu 26 | https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.7.0-cp27-none-linux_x86_64.whl 27 | # TensorFlow 1.4 win cpu 28 | # https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.4.0-cp27-none-win_amd64.whl 29 | 30 | # Dev 31 | nose2>=0.6.5 32 | green>=2.5.1 33 | ipdb>=0.10.1 34 | pudb>=2016.2 35 | 36 | # vim: set ft=config: 37 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for object_detection.""" 2 | 3 | from setuptools
import find_packages 4 | from setuptools import setup 5 | 6 | 7 | REQUIRED_PACKAGES = ['Pillow>=1.0'] 8 | 9 | setup( 10 | name='object_detection', 11 | version='0.1', 12 | install_requires=REQUIRED_PACKAGES, 13 | include_package_data=True, 14 | packages=[p for p in find_packages() if p.startswith('object_detection')], 15 | description='Tensorflow Object Detection Library', 16 | ) 17 | -------------------------------------------------------------------------------- /slim/WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/slim/WORKSPACE -------------------------------------------------------------------------------- /slim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/slim/__init__.py -------------------------------------------------------------------------------- /slim/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [32 x 32 x 3] color image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 
60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /slim/datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | 26 | datasets_map = { 27 | 'cifar10': cifar10, 28 | 'flowers': flowers, 29 | 'imagenet': imagenet, 30 | 'mnist': mnist, 31 | } 32 | 33 | 34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 35 | """Given a dataset name and a split_name returns a Dataset. 36 | 37 | Args: 38 | name: String, the name of the dataset. 39 | split_name: A train/test split name. 40 | dataset_dir: The directory where the dataset files are stored. 41 | file_pattern: The file pattern to use for matching the dataset source files. 42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 43 | reader defined by each dataset is used. 44 | 45 | Returns: 46 | A `Dataset` class. 47 | 48 | Raises: 49 | ValueError: If the dataset `name` is unknown. 
50 | """ 51 | if name not in datasets_map: 52 | raise ValueError('Name of dataset unknown %s' % name) 53 | return datasets_map[name].get_split( 54 | split_name, 55 | dataset_dir, 56 | file_pattern, 57 | reader) 58 | -------------------------------------------------------------------------------- /slim/datasets/flowers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the flowers dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/datasets/download_and_convert_flowers.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'flowers_%s_*.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 3320, 'validation': 350} 35 | 36 | _NUM_CLASSES = 5 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 4', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading flowers. 46 | 47 | Args: 48 | split_name: A train/validation split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/validation split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 
69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /slim/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the MNIST dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/datasets/download_and_convert_mnist.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'mnist_%s.tfrecord' 33 | 34 | _SPLITS_TO_SIZES = {'train': 60000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [28 x 28 x 1] grayscale image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading MNIST. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in _SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' 
% split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=_SPLITS_TO_SIZES[split_name], 96 | num_classes=_NUM_CLASSES, 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /slim/datasets/preprocess_imagenet_validation_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | Associate the ImageNet 2012 Challenge validation data set with labels. 19 | 20 | The raw ImageNet validation data set is expected to reside in JPEG files 21 | located in the following directory structure. 22 | 23 | data_dir/ILSVRC2012_val_00000001.JPEG 24 | data_dir/ILSVRC2012_val_00000002.JPEG 25 | ... 26 | data_dir/ILSVRC2012_val_00050000.JPEG 27 | 28 | This script moves the files into a directory structure like the following: 29 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 30 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 31 | ... 32 | where 'n01440764' is the unique synset label associated with 33 | these images. 34 | 35 | This directory reorganization requires a mapping from validation image 36 | number (i.e. suffix of the original file) to the associated label. This 37 | is provided in the ImageNet development kit via a Matlab file. 38 | 39 | In order to make life easier and divorce ourselves from Matlab, we instead 40 | supply a custom text file that provides this mapping for us.
41 | 42 | Sample usage: 43 | ./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \ 44 | imagenet_2012_validation_synset_labels.txt 45 | """ 46 | 47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | 51 | import os 52 | import os.path 53 | import sys 54 | 55 | 56 | if __name__ == '__main__': 57 | if len(sys.argv) < 3: 58 | print('Invalid usage\n' 59 | 'usage: preprocess_imagenet_validation_data.py ' 60 | '<data dir> <validation labels file>') 61 | sys.exit(-1) 62 | data_dir = sys.argv[1] 63 | validation_labels_file = sys.argv[2] 64 | 65 | # Read in the 50000 synsets associated with the validation data set. 66 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 67 | unique_labels = set(labels) 68 | 69 | # Make all sub-directories in the validation data dir. 70 | for label in unique_labels: 71 | labeled_data_dir = os.path.join(data_dir, label) 72 | os.makedirs(labeled_data_dir) 73 | 74 | # Move all of the images to the appropriate sub-directory. 75 | for i in xrange(len(labels)): 76 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 77 | original_filename = os.path.join(data_dir, basename) 78 | if not os.path.exists(original_filename): 79 | print('Failed to find: %s' % original_filename) 80 | sys.exit(-1) 81 | new_filename = os.path.join(data_dir, labels[i], basename) 82 | os.rename(original_filename, new_filename) 83 | -------------------------------------------------------------------------------- /slim/deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset.
16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | 43 | FLAGS = tf.app.flags.FLAGS 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 47 | None, 48 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 49 | 50 | tf.app.flags.DEFINE_string( 51 | 'dataset_dir', 52 | None, 53 | 'The directory where the output TFRecords and temporary files are saved.') 54 | 55 | 56 | def main(_): 57 | if not FLAGS.dataset_name: 58 | raise ValueError('You must supply the dataset name with --dataset_name') 59 | if not FLAGS.dataset_dir: 60 | raise ValueError('You must supply the dataset directory with --dataset_dir') 61 | 62 | if FLAGS.dataset_name == 'cifar10': 63 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 64 | elif FLAGS.dataset_name == 'flowers': 65 | download_and_convert_flowers.run(FLAGS.dataset_dir) 66 | elif FLAGS.dataset_name == 'mnist': 67 | download_and_convert_mnist.run(FLAGS.dataset_dir) 68 | else: 69 | raise ValueError( 70 | 'dataset_name [%s] was not recognized.' % FLAGS.dataset_name) 71 | 72 | if __name__ == '__main__': 73 | tf.app.run() 74 | -------------------------------------------------------------------------------- /slim/export_inference_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for export_inference_graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import gfile 28 | import export_inference_graph 29 | 30 | 31 | class ExportInferenceGraphTest(tf.test.TestCase): 32 | 33 | def testExportInferenceGraph(self): 34 | tmpdir = self.get_temp_dir() 35 | output_file = os.path.join(tmpdir, 'inception_v3.pb') 36 | flags = tf.app.flags.FLAGS 37 | flags.output_file = output_file 38 | flags.model_name = 'inception_v3' 39 | flags.dataset_dir = tmpdir 40 | export_inference_graph.main(None) 41 | self.assertTrue(gfile.Exists(output_file)) 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /slim/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /slim/nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /slim/nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wonheeML/mtl-ssl/d3e9d5b60cb274eff0890aae8b4528f2cb82e20d/slim/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /slim/nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for slim.nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map.keys()[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in nets_factory.networks_map.keys()[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /slim/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /slim/preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_height, output_width) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /slim/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized.
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'resnet_v1_50': vgg_preprocessing, 58 | 'resnet_v1_101': vgg_preprocessing, 59 | 'resnet_v1_152': vgg_preprocessing, 60 | 'resnet_v1_200': vgg_preprocessing, 61 | 'resnet_v2_50': vgg_preprocessing, 62 | 'resnet_v2_101': vgg_preprocessing, 63 | 'resnet_v2_152': vgg_preprocessing, 64 | 'resnet_v2_200': vgg_preprocessing, 65 | 'vgg': vgg_preprocessing, 66 | 'vgg_a': vgg_preprocessing, 67 | 'vgg_16': vgg_preprocessing, 68 | 'vgg_19': vgg_preprocessing, 69 | } 70 | 71 | if name not in preprocessing_fn_map: 72 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 73 | 74 | def preprocessing_fn(image, output_height, output_width, **kwargs): 75 | return preprocessing_fn_map[name].preprocess_image( 76 | image, output_height, output_width, is_training=is_training, **kwargs) 77 | 78 | return preprocessing_fn 79 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_resnet_v2_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Flowers dataset 5 | # 2. Fine-tunes an Inception Resnet V2 model on the Flowers training set. 6 | # 3. Evaluates the model on the Flowers validation set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./slim/scripts/finetune_inception_resnet_v2_on_flowers.sh 11 | set -e 12 | 13 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 14 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 15 | 16 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 17 | MODEL_NAME=inception_resnet_v2 18 | 19 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 20 | TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME} 21 | 22 | # Where the dataset is saved to. 23 | DATASET_DIR=/tmp/flowers 24 | 25 | # Download the pre-trained checkpoint. 26 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 27 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 28 | fi 29 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 30 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 31 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 32 | mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 33 | rm inception_resnet_v2_2016_08_30.tar.gz 34 | fi 35 | 36 | # Download the dataset 37 | python download_and_convert_data.py \ 38 | --dataset_name=flowers \ 39 | --dataset_dir=${DATASET_DIR} 40 | 41 | # Fine-tune only the new layers for 1000 steps. 
42 | python train_image_classifier.py \ 43 | --train_dir=${TRAIN_DIR} \ 44 | --dataset_name=flowers \ 45 | --dataset_split_name=train \ 46 | --dataset_dir=${DATASET_DIR} \ 47 | --model_name=${MODEL_NAME} \ 48 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 49 | --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 50 | --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 51 | --max_number_of_steps=1000 \ 52 | --batch_size=32 \ 53 | --learning_rate=0.01 \ 54 | --learning_rate_decay_type=fixed \ 55 | --save_interval_secs=60 \ 56 | --save_summaries_secs=60 \ 57 | --log_every_n_steps=10 \ 58 | --optimizer=rmsprop \ 59 | --weight_decay=0.00004 60 | 61 | # Run evaluation. 62 | python eval_image_classifier.py \ 63 | --checkpoint_path=${TRAIN_DIR} \ 64 | --eval_dir=${TRAIN_DIR} \ 65 | --dataset_name=flowers \ 66 | --dataset_split_name=validation \ 67 | --dataset_dir=${DATASET_DIR} \ 68 | --model_name=${MODEL_NAME} 69 | 70 | # Fine-tune all layers for 500 steps. 71 | python train_image_classifier.py \ 72 | --train_dir=${TRAIN_DIR}/all \ 73 | --dataset_name=flowers \ 74 | --dataset_split_name=train \ 75 | --dataset_dir=${DATASET_DIR} \ 76 | --model_name=${MODEL_NAME} \ 77 | --checkpoint_path=${TRAIN_DIR} \ 78 | --max_number_of_steps=500 \ 79 | --batch_size=32 \ 80 | --learning_rate=0.0001 \ 81 | --learning_rate_decay_type=fixed \ 82 | --save_interval_secs=60 \ 83 | --save_summaries_secs=60 \ 84 | --log_every_n_steps=10 \ 85 | --optimizer=rmsprop \ 86 | --weight_decay=0.00004 87 | 88 | # Run evaluation. 89 | python eval_image_classifier.py \ 90 | --checkpoint_path=${TRAIN_DIR}/all \ 91 | --eval_dir=${TRAIN_DIR}/all \ 92 | --dataset_name=flowers \ 93 | --dataset_split_name=validation \ 94 | --dataset_dir=${DATASET_DIR} \ 95 | --model_name=${MODEL_NAME} 96 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Flowers dataset 5 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 6 | # 3. Evaluates the model on the Flowers validation set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./slim/scripts/finetune_inception_v1_on_flowers.sh 11 | set -e 12 | 13 | # Where the pre-trained InceptionV1 checkpoint is saved to. 14 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 15 | 16 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 17 | TRAIN_DIR=/tmp/flowers-models/inception_v1 18 | 19 | # Where the dataset is saved to. 20 | DATASET_DIR=/tmp/flowers 21 | 22 | # Download the pre-trained checkpoint. 23 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 24 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 25 | fi 26 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 27 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 28 | tar -xvf inception_v1_2016_08_28.tar.gz 29 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 30 | rm inception_v1_2016_08_28.tar.gz 31 | fi 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=flowers \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Fine-tune only the new layers for 3000 steps.
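# (Editor's note: unlike the Inception V3 / Inception Resnet V2 scripts, this
# script does not pass --learning_rate_decay_type=fixed, so stage one relies on
# train_image_classifier.py's default learning-rate decay schedule, which is
# exponential decay under the usual slim defaults.)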
39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=flowers \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=inception_v1 \ 45 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 46 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 47 | --trainable_scopes=InceptionV1/Logits \ 48 | --max_number_of_steps=3000 \ 49 | --batch_size=32 \ 50 | --learning_rate=0.01 \ 51 | --save_interval_secs=60 \ 52 | --save_summaries_secs=60 \ 53 | --log_every_n_steps=100 \ 54 | --optimizer=rmsprop \ 55 | --weight_decay=0.00004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=flowers \ 62 | --dataset_split_name=validation \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=inception_v1 65 | 66 | # Fine-tune all layers for 1000 steps. 67 | python train_image_classifier.py \ 68 | --train_dir=${TRAIN_DIR}/all \ 69 | --dataset_name=flowers \ 70 | --dataset_split_name=train \ 71 | --dataset_dir=${DATASET_DIR} \ 72 | --checkpoint_path=${TRAIN_DIR} \ 73 | --model_name=inception_v1 \ 74 | --max_number_of_steps=1000 \ 75 | --batch_size=32 \ 76 | --learning_rate=0.001 \ 77 | --save_interval_secs=60 \ 78 | --save_summaries_secs=60 \ 79 | --log_every_n_steps=100 \ 80 | --optimizer=rmsprop \ 81 | --weight_decay=0.00004 82 | 83 | # Run evaluation. 84 | python eval_image_classifier.py \ 85 | --checkpoint_path=${TRAIN_DIR}/all \ 86 | --eval_dir=${TRAIN_DIR}/all \ 87 | --dataset_name=flowers \ 88 | --dataset_split_name=validation \ 89 | --dataset_dir=${DATASET_DIR} \ 90 | --model_name=inception_v1 91 | -------------------------------------------------------------------------------- /slim/scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Downloads the Flowers dataset 5 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 6 | # 3. Evaluates the model on the Flowers validation set. 7 | # 8 | # Usage: 9 | # cd slim 10 | # ./slim/scripts/finetune_inception_v3_on_flowers.sh 11 | set -e 12 | 13 | # Where the pre-trained InceptionV3 checkpoint is saved to. 14 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 15 | 16 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 17 | TRAIN_DIR=/tmp/flowers-models/inception_v3 18 | 19 | # Where the dataset is saved to. 20 | DATASET_DIR=/tmp/flowers 21 | 22 | # Download the pre-trained checkpoint. 23 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 24 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 25 | fi 26 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 27 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 28 | tar -xvf inception_v3_2016_08_28.tar.gz 29 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 30 | rm inception_v3_2016_08_28.tar.gz 31 | fi 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=flowers \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Fine-tune only the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=inception_v3 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \
  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=inception_v3

# Fine-tune all layers for 500 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=inception_v3 \
  --checkpoint_path=${TRAIN_DIR} \
  --max_number_of_steps=500 \
  --batch_size=32 \
  --learning_rate=0.0001 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=inception_v3

--------------------------------------------------------------------------------
/slim/scripts/finetune_resnet_v1_50_on_flowers.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset.
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset.
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
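# Note (added): scope names are case-sensitive and model-specific; the
# ResNetV1-50 head is named resnet_v1_50/logits (lowercase), unlike the
# Inception models' InceptionV*/Logits, so the scopes below must match the
# variable names created by the slim network definition exactly.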
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

--------------------------------------------------------------------------------
/slim/scripts/train_cifarnet_on_cifar10.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the Cifar10 dataset.
# 2. Trains a CifarNet model on the Cifar10 training set.
# 3. Evaluates the model on the Cifar10 testing set.
#
# Usage:
# cd slim
# ./scripts/train_cifarnet_on_cifar10.sh
set -e

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/cifarnet-model

# Where the dataset is saved to.
DATASET_DIR=/tmp/cifar10

# Download the dataset.
python download_and_convert_data.py \
  --dataset_name=cifar10 \
  --dataset_dir=${DATASET_DIR}

# Run training.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=cifar10 \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=cifarnet \
  --preprocessing_name=cifarnet \
  --max_number_of_steps=100000 \
  --batch_size=128 \
  --save_interval_secs=120 \
  --save_summaries_secs=120 \
  --log_every_n_steps=100 \
  --optimizer=sgd \
  --learning_rate=0.1 \
  --learning_rate_decay_factor=0.1 \
  --num_epochs_per_decay=200 \
  --weight_decay=0.004

# Run evaluation.
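# Note (added): because --checkpoint_path points at a directory here,
# eval_image_classifier.py evaluates the most recent checkpoint it finds in
# ${TRAIN_DIR}; pass an explicit checkpoint file instead to evaluate an
# earlier snapshot.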
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=cifar10 \
  --dataset_split_name=test \
  --dataset_dir=${DATASET_DIR} \
  --model_name=cifarnet

--------------------------------------------------------------------------------
/slim/scripts/train_lenet_on_mnist.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the MNIST dataset.
# 2. Trains a LeNet model on the MNIST training set.
# 3. Evaluates the model on the MNIST testing set.
#
# Usage:
# cd slim
# ./scripts/train_lenet_on_mnist.sh
set -e

# Where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/lenet-model

# Where the dataset is saved to.
DATASET_DIR=/tmp/mnist

# Download the dataset.
python download_and_convert_data.py \
  --dataset_name=mnist \
  --dataset_dir=${DATASET_DIR}

# Run training.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=mnist \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=lenet \
  --preprocessing_name=lenet \
  --max_number_of_steps=20000 \
  --batch_size=50 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=sgd \
  --learning_rate_decay_type=fixed \
  --weight_decay=0

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=mnist \
  --dataset_split_name=test \
  --dataset_dir=${DATASET_DIR} \
  --model_name=lenet

--------------------------------------------------------------------------------
/slim/setup.py:
--------------------------------------------------------------------------------
"""Setup script for slim."""

from setuptools import find_packages
from setuptools import setup


setup(
    name='slim',
    version='0.1',
    include_package_data=True,
    packages=find_packages(),
    description='tf-slim',
)

--------------------------------------------------------------------------------
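Note (added): the setup.py above is a standard setuptools script; a minimal
sketch of using it, assuming pip is available and the working directory is the
repository's slim/ folder, is an editable install so the slim libraries can be
imported as a package from outside the checkout:

  cd slim
  pip install -e .   # installs the 'slim' package defined in setup.py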