├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── cifar10.py
│   ├── dataset_factory.py
│   ├── dataset_utils.py
│   ├── download_and_convert_cifar10.py
│   ├── download_and_convert_flowers.py
│   ├── download_and_convert_mnist.py
│   ├── flowers.py
│   ├── imagenet.py
│   └── mnist.py
├── deployment
│   ├── __init__.py
│   └── model_deploy.py
├── download_and_convert_data.py
├── eval_image_classifier.py
├── nets
│   ├── __init__.py
│   ├── densenet.py
│   └── nets_factory.py
├── preprocessing
│   ├── __init__.py
│   ├── densenet_preprocessing.py
│   └── preprocessing_factory.py
└── train_image_classifier.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Tensorflow-DenseNet with ImageNet Pretrained Models

This is a [TensorFlow](https://www.tensorflow.org/) implementation of [DenseNet](https://arxiv.org/pdf/1608.06993.pdf) by G. Huang, Z. Liu, K. Weinberger, and L. van der Maaten, with [ImageNet](http://www.image-net.org/) pretrained models. The weights are converted from [DenseNet-Keras Models](https://github.com/flyyufelix/DenseNet-Keras).

The code is largely borrowed from [TensorFlow-Slim Models](https://github.com/tensorflow/models/tree/master/slim).

## Pre-trained Models

Top-1/top-5 accuracy rates using a single center crop (crop size: 224x224, image size: 256xN):

Network|Top-1|Top-5|Checkpoints
:---:|:---:|:---:|:---:
DenseNet 121 (k=32)| 74.91| 92.19| [model](https://drive.google.com/file/d/0B_fUSpodN0t0eW1sVk1aeWREaDA/view?usp=sharing&resourcekey=0-z_rek7lZEjA4nwXzIhqDZg)
DenseNet 169 (k=32)| 76.09| 93.14| [model](https://drive.google.com/file/d/0B_fUSpodN0t0TDB5Ti1PeTZMM2c/view?usp=sharing&resourcekey=0-EJzINUM7lBPKX-fOCCdsog)
DenseNet 161 (k=48)| 77.64| 93.79| [model](https://drive.google.com/file/d/0B_fUSpodN0t0NmZvTnZZa2plaHc/view?usp=sharing&resourcekey=0-l8IZdDxN30rjeq4K4Mfeow)

## Usage
Follow the instructions for [TensorFlow-Slim Models](https://github.com/tensorflow/models/tree/master/slim).

### Step-by-step example of training on the flowers dataset
#### Downloading and converting the flowers dataset

```
$ DATA_DIR=/tmp/data/flowers
$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir="${DATA_DIR}"
```

#### Training a model from scratch

```
$ DATASET_DIR=/tmp/data/flowers
$ TRAIN_DIR=/tmp/train_logs
$ python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=${DATASET_DIR} \
    --model_name=densenet121
```

#### Fine-tuning a model from an existing checkpoint

```
$ DATASET_DIR=/tmp/data/flowers
$ TRAIN_DIR=/tmp/train_logs
$ CHECKPOINT_PATH=/tmp/my_checkpoints/tf-densenet121.ckpt
$ python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --dataset_dir=${DATASET_DIR} \
    --model_name=densenet121 \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --checkpoint_exclude_scopes=global_step,densenet121/logits \
    --trainable_scopes=densenet121/logits
```

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/datasets/cifar10.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the Cifar10 dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_cifar10.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'cifar10_%s.tfrecord'

SPLITS_TO_SIZES = {'train': 50000, 'test': 10000}

_NUM_CLASSES = 10

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A [32 x 32 x 3] color image.',
    'label': 'A single integer between 0 and 9',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading cifar10.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if not reader:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

--------------------------------------------------------------------------------
/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns classification image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
}


def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
  """Given a dataset name and a split_name returns a Dataset.

  Args:
    name: String, the name of the dataset.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A `Dataset` class.

  Raises:
    ValueError: If the dataset `name` is unknown.
  """
  if name not in datasets_map:
    raise ValueError('Name of dataset unknown %s' % name)
  return datasets_map[name].get_split(
      split_name,
      dataset_dir,
      file_pattern,
      reader)

--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains utilities for downloading and converting datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

LABELS_FILENAME = 'labels.txt'


def int64_feature(values):
  """Returns a TF-Feature of int64s.

  Args:
    values: A scalar or list of values.

  Returns:
    a TF-Feature.
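
    For example, int64_feature(5) and int64_feature([5]) both yield a
    TF-Feature wrapping an Int64List containing the single value 5.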
38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_feature(values): 45 | """Returns a TF-Feature of bytes. 46 | 47 | Args: 48 | values: A string. 49 | 50 | Returns: 51 | a TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 54 | 55 | 56 | def image_to_tfexample(image_data, image_format, height, width, class_id): 57 | return tf.train.Example(features=tf.train.Features(feature={ 58 | 'image/encoded': bytes_feature(image_data), 59 | 'image/format': bytes_feature(image_format), 60 | 'image/class/label': int64_feature(class_id), 61 | 'image/height': int64_feature(height), 62 | 'image/width': int64_feature(width), 63 | })) 64 | 65 | 66 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 67 | """Downloads the `tarball_url` and uncompresses it locally. 68 | 69 | Args: 70 | tarball_url: The URL of a tarball file. 71 | dataset_dir: The directory where the temporary files are stored. 72 | """ 73 | filename = tarball_url.split('/')[-1] 74 | filepath = os.path.join(dataset_dir, filename) 75 | 76 | def _progress(count, block_size, total_size): 77 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 78 | filename, float(count * block_size) / float(total_size) * 100.0)) 79 | sys.stdout.flush() 80 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 81 | print() 82 | statinfo = os.stat(filepath) 83 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 84 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 85 | 86 | 87 | def write_label_file(labels_to_class_names, dataset_dir, 88 | filename=LABELS_FILENAME): 89 | """Writes a file with the list of class names. 90 | 91 | Args: 92 | labels_to_class_names: A map of (integer) labels to class names. 93 | dataset_dir: The directory in which the labels file should be written. 94 | filename: The filename where the class names are written. 95 | """ 96 | labels_filename = os.path.join(dataset_dir, filename) 97 | with tf.gfile.Open(labels_filename, 'w') as f: 98 | for label in labels_to_class_names: 99 | class_name = labels_to_class_names[label] 100 | f.write('%d:%s\n' % (label, class_name)) 101 | 102 | 103 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 104 | """Specifies whether or not the dataset directory contains a label map file. 105 | 106 | Args: 107 | dataset_dir: The directory in which the labels file is found. 108 | filename: The filename where the class names are written. 109 | 110 | Returns: 111 | `True` if the labels file exists and `False` otherwise. 112 | """ 113 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 114 | 115 | 116 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 117 | """Reads the labels file and returns a mapping from ID to class name. 118 | 119 | Args: 120 | dataset_dir: The directory in which the labels file is found. 121 | filename: The filename where the class names are written. 122 | 123 | Returns: 124 | A map from a label (integer) to class name. 
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'rb') as f:
    lines = f.read().decode()
  lines = lines.split('\n')
  lines = filter(None, lines)

  labels_to_class_names = {}
  for line in lines:
    index = line.index(':')
    labels_to_class_names[int(line[:index])] = line[index+1:]
  return labels_to_class_names

--------------------------------------------------------------------------------
/datasets/download_and_convert_cifar10.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos.

This module downloads the cifar10 data, uncompresses it, reads the files
that make up the cifar10 data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contains a single image and label.

The script should take several minutes to run.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import cPickle
import os
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

from datasets import dataset_utils

# The URL where the CIFAR data can be downloaded.
_DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

# The number of training files.
_NUM_TRAIN_FILES = 5

# The height and width of each image.
_IMAGE_SIZE = 32

# The names of the classes.
_CLASS_NAMES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]


def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
  """Loads data from the cifar10 pickle files and writes files to a TFRecord.

  Args:
    filename: The filename of the cifar10 pickle file.
    tfrecord_writer: The TFRecord writer to use for writing.
    offset: An offset into the absolute number of images previously written.

  Returns:
    The new offset.
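
    run() threads this offset through the five CIFAR-10 training batch files
    so that the progress counter reflects a global image index across files.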
  """
  with tf.gfile.Open(filename, 'rb') as f:
    if sys.version_info < (3,):
      data = cPickle.load(f)
    else:
      data = cPickle.load(f, encoding='bytes')

  images = data[b'data']
  num_images = images.shape[0]

  images = images.reshape((num_images, 3, 32, 32))
  labels = data[b'labels']

  with tf.Graph().as_default():
    image_placeholder = tf.placeholder(dtype=tf.uint8)
    encoded_image = tf.image.encode_png(image_placeholder)

    with tf.Session('') as sess:

      for j in range(num_images):
        sys.stdout.write('\r>> Reading file [%s] image %d/%d' % (
            filename, offset + j + 1, offset + num_images))
        sys.stdout.flush()

        image = np.squeeze(images[j]).transpose((1, 2, 0))
        label = labels[j]

        png_string = sess.run(encoded_image,
                              feed_dict={image_placeholder: image})

        example = dataset_utils.image_to_tfexample(
            png_string, b'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
        tfrecord_writer.write(example.SerializeToString())

  return offset + num_images


def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
    split_name: The name of the train/test split.

  Returns:
    An absolute file path.
  """
  return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name)


def _download_and_uncompress_dataset(dataset_dir):
  """Downloads cifar10 and uncompresses it locally.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = _DATA_URL.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)


def _clean_up_temporary_files(dataset_dir):
  """Removes temporary files used to create the dataset.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = _DATA_URL.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)
  tf.gfile.Remove(filepath)

  tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py')
  tf.gfile.DeleteRecursively(tmp_dir)


def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  training_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)

  # First, process the training data:
  with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
    offset = 0
    for i in range(_NUM_TRAIN_FILES):
      filename = os.path.join(dataset_dir,
                              'cifar-10-batches-py',
                              'data_batch_%d' % (i + 1))  # 1-indexed.
      offset = _add_to_tfrecord(filename, tfrecord_writer, offset)

  # Next, process the testing data:
  with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
    filename = os.path.join(dataset_dir,
                            'cifar-10-batches-py',
                            'test_batch')
    _add_to_tfrecord(filename, tfrecord_writer)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Cifar10 dataset!')

--------------------------------------------------------------------------------
/datasets/download_and_convert_flowers.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.

This module downloads the Flowers data, uncompresses it, reads the files
that make up the Flowers data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contains a single image and label.

The script should take about a minute to run.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys

import tensorflow as tf

from datasets import dataset_utils

# The URL where the Flowers data can be downloaded.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# The number of images in the validation set.
_NUM_VALIDATION = 350

# Seed for repeatability.
_RANDOM_SEED = 0

# The number of shards per dataset split.
_NUM_SHARDS = 5


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
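    # Raw JPEG bytes are fed through a string placeholder and decoded to a
    # [height, width, 3] uint8 tensor; building the decode op once and reusing
    # it via sess.run avoids adding a new graph node for every image.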
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _get_filenames_and_classes(dataset_dir):
  """Returns a list of filenames and inferred class names.

  Args:
    dataset_dir: A directory containing a set of subdirectories representing
      class names. Each subdirectory should contain PNG or JPG encoded images.

  Returns:
    A list of image file paths, relative to `dataset_dir` and the list of
    subdirectories, representing class names.
  """
  flower_root = os.path.join(dataset_dir, 'flower_photos')
  directories = []
  class_names = []
  for filename in os.listdir(flower_root):
    path = os.path.join(flower_root, filename)
    if os.path.isdir(path):
      directories.append(path)
      class_names.append(filename)

  photo_filenames = []
  for directory in directories:
    for filename in os.listdir(directory):
      path = os.path.join(directory, filename)
      photo_filenames.append(path)

  return photo_filenames, sorted(class_names)


def _get_dataset_filename(dataset_dir, split_name, shard_id):
  output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
      split_name, shard_id, _NUM_SHARDS)
  return os.path.join(dataset_dir, output_filename)


def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train' or 'validation'.
    filenames: A list of absolute paths to png or jpg images.
    class_names_to_ids: A dictionary from class names (strings) to ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
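
    Each split is spread over _NUM_SHARDS TFRecord files named
    'flowers_<split>_<shard>-of-<num_shards>.tfrecord'.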
  """
  assert split_name in ['train', 'validation']

  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:

      for shard_id in range(_NUM_SHARDS):
        output_filename = _get_dataset_filename(
            dataset_dir, split_name, shard_id)

        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
          start_ndx = shard_id * num_per_shard
          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
          for i in range(start_ndx, end_ndx):
            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
                i+1, len(filenames), shard_id))
            sys.stdout.flush()

            # Read the filename:
            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
            height, width = image_reader.read_image_dims(sess, image_data)

            class_name = os.path.basename(os.path.dirname(filenames[i]))
            class_id = class_names_to_ids[class_name]

            example = dataset_utils.image_to_tfexample(
                image_data, b'jpg', height, width, class_id)
            tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


def _clean_up_temporary_files(dataset_dir):
  """Removes temporary files used to create the dataset.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = _DATA_URL.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)
  tf.gfile.Remove(filepath)

  tmp_dir = os.path.join(dataset_dir, 'flower_photos')
  tf.gfile.DeleteRecursively(tmp_dir)


def _dataset_exists(dataset_dir):
  for split_name in ['train', 'validation']:
    for shard_id in range(_NUM_SHARDS):
      output_filename = _get_dataset_filename(
          dataset_dir, split_name, shard_id)
      if not tf.gfile.Exists(output_filename):
        return False
  return True


def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  if _dataset_exists(dataset_dir):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
  photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  class_names_to_ids = dict(zip(class_names, range(len(class_names))))

  # Divide into train and validation:
  random.seed(_RANDOM_SEED)
  random.shuffle(photo_filenames)
  training_filenames = photo_filenames[_NUM_VALIDATION:]
  validation_filenames = photo_filenames[:_NUM_VALIDATION]

  # First, convert the training and validation sets.
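  # (The shuffled filename list was partitioned above: the first
  # _NUM_VALIDATION entries form the validation set, the rest the train set.)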
  _convert_dataset('train', training_filenames, class_names_to_ids,
                   dataset_dir)
  _convert_dataset('validation', validation_filenames, class_names_to_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the Flowers dataset!')

--------------------------------------------------------------------------------
/datasets/download_and_convert_mnist.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.

This module downloads the MNIST data, uncompresses it, reads the files
that make up the MNIST data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contains a single image and label.

The script should take about a minute to run.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import sys

import numpy as np
from six.moves import urllib
import tensorflow as tf

from datasets import dataset_utils

# The URLs where the MNIST data can be downloaded.
_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
_TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
_TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
_TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
_TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'

_IMAGE_SIZE = 28
_NUM_CHANNELS = 1

# The names of the classes.
_CLASS_NAMES = [
    'zero',
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
]


def _extract_images(filename, num_images):
  """Extract the images into a numpy array.

  Args:
    filename: The path to an MNIST images file.
    num_images: The number of images in the file.

  Returns:
    A numpy array of shape [number_of_images, height, width, channels].
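
    The 16-byte IDX header at the start of the images file is skipped before
    the raw pixel bytes are read.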
  """
  print('Extracting images from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(16)
    buf = bytestream.read(
        _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS)
    data = np.frombuffer(buf, dtype=np.uint8)
    data = data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  return data


def _extract_labels(filename, num_labels):
  """Extract the labels into a vector of int64 label IDs.

  Args:
    filename: The path to an MNIST labels file.
    num_labels: The number of labels in the file.

  Returns:
    A numpy array of shape [number_of_labels].
  """
  print('Extracting labels from: ', filename)
  with gzip.open(filename) as bytestream:
    bytestream.read(8)
    buf = bytestream.read(1 * num_labels)
    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  return labels


def _add_to_tfrecord(data_filename, labels_filename, num_images,
                     tfrecord_writer):
  """Loads data from the binary MNIST files and writes files to a TFRecord.

  Args:
    data_filename: The filename of the MNIST images.
    labels_filename: The filename of the MNIST labels.
    num_images: The number of images in the dataset.
    tfrecord_writer: The TFRecord writer to use for writing.
  """
  images = _extract_images(data_filename, num_images)
  labels = _extract_labels(labels_filename, num_images)

  shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  with tf.Graph().as_default():
    image = tf.placeholder(dtype=tf.uint8, shape=shape)
    encoded_png = tf.image.encode_png(image)

    with tf.Session('') as sess:
      for j in range(num_images):
        sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
        sys.stdout.flush()

        png_string = sess.run(encoded_png, feed_dict={image: images[j]})

        example = dataset_utils.image_to_tfexample(
            png_string, 'png'.encode(), _IMAGE_SIZE, _IMAGE_SIZE, labels[j])
        tfrecord_writer.write(example.SerializeToString())


def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.

  Args:
    dataset_dir: The directory where the temporary files are stored.
    split_name: The name of the train/test split.

  Returns:
    An absolute file path.
  """
  return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)


def _download_dataset(dataset_dir):
  """Downloads MNIST locally.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  for filename in [_TRAIN_DATA_FILENAME,
                   _TRAIN_LABELS_FILENAME,
                   _TEST_DATA_FILENAME,
                   _TEST_LABELS_FILENAME]:
    filepath = os.path.join(dataset_dir, filename)

    if not os.path.exists(filepath):
      print('Downloading file %s...' % filename)
      def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %.1f%%' % (
            float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()
      filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
                                               filepath,
                                               _progress)
      print()
      with tf.gfile.GFile(filepath) as f:
        size = f.size()
      print('Successfully downloaded', filename, size, 'bytes.')


def _clean_up_temporary_files(dataset_dir):
  """Removes temporary files used to create the dataset.

  Args:
    dataset_dir: The directory where the temporary files are stored.
  """
  for filename in [_TRAIN_DATA_FILENAME,
                   _TRAIN_LABELS_FILENAME,
                   _TEST_DATA_FILENAME,
                   _TEST_LABELS_FILENAME]:
    filepath = os.path.join(dataset_dir, filename)
    tf.gfile.Remove(filepath)


def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  training_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  _download_dataset(dataset_dir)

  # First, process the training data:
  with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
    data_filename = os.path.join(dataset_dir, _TRAIN_DATA_FILENAME)
    labels_filename = os.path.join(dataset_dir, _TRAIN_LABELS_FILENAME)
    _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)

  # Next, process the testing data:
  with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
    data_filename = os.path.join(dataset_dir, _TEST_DATA_FILENAME)
    labels_filename = os.path.join(dataset_dir, _TEST_LABELS_FILENAME)
    _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  _clean_up_temporary_files(dataset_dir)
  print('\nFinished converting the MNIST dataset!')

--------------------------------------------------------------------------------
/datasets/flowers.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'flowers_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}

_NUM_CLASSES = 5

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)

--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes.

Some images have one or more bounding boxes associated with the label of the
image. See details here: http://image-net.org/download-bboxes

ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use the
"WordNet ID" (wnid), which is a concatenation of the POS (i.e., part of speech)
and the SYNSET OFFSET of WordNet. For more information, please refer to the
WordNet documentation [http://wordnet.princeton.edu/wordnet/documentation/].

"There are bounding boxes for over 3000 popular synsets available.
For each synset, there are on average 150 images with bounding boxes."

WARNING: Don't use this data for object detection; in this dataset all the
bounding boxes of an image belong to just one class.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from six.moves import urllib
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

# TODO(nsilberman): Add tfrecord file type once the script is updated.
_FILE_PATTERN = '%s-*'

_SPLITS_TO_SIZES = {
    'train': 1281167,
    'validation': 50000,
}

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'label': 'The label id of the image, integer between 0 and 999',
    'label_text': 'The text of the label.',
    'object/bbox': 'A list of bounding boxes.',
    'object/label': 'A list of labels, one per each object.',
}

_NUM_CLASSES = 1001


def create_readable_names_for_imagenet_labels():
  """Create a dict mapping label id to human readable string.

  Returns:
    labels_to_names: dictionary where keys are integers from 0 to 1000
    and values are human-readable names.

  We retrieve a synset file, which contains a list of valid synset labels used
  by the ILSVRC competition. There is one synset per line, e.g.:
    # n01440764
    # n01443537
  We also retrieve a synset_to_human_file, which contains a mapping from synsets
  to human-readable names for every synset in Imagenet. These are stored in a
  tsv format, as follows:
    # n02119247 black fox
    # n02119359 silver fox
  We assign each synset (in alphabetical order) an integer, starting from 1
  (since 0 is reserved for the background class).
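  For example, the alphabetically first ILSVRC synset, n01440764
  ('tench, Tinca tinca'), is assigned label 1.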
80 | 
81 |   Code is based on
82 |   https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463
83 |   """
84 | 
85 |   # pylint: disable=g-line-too-long
86 |   base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data'
87 |   synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url)
88 |   synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url)
89 | 
90 |   filename, _ = urllib.request.urlretrieve(synset_url)
91 |   synset_list = [s.strip() for s in open(filename).readlines()]
92 |   num_synsets_in_ilsvrc = len(synset_list)
93 |   assert num_synsets_in_ilsvrc == 1000
94 | 
95 |   filename, _ = urllib.request.urlretrieve(synset_to_human_url)
96 |   synset_to_human_list = open(filename).readlines()
97 |   num_synsets_in_all_imagenet = len(synset_to_human_list)
98 |   assert num_synsets_in_all_imagenet == 21842
99 | 
100 |   synset_to_human = {}
101 |   for s in synset_to_human_list:
102 |     parts = s.strip().split('\t')
103 |     assert len(parts) == 2
104 |     synset = parts[0]
105 |     human = parts[1]
106 |     synset_to_human[synset] = human
107 | 
108 |   label_index = 1
109 |   labels_to_names = {0: 'background'}
110 |   for synset in synset_list:
111 |     name = synset_to_human[synset]
112 |     labels_to_names[label_index] = name
113 |     label_index += 1
114 | 
115 |   return labels_to_names
116 | 
117 | 
118 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
119 |   """Gets a dataset tuple with instructions for reading ImageNet.
120 | 
121 |   Args:
122 |     split_name: A train/validation split name.
123 |     dataset_dir: The base directory of the dataset sources.
124 |     file_pattern: The file pattern to use when matching the dataset sources.
125 |       It is assumed that the pattern contains a '%s' string so that the split
126 |       name can be inserted.
127 |     reader: The TensorFlow reader type.
128 | 
129 |   Returns:
130 |     A `Dataset` namedtuple.
131 | 
132 |   Raises:
133 |     ValueError: if `split_name` is not a valid train/validation split.
134 |   """
135 |   if split_name not in _SPLITS_TO_SIZES:
136 |     raise ValueError('split name %s was not recognized.' % split_name)
137 | 
138 |   if not file_pattern:
139 |     file_pattern = _FILE_PATTERN
140 |   file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
141 | 
142 |   # Allowing None in the signature so that dataset_factory can use the default.
143 | if reader is None: 144 | reader = tf.TFRecordReader 145 | 146 | keys_to_features = { 147 | 'image/encoded': tf.FixedLenFeature( 148 | (), tf.string, default_value=''), 149 | 'image/format': tf.FixedLenFeature( 150 | (), tf.string, default_value='jpeg'), 151 | 'image/class/label': tf.FixedLenFeature( 152 | [], dtype=tf.int64, default_value=-1), 153 | 'image/class/text': tf.FixedLenFeature( 154 | [], dtype=tf.string, default_value=''), 155 | 'image/object/bbox/xmin': tf.VarLenFeature( 156 | dtype=tf.float32), 157 | 'image/object/bbox/ymin': tf.VarLenFeature( 158 | dtype=tf.float32), 159 | 'image/object/bbox/xmax': tf.VarLenFeature( 160 | dtype=tf.float32), 161 | 'image/object/bbox/ymax': tf.VarLenFeature( 162 | dtype=tf.float32), 163 | 'image/object/class/label': tf.VarLenFeature( 164 | dtype=tf.int64), 165 | } 166 | 167 | items_to_handlers = { 168 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 169 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 170 | 'label_text': slim.tfexample_decoder.Tensor('image/class/text'), 171 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 172 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 173 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'), 174 | } 175 | 176 | decoder = slim.tfexample_decoder.TFExampleDecoder( 177 | keys_to_features, items_to_handlers) 178 | 179 | labels_to_names = None 180 | if dataset_utils.has_labels(dataset_dir): 181 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 182 | else: 183 | labels_to_names = create_readable_names_for_imagenet_labels() 184 | dataset_utils.write_label_file(labels_to_names, dataset_dir) 185 | 186 | return slim.dataset.Dataset( 187 | data_sources=file_pattern, 188 | reader=reader, 189 | decoder=decoder, 190 | num_samples=_SPLITS_TO_SIZES[split_name], 191 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 192 | num_classes=_NUM_CLASSES, 193 | labels_to_names=labels_to_names) 194 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the MNIST dataset. 
16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/slim/datasets/download_and_convert_mnist.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'mnist_%s.tfrecord' 33 | 34 | _SPLITS_TO_SIZES = {'train': 60000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [28 x 28 x 1] grayscale image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading MNIST. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in _SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=_SPLITS_TO_SIZES[split_name], 96 | num_classes=_NUM_CLASSES, 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deployment/model_deploy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Deploy Slim models across multiple clones and replicas.
16 | 
17 | # TODO(sguada) docstring paragraph by (a) motivating the need for the file and
18 | # (b) defining clones.
19 | 
20 | # TODO(sguada) describe the high-level components of model deployment.
21 | # E.g. "each model deployment is composed of several parts: a DeploymentConfig,
22 | # which captures A, B and C, an input_fn which loads data.. etc
23 | 
24 | To easily train a model on multiple GPUs or across multiple machines, this
25 | module provides a set of helper functions: `create_clones`,
26 | `optimize_clones` and `deploy`.
27 | 
28 | Usage:
29 | 
30 |   g = tf.Graph()
31 | 
32 |   # Set up DeploymentConfig
33 |   config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
34 | 
35 |   # Create the global step on the device storing the variables.
36 |   with tf.device(config.variables_device()):
37 |     global_step = slim.create_global_step()
38 | 
39 |   # Define the inputs
40 |   with tf.device(config.inputs_device()):
41 |     images, labels = LoadData(...)
42 |     inputs_queue = slim.data.prefetch_queue((images, labels))
43 | 
44 |   # Define the optimizer.
45 |   with tf.device(config.optimizer_device()):
46 |     optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
47 | 
48 |   # Define the model including the loss.
49 |   def model_fn(inputs_queue):
50 |     images, labels = inputs_queue.dequeue()
51 |     predictions = CreateNetwork(images)
52 |     slim.losses.log_loss(predictions, labels)
53 | 
54 |   model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
55 |                                  optimizer=optimizer)
56 | 
57 |   # Run training.
58 |   slim.learning.train(model_dp.train_op, my_log_dir,
59 |                       summary_op=model_dp.summary_op)
60 | 
61 | The Clone namedtuple holds together the values associated with each call to
62 | model_fn:
63 |   * outputs: The return values of the calls to `model_fn()`.
64 |   * scope: The scope used to create the clone.
65 |   * device: The device used to create the clone.
66 | 
67 | The DeployedModel namedtuple holds together the values needed to train
68 | multiple clones:
69 |   * train_op: An operation that runs the optimizer training op and includes
70 |     all the update ops created by `model_fn`. Present only if an optimizer
71 |     was specified.
72 |   * summary_op: An operation that runs the summaries created by `model_fn`
73 |     and process_gradients.
74 |   * total_loss: A `Tensor` that contains the sum of all losses created by
75 |     `model_fn` plus the regularization losses.
76 |   * clones: List of `Clone` tuples returned by `create_clones()`.
77 | 
78 | DeploymentConfig parameters:
79 |   * num_clones: Number of model clones to deploy in each replica.
80 |   * clone_on_cpu: True if clones should be placed on CPU.
81 |   * replica_id: Integer. Index of the replica for which the model is
82 |     deployed. Usually 0 for the chief replica.
83 |   * num_replicas: Number of replicas to use.
84 |   * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
85 |   * worker_job_name: A name for the worker job.
86 |   * ps_job_name: A name for the parameter server job.
87 | 
88 | TODO(sguada):
89 |   - describe side effect to the graph.
90 |   - what happens to summaries and update_ops.
91 |   - which graph collections are altered.
92 |   - write a tutorial on how to use this.
93 |   - analyze the possibility of calling deploy more than once.
94 | 
95 | 
96 | """
97 | 
98 | from __future__ import absolute_import
99 | from __future__ import division
100 | from __future__ import print_function
101 | 
102 | import collections
103 | 
104 | import tensorflow as tf
105 | 
106 | slim = tf.contrib.slim
107 | 
108 | 
109 | __all__ = ['create_clones',
110 |            'deploy',
111 |            'optimize_clones',
112 |            'DeployedModel',
113 |            'DeploymentConfig',
114 |            'Clone',
115 |           ]
116 | 
117 | 
118 | # Namedtuple used to represent a clone during deployment.
119 | Clone = collections.namedtuple('Clone',
120 |                                ['outputs',  # Whatever model_fn() returned.
121 |                                 'scope',  # The scope used to create it.
122 |                                 'device',  # The device used to create it.
123 |                                ])
124 | 
125 | # Namedtuple used to represent a DeployedModel, returned by deploy().
126 | DeployedModel = collections.namedtuple('DeployedModel',
127 |                                        ['train_op',  # The `train_op`
128 |                                         'summary_op',  # The `summary_op`
129 |                                         'total_loss',  # The loss `Tensor`
130 |                                         'clones',  # A list of `Clone` tuples.
131 |                                        ])
132 | 
133 | # Default parameters for DeploymentConfig
134 | _deployment_params = {'num_clones': 1,
135 |                       'clone_on_cpu': False,
136 |                       'replica_id': 0,
137 |                       'num_replicas': 1,
138 |                       'num_ps_tasks': 0,
139 |                       'worker_job_name': 'worker',
140 |                       'ps_job_name': 'ps'}
141 | 
142 | 
143 | def create_clones(config, model_fn, args=None, kwargs=None):
144 |   """Creates multiple clones according to config using a `model_fn`.
145 | 
146 |   The returned values of `model_fn(*args, **kwargs)` are collected along with
147 |   the scope and device used to create it in a namedtuple
148 |   `Clone(outputs, scope, device)`.
149 | 
150 |   Note: it is assumed that any loss created by `model_fn` is collected in
151 |   the tf.GraphKeys.LOSSES collection.
152 | 
153 |   To recover the losses, summaries or update_ops created by the clone use:
154 |   ```python
155 |   losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
156 |   summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
157 |   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
158 |   ```
159 | 
160 |   The deployment options are specified by the config object and support
161 |   deploying one or several clones on different GPUs and one or several replicas
162 |   of such clones.
163 | 
164 |   The argument `model_fn` is called `config.num_clones` times to create the
165 |   model clones as `model_fn(*args, **kwargs)`.
166 | 
167 |   If `config` specifies deployment on multiple replicas then the default
168 |   tensorflow device is set appropriately for each call to `model_fn` and for the
169 |   slim variable creation functions: model and global variables will be created
170 |   on the `ps` device, the clone operations will be on the `worker` device.
171 | 
172 |   Args:
173 |     config: A DeploymentConfig object.
174 |     model_fn: A callable. Called as `model_fn(*args, **kwargs)`.
175 |     args: Optional list of arguments to pass to `model_fn`.
176 |     kwargs: Optional dict of keyword arguments to pass to `model_fn`.
177 | 
178 |   Returns:
179 |     A list of namedtuples `Clone`.
180 |   """
181 |   clones = []
182 |   args = args or []
183 |   kwargs = kwargs or {}
184 |   with slim.arg_scope([slim.model_variable, slim.variable],
185 |                       device=config.variables_device()):
186 |     # Create clones.
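    # Each clone runs model_fn under its own name_scope on its own device;
    # the variable scope is reused for i > 0, so all clones share a single
    # set of model variables.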
187 | for i in range(0, config.num_clones): 188 | with tf.name_scope(config.clone_scope(i)) as clone_scope: 189 | clone_device = config.clone_device(i) 190 | with tf.device(clone_device): 191 | with tf.variable_scope(tf.get_variable_scope(), 192 | reuse=True if i > 0 else None): 193 | outputs = model_fn(*args, **kwargs) 194 | clones.append(Clone(outputs, clone_scope, clone_device)) 195 | return clones 196 | 197 | 198 | def _gather_clone_loss(clone, num_clones, regularization_losses): 199 | """Gather the loss for a single clone. 200 | 201 | Args: 202 | clone: A Clone namedtuple. 203 | num_clones: The number of clones being deployed. 204 | regularization_losses: Possibly empty list of regularization_losses 205 | to add to the clone losses. 206 | 207 | Returns: 208 | A tensor for the total loss for the clone. Can be None. 209 | """ 210 | # The return value. 211 | sum_loss = None 212 | # Individual components of the loss that will need summaries. 213 | clone_loss = None 214 | regularization_loss = None 215 | # Compute and aggregate losses on the clone device. 216 | with tf.device(clone.device): 217 | all_losses = [] 218 | clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope) 219 | if clone_losses: 220 | clone_loss = tf.add_n(clone_losses, name='clone_loss') 221 | if num_clones > 1: 222 | clone_loss = tf.div(clone_loss, 1.0 * num_clones, 223 | name='scaled_clone_loss') 224 | all_losses.append(clone_loss) 225 | if regularization_losses: 226 | regularization_loss = tf.add_n(regularization_losses, 227 | name='regularization_loss') 228 | all_losses.append(regularization_loss) 229 | if all_losses: 230 | sum_loss = tf.add_n(all_losses) 231 | # Add the summaries out of the clone device block. 232 | if clone_loss is not None: 233 | tf.summary.scalar(clone.scope + '/clone_loss', clone_loss) 234 | if regularization_loss is not None: 235 | tf.summary.scalar('regularization_loss', regularization_loss) 236 | return sum_loss 237 | 238 | 239 | def _optimize_clone(optimizer, clone, num_clones, regularization_losses, 240 | **kwargs): 241 | """Compute losses and gradients for a single clone. 242 | 243 | Args: 244 | optimizer: A tf.Optimizer object. 245 | clone: A Clone namedtuple. 246 | num_clones: The number of clones being deployed. 247 | regularization_losses: Possibly empty list of regularization_losses 248 | to add to the clone losses. 249 | **kwargs: Dict of kwarg to pass to compute_gradients(). 250 | 251 | Returns: 252 | A tuple (clone_loss, clone_grads_and_vars). 253 | - clone_loss: A tensor for the total loss for the clone. Can be None. 254 | - clone_grads_and_vars: List of (gradient, variable) for the clone. 255 | Can be empty. 256 | """ 257 | sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses) 258 | clone_grad = None 259 | if sum_loss is not None: 260 | with tf.device(clone.device): 261 | clone_grad = optimizer.compute_gradients(sum_loss, **kwargs) 262 | return sum_loss, clone_grad 263 | 264 | 265 | def optimize_clones(clones, optimizer, 266 | regularization_losses=None, 267 | **kwargs): 268 | """Compute clone losses and gradients for the given list of `Clones`. 269 | 270 | Note: The regularization_losses are added to the first clone losses. 271 | 272 | Args: 273 | clones: List of `Clones` created by `create_clones()`. 274 | optimizer: An `Optimizer` object. 275 | regularization_losses: Optional list of regularization losses. If None it 276 | will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to 277 | exclude them. 
278 |     **kwargs: Optional dict of keyword arguments to pass to
279 |       `compute_gradients`.
280 | 
281 |   Returns:
282 |     A tuple (total_loss, grads_and_vars).
283 |       - total_loss: A Tensor containing the average of the clone losses
284 |         including the regularization loss.
285 |       - grads_and_vars: A List of tuples (gradient, variable) containing the
286 |         sum of the gradients for each variable.
287 |   """
288 |   grads_and_vars = []
289 |   clones_losses = []
290 |   num_clones = len(clones)
291 |   if regularization_losses is None:
292 |     regularization_losses = tf.get_collection(
293 |         tf.GraphKeys.REGULARIZATION_LOSSES)
294 |   for clone in clones:
295 |     with tf.name_scope(clone.scope):
296 |       clone_loss, clone_grad = _optimize_clone(
297 |           optimizer, clone, num_clones, regularization_losses, **kwargs)
298 |       if clone_loss is not None:
299 |         clones_losses.append(clone_loss)
300 |         grads_and_vars.append(clone_grad)
301 |       # Only use regularization_losses for the first clone.
302 |       regularization_losses = None
303 |   # Compute the total_loss summing all the clones_losses.
304 |   total_loss = tf.add_n(clones_losses, name='total_loss')
305 |   # Sum the gradients across clones.
306 |   grads_and_vars = _sum_clones_gradients(grads_and_vars)
307 |   return total_loss, grads_and_vars
308 | 
309 | 
310 | def deploy(config,
311 |            model_fn,
312 |            args=None,
313 |            kwargs=None,
314 |            optimizer=None,
315 |            summarize_gradients=False):
316 |   """Deploys a Slim-constructed model across multiple clones.
317 | 
318 |   The deployment options are specified by the config object and support
319 |   deploying one or several clones on different GPUs and one or several replicas
320 |   of such clones.
321 | 
322 |   The argument `model_fn` is called `config.num_clones` times to create the
323 |   model clones as `model_fn(*args, **kwargs)`.
324 | 
325 |   The optional argument `optimizer` is an `Optimizer` object. If not `None`,
326 |   the deployed model is configured for training with that optimizer.
327 | 
328 |   If `config` specifies deployment on multiple replicas then the default
329 |   tensorflow device is set appropriately for each call to `model_fn` and for the
330 |   slim variable creation functions: model and global variables will be created
331 |   on the `ps` device, the clone operations will be on the `worker` device.
332 | 
333 |   Args:
334 |     config: A `DeploymentConfig` object.
335 |     model_fn: A callable. Called as `model_fn(*args, **kwargs)`.
336 |     args: Optional list of arguments to pass to `model_fn`.
337 |     kwargs: Optional dict of keyword arguments to pass to `model_fn`.
338 |     optimizer: Optional `Optimizer` object. If passed the model is deployed
339 |       for training with that optimizer.
340 |     summarize_gradients: Whether or not to add summaries to the gradients.
341 | 
342 |   Returns:
343 |     A `DeployedModel` namedtuple.
344 |   """
345 | 
346 |   # Gather initial summaries.
347 |   summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
348 | 
349 |   # Create Clones.
350 |   clones = create_clones(config, model_fn, args, kwargs)
351 |   first_clone = clones[0]
352 | 
353 |   # Gather update_ops from the first clone. These contain, for example,
354 |   # the updates for the batch_norm variables created by model_fn.
355 |   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)
356 | 
357 |   train_op = None
358 |   total_loss = None
359 |   with tf.device(config.optimizer_device()):
360 |     if optimizer:
361 |       # Place the global step on the device storing the variables.
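      # (When parameter servers are in use, config.variables_device() routes
      # this to the `ps` job, keeping the step next to the model variables.)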
362 | with tf.device(config.variables_device()): 363 | global_step = slim.get_or_create_global_step() 364 | 365 | # Compute the gradients for the clones. 366 | total_loss, clones_gradients = optimize_clones(clones, optimizer) 367 | 368 | if clones_gradients: 369 | if summarize_gradients: 370 | # Add summaries to the gradients. 371 | summaries |= set(_add_gradients_summaries(clones_gradients)) 372 | 373 | # Create gradient updates. 374 | grad_updates = optimizer.apply_gradients(clones_gradients, 375 | global_step=global_step) 376 | update_ops.append(grad_updates) 377 | 378 | update_op = tf.group(*update_ops) 379 | with tf.control_dependencies([update_op]): 380 | train_op = tf.identity(total_loss, name='train_op') 381 | else: 382 | clones_losses = [] 383 | regularization_losses = tf.get_collection( 384 | tf.GraphKeys.REGULARIZATION_LOSSES) 385 | for clone in clones: 386 | with tf.name_scope(clone.scope): 387 | clone_loss = _gather_clone_loss(clone, len(clones), 388 | regularization_losses) 389 | if clone_loss is not None: 390 | clones_losses.append(clone_loss) 391 | # Only use regularization_losses for the first clone 392 | regularization_losses = None 393 | if clones_losses: 394 | total_loss = tf.add_n(clones_losses, name='total_loss') 395 | 396 | # Add the summaries from the first clone. These contain the summaries 397 | # created by model_fn and either optimize_clones() or _gather_clone_loss(). 398 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 399 | first_clone.scope)) 400 | 401 | if total_loss is not None: 402 | # Add total_loss to summary. 403 | summaries.add(tf.summary.scalar('total_loss', total_loss)) 404 | 405 | if summaries: 406 | # Merge all summaries together. 407 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 408 | else: 409 | summary_op = None 410 | 411 | return DeployedModel(train_op, summary_op, total_loss, clones) 412 | 413 | 414 | def _sum_clones_gradients(clone_grads): 415 | """Calculate the sum gradient for each shared variable across all clones. 416 | 417 | This function assumes that the clone_grads has been scaled appropriately by 418 | 1 / num_clones. 419 | 420 | Args: 421 | clone_grads: A List of List of tuples (gradient, variable), one list per 422 | `Clone`. 423 | 424 | Returns: 425 | List of tuples of (gradient, variable) where the gradient has been summed 426 | across all clones. 427 | """ 428 | sum_grads = [] 429 | for grad_and_vars in zip(*clone_grads): 430 | # Note that each grad_and_vars looks like the following: 431 | # ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN)) 432 | grads = [] 433 | var = grad_and_vars[0][1] 434 | for g, v in grad_and_vars: 435 | assert v == var 436 | if g is not None: 437 | grads.append(g) 438 | if grads: 439 | if len(grads) > 1: 440 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads') 441 | else: 442 | sum_grad = grads[0] 443 | sum_grads.append((sum_grad, var)) 444 | return sum_grads 445 | 446 | 447 | def _add_gradients_summaries(grads_and_vars): 448 | """Add histogram summaries to gradients. 449 | 450 | Note: The summaries are also added to the SUMMARIES collection. 451 | 452 | Args: 453 | grads_and_vars: A list of gradient to variable pairs (tuples). 454 | 455 | Returns: 456 | The _list_ of the added summaries for grads_and_vars. 
457 |   """
458 |   summaries = []
459 |   for grad, var in grads_and_vars:
460 |     if grad is not None:
461 |       if isinstance(grad, tf.IndexedSlices):
462 |         grad_values = grad.values
463 |       else:
464 |         grad_values = grad
465 |       summaries.append(tf.summary.histogram(var.op.name + ':gradient',
466 |                                             grad_values))
467 |       summaries.append(tf.summary.histogram(var.op.name + ':gradient_norm',
468 |                                             tf.global_norm([grad_values])))
469 |     else:
470 |       tf.logging.info('Var %s has no gradient', var.op.name)
471 |   return summaries
472 | 
473 | 
474 | class DeploymentConfig(object):
475 |   """Configuration for deploying a model with `deploy()`.
476 | 
477 |   You can pass an instance of this class to `deploy()` to specify exactly
478 |   how the model should be deployed. If you do not pass one, an instance
479 |   built from the default `_deployment_params` will be used.
480 |   """
481 | 
482 |   def __init__(self,
483 |                num_clones=1,
484 |                clone_on_cpu=False,
485 |                replica_id=0,
486 |                num_replicas=1,
487 |                num_ps_tasks=0,
488 |                worker_job_name='worker',
489 |                ps_job_name='ps'):
490 |     """Create a DeploymentConfig.
491 | 
492 |     The config describes how to deploy a model across multiple clones and
493 |     replicas. The model will be replicated `num_clones` times in each replica.
494 |     If `clone_on_cpu` is True, each clone will be placed on CPU.
495 | 
496 |     If `num_replicas` is 1, the model is deployed via a single process. In that
497 |     case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored.
498 | 
499 |     If `num_replicas` is greater than 1, then `worker_device` and `ps_device`
500 |     must specify TensorFlow devices for the `worker` and `ps` jobs and
501 |     `num_ps_tasks` must be positive.
502 | 
503 |     Args:
504 |       num_clones: Number of model clones to deploy in each replica.
505 |       clone_on_cpu: If True, clones would be placed on CPU.
506 |       replica_id: Integer. Index of the replica for which the model is
507 |         deployed. Usually 0 for the chief replica.
508 |       num_replicas: Number of replicas to use.
509 |       num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
510 |       worker_job_name: A name for the worker job.
511 |       ps_job_name: A name for the parameter server job.
512 | 
513 |     Raises:
514 |       ValueError: If the arguments are invalid.
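 
    For example, a single-process, two-GPU setup (an illustrative sketch,
    mirroring the module-level usage example):
      config = DeploymentConfig(num_clones=2, clone_on_cpu=False)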
515 |     """
516 |     if num_replicas > 1:
517 |       if num_ps_tasks < 1:
518 |         raise ValueError('When using replicas num_ps_tasks must be positive')
519 |     if num_replicas > 1 or num_ps_tasks > 0:
520 |       if not worker_job_name:
521 |         raise ValueError('Must specify worker_job_name when using replicas')
522 |       if not ps_job_name:
523 |         raise ValueError('Must specify ps_job_name when using parameter server')
524 |     if replica_id >= num_replicas:
525 |       raise ValueError('replica_id must be less than num_replicas')
526 |     self._num_clones = num_clones
527 |     self._clone_on_cpu = clone_on_cpu
528 |     self._replica_id = replica_id
529 |     self._num_replicas = num_replicas
530 |     self._num_ps_tasks = num_ps_tasks
531 |     self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
532 |     self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else ''
533 | 
534 |   @property
535 |   def num_clones(self):
536 |     return self._num_clones
537 | 
538 |   @property
539 |   def clone_on_cpu(self):
540 |     return self._clone_on_cpu
541 | 
542 |   @property
543 |   def replica_id(self):
544 |     return self._replica_id
545 | 
546 |   @property
547 |   def num_replicas(self):
548 |     return self._num_replicas
549 | 
550 |   @property
551 |   def num_ps_tasks(self):
552 |     return self._num_ps_tasks
553 | 
554 |   @property
555 |   def ps_device(self):
556 |     return self._ps_device
557 | 
558 |   @property
559 |   def worker_device(self):
560 |     return self._worker_device
561 | 
562 |   def caching_device(self):
563 |     """Returns the device to use for caching variables.
564 | 
565 |     Variables are cached on the worker CPU when using replicas.
566 | 
567 |     Returns:
568 |       A device string or None if the variables do not need to be cached.
569 |     """
570 |     if self._num_ps_tasks > 0:
571 |       return lambda op: op.device
572 |     else:
573 |       return None
574 | 
575 |   def clone_device(self, clone_index):
576 |     """Device used to create the clone and all the ops inside the clone.
577 | 
578 |     Args:
579 |       clone_index: Int, representing the clone_index.
580 | 
581 |     Returns:
582 |       A value suitable for `tf.device()`.
583 | 
584 |     Raises:
585 |       ValueError: if `clone_index` is greater than or equal to the number of clones.
586 |     """
587 |     if clone_index >= self._num_clones:
588 |       raise ValueError('clone_index must be less than num_clones')
589 |     device = ''
590 |     if self._num_ps_tasks > 0:
591 |       device += self._worker_device
592 |     if self._clone_on_cpu:
593 |       device += '/device:CPU:0'
594 |     else:
595 |       device += '/device:GPU:%d' % clone_index
596 |     return device
597 | 
598 |   def clone_scope(self, clone_index):
599 |     """Name scope to create the clone.
600 | 
601 |     Args:
602 |       clone_index: Int, representing the clone_index.
603 | 
604 |     Returns:
605 |       A name_scope suitable for `tf.name_scope()`.
606 | 
607 |     Raises:
608 |       ValueError: if `clone_index` is greater than or equal to the number of clones.
609 |     """
610 |     if clone_index >= self._num_clones:
611 |       raise ValueError('clone_index must be less than num_clones')
612 |     scope = ''
613 |     if self._num_clones > 1:
614 |       scope = 'clone_%d' % clone_index
615 |     return scope
616 | 
617 |   def optimizer_device(self):
618 |     """Device to use with the optimizer.
619 | 
620 |     Returns:
621 |       A value suitable for `tf.device()`.
622 |     """
623 |     if self._num_ps_tasks > 0 or self._num_clones > 0:
624 |       return self._worker_device + '/device:CPU:0'
625 |     else:
626 |       return ''
627 | 
628 |   def inputs_device(self):
629 |     """Device to use to build the inputs.
630 | 
631 |     Returns:
632 |       A value suitable for `tf.device()`.
633 | """ 634 | device = '' 635 | if self._num_ps_tasks > 0: 636 | device += self._worker_device 637 | device += '/device:CPU:0' 638 | return device 639 | 640 | def variables_device(self): 641 | """Returns the device to use for variables created inside the clone. 642 | 643 | Returns: 644 | A value suitable for `tf.device()`. 645 | """ 646 | device = '' 647 | if self._num_ps_tasks > 0: 648 | device += self._ps_device 649 | device += '/device:CPU:0' 650 | 651 | class _PSDeviceChooser(object): 652 | """Slim device chooser for variables when using PS.""" 653 | 654 | def __init__(self, device, tasks): 655 | self._device = device 656 | self._tasks = tasks 657 | self._task = 0 658 | 659 | def choose(self, op): 660 | if op.device: 661 | return op.device 662 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 663 | if node_def.op.startswith('Variable'): 664 | t = self._task 665 | self._task = (self._task + 1) % self._tasks 666 | d = '%s/task:%d' % (self._device, t) 667 | return d 668 | else: 669 | return op.device 670 | 671 | if not self._num_ps_tasks: 672 | return device 673 | else: 674 | chooser = _PSDeviceChooser(device, self._num_ps_tasks) 675 | return chooser.choose 676 | -------------------------------------------------------------------------------- /download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset. 
16 | 
17 | Usage:
18 | ```shell
19 | 
20 | $ python download_and_convert_data.py \
21 |     --dataset_name=mnist \
22 |     --dataset_dir=/tmp/mnist
23 | 
24 | $ python download_and_convert_data.py \
25 |     --dataset_name=cifar10 \
26 |     --dataset_dir=/tmp/cifar10
27 | 
28 | $ python download_and_convert_data.py \
29 |     --dataset_name=flowers \
30 |     --dataset_dir=/tmp/flowers
31 | ```
32 | """
33 | from __future__ import absolute_import
34 | from __future__ import division
35 | from __future__ import print_function
36 | 
37 | import tensorflow as tf
38 | 
39 | from datasets import download_and_convert_cifar10
40 | from datasets import download_and_convert_flowers
41 | from datasets import download_and_convert_mnist
42 | 
43 | FLAGS = tf.app.flags.FLAGS
44 | 
45 | tf.app.flags.DEFINE_string(
46 |     'dataset_name',
47 |     None,
48 |     'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')
49 | 
50 | tf.app.flags.DEFINE_string(
51 |     'dataset_dir',
52 |     None,
53 |     'The directory where the output TFRecords and temporary files are saved.')
54 | 
55 | 
56 | def main(_):
57 |   if not FLAGS.dataset_name:
58 |     raise ValueError('You must supply the dataset name with --dataset_name')
59 |   if not FLAGS.dataset_dir:
60 |     raise ValueError('You must supply the dataset directory with --dataset_dir')
61 | 
62 |   if FLAGS.dataset_name == 'cifar10':
63 |     download_and_convert_cifar10.run(FLAGS.dataset_dir)
64 |   elif FLAGS.dataset_name == 'flowers':
65 |     download_and_convert_flowers.run(FLAGS.dataset_dir)
66 |   elif FLAGS.dataset_name == 'mnist':
67 |     download_and_convert_mnist.run(FLAGS.dataset_dir)
68 |   else:
69 |     raise ValueError(
70 |         'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
71 | 
72 | if __name__ == '__main__':
73 |   tf.app.run()
74 | 
75 | 
-------------------------------------------------------------------------------- /eval_image_classifier.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Generic evaluation script that evaluates a model using a given dataset."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import math
22 | import tensorflow as tf
23 | 
24 | from datasets import dataset_factory
25 | from nets import nets_factory
26 | from preprocessing import preprocessing_factory
27 | 
28 | slim = tf.contrib.slim
29 | 
30 | tf.app.flags.DEFINE_integer(
31 |     'batch_size', 100, 'The number of samples in each batch.')
32 | 
33 | tf.app.flags.DEFINE_integer(
34 |     'max_num_batches', None,
35 |     'Max number of batches to evaluate. By default, use all.')
36 | 
37 | tf.app.flags.DEFINE_string(
38 |     'master', '', 'The address of the TensorFlow master to use.')
39 | 
40 | tf.app.flags.DEFINE_string(
41 |     'checkpoint_path', '/tmp/tfmodel/',
42 |     'The directory where the model was written to or an absolute path to a '
43 |     'checkpoint file.')
44 | 
45 | tf.app.flags.DEFINE_string(
46 |     'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')
47 | 
48 | tf.app.flags.DEFINE_integer(
49 |     'num_preprocessing_threads', 4,
50 |     'The number of threads used to create the batches.')
51 | 
52 | tf.app.flags.DEFINE_string(
53 |     'dataset_name', 'imagenet', 'The name of the dataset to load.')
54 | 
55 | tf.app.flags.DEFINE_string(
56 |     'dataset_split_name', 'test', 'The name of the train/test split.')
57 | 
58 | tf.app.flags.DEFINE_string(
59 |     'dataset_dir', None, 'The directory where the dataset files are stored.')
60 | 
61 | tf.app.flags.DEFINE_integer(
62 |     'labels_offset', 0,
63 |     'An offset for the labels in the dataset. This flag is primarily used to '
64 |     'evaluate the VGG and ResNet architectures which do not use a background '
65 |     'class for the ImageNet dataset.')
66 | 
67 | tf.app.flags.DEFINE_string(
68 |     'model_name', 'densenet121', 'The name of the architecture to evaluate.')
69 | 
70 | tf.app.flags.DEFINE_string(
71 |     'data_format', 'NHWC', 'The layout of the image Tensors: NHWC or NCHW.')
72 | 
73 | tf.app.flags.DEFINE_string(
74 |     'preprocessing_name', None, 'The name of the preprocessing to use. If left '
75 |     'as `None`, then the model_name flag is used.')
76 | 
77 | tf.app.flags.DEFINE_float(
78 |     'moving_average_decay', None,
79 |     'The decay to use for the moving average. '
80 | 'If left as None, then moving averages are not used.') 81 | 82 | tf.app.flags.DEFINE_integer( 83 | 'eval_image_size', None, 'Eval image size') 84 | 85 | FLAGS = tf.app.flags.FLAGS 86 | 87 | 88 | def main(_): 89 | if not FLAGS.dataset_dir: 90 | raise ValueError('You must supply the dataset directory with --dataset_dir') 91 | 92 | tf.logging.set_verbosity(tf.logging.INFO) 93 | with tf.Graph().as_default(): 94 | tf_global_step = slim.get_or_create_global_step() 95 | 96 | ###################### 97 | # Select the dataset # 98 | ###################### 99 | dataset = dataset_factory.get_dataset( 100 | FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 101 | 102 | #################### 103 | # Select the model # 104 | #################### 105 | network_fn = nets_factory.get_network_fn( 106 | FLAGS.model_name, 107 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 108 | data_format=FLAGS.data_format, 109 | is_training=False) 110 | 111 | ############################################################## 112 | # Create a dataset provider that loads data from the dataset # 113 | ############################################################## 114 | provider = slim.dataset_data_provider.DatasetDataProvider( 115 | dataset, 116 | shuffle=False, 117 | common_queue_capacity=2 * FLAGS.batch_size, 118 | common_queue_min=FLAGS.batch_size) 119 | [image, label] = provider.get(['image', 'label']) 120 | label -= FLAGS.labels_offset 121 | 122 | ##################################### 123 | # Select the preprocessing function # 124 | ##################################### 125 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 126 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 127 | preprocessing_name, 128 | is_training=False) 129 | 130 | eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size 131 | 132 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 133 | 134 | images, labels = tf.train.batch( 135 | [image, label], 136 | batch_size=FLAGS.batch_size, 137 | num_threads=FLAGS.num_preprocessing_threads, 138 | capacity=5 * FLAGS.batch_size) 139 | 140 | #################### 141 | # Define the model # 142 | #################### 143 | logits, _ = network_fn(images) 144 | 145 | if FLAGS.moving_average_decay: 146 | variable_averages = tf.train.ExponentialMovingAverage( 147 | FLAGS.moving_average_decay, tf_global_step) 148 | variables_to_restore = variable_averages.variables_to_restore( 149 | slim.get_model_variables()) 150 | variables_to_restore[tf_global_step.op.name] = tf_global_step 151 | else: 152 | variables_to_restore = slim.get_variables_to_restore() 153 | 154 | logits = tf.squeeze(logits) 155 | predictions = tf.argmax(logits, 1) 156 | labels = tf.squeeze(labels) 157 | 158 | # Define the metrics: 159 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 160 | 'Accuracy': slim.metrics.streaming_accuracy(predictions, labels), 161 | 'Recall_5': slim.metrics.streaming_recall_at_k( 162 | logits, labels, 5), 163 | }) 164 | 165 | # Print the summaries to screen. 
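    # (Each slim streaming metric is a (value, update_op) pair: the update ops
    # run once per batch, and the loop below attaches a scalar summary and a
    # tf.Print op to every final metric value.)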
166 | print_ops = [] 167 | for name, value in names_to_values.items(): 168 | summary_name = 'eval/%s' % name 169 | op = tf.summary.scalar(summary_name, value, collections=[]) 170 | op = tf.Print(op, [value], summary_name) 171 | print_ops.append(tf.Print(value, [value], summary_name)) 172 | tf.add_to_collection(tf.GraphKeys.SUMMARIES, op) 173 | 174 | # TODO(sguada) use num_epochs=1 175 | if FLAGS.max_num_batches: 176 | num_batches = FLAGS.max_num_batches 177 | else: 178 | # This ensures that we make a single pass over all of the data. 179 | num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size)) 180 | 181 | if tf.gfile.IsDirectory(FLAGS.checkpoint_path): 182 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 183 | else: 184 | checkpoint_path = FLAGS.checkpoint_path 185 | 186 | tf.logging.info('Evaluating %s' % checkpoint_path) 187 | 188 | slim.evaluation.evaluate_once( 189 | master=FLAGS.master, 190 | checkpoint_path=checkpoint_path, 191 | logdir=FLAGS.eval_dir, 192 | num_evals=num_batches, 193 | eval_op=list(names_to_updates.values()) + print_ops, 194 | variables_to_restore=variables_to_restore) 195 | 196 | 197 | if __name__ == '__main__': 198 | tf.app.run() 199 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pudae/tensorflow-densenet/d664b5ceab05466cbf36b4dd95602437cc61e0a8/nets/__init__.py -------------------------------------------------------------------------------- /nets/densenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 pudae. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition of the DenseNet architecture. 16 | 17 | As described in https://arxiv.org/abs/1608.06993. 18 | 19 | Densely Connected Convolutional Networks 20 | Gao Huang, Zhuang Liu, Kilian Q. 
Weinberger, Laurens van der Maaten
21 | """
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 | 
26 | import tensorflow as tf
27 | 
28 | slim = tf.contrib.slim
29 | 
30 | 
31 | @slim.add_arg_scope
32 | def _global_avg_pool2d(inputs, data_format='NHWC', scope=None, outputs_collections=None):
33 |   with tf.variable_scope(scope, 'xx', [inputs]) as sc:
34 |     axis = [1, 2] if data_format == 'NHWC' else [2, 3]
35 |     net = tf.reduce_mean(inputs, axis=axis, keep_dims=True)
36 |     net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
37 |     return net
38 | 
39 | 
40 | @slim.add_arg_scope
41 | def _conv(inputs, num_filters, kernel_size, stride=1, dropout_rate=None,
42 |           scope=None, outputs_collections=None):
43 |   with tf.variable_scope(scope, 'xx', [inputs]) as sc:
44 |     net = slim.batch_norm(inputs)
45 |     net = tf.nn.relu(net)
46 |     net = slim.conv2d(net, num_filters, kernel_size, stride=stride)
47 | 
48 |     if dropout_rate:
49 |       net = tf.nn.dropout(net, keep_prob=1.0 - dropout_rate)
50 | 
51 |     net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
52 | 
53 |     return net
54 | 
55 | 
56 | @slim.add_arg_scope
57 | def _conv_block(inputs, num_filters, data_format='NHWC', scope=None, outputs_collections=None):
58 |   with tf.variable_scope(scope, 'conv_blockx', [inputs]) as sc:
59 |     net = inputs
60 |     net = _conv(net, num_filters*4, 1, scope='x1')
61 |     net = _conv(net, num_filters, 3, scope='x2')
62 |     if data_format == 'NHWC':
63 |       net = tf.concat([inputs, net], axis=3)
64 |     else:  # "NCHW"
65 |       net = tf.concat([inputs, net], axis=1)
66 | 
67 |     net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
68 | 
69 |     return net
70 | 
71 | 
72 | @slim.add_arg_scope
73 | def _dense_block(inputs, num_layers, num_filters, growth_rate,
74 |                  grow_num_filters=True, scope=None, outputs_collections=None):
75 | 
76 |   with tf.variable_scope(scope, 'dense_blockx', [inputs]) as sc:
77 |     net = inputs
78 |     for i in range(num_layers):
79 |       branch = i + 1
80 |       net = _conv_block(net, growth_rate, scope='conv_block'+str(branch))
81 | 
82 |       if grow_num_filters:
83 |         num_filters += growth_rate
84 | 
85 |     net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
86 | 
87 |   return net, num_filters
88 | 
89 | 
90 | @slim.add_arg_scope
91 | def _transition_block(inputs, num_filters, compression=1.0,
92 |                       scope=None, outputs_collections=None):
93 | 
94 |   num_filters = int(num_filters * compression)
95 |   with tf.variable_scope(scope, 'transition_blockx', [inputs]) as sc:
96 |     net = inputs
97 |     net = _conv(net, num_filters, 1, scope='blk')
98 | 
99 |     net = slim.avg_pool2d(net, 2)
100 | 
101 |     net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
102 | 
103 |   return net, num_filters
104 | 
105 | 
106 | def densenet(inputs,
107 |              num_classes=1000,
108 |              reduction=None,
109 |              growth_rate=None,
110 |              num_filters=None,
111 |              num_layers=None,
112 |              dropout_rate=None,
113 |              data_format='NHWC',
114 |              is_training=True,
115 |              reuse=None,
116 |              scope=None):
117 |   assert reduction is not None
118 |   assert growth_rate is not None
119 |   assert num_filters is not None
120 |   assert num_layers is not None
121 | 
122 |   compression = 1.0 - reduction
123 |   num_dense_blocks = len(num_layers)
124 | 
125 |   if data_format == 'NCHW':
126 |     inputs = tf.transpose(inputs, [0, 3, 1, 2])
127 | 
128 |   with tf.variable_scope(scope, 'densenetxxx', [inputs, num_classes],
129 |                          reuse=reuse) as sc:
130 |     end_points_collection = sc.name + '_end_points'
131 |     with slim.arg_scope([slim.batch_norm,
slim.dropout], 132 | is_training=is_training), \ 133 | slim.arg_scope([slim.conv2d, _conv, _conv_block, 134 | _dense_block, _transition_block], 135 | outputs_collections=end_points_collection), \ 136 | slim.arg_scope([_conv], dropout_rate=dropout_rate): 137 | net = inputs 138 | 139 | # initial convolution 140 | net = slim.conv2d(net, num_filters, 7, stride=2, scope='conv1') 141 | net = slim.batch_norm(net) 142 | net = tf.nn.relu(net) 143 | net = slim.max_pool2d(net, 3, stride=2, padding='SAME') 144 | 145 | # blocks 146 | for i in range(num_dense_blocks - 1): 147 | # dense blocks 148 | net, num_filters = _dense_block(net, num_layers[i], num_filters, 149 | growth_rate, 150 | scope='dense_block' + str(i+1)) 151 | 152 | # Add transition_block 153 | net, num_filters = _transition_block(net, num_filters, 154 | compression=compression, 155 | scope='transition_block' + str(i+1)) 156 | 157 | net, num_filters = _dense_block( 158 | net, num_layers[-1], num_filters, 159 | growth_rate, 160 | scope='dense_block' + str(num_dense_blocks)) 161 | 162 | # final blocks 163 | with tf.variable_scope('final_block', [inputs]): 164 | net = slim.batch_norm(net) 165 | net = tf.nn.relu(net) 166 | net = _global_avg_pool2d(net, scope='global_avg_pool') 167 | 168 | net = slim.conv2d(net, num_classes, 1, 169 | biases_initializer=tf.zeros_initializer(), 170 | scope='logits') 171 | 172 | end_points = slim.utils.convert_collection_to_dict( 173 | end_points_collection) 174 | 175 | if num_classes is not None: 176 | end_points['predictions'] = slim.softmax(net, scope='predictions') 177 | 178 | return net, end_points 179 | 180 | 181 | def densenet121(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 182 | return densenet(inputs, 183 | num_classes=num_classes, 184 | reduction=0.5, 185 | growth_rate=32, 186 | num_filters=64, 187 | num_layers=[6,12,24,16], 188 | data_format=data_format, 189 | is_training=is_training, 190 | reuse=reuse, 191 | scope='densenet121') 192 | densenet121.default_image_size = 224 193 | 194 | 195 | def densenet161(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 196 | return densenet(inputs, 197 | num_classes=num_classes, 198 | reduction=0.5, 199 | growth_rate=48, 200 | num_filters=96, 201 | num_layers=[6,12,36,24], 202 | data_format=data_format, 203 | is_training=is_training, 204 | reuse=reuse, 205 | scope='densenet161') 206 | densenet161.default_image_size = 224 207 | 208 | 209 | def densenet169(inputs, num_classes=1000, data_format='NHWC', is_training=True, reuse=None): 210 | return densenet(inputs, 211 | num_classes=num_classes, 212 | reduction=0.5, 213 | growth_rate=32, 214 | num_filters=64, 215 | num_layers=[6,12,32,32], 216 | data_format=data_format, 217 | is_training=is_training, 218 | reuse=reuse, 219 | scope='densenet169') 220 | densenet169.default_image_size = 224 221 | 222 | 223 | def densenet_arg_scope(weight_decay=1e-4, 224 | batch_norm_decay=0.99, 225 | batch_norm_epsilon=1.1e-5, 226 | data_format='NHWC'): 227 | with slim.arg_scope([slim.conv2d, slim.batch_norm, slim.avg_pool2d, slim.max_pool2d, 228 | _conv_block, _global_avg_pool2d], 229 | data_format=data_format): 230 | with slim.arg_scope([slim.conv2d], 231 | weights_regularizer=slim.l2_regularizer(weight_decay), 232 | activation_fn=None, 233 | biases_initializer=None): 234 | with slim.arg_scope([slim.batch_norm], 235 | scale=True, 236 | decay=batch_norm_decay, 237 | epsilon=batch_norm_epsilon) as scope: 238 | return scope 239 | 240 | 241 | 
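# A minimal usage sketch (illustrative only, not part of the original file):
# build a DenseNet-121 classifier under its arg_scope and read out the
# softmax predictions.
#
#   import tensorflow as tf
#   from nets import densenet
#
#   slim = tf.contrib.slim
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   with slim.arg_scope(densenet.densenet_arg_scope()):
#     logits, end_points = densenet.densenet121(images, num_classes=1000,
#                                               is_training=False)
#   probabilities = end_points['predictions']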
-------------------------------------------------------------------------------- /nets/nets_factory.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 | 
22 | import tensorflow as tf
23 | 
24 | from nets import densenet
25 | 
26 | slim = tf.contrib.slim
27 | 
28 | networks_map = {
29 |     'densenet121': densenet.densenet121,
30 |     'densenet161': densenet.densenet161,
31 |     'densenet169': densenet.densenet169,
32 | }
33 | 
34 | arg_scopes_map = {
35 |     'densenet121': densenet.densenet_arg_scope,
36 |     'densenet161': densenet.densenet_arg_scope,
37 |     'densenet169': densenet.densenet_arg_scope,
38 | }
39 | 
40 | 
41 | def get_network_fn(name, num_classes, weight_decay=0.0, data_format='NHWC',
42 |                    is_training=False):
43 |   """Returns a network_fn such as `logits, end_points = network_fn(images)`.
44 | 
45 |   Args:
46 |     name: The name of the network.
47 |     num_classes: The number of classes to use for classification.
48 |     weight_decay: The l2 coefficient for the model weights.
49 |     data_format: The layout of the activations, 'NHWC' or 'NCHW'.
50 |     is_training: `True` if the model is being used for training and `False`
51 |       otherwise.
52 | 
53 |   Returns:
54 |     network_fn: A function that applies the model to a batch of images. It has
55 |       the following signature:
56 |         logits, end_points = network_fn(images)
57 |   Raises:
58 |     ValueError: If network `name` is not recognized.
59 |   """
60 |   if name not in networks_map:
61 |     raise ValueError('Name of network unknown: %s' % name)
62 |   arg_scope = arg_scopes_map[name](weight_decay=weight_decay, data_format=data_format)
63 |   func = networks_map[name]
64 |   @functools.wraps(func)
65 |   def network_fn(images):
66 |     with slim.arg_scope(arg_scope):
67 |       return func(images, num_classes, data_format=data_format, is_training=is_training)
68 |   if hasattr(func, 'default_image_size'):
69 |     network_fn.default_image_size = func.default_image_size
70 | 
71 |   return network_fn
72 | 
-------------------------------------------------------------------------------- /preprocessing/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/pudae/tensorflow-densenet/d664b5ceab05466cbf36b4dd95602437cc61e0a8/preprocessing/__init__.py
-------------------------------------------------------------------------------- /preprocessing/densenet_preprocessing.py: --------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images. 16 | 17 | The preprocessing steps for VGG were introduced in the following technical 18 | report: 19 | 20 | Very Deep Convolutional Networks For Large-Scale Image Recognition 21 | Karen Simonyan and Andrew Zisserman 22 | arXiv technical report, 2015 23 | PDF: http://arxiv.org/pdf/1409.1556.pdf 24 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 25 | CC-BY-4.0 26 | 27 | More information can be obtained from the VGG website: 28 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 29 | """ 30 | 31 | from __future__ import absolute_import 32 | from __future__ import division 33 | from __future__ import print_function 34 | 35 | import tensorflow as tf 36 | 37 | slim = tf.contrib.slim 38 | 39 | _R_MEAN = 123.68 40 | _G_MEAN = 116.78 41 | _B_MEAN = 103.94 42 | 43 | _SCALE_FACTOR = 0.017 44 | 45 | _RESIZE_SIDE_MIN = 256 46 | _RESIZE_SIDE_MAX = 512 47 | 48 | 49 | def _crop(image, offset_height, offset_width, crop_height, crop_width): 50 | """Crops the given image using the provided offsets and sizes. 51 | 52 | Note that the method doesn't assume we know the input image size but it does 53 | assume we know the input image rank. 54 | 55 | Args: 56 | image: an image of shape [height, width, channels]. 57 | offset_height: a scalar tensor indicating the height offset. 58 | offset_width: a scalar tensor indicating the width offset. 59 | crop_height: the height of the cropped image. 60 | crop_width: the width of the cropped image. 61 | 62 | Returns: 63 | the cropped (and resized) image. 64 | 65 | Raises: 66 | InvalidArgumentError: if the rank is not 3 or if the image dimensions are 67 | less than the crop size. 68 | """ 69 | original_shape = tf.shape(image) 70 | 71 | rank_assertion = tf.Assert( 72 | tf.equal(tf.rank(image), 3), 73 | ['Rank of image must be equal to 3.']) 74 | with tf.control_dependencies([rank_assertion]): 75 | cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]]) 76 | 77 | size_assertion = tf.Assert( 78 | tf.logical_and( 79 | tf.greater_equal(original_shape[0], crop_height), 80 | tf.greater_equal(original_shape[1], crop_width)), 81 | ['Crop size greater than the image size.']) 82 | 83 | offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0])) 84 | 85 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to 86 | # define the crop size. 87 | with tf.control_dependencies([size_assertion]): 88 | image = tf.slice(image, offsets, cropped_shape) 89 | return tf.reshape(image, cropped_shape) 90 | 91 | 92 | def _random_crop(image_list, crop_height, crop_width): 93 | """Crops the given list of images. 94 | 95 | The function applies the same crop to each image in the list. 
This can be 96 | effectively applied when there are multiple image inputs of the same 97 | dimension such as: 98 | 99 | image, depths, normals = _random_crop([image, depths, normals], 120, 150) 100 | 101 | Args: 102 | image_list: a list of image tensors of the same dimension but possibly 103 | varying channel. 104 | crop_height: the new height. 105 | crop_width: the new width. 106 | 107 | Returns: 108 | the image_list with cropped images. 109 | 110 | Raises: 111 | ValueError: if there are multiple image inputs provided with different size 112 | or the images are smaller than the crop dimensions. 113 | """ 114 | if not image_list: 115 | raise ValueError('Empty image_list.') 116 | 117 | # Compute the rank assertions. 118 | rank_assertions = [] 119 | for i in range(len(image_list)): 120 | image_rank = tf.rank(image_list[i]) 121 | rank_assert = tf.Assert( 122 | tf.equal(image_rank, 3), 123 | ['Wrong rank for tensor %s [expected] [actual]', 124 | image_list[i].name, 3, image_rank]) 125 | rank_assertions.append(rank_assert) 126 | 127 | with tf.control_dependencies([rank_assertions[0]]): 128 | image_shape = tf.shape(image_list[0]) 129 | image_height = image_shape[0] 130 | image_width = image_shape[1] 131 | crop_size_assert = tf.Assert( 132 | tf.logical_and( 133 | tf.greater_equal(image_height, crop_height), 134 | tf.greater_equal(image_width, crop_width)), 135 | ['Crop size greater than the image size.']) 136 | 137 | asserts = [rank_assertions[0], crop_size_assert] 138 | 139 | for i in range(1, len(image_list)): 140 | image = image_list[i] 141 | asserts.append(rank_assertions[i]) 142 | with tf.control_dependencies([rank_assertions[i]]): 143 | shape = tf.shape(image) 144 | height = shape[0] 145 | width = shape[1] 146 | 147 | height_assert = tf.Assert( 148 | tf.equal(height, image_height), 149 | ['Wrong height for tensor %s [expected][actual]', 150 | image.name, height, image_height]) 151 | width_assert = tf.Assert( 152 | tf.equal(width, image_width), 153 | ['Wrong width for tensor %s [expected][actual]', 154 | image.name, width, image_width]) 155 | asserts.extend([height_assert, width_assert]) 156 | 157 | # Create a random bounding box. 158 | # 159 | # Use tf.random_uniform and not numpy.random.rand as doing the former would 160 | # generate random numbers at graph eval time, unlike the latter which 161 | # generates random numbers at graph definition time. 162 | with tf.control_dependencies(asserts): 163 | max_offset_height = tf.reshape(image_height - crop_height + 1, []) 164 | with tf.control_dependencies(asserts): 165 | max_offset_width = tf.reshape(image_width - crop_width + 1, []) 166 | offset_height = tf.random_uniform( 167 | [], maxval=max_offset_height, dtype=tf.int32) 168 | offset_width = tf.random_uniform( 169 | [], maxval=max_offset_width, dtype=tf.int32) 170 | 171 | return [_crop(image, offset_height, offset_width, 172 | crop_height, crop_width) for image in image_list] 173 | 174 | 175 | def _central_crop(image_list, crop_height, crop_width): 176 | """Performs central crops of the given image list. 177 | 178 | Args: 179 | image_list: a list of image tensors of the same dimension but possibly 180 | varying channel. 181 | crop_height: the height of the image following the crop. 182 | crop_width: the width of the image following the crop. 183 | 184 | Returns: 185 | the list of cropped images. 
186 |   """
187 |   outputs = []
188 |   for image in image_list:
189 |     image_height = tf.shape(image)[0]
190 |     image_width = tf.shape(image)[1]
191 | 
192 |     offset_height = (image_height - crop_height) / 2
193 |     offset_width = (image_width - crop_width) / 2
194 | 
195 |     outputs.append(_crop(image, offset_height, offset_width,
196 |                          crop_height, crop_width))
197 |   return outputs
198 | 
199 | 
200 | def _mean_image_subtraction(image, means):
201 |   """Subtracts the given means from each image channel.
202 | 
203 |   For example:
204 |     means = [123.68, 116.779, 103.939]
205 |     image = _mean_image_subtraction(image, means)
206 | 
207 |   Note that the rank of `image` must be known.
208 | 
209 |   Args:
210 |     image: a tensor of size [height, width, C].
211 |     means: a C-vector of values to subtract from each channel.
212 | 
213 |   Returns:
214 |     the centered image.
215 | 
216 |   Raises:
217 |     ValueError: If the rank of `image` is unknown, if `image` has a rank other
218 |       than three or if the number of channels in `image` doesn't match the
219 |       number of values in `means`.
220 |   """
221 |   if image.get_shape().ndims != 3:
222 |     raise ValueError('Input must be of size [height, width, C>0]')
223 |   num_channels = image.get_shape().as_list()[-1]
224 |   if len(means) != num_channels:
225 |     raise ValueError('len(means) must match the number of channels')
226 | 
227 |   channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
228 |   for i in range(num_channels):
229 |     channels[i] -= means[i]
230 |   return tf.concat(axis=2, values=channels)
231 | 
232 | 
233 | def _smallest_size_at_least(height, width, smallest_side):
234 |   """Computes new shape with the smallest side equal to `smallest_side`.
235 | 
236 |   Computes new shape with the smallest side equal to `smallest_side` while
237 |   preserving the original aspect ratio.
238 | 
239 |   Args:
240 |     height: an int32 scalar tensor indicating the current height.
241 |     width: an int32 scalar tensor indicating the current width.
242 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
243 |       the smallest side after resize.
244 | 
245 |   Returns:
246 |     new_height: an int32 scalar tensor indicating the new height.
247 |     new_width: an int32 scalar tensor indicating the new width.
248 |   """
249 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
250 | 
251 |   height = tf.to_float(height)
252 |   width = tf.to_float(width)
253 |   smallest_side = tf.to_float(smallest_side)
254 | 
255 |   scale = tf.cond(tf.greater(height, width),
256 |                   lambda: smallest_side / width,
257 |                   lambda: smallest_side / height)
258 |   new_height = tf.to_int32(height * scale)
259 |   new_width = tf.to_int32(width * scale)
260 |   return new_height, new_width
261 | 
262 | 
263 | def _aspect_preserving_resize(image, smallest_side):
264 |   """Resize images preserving the original aspect ratio.
265 | 
266 |   Args:
267 |     image: A 3-D image `Tensor`.
268 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
269 |       the smallest side after resize.
270 | 
271 |   Returns:
272 |     resized_image: A 3-D tensor containing the resized image. 
273 |   """
274 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
275 | 
276 |   shape = tf.shape(image)
277 |   height = shape[0]
278 |   width = shape[1]
279 |   new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
280 |   image = tf.expand_dims(image, 0)
281 |   resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
282 |                                            align_corners=False)
283 |   resized_image = tf.squeeze(resized_image)
284 |   resized_image.set_shape([None, None, 3])
285 |   return resized_image
286 | 
287 | 
288 | def preprocess_for_train(image,
289 |                          output_height,
290 |                          output_width,
291 |                          resize_side_min=_RESIZE_SIDE_MIN,
292 |                          resize_side_max=_RESIZE_SIDE_MAX):
293 |   """Preprocesses the given image for training.
294 | 
295 |   Note that the actual resizing scale is sampled from
296 |   [`resize_side_min`, `resize_side_max`].
297 | 
298 |   Args:
299 |     image: A `Tensor` representing an image of arbitrary size.
300 |     output_height: The height of the image after preprocessing.
301 |     output_width: The width of the image after preprocessing.
302 |     resize_side_min: The lower bound for the smallest side of the image for
303 |       aspect-preserving resizing.
304 |     resize_side_max: The upper bound for the smallest side of the image for
305 |       aspect-preserving resizing.
306 | 
307 |   Returns:
308 |     A preprocessed image.
309 |   """
310 |   resize_side = tf.random_uniform(
311 |       [], minval=resize_side_min, maxval=resize_side_max+1, dtype=tf.int32)
312 | 
313 |   image = _aspect_preserving_resize(image, resize_side)
314 |   image = _random_crop([image], output_height, output_width)[0]
315 |   image.set_shape([output_height, output_width, 3])
316 |   image = tf.to_float(image)
317 |   image = tf.image.random_flip_left_right(image)
318 | 
319 |   image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
320 |   return image * _SCALE_FACTOR
321 | 
322 | 
323 | def preprocess_for_eval(image, output_height, output_width, resize_side):
324 |   """Preprocesses the given image for evaluation.
325 | 
326 |   Args:
327 |     image: A `Tensor` representing an image of arbitrary size.
328 |     output_height: The height of the image after preprocessing.
329 |     output_width: The width of the image after preprocessing.
330 |     resize_side: The smallest side of the image for aspect-preserving resizing.
331 | 
332 |   Returns:
333 |     A preprocessed image.
334 |   """
335 |   image = _aspect_preserving_resize(image, resize_side)
336 |   image = _central_crop([image], output_height, output_width)[0]
337 |   image.set_shape([output_height, output_width, 3])
338 |   image = tf.to_float(image)
339 | 
340 |   image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
341 |   return image * _SCALE_FACTOR
342 | 
343 | 
344 | def preprocess_image(image, output_height, output_width, is_training=False,
345 |                      resize_side_min=_RESIZE_SIDE_MIN,
346 |                      resize_side_max=_RESIZE_SIDE_MAX):
347 |   """Preprocesses the given image.
348 | 
349 |   Args:
350 |     image: A `Tensor` representing an image of arbitrary size.
351 |     output_height: The height of the image after preprocessing.
352 |     output_width: The width of the image after preprocessing.
353 |     is_training: `True` if we're preprocessing the image for training and
354 |       `False` otherwise.
355 |     resize_side_min: The lower bound for the smallest side of the image for
356 |       aspect-preserving resizing. If `is_training` is `False`, then this value
357 |       is used for rescaling.
358 |     resize_side_max: The upper bound for the smallest side of the image for
359 |       aspect-preserving resizing. If `is_training` is `False`, this value is
360 |       ignored. Otherwise, the resize side is sampled from
361 |       [resize_side_min, resize_side_max].
362 | 
363 |   Returns:
364 |     A preprocessed image.
365 |   """
366 |   if is_training:
367 |     return preprocess_for_train(image, output_height, output_width,
368 |                                 resize_side_min, resize_side_max)
369 |   else:
370 |     return preprocess_for_eval(image, output_height, output_width,
371 |                                resize_side_min)
372 | 
--------------------------------------------------------------------------------
/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building preprocessing functions."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import tensorflow as tf
22 | 
23 | from preprocessing import densenet_preprocessing
24 | 
25 | slim = tf.contrib.slim
26 | 
27 | 
28 | def get_preprocessing(name, is_training=False):
29 |   """Returns preprocessing_fn(image, height, width, **kwargs).
30 | 
31 |   Args:
32 |     name: The name of the preprocessing function.
33 |     is_training: `True` if the model is being used for training and `False`
34 |       otherwise.
35 | 
36 |   Returns:
37 |     preprocessing_fn: A function that preprocesses a single image (pre-batch).
38 |       It has the following signature:
39 |         image = preprocessing_fn(image, output_height, output_width, ...).
40 | 
41 |   Raises:
42 |     ValueError: If preprocessing `name` is not recognized.
43 |   """
44 |   preprocessing_fn_map = {
45 |       'densenet121': densenet_preprocessing,
46 |       'densenet161': densenet_preprocessing,
47 |       'densenet169': densenet_preprocessing,
48 |   }
49 | 
50 |   if name not in preprocessing_fn_map:
51 |     raise ValueError('Preprocessing name [%s] was not recognized' % name)
52 | 
53 |   def preprocessing_fn(image, output_height, output_width, **kwargs):
54 |     return preprocessing_fn_map[name].preprocess_image(
55 |         image, output_height, output_width, is_training=is_training, **kwargs)
56 | 
57 |   return preprocessing_fn
58 | 
--------------------------------------------------------------------------------
/train_image_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Generic training script that trains a model using a given dataset."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import tensorflow as tf
22 | 
23 | from datasets import dataset_factory
24 | from deployment import model_deploy
25 | from nets import nets_factory
26 | from preprocessing import preprocessing_factory
27 | 
28 | slim = tf.contrib.slim
29 | 
30 | tf.app.flags.DEFINE_string(
31 |     'master', '', 'The address of the TensorFlow master to use.')
32 | 
33 | tf.app.flags.DEFINE_string(
34 |     'train_dir', '/tmp/tfmodel/',
35 |     'Directory where checkpoints and event logs are written to.')
36 | 
37 | tf.app.flags.DEFINE_integer('num_clones', 1,
38 |                             'Number of model clones to deploy.')
39 | 
40 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
41 |                             'Use CPUs to deploy clones.')
42 | 
43 | tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
44 | 
45 | tf.app.flags.DEFINE_integer(
46 |     'num_ps_tasks', 0,
47 |     'The number of parameter servers. If the value is 0, then the parameters '
48 |     'are handled locally by the worker.')
49 | 
50 | tf.app.flags.DEFINE_integer(
51 |     'num_readers', 4,
52 |     'The number of parallel readers that read data from the dataset.')
53 | 
54 | tf.app.flags.DEFINE_integer(
55 |     'num_preprocessing_threads', 4,
56 |     'The number of threads used to create the batches.')
57 | 
58 | tf.app.flags.DEFINE_integer(
59 |     'log_every_n_steps', 10,
60 |     'The frequency with which logs are printed.')
61 | 
62 | tf.app.flags.DEFINE_integer(
63 |     'save_summaries_secs', 600,
64 |     'The frequency with which summaries are saved, in seconds.')
65 | 
66 | tf.app.flags.DEFINE_integer(
67 |     'save_interval_secs', 600,
68 |     'The frequency with which the model is saved, in seconds.')
69 | 
70 | tf.app.flags.DEFINE_integer(
71 |     'task', 0, 'Task id of the replica running the training.')
72 | 
73 | ######################
74 | # Optimization Flags #
75 | ######################
76 | 
77 | tf.app.flags.DEFINE_float(
78 |     'weight_decay', 0.00004, 'The weight decay on the model weights.')
79 | 
80 | tf.app.flags.DEFINE_string(
81 |     'optimizer', 'rmsprop',
82 |     'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
83 |     '"ftrl", "momentum", "sgd" or "rmsprop".')
84 | 
85 | tf.app.flags.DEFINE_float(
86 |     'adadelta_rho', 0.95,
87 |     'The decay rate for adadelta.')
88 | 
89 | tf.app.flags.DEFINE_float(
90 |     'adagrad_initial_accumulator_value', 0.1,
91 |     'Starting value for the AdaGrad accumulators.')
92 | 
93 | tf.app.flags.DEFINE_float(
94 |     'adam_beta1', 0.9,
95 |     'The exponential decay rate for the 1st moment estimates.')
96 | 
97 | tf.app.flags.DEFINE_float(
98 |     'adam_beta2', 0.999,
99 |     'The exponential decay rate for the 2nd moment estimates.')
100 | 
101 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
102 | 
103 | tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
104 |                           'The learning rate power.')
105 | 
106 | tf.app.flags.DEFINE_float(
107 |     'ftrl_initial_accumulator_value', 0.1,
108 |     'Starting value for the FTRL accumulators.')
109 | 
110 | tf.app.flags.DEFINE_float(
111 |     'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
112 | 
113 | tf.app.flags.DEFINE_float(
114 |     'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
115 | 
116 | tf.app.flags.DEFINE_float(
117 |     'momentum', 0.9,
118 |     'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
119 | 
120 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
121 | 
122 | #######################
123 | # Learning Rate Flags #
124 | #######################
125 | 
126 | tf.app.flags.DEFINE_string(
127 |     'learning_rate_decay_type',
128 |     'exponential',
129 |     'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
130 |     ' or "polynomial".')
131 | 
132 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
133 | 
134 | tf.app.flags.DEFINE_float(
135 |     'end_learning_rate', 0.0001,
136 |     'The minimal end learning rate used by a polynomial decay learning rate.')
137 | 
138 | tf.app.flags.DEFINE_float(
139 |     'label_smoothing', 0.0, 'The amount of label smoothing.')
140 | 
141 | tf.app.flags.DEFINE_float(
142 |     'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
143 | 
144 | tf.app.flags.DEFINE_float(
145 |     'num_epochs_per_decay', 2.0,
146 |     'Number of epochs after which learning rate decays.')
147 | 
148 | tf.app.flags.DEFINE_bool(
149 |     'sync_replicas', False,
150 |     'Whether or not to synchronize the replicas during training.')
151 | 
152 | tf.app.flags.DEFINE_integer(
153 |     'replicas_to_aggregate', 1,
154 |     'The number of gradients to collect before updating params.')
155 | 
156 | tf.app.flags.DEFINE_float(
157 |     'moving_average_decay', None,
158 |     'The decay to use for the moving average. '
159 |     'If left as None, then moving averages are not used.')
160 | 
161 | #######################
162 | # Dataset Flags #
163 | #######################
164 | 
165 | tf.app.flags.DEFINE_string(
166 |     'dataset_name', 'imagenet', 'The name of the dataset to load.')
167 | 
168 | tf.app.flags.DEFINE_string(
169 |     'dataset_split_name', 'train', 'The name of the train/test split.')
170 | 
171 | tf.app.flags.DEFINE_string(
172 |     'dataset_dir', None, 'The directory where the dataset files are stored.')
173 | 
174 | tf.app.flags.DEFINE_integer(
175 |     'labels_offset', 0,
176 |     'An offset for the labels in the dataset. This flag is primarily used to '
177 |     'evaluate the VGG and ResNet architectures which do not use a background '
178 |     'class for the ImageNet dataset.')
179 | 
180 | tf.app.flags.DEFINE_string(
181 |     'model_name', 'densenet121', 'The name of the architecture to train.')
182 | 
183 | tf.app.flags.DEFINE_string(
184 |     'data_format', 'NHWC', 'The data layout of image tensors: NHWC or NCHW.')
185 | 
186 | tf.app.flags.DEFINE_string(
187 |     'preprocessing_name', None, 'The name of the preprocessing to use. If left '
188 |     'as `None`, then the model_name flag is used.')
189 | 
190 | tf.app.flags.DEFINE_integer(
191 |     'batch_size', 32, 'The number of samples in each batch.')
192 | 
193 | tf.app.flags.DEFINE_integer(
194 |     'train_image_size', None, 'Train image size.')
195 | 
196 | tf.app.flags.DEFINE_integer('max_number_of_steps', None,
197 |                             'The maximum number of training steps.')
198 | 
199 | #####################
200 | # Fine-Tuning Flags #
201 | #####################
202 | 
203 | tf.app.flags.DEFINE_string(
204 |     'checkpoint_path', None,
205 |     'The path to a checkpoint from which to fine-tune.')
206 | 
207 | tf.app.flags.DEFINE_string(
208 |     'checkpoint_exclude_scopes', None,
209 |     'Comma-separated list of scopes of variables to exclude when restoring '
210 |     'from a checkpoint.')
211 | 
212 | tf.app.flags.DEFINE_string(
213 |     'trainable_scopes', None,
214 |     'Comma-separated list of scopes to filter the set of variables to train. '
215 |     'By default, None trains all the variables.')
216 | 
217 | tf.app.flags.DEFINE_boolean(
218 |     'ignore_missing_vars', False,
219 |     'When restoring a checkpoint, ignore variables that are missing from it.')
220 | 
221 | FLAGS = tf.app.flags.FLAGS
222 | 
223 | 
224 | def _configure_learning_rate(num_samples_per_epoch, global_step):
225 |   """Configures the learning rate.
226 | 
227 |   Args:
228 |     num_samples_per_epoch: The number of samples in each epoch of training.
229 |     global_step: The global_step tensor.
230 | 
231 |   Returns:
232 |     A `Tensor` representing the learning rate.
233 | 
234 |   Raises:
235 |     ValueError: if `FLAGS.learning_rate_decay_type` is not recognized.
236 |   """
237 |   decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
238 |                     FLAGS.num_epochs_per_decay)
239 |   if FLAGS.sync_replicas:
240 |     decay_steps /= FLAGS.replicas_to_aggregate
241 | 
242 |   if FLAGS.learning_rate_decay_type == 'exponential':
243 |     return tf.train.exponential_decay(FLAGS.learning_rate,
244 |                                       global_step,
245 |                                       decay_steps,
246 |                                       FLAGS.learning_rate_decay_factor,
247 |                                       staircase=True,
248 |                                       name='exponential_decay_learning_rate')
249 |   elif FLAGS.learning_rate_decay_type == 'fixed':
250 |     return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
251 |   elif FLAGS.learning_rate_decay_type == 'polynomial':
252 |     return tf.train.polynomial_decay(FLAGS.learning_rate,
253 |                                      global_step,
254 |                                      decay_steps,
255 |                                      FLAGS.end_learning_rate,
256 |                                      power=1.0,
257 |                                      cycle=False,
258 |                                      name='polynomial_decay_learning_rate')
259 |   else:
260 |     raise ValueError('learning_rate_decay_type [%s] was not recognized' %
261 |                      FLAGS.learning_rate_decay_type)
262 | 
263 | 
264 | def _configure_optimizer(learning_rate):
265 |   """Configures the optimizer used for training.
266 | 
267 |   Args:
268 |     learning_rate: A scalar or `Tensor` learning rate.
269 | 
270 |   Returns:
271 |     An instance of an optimizer.
272 | 
273 |   Raises:
274 |     ValueError: if FLAGS.optimizer is not recognized. 
275 |   """
276 |   if FLAGS.optimizer == 'adadelta':
277 |     optimizer = tf.train.AdadeltaOptimizer(
278 |         learning_rate,
279 |         rho=FLAGS.adadelta_rho,
280 |         epsilon=FLAGS.opt_epsilon)
281 |   elif FLAGS.optimizer == 'adagrad':
282 |     optimizer = tf.train.AdagradOptimizer(
283 |         learning_rate,
284 |         initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
285 |   elif FLAGS.optimizer == 'adam':
286 |     optimizer = tf.train.AdamOptimizer(
287 |         learning_rate,
288 |         beta1=FLAGS.adam_beta1,
289 |         beta2=FLAGS.adam_beta2,
290 |         epsilon=FLAGS.opt_epsilon)
291 |   elif FLAGS.optimizer == 'ftrl':
292 |     optimizer = tf.train.FtrlOptimizer(
293 |         learning_rate,
294 |         learning_rate_power=FLAGS.ftrl_learning_rate_power,
295 |         initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
296 |         l1_regularization_strength=FLAGS.ftrl_l1,
297 |         l2_regularization_strength=FLAGS.ftrl_l2)
298 |   elif FLAGS.optimizer == 'momentum':
299 |     optimizer = tf.train.MomentumOptimizer(
300 |         learning_rate,
301 |         momentum=FLAGS.momentum,
302 |         name='Momentum')
303 |   elif FLAGS.optimizer == 'rmsprop':
304 |     optimizer = tf.train.RMSPropOptimizer(
305 |         learning_rate,
306 |         decay=FLAGS.rmsprop_decay,
307 |         momentum=FLAGS.momentum,
308 |         epsilon=FLAGS.opt_epsilon)
309 |   elif FLAGS.optimizer == 'sgd':
310 |     optimizer = tf.train.GradientDescentOptimizer(learning_rate)
311 |   else:
312 |     raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
313 |   return optimizer
314 | 
315 | def _get_init_fn():
316 |   """Returns a function run by the chief worker to warm-start the training.
317 | 
318 |   Note that the init_fn is only run when initializing the model during the very
319 |   first global step.
320 | 
321 |   Returns:
322 |     An init function run by the supervisor.
323 |   """
324 |   if FLAGS.checkpoint_path is None:
325 |     return None
326 | 
327 |   # Warn the user if a checkpoint exists in the train_dir; if one does,
328 |   # --checkpoint_path will be ignored.
329 |   if tf.train.latest_checkpoint(FLAGS.train_dir):
330 |     tf.logging.info(
331 |         'Ignoring --checkpoint_path because a checkpoint already exists in %s'
332 |         % FLAGS.train_dir)
333 |     return None
334 | 
335 |   exclusions = []
336 |   if FLAGS.checkpoint_exclude_scopes:
337 |     exclusions = [scope.strip()
338 |                   for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
339 | 
340 |   # TODO(sguada) variables.filter_variables()
341 |   variables_to_restore = []
342 |   for var in slim.get_model_variables():
343 |     excluded = False
344 |     for exclusion in exclusions:
345 |       if var.op.name.startswith(exclusion):
346 |         excluded = True
347 |         break
348 |     if not excluded:
349 |       variables_to_restore.append(var)
350 | 
351 |   if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
352 |     checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
353 |   else:
354 |     checkpoint_path = FLAGS.checkpoint_path
355 | 
356 |   tf.logging.info('Fine-tuning from %s' % checkpoint_path)
357 | 
358 |   return slim.assign_from_checkpoint_fn(
359 |       checkpoint_path,
360 |       variables_to_restore,
361 |       ignore_missing_vars=FLAGS.ignore_missing_vars)
362 | 
363 | 
364 | def _get_variables_to_train():
365 |   """Returns a list of variables to train.
366 | 
367 |   Returns:
368 |     A list of variables to be trained by the optimizer. 
369 | """ 370 | if FLAGS.trainable_scopes is None: 371 | return tf.trainable_variables() 372 | else: 373 | scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')] 374 | 375 | variables_to_train = [] 376 | for scope in scopes: 377 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 378 | variables_to_train.extend(variables) 379 | return variables_to_train 380 | 381 | 382 | def main(_): 383 | if not FLAGS.dataset_dir: 384 | raise ValueError('You must supply the dataset directory with --dataset_dir') 385 | 386 | tf.logging.set_verbosity(tf.logging.INFO) 387 | with tf.Graph().as_default(): 388 | ####################### 389 | # Config model_deploy # 390 | ####################### 391 | deploy_config = model_deploy.DeploymentConfig( 392 | num_clones=FLAGS.num_clones, 393 | clone_on_cpu=FLAGS.clone_on_cpu, 394 | replica_id=FLAGS.task, 395 | num_replicas=FLAGS.worker_replicas, 396 | num_ps_tasks=FLAGS.num_ps_tasks) 397 | 398 | # Create global_step 399 | with tf.device(deploy_config.variables_device()): 400 | global_step = slim.create_global_step() 401 | 402 | ###################### 403 | # Select the dataset # 404 | ###################### 405 | dataset = dataset_factory.get_dataset( 406 | FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 407 | 408 | ###################### 409 | # Select the network # 410 | ###################### 411 | network_fn = nets_factory.get_network_fn( 412 | FLAGS.model_name, 413 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 414 | weight_decay=FLAGS.weight_decay, 415 | data_format=FLAGS.data_format, 416 | is_training=True) 417 | 418 | ##################################### 419 | # Select the preprocessing function # 420 | ##################################### 421 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 422 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 423 | preprocessing_name, 424 | is_training=True) 425 | 426 | ############################################################## 427 | # Create a dataset provider that loads data from the dataset # 428 | ############################################################## 429 | with tf.device(deploy_config.inputs_device()): 430 | provider = slim.dataset_data_provider.DatasetDataProvider( 431 | dataset, 432 | num_readers=FLAGS.num_readers, 433 | common_queue_capacity=20 * FLAGS.batch_size, 434 | common_queue_min=10 * FLAGS.batch_size) 435 | [image, label] = provider.get(['image', 'label']) 436 | label -= FLAGS.labels_offset 437 | 438 | train_image_size = FLAGS.train_image_size or network_fn.default_image_size 439 | 440 | image = image_preprocessing_fn(image, train_image_size, train_image_size) 441 | 442 | images, labels = tf.train.batch( 443 | [image, label], 444 | batch_size=FLAGS.batch_size, 445 | num_threads=FLAGS.num_preprocessing_threads, 446 | capacity=5 * FLAGS.batch_size) 447 | labels = slim.one_hot_encoding( 448 | labels, dataset.num_classes - FLAGS.labels_offset) 449 | batch_queue = slim.prefetch_queue.prefetch_queue( 450 | [images, labels], capacity=2 * deploy_config.num_clones) 451 | 452 | #################### 453 | # Define the model # 454 | #################### 455 | def clone_fn(batch_queue): 456 | """Allows data parallelism by creating multiple clones of network_fn.""" 457 | with tf.device(deploy_config.inputs_device()): 458 | images, labels = batch_queue.dequeue() 459 | logits, end_points = network_fn(images) 460 | logits = tf.squeeze(logits) 461 | 462 | ############################# 463 | # Specify the loss 
function #
464 |       #############################
465 |       if 'AuxLogits' in end_points:
466 |         tf.losses.softmax_cross_entropy(
467 |             logits=end_points['AuxLogits'], onehot_labels=labels,
468 |             label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
469 |       tf.losses.softmax_cross_entropy(
470 |           logits=logits, onehot_labels=labels,
471 |           label_smoothing=FLAGS.label_smoothing, weights=1.0)
472 |       return end_points
473 | 
474 |     # Gather initial summaries.
475 |     summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
476 | 
477 |     clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
478 |     first_clone_scope = deploy_config.clone_scope(0)
479 |     # Gather update_ops from the first clone. These contain, for example,
480 |     # the updates for the batch_norm variables created by network_fn.
481 |     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
482 | 
483 |     # Add summaries for end_points.
484 |     end_points = clones[0].outputs
485 |     for end_point in end_points:
486 |       x = end_points[end_point]
487 |       summaries.add(tf.summary.histogram('activations/' + end_point, x))
488 |       summaries.add(tf.summary.scalar('sparsity/' + end_point,
489 |                                       tf.nn.zero_fraction(x)))
490 | 
491 |     # Add summaries for losses.
492 |     for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
493 |       summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss))
494 | 
495 |     # Add summaries for variables.
496 |     for variable in slim.get_model_variables():
497 |       summaries.add(tf.summary.histogram(variable.op.name, variable))
498 | 
499 |     #################################
500 |     # Configure the moving averages #
501 |     #################################
502 |     if FLAGS.moving_average_decay:
503 |       moving_average_variables = slim.get_model_variables()
504 |       variable_averages = tf.train.ExponentialMovingAverage(
505 |           FLAGS.moving_average_decay, global_step)
506 |     else:
507 |       moving_average_variables, variable_averages = None, None
508 | 
509 |     #########################################
510 |     # Configure the optimization procedure. #
511 |     #########################################
512 |     with tf.device(deploy_config.optimizer_device()):
513 |       learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
514 |       optimizer = _configure_optimizer(learning_rate)
515 |       summaries.add(tf.summary.scalar('learning_rate', learning_rate))
516 | 
517 |     if FLAGS.sync_replicas:
518 |       # If sync_replicas is enabled, the averaging will be done in the chief
519 |       # queue runner.
520 |       optimizer = tf.train.SyncReplicasOptimizer(
521 |           opt=optimizer,
522 |           replicas_to_aggregate=FLAGS.replicas_to_aggregate,
523 |           variable_averages=variable_averages,
524 |           variables_to_average=moving_average_variables,
525 |           replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
526 |           total_num_replicas=FLAGS.worker_replicas)
527 |     elif FLAGS.moving_average_decay:
528 |       # Update ops executed locally by trainer.
529 |       update_ops.append(variable_averages.apply(moving_average_variables))
530 | 
531 |     # Variables to train.
532 |     variables_to_train = _get_variables_to_train()
533 | 
534 |     # Compute the total loss and the gradients across all model clones.
535 |     total_loss, clones_gradients = model_deploy.optimize_clones(
536 |         clones,
537 |         optimizer,
538 |         var_list=variables_to_train)
539 |     # Add total_loss to summary.
540 |     summaries.add(tf.summary.scalar('total_loss', total_loss))
541 | 
542 |     # Create gradient updates. 
543 | grad_updates = optimizer.apply_gradients(clones_gradients, 544 | global_step=global_step) 545 | update_ops.append(grad_updates) 546 | 547 | update_op = tf.group(*update_ops) 548 | with tf.control_dependencies([update_op]): 549 | train_tensor = tf.identity(total_loss, name='train_op') 550 | 551 | # Add the summaries from the first clone. These contain the summaries 552 | # created by model_fn and either optimize_clones() or _gather_clone_loss(). 553 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 554 | first_clone_scope)) 555 | 556 | # Merge all summaries together. 557 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 558 | 559 | 560 | ########################### 561 | # Kicks off the training. # 562 | ########################### 563 | slim.learning.train( 564 | train_tensor, 565 | logdir=FLAGS.train_dir, 566 | master=FLAGS.master, 567 | is_chief=(FLAGS.task == 0), 568 | init_fn=_get_init_fn(), 569 | summary_op=summary_op, 570 | number_of_steps=FLAGS.max_number_of_steps, 571 | log_every_n_steps=FLAGS.log_every_n_steps, 572 | save_summaries_secs=FLAGS.save_summaries_secs, 573 | save_interval_secs=FLAGS.save_interval_secs, 574 | sync_optimizer=optimizer if FLAGS.sync_replicas else None) 575 | 576 | 577 | if __name__ == '__main__': 578 | tf.app.run() 579 | --------------------------------------------------------------------------------
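
The two factories above are the seams the training script is built on: nets_factory.get_network_fn selects a DenseNet variant and wraps it in its arg scope, while preprocessing_factory.get_preprocessing returns the matching VGG-style preprocessing. A minimal single-image inference sketch of how they compose outside train_image_classifier.py follows; it assumes densenet.densenet121 exposes default_image_size (the training script relies on this as well), and the image and checkpoint paths are hypothetical placeholders, not files shipped with this repo.

import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

# Build the model function and the matching preprocessing function.
network_fn = nets_factory.get_network_fn(
    'densenet121', num_classes=1000, is_training=False)
preprocessing_fn = preprocessing_factory.get_preprocessing(
    'densenet121', is_training=False)

image_size = network_fn.default_image_size

# Decode and preprocess a single image (the path is a placeholder).
raw = tf.image.decode_jpeg(tf.read_file('/path/to/image.jpg'), channels=3)
processed = preprocessing_fn(raw, image_size, image_size)

# network_fn expects a batch; the squeeze mirrors train_image_classifier.py,
# which strips the extra spatial dimensions from the DenseNet logits.
logits, _ = network_fn(tf.expand_dims(processed, 0))
logits = tf.squeeze(logits)
prediction = tf.argmax(logits, axis=-1)

with tf.Session() as sess:
  saver = tf.train.Saver()
  saver.restore(sess, '/path/to/densenet121.ckpt')  # hypothetical checkpoint
  print(sess.run(prediction))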