├── .gitignore
├── .pylintrc
├── LICENSES.md
├── README.md
├── datasets
│   ├── __init__.py
│   ├── dataset_factory.py
│   ├── dataset_utils.py
│   ├── download_and_convert_pascal.py
│   └── pascal.py
├── deployment
│   ├── __init__.py
│   ├── model_deploy.py
│   └── model_deploy_test.py
├── eval.py
├── nets
│   ├── __init__.py
│   ├── fcn.py
│   ├── layers.py
│   └── nets_factory.py
├── prepare_data.py
├── preprocessing
│   ├── __init__.py
│   ├── cifarnet_preprocessing.py
│   ├── fcn_preprocessing.py
│   ├── inception_preprocessing.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── vgg_preprocessing.py
├── scripts
│   ├── eval.sh
│   ├── train.sh
│   └── unet-pascal.sh
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 | bazel*/
91 | *~
92 | .*.swp
93 | /.classpath
94 | /.factorypath
95 | /.idea/
96 | /.project
97 | /.settings
98 | /WORKSPACE.user.bzl
99 | /bazel-bazel
100 | /bazel-bin
101 | /bazel-genfiles
102 | /bazel-io_bazel
103 | /bazel-out
104 | /bazel-testlogs
105 | /bazel.iml
106 | /output/

--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # Specify a configuration file.
4 | #rcfile=
5 |
6 | # Python code to execute, usually for sys.path manipulation such as
7 | # pygtk.require().
8 | #init-hook=
9 |
10 | # Profiled execution.
11 | profile=no
12 |
13 | # Add files or directories to the blacklist. They should be base names, not
14 | # paths.
15 | ignore=CVS
16 |
17 | # Pickle collected data for later comparisons.
18 | persistent=yes
19 |
20 | # List of plugins (as comma separated values of python module names) to load,
21 | # usually to register additional checkers.
22 | load-plugins=
23 |
24 |
25 | [MESSAGES CONTROL]
26 |
27 | # Enable the message, report, category or checker with the given id(s). You can
28 | # either give multiple identifiers separated by comma (,) or put this option
29 | # multiple times. See also the "--disable" option for examples.
30 | enable=old-raise-syntax 31 | 32 | # Disable the message, report, category or checker with the given id(s). You 33 | # can either give multiple identifiers separated by comma (,) or put this 34 | # option multiple times (only on the command line, not in the configuration 35 | # file where it should appear only once).You can also use "--disable=all" to 36 | # disable everything first and then reenable specific checks. For example, if 37 | # you want to run only the similarities checker, you can use "--disable=all 38 | # --enable=similarities". If you want to run only the classes checker, but have 39 | # no Warning level messages displayed, use"--disable=all --enable=classes 40 | # --disable=W" 41 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable 42 | 43 | 44 | 45 | [REPORTS] 46 | 47 | # Set the output format. Available formats are text, parseable, colorized, msvs 48 | # (visual studio) and html. You can also give a reporter class, eg 49 | # mypackage.mymodule.MyReporterClass. 50 | output-format=text 51 | 52 | # Put messages in a separate file for each module / package specified on the 53 | # command line instead of printing them on stdout. Reports (if any) will be 54 | # written in a file name "pylint_global.[txt|html]". 55 | files-output=no 56 | 57 | # Tells whether to display a full report or only the messages 58 | reports=no 59 | 60 | # Python expression which should return a note less than 10 (10 is the highest 61 | # note). You have access to the variables errors warning, statement which 62 | # respectively contain the number of errors / warnings messages and the total 63 | # number of statements analyzed. This is used by the global evaluation report 64 | # (RP0004). 65 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 66 | 67 | # Add a comment according to your evaluation note. This is used by the global 68 | # evaluation report (RP0004). 69 | comment=no 70 | 71 | # Template used to display messages. This is a python new-style format string 72 | # used to format the message information. See doc for all details 73 | #msg-template= 74 | 75 | 76 | [TYPECHECK] 77 | 78 | # Tells whether missing members accessed in mixin class should be ignored. A 79 | # mixin class is detected if its name ends with "mixin" (case insensitive). 80 | ignore-mixin-members=yes 81 | 82 | # List of classes names for which member attributes should not be checked 83 | # (useful for classes with attributes dynamically set). 84 | ignored-classes=SQLObject 85 | 86 | # When zope mode is activated, add a predefined set of Zope acquired attributes 87 | # to generated-members. 88 | zope=no 89 | 90 | # List of members which are set dynamically and missed by pylint inference 91 | # system, and so shouldn't trigger E0201 when accessed. Python regular 92 | # expressions are accepted. 93 | generated-members=REQUEST,acl_users,aq_parent 94 | 95 | 96 | [VARIABLES] 97 | 98 | # Tells whether we should check for unused import in __init__ files. 99 | init-import=no 100 | 101 | # A regular expression matching the beginning of the name of dummy variables 102 | # (i.e. not used). 103 | dummy-variables-rgx=_$|dummy 104 | 105 | # List of additional names supposed to be defined in builtins. 
Remember that
106 | # you should avoid defining new builtins when possible.
107 | additional-builtins=
108 |
109 |
110 | [BASIC]
111 |
112 | # Required attributes for module, separated by a comma
113 | required-attributes=
114 |
115 | # List of builtin function names that should not be used, separated by a comma
116 | bad-functions=apply,input,reduce
117 |
118 | # Regular expression which should only match correct module names
119 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
120 |
121 | # Regular expression which should only match correct module level names
122 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
123 |
124 | # Regular expression which should only match correct class names
125 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$
126 |
127 | # Regular expression which should only match correct function names
128 | function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
129 |
130 | # Regular expression which should only match correct method names
131 | method-rgx=^(?:(?P<exempt>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
132 |
133 | # Regular expression which should only match correct instance attribute names
134 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
135 |
136 | # Regular expression which should only match correct argument names
137 | argument-rgx=^[a-z][a-z0-9_]*$
138 |
139 | # Regular expression which should only match correct variable names
140 | variable-rgx=^[a-z][a-z0-9_]*$
141 |
142 | # Regular expression which should only match correct attribute names in class
143 | # bodies
144 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
145 |
146 | # Regular expression which should only match correct list comprehension /
147 | # generator expression variable names
148 | inlinevar-rgx=^[a-z][a-z0-9_]*$
149 |
150 | # Good variable names which should always be accepted, separated by a comma
151 | good-names=main,_
152 |
153 | # Bad variable names which should always be refused, separated by a comma
154 | bad-names=
155 |
156 | # Regular expression which should only match function or class names that do
157 | # not require a docstring.
158 | no-docstring-rgx=(__.*__|main)
159 |
160 | # Minimum line length for functions/classes that require docstrings, shorter
161 | # ones are exempt.
162 | docstring-min-length=10
163 |
164 |
165 | [FORMAT]
166 |
167 | # Maximum number of characters on a single line.
168 | max-line-length=80
169 |
170 | # Regexp for a line that is allowed to be longer than the limit.
171 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
172 |
173 | # Allow the body of an if to be on the same line as the test if there is no
174 | # else.
175 | single-line-if-stmt=y
176 |
177 | # List of optional constructs for which whitespace checking is disabled
178 | no-space-check=
179 |
180 | # Maximum number of lines in a module
181 | max-module-lines=99999
182 |
183 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
184 | # tab).
185 | indent-string='  '
186 |
187 |
188 | [SIMILARITIES]
189 |
190 | # Minimum lines number of a similarity.
191 | min-similarity-lines=4
192 |
193 | # Ignore comments when computing similarities.
194 | ignore-comments=yes
195 |
196 | # Ignore docstrings when computing similarities.
197 | ignore-docstrings=yes
198 |
199 | # Ignore imports when computing similarities.
200 | ignore-imports=no
201 |
202 |
203 | [MISCELLANEOUS]
204 |
205 | # List of note tags to take into consideration, separated by a comma.
206 | notes= 207 | 208 | 209 | [IMPORTS] 210 | 211 | # Deprecated modules which should not be used, separated by a comma 212 | deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets 213 | 214 | # Create a graph of every (i.e. internal and external) dependencies in the 215 | # given file (report RP0402 must not be disabled) 216 | import-graph= 217 | 218 | # Create a graph of external dependencies in the given file (report RP0402 must 219 | # not be disabled) 220 | ext-import-graph= 221 | 222 | # Create a graph of internal dependencies in the given file (report RP0402 must 223 | # not be disabled) 224 | int-import-graph= 225 | 226 | 227 | [CLASSES] 228 | 229 | # List of interface methods to ignore, separated by a comma. This is used for 230 | # instance to not check methods defines in Zope's Interface base class. 231 | ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by 232 | 233 | # List of method names used to declare (i.e. assign) instance attributes. 234 | defining-attr-methods=__init__,__new__,setUp 235 | 236 | # List of valid names for the first argument in a class method. 237 | valid-classmethod-first-arg=cls,class_ 238 | 239 | # List of valid names for the first argument in a metaclass class method. 240 | valid-metaclass-classmethod-first-arg=mcs 241 | 242 | 243 | [DESIGN] 244 | 245 | # Maximum number of arguments for function / method 246 | max-args=5 247 | 248 | # Argument names that match this expression will be ignored. Default to name 249 | # with leading underscore 250 | ignored-argument-names=_.* 251 | 252 | # Maximum number of locals for function / method body 253 | max-locals=15 254 | 255 | # Maximum number of return / yield for function / method body 256 | max-returns=6 257 | 258 | # Maximum number of branch for function / method body 259 | max-branches=12 260 | 261 | # Maximum number of statements in function / method body 262 | max-statements=50 263 | 264 | # Maximum number of parents for a class (see R0901). 265 | max-parents=7 266 | 267 | # Maximum number of attributes for a class (see R0902). 268 | max-attributes=7 269 | 270 | # Minimum number of public methods for a class (see R0903). 271 | min-public-methods=2 272 | 273 | # Maximum number of public methods for a class (see R0904). 274 | max-public-methods=20 275 | 276 | 277 | [EXCEPTIONS] 278 | 279 | # Exceptions that will emit a warning when being caught. Defaults to 280 | # "Exception" 281 | overgeneral-exceptions=Exception,StandardError,BaseException -------------------------------------------------------------------------------- /LICENSES.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Bobby D. DeSimone 2 | 3 | Contains modified code from 4 | [TensorFlow](https://www.tensorflow.org/), [LICENSE](https://github.com/tensorflow/tensorflow/blob/master/LICENSE) 5 | ``` 6 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 7 | 8 | Licensed under the Apache License, Version 2.0 (the "License"); 9 | you may not use this file except in compliance with the License. 
10 | You may obtain a copy of the License at
11 |
12 |     http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | Unless required by applicable law or agreed to in writing, software
15 | distributed under the License is distributed on an "AS IS" BASIS,
16 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | See the License for the specific language governing permissions and
18 | limitations under the License.
19 | ```

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Summary
2 |
3 | ⚠️ Work in progress ⚠️
4 |
5 | A collection of semantic image segmentation models implemented in TensorFlow, with data loaders for generic and medical benchmark datasets.
6 |
7 | Hopefully this project will enable researchers to spend less time scaffolding and more time building.
8 |
9 |
10 | ## Datasets & Benchmarks
11 |
12 | Generic
13 |
14 | - [ ] [ADE20K Scene Parsing](https://groups.csail.mit.edu/vision/datasets/ADE20K/) : [paper](https://arxiv.org/pdf/1608.05442.pdf)
15 | - [ ] [Microsoft COCO: Common Objects in Context](http://mscoco.org/home/) : [paper](https://arxiv.org/abs/1405.0312)
16 | - [ ] [Cityscapes](https://www.cityscapes-dataset.com/) : [paper](https://arxiv.org/abs/1604.01685)
17 | - [ ] [PASCAL Visual Object Classes](http://host.robots.ox.ac.uk/pascal/VOC/) : [paper](https://link.springer.com/article/10.1007/s11263-014-0733-5)
18 | - [ ] [SUN RGB-D Scene Understanding Benchmark Suite](http://rgbd.cs.princeton.edu/) : [paper](http://rgbd.cs.princeton.edu/paper.pdf)
19 |
20 | Medical
21 |
22 | - [ ] MICCAI - Brain Tumor Image Segmentation Challenge (BRATS)
23 | - [ ] MICCAI - Ischemic Stroke Lesion Segmentation (ISLES)
24 |
25 | ## Networks & Models
26 |
27 | Generic
28 |
29 | - [ ] [DeepLab v2](http://arxiv.org/abs/1412.7062) : [project](http://liangchiehchen.com/projects/DeepLab.html) : [C++ code](https://bitbucket.org/deeplab/deeplab-public/)
30 | - [ ] [RefineNet](https://arxiv.org/abs/1611.06612) : [MATLAB code](https://github.com/guosheng/refinenet)
31 | - [ ] [I-FCN](https://arxiv.org/abs/1611.08986)
32 | - [ ] [FC-DenseNet](https://arxiv.org/abs/1611.09326) : [Theano/Lasagne code](https://github.com/SimJeg/FC-DenseNet)
33 | - [ ] [PixelNet](https://arxiv.org/abs/1609.06694) : [caffe code](https://github.com/endernewton/PixelNet)
34 | - [ ] [FCN](http://arxiv.org/abs/1411.4038) : [slides](https://docs.google.com/presentation/d/1VeWFMpZ8XN7OC3URZP4WdXvOGYckoFWGVN7hApoXVnc)
35 | - [ ] [SegNet](http://arxiv.org/abs/1505.07293) : [caffe code](https://github.com/alexgkendall/caffe-segnet)
36 |
37 | Medical
38 |
39 | - [ ] [U-Net](http://lmb.informatik.uni-freiburg.de/resources/opensource/unet.en.html)
40 |
41 | ## Usage
42 |
43 | See `./scripts/`
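
For example, here is a minimal end-to-end sketch of preparing PASCAL VOC and reading a split. `run` and `get_dataset` are this repo's entry points; the `/tmp/pascal` path and the TF-Slim data-provider pattern are illustrative:

```python
import tensorflow as tf

from datasets import dataset_factory, download_and_convert_pascal

slim = tf.contrib.slim

# One time: fetch PASCAL VOC 2012 and convert it to tfrecord shards.
download_and_convert_pascal.run('/tmp/pascal')  # any writable directory works

# Load a split and read decoded image/mask pairs, TF-Slim style.
dataset = dataset_factory.get_dataset('pascal', 'train', '/tmp/pascal')
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])
```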
44 |
45 | ## Requirements
46 |
47 | - Python 2.7
48 | - [TensorFlow](https://www.tensorflow.org/get_started/os_setup) `0.12+`
49 |
50 | ## Resources
51 |
52 | Learn
53 |
54 | 1. [TensorFlow Deep Learning Course](https://www.udacity.com/course/deep-learning--ud730) Get hands-on right away with TensorFlow and deep learning.
55 | 2. [Machine Learning, Andrew Ng](https://www.coursera.org/learn/machine-learning) A deeper dive into the basics; less hands-on.
56 | 3. [Stanford CS231n](https://cs231n.github.io/) [videos](https://www.youtube.com/playlist?list=PLkt2uSq6rBVctENoVBg1TpCC7OQi31AlC) I can't overstate how fantastic the notes and videos are.
57 | 4. [Deep Learning : Book](http://www.deeplearningbook.org/) A helpful reference for filling in gaps.
58 | 5. The papers above, starting with [Fully Convolutional Networks for Semantic Segmentation](https://arxiv.org/abs/1411.4038) and its [video](http://techtalks.tv/talks/fully-convolutional-networks-for-semantic-segmentation/61606/)
59 |
60 | Code
61 |
62 | - [TF-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim)
63 | - [TF-Slim : Classification Networks](https://github.com/tensorflow/models/tree/master/slim)
64 | - [imagenet-multiGPU.torch](https://github.com/soumith/imagenet-multiGPU.torch)
65 | - [pixel-cnn++](https://github.com/openai/pixel-cnn)
66 | - NVIDIA DIGITS [Semantic Segmentation Example](https://github.com/NVIDIA/DIGITS/tree/master/examples/semantic-segmentation) : [Medical Imaging Example](https://github.com/NVIDIA/DIGITS/tree/master/examples/medical-imaging)
67 |
68 | ## Contributing
69 |
70 | Please do. [PEP-8](https://www.python.org/dev/peps/pep-0008/) and [Google style](https://google.github.io/styleguide/pyguide.html), with 2-space indents [🤦️](https://www.tensorflow.org/how_tos/style_guide).

--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/desimone/segmentation-models/7f9e5a182891d20b2110dc572b57251c0111988c/datasets/__init__.py

--------------------------------------------------------------------------------
/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | """A factory-pattern module which returns segmentation image/label datasets."""
2 |
3 | from __future__ import absolute_import, division, print_function
4 |
5 | from datasets import pascal
6 |
7 | datasets_map = {'pascal': pascal}
8 |
9 |
10 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
11 |   """Given a dataset name and a split_name, returns a Dataset."""
12 |   if name not in datasets_map:
13 |     raise ValueError('Unknown dataset name: %s' % name)
14 |   return datasets_map[name].get_split(split_name, dataset_dir, file_pattern,
15 |                                       reader)

--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | """Contains utilities for downloading and converting datasets."""
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import hashlib
5 | import os
6 | import sys
7 | import tarfile
8 |
9 | import glob2
10 | import tensorflow as tf
11 |
12 | from six.moves import urllib
13 |
14 | _RANDOM_SEED = 0
15 |
16 |
17 | def parse_glob(path):
18 |   """Returns the first file path matching `path`, or an empty string."""
19 |   try:
20 |     return glob2.glob(path)[0]
21 |   except IndexError:
22 |     return ""
23 |
24 |
25 | def _int64_feature(value):
26 |   """Wrapper for inserting int64 features into Example proto."""
27 |   if not isinstance(value, list):
28 |     value = [value]
29 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
30 |
31 |
32 | def _bytes_feature(value):
33 |   """Wrapper for inserting bytes features into Example proto."""
34 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
35 |
36 |
37 | def tfrecord(image, mask, height, width, channels):
38 |   """Build an Example proto for an example.
39 |
40 |   Args:
41 |     image: string, JPEG encoding of RGB image
42 |     mask: string, PNG encoding of ground truth image
43 |     height: integer, image height in pixels
44 |     width: integer, image width in pixels
45 |     channels: integer, number of image channels
46 |
47 |   Returns:
48 |     Example proto
49 |   """
50 |
51 |   example = tf.train.Example(features=tf.train.Features(feature={
52 |       'image/height': _int64_feature(height),
53 |       'image/width': _int64_feature(width),
54 |       'image/channels': _int64_feature(channels),
55 |       'image/encoded': _bytes_feature(image),
56 |       'image/mask/encoded': _bytes_feature(mask)
57 |   }))
58 |   return example
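
# Illustrative sketch (an assumed helper, not part of the original file): the
# inverse of tfrecord() above. pascal.py does the equivalent with slim's
# TFExampleDecoder; this shows the raw tf.parse_single_example form using the
# same feature keys.
def parse_tfrecord(serialized_example):
  """Decode a serialized Example produced by tfrecord() (sketch only)."""
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image/height': tf.FixedLenFeature([], tf.int64),
          'image/width': tf.FixedLenFeature([], tf.int64),
          'image/channels': tf.FixedLenFeature([], tf.int64),
          'image/encoded': tf.FixedLenFeature([], tf.string),
          'image/mask/encoded': tf.FixedLenFeature([], tf.string),
      })
  image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
  mask = tf.image.decode_png(features['image/mask/encoded'], channels=1)
  return image, mask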
59 |
60 |
61 | class ImageReader(object):
62 |   """Helper class that provides TensorFlow image coding utilities."""
63 |
64 |   def __init__(self):
65 |     # Initializes function that decodes RGB JPEG data.
66 |     self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
67 |     self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
68 |
69 |   def decode_image(self, sess, image):
70 |     image = sess.run(self._decode_jpeg,
71 |                      feed_dict={self._decode_jpeg_data: image})
72 |     height = image.shape[0]
73 |     width = image.shape[1]
74 |     channels = image.shape[2]
75 |     # some sanity checking
76 |     assert len(image.shape) == 3
77 |     assert channels == 3  # TODO(BDD) : Support other sets of channels
78 |     assert width != 0
79 |     assert height != 0
80 |     return height, width, channels
81 |
82 |
83 | def get_filenames(dataset_dir, split_name, shard_id):
84 |   output_filename = '%s_%05d.tfrecord' % (split_name, shard_id)
85 |   return os.path.join(dataset_dir, output_filename)
86 |
87 |
88 | def is_png(filename):
89 |   """Determine if a file contains a PNG format image."""
90 |   return filename.endswith('.png')
91 |
92 |
93 | def is_jpg(filename):
94 |   """Determine if a file contains a JPG format image."""
95 |   return filename.endswith('.jpg')
96 |
97 |
98 | def file_matches(file_name, file_hash):
99 |   if not tf.gfile.Exists(file_name):
100 |     return False
101 |   # Read as binary so the MD5 is computed over the raw bytes.
102 |   with open(file_name, 'rb') as file_to_check:
103 |     md5_returned = hashlib.md5(file_to_check.read()).hexdigest()
104 |   return md5_returned == file_hash
105 |
106 |
107 | def dataset_exists(dataset_dir, num_shards):
108 |   # The shard count is dataset-specific, so callers pass it in explicitly.
109 |   for split_name in ['train', 'validation']:
110 |     for shard_id in range(num_shards):
111 |       out_filename = get_filenames(dataset_dir, split_name, shard_id)
112 |       if not tf.gfile.Exists(out_filename):
113 |         return False
114 |   return True
115 |
116 |
117 | def cleanup_directory(dataset_dir, data_url, voc_root):
118 |   """Removes temporary files used to create the dataset."""
119 |   filename = data_url.split('/')[-1]
120 |   filepath = os.path.join(dataset_dir, filename)
121 |   tf.gfile.Remove(filepath)
122 |   tmp_dir = os.path.join(dataset_dir, voc_root)
123 |   tf.gfile.DeleteRecursively(tmp_dir)
124 |
125 |
126 | def download(dataset_url, dataset_hash, dataset_dir):
127 |   filename = dataset_url.split('/')[-1]
128 |   filepath = os.path.join(dataset_dir, filename)
129 |
130 |   def _progress(count, block_size, total_size):
131 |     sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
132 |                                                      float(count * block_size) /
133 |                                                      float(total_size) * 100.0))
134 |     sys.stdout.flush()
135 |
136 |   if not file_matches(filepath, dataset_hash):
137 |     print("%s not found, downloading it!" % filepath)
138 |     filepath, _ = urllib.request.urlretrieve(dataset_url, filepath, _progress)
139 |     statinfo = os.stat(filepath)
140 |     print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
141 |   tarfile.open(filepath, 'r').extractall(dataset_dir)
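
# Illustrative usage sketch (an assumed helper, not part of the original file):
# fetching PASCAL VOC with the URL/MD5 constants that
# datasets/download_and_convert_pascal.py defines. download() skips the fetch
# when the archive already exists with a matching checksum, then extracts it.
def _example_download_pascal(dataset_dir):
  url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/'
         'VOCtrainval_11-May-2012.tar')
  md5 = '6cd6e144f989b92b3379bac3b3de84fd'
  download(url, md5, dataset_dir)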
--------------------------------------------------------------------------------
/datasets/download_and_convert_pascal.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import math
4 | import os
5 | import random
6 | import sys
7 |
8 | import tensorflow as tf
9 |
10 | from datasets import dataset_utils
11 |
12 | _DATA_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'
13 | _DATA_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
14 | _NUM_VALIDATION = 350
15 | _RANDOM_SEED = 0
16 | _NUM_SHARDS = 1
17 | _VOC_ROOT = 'VOCdevkit/VOC2012'
18 |
19 |
20 | def get_images_and_masks(dataset_dir):
21 |   """Returns lists of image and mask file names."""
22 |   voc_root = os.path.join(dataset_dir, _VOC_ROOT)
23 |   mask_root = os.path.join(voc_root, 'SegmentationClass')
24 |   img_root = os.path.join(voc_root, 'JPEGImages')
25 |
26 |   print('Root:%s\nMasks:%s\nImages:%s' % (voc_root, mask_root, img_root))
27 |   images = []
28 |   masks = []
29 |   # Each jpg image has a corresponding PNG segmentation mask
30 |   for filename in os.listdir(mask_root):
31 |     masks.append(os.path.join(mask_root, filename))
32 |     # drop the .png extension and grab the source .jpg
33 |     images.append(os.path.join(img_root, os.path.splitext(filename)[0] + '.jpg'))
34 |
35 |   print('found\n\t%d images\n\t %d masks' % (len(images), len(masks)))
36 |   return images, masks
37 |
38 |
39 | def _convert_dataset(split_name, img_fns, mask_fns, dataset_dir):
40 |   """Converts the given filenames to a TFRecord dataset."""
41 |   assert split_name in ['train', 'validation']
42 |
43 |   num_per_shard = int(math.ceil(len(img_fns) / float(_NUM_SHARDS)))
44 |
45 |   with tf.Graph().as_default():
46 |     reader = dataset_utils.ImageReader()
47 |
48 |     with tf.Session('') as sess:
49 |
50 |       for shard_id in range(_NUM_SHARDS):
51 |
52 |         out_fns = dataset_utils.get_filenames(dataset_dir, split_name, shard_id)
53 |
54 |         with tf.python_io.TFRecordWriter(out_fns) as tfr_writer:
55 |           start_ndx = shard_id * num_per_shard
56 |           end_ndx = min((shard_id + 1) * num_per_shard, len(img_fns))
57 |           for i in range(start_ndx, end_ndx):
58 |             sys.stdout.write('\r>> Converting image %d/%d shard %d' %
59 |                              (i + 1, len(img_fns), shard_id))
60 |             sys.stdout.flush()
61 |             # read raw mask and image as bytes
62 |             image = tf.gfile.FastGFile(img_fns[i], 'rb').read()
63 |             mask = tf.gfile.FastGFile(mask_fns[i], 'rb').read()
64 |             # use image reader to grab some properties about the image
65 |             height, width, chans = reader.decode_image(sess, image)
66 |             # Prepare a tf-record example proto
67 |             example = dataset_utils.tfrecord(image, mask, height, width, chans)
68 |             tfr_writer.write(example.SerializeToString())
69 |
70 |   sys.stdout.write('\n')
71 |   sys.stdout.flush()
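
# Illustrative sketch (an assumed helper, not part of the original file): a
# quick sanity check that counts the serialized Examples in one shard written
# by _convert_dataset above.
def _example_count_records(dataset_dir, split_name='train', shard_id=0):
  path = dataset_utils.get_filenames(dataset_dir, split_name, shard_id)
  return sum(1 for _ in tf.python_io.tf_record_iterator(path))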
72 |
73 |
74 | def _dataset_exists(dataset_dir):
75 |   for split_name in ['train', 'validation']:
76 |     for shard_id in range(_NUM_SHARDS):
77 |       out_fns = dataset_utils.get_filenames(dataset_dir, split_name, shard_id)
78 |       if not tf.gfile.Exists(out_fns):
79 |         return False
80 |   return True
81 |
82 |
83 | def run(dataset_dir):
84 |   """Runs the download and conversion operation."""
85 |   if not tf.gfile.Exists(dataset_dir):
86 |     tf.gfile.MakeDirs(dataset_dir)
87 |
88 |   if _dataset_exists(dataset_dir):
89 |     print('Dataset files already exist. Exiting without re-creating them.')
90 |     return
91 |
92 |   dataset_utils.download(_DATA_URL, _DATA_MD5, dataset_dir)
93 |   images, masks = get_images_and_masks(dataset_dir)
94 |
95 |   # Divide into train and test:
96 |   #
97 |   # TODO(BDD): replace with precut train/val/test splits in folder
98 |   #
99 |   random.seed(_RANDOM_SEED)
100 |   # Shuffle images and masks together so image/mask pairs stay aligned.
101 |   pairs = list(zip(images, masks))
102 |   random.shuffle(pairs)
103 |   images, masks = zip(*pairs)
104 |
105 |   train_imgs = images[_NUM_VALIDATION:]
106 |   val_imgs = images[:_NUM_VALIDATION]
107 |   train_masks = masks[_NUM_VALIDATION:]
108 |   val_masks = masks[:_NUM_VALIDATION]
109 |
110 |   # First, convert the training and validation sets.
111 |   _convert_dataset('train', train_imgs, train_masks, dataset_dir)
112 |   _convert_dataset('validation', val_imgs, val_masks, dataset_dir)
113 |
114 |   # TODO(BDD) : re-enable cleanup when working
115 |   #_clean_up_temporary_files(dataset_dir)
116 |   print('\nFinished converting the PASCAL-VOC dataset!')

--------------------------------------------------------------------------------
/datasets/pascal.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import os
4 |
5 | import tensorflow as tf
6 |
7 | from datasets import dataset_utils
8 |
9 | slim = tf.contrib.slim
10 |
11 | _FILE_PATTERN = '%s_*.tfrecord'
12 | _CLASS_NAMES = [
13 |     'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
14 |     'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
15 |     'person', 'potted-plant', 'sheep', 'sofa', 'train', 'tv/monitor',
16 |     'ambiguous'
17 | ]
18 |
19 | SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
20 |
21 | _NUM_CLASSES = 21
22 |
23 | _ITEMS_TO_DESCRIPTIONS = {
24 |     'image': 'A color image of varying size.',
25 |     'label': 'Ground truth segmentation mask',
26 | }
27 |
28 |
29 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
30 |   if split_name not in SPLITS_TO_SIZES:
31 |     raise ValueError('split name %s was not recognized.' % split_name)
32 |
33 |   if not file_pattern:
34 |     file_pattern = _FILE_PATTERN
35 |   file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
36 |   print(file_pattern)
37 |   # Allowing None in the signature so that dataset_factory can use the
38 |   # default.
39 | if reader is None: 40 | reader = tf.TFRecordReader 41 | 42 | keys_to_features = { 43 | 'image/encoded': 44 | tf.FixedLenFeature( 45 | (), tf.string, default_value=''), 46 | 'image/format': 47 | tf.FixedLenFeature( 48 | (), tf.string, default_value='jpeg'), 49 | 'image/mask/encoded': 50 | tf.FixedLenFeature( 51 | (), tf.string, default_value=''), 52 | 'image/mask/format': 53 | tf.FixedLenFeature( 54 | (), tf.string, default_value='png'), 55 | } 56 | 57 | items_to_handlers = { 58 | 'image': 59 | slim.tfexample_decoder.Image(), 60 | 'label': 61 | slim.tfexample_decoder.Image( 62 | 'image/mask/encoded', 'image/mask/format', channels=1), 63 | } 64 | 65 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, 66 | items_to_handlers) 67 | 68 | # TODO(bdd): name labels 69 | labels_to_names = None 70 | 71 | return slim.dataset.Dataset( 72 | data_sources=file_pattern, 73 | reader=reader, 74 | decoder=decoder, 75 | num_samples=SPLITS_TO_SIZES[split_name], 76 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 77 | num_classes=_NUM_CLASSES, 78 | labels_to_names=labels_to_names) 79 | -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desimone/segmentation-models/7f9e5a182891d20b2110dc572b57251c0111988c/deployment/__init__.py -------------------------------------------------------------------------------- /deployment/model_deploy.py: -------------------------------------------------------------------------------- 1 | """Deploy Slim models across multiple clones and replicas. 2 | 3 | # TODO(sguada) docstring paragraph by (a) motivating the need for the file and 4 | # (b) defining clones. 5 | 6 | # TODO(sguada) describe the high-level components of model deployment. 7 | # E.g. "each model deployment is composed of several parts: a DeploymentConfig, 8 | # which captures A, B and C, an input_fn which loads data.. etc 9 | 10 | To easily train a model on multiple GPUs or across multiple machines this 11 | module provides a set of helper functions: `create_clones`, 12 | `optimize_clones` and `deploy`. 13 | 14 | Usage: 15 | 16 | g = tf.Graph() 17 | 18 | # Set up DeploymentConfig 19 | config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True) 20 | 21 | # Create the global step on the device storing the variables. 22 | with tf.device(config.variables_device()): 23 | global_step = slim.create_global_step() 24 | 25 | # Define the inputs 26 | with tf.device(config.inputs_device()): 27 | images, labels = LoadData(...) 28 | inputs_queue = slim.data.prefetch_queue((images, labels)) 29 | 30 | # Define the optimizer. 31 | with tf.device(config.optimizer_device()): 32 | optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 33 | 34 | # Define the model including the loss. 35 | def model_fn(inputs_queue): 36 | images, labels = inputs_queue.dequeue() 37 | predictions = CreateNetwork(images) 38 | slim.losses.log_loss(predictions, labels) 39 | 40 | model_dp = model_deploy.deploy(config, model_fn, [inputs_queue], 41 | optimizer=optimizer) 42 | 43 | # Run training. 44 | slim.learning.train(model_dp.train_op, my_log_dir, 45 | summary_op=model_dp.summary_op) 46 | 47 | The Clone namedtuple holds together the values associated with each call to 48 | model_fn: 49 | * outputs: The return values of the calls to `model_fn()`. 50 | * scope: The scope used to create the clone. 
51 |     * device: The device used to create the clone.
52 |
53 | The DeployedModel namedtuple holds together the values needed to train multiple
54 | clones:
55 |     * train_op: An operation that runs the optimizer training op and includes
56 |       all the update ops created by `model_fn`. Present only if an optimizer
57 |       was specified.
58 |     * summary_op: An operation that runs the summaries created by `model_fn`
59 |       and process_gradients.
60 |     * total_loss: A `Tensor` that contains the sum of all losses created by
61 |       `model_fn` plus the regularization losses.
62 |     * clones: List of `Clone` tuples returned by `create_clones()`.
63 |
64 | DeploymentConfig parameters:
65 |     * num_clones: Number of model clones to deploy in each replica.
66 |     * clone_on_cpu: True if clones should be placed on CPU.
67 |     * replica_id: Integer. Index of the replica for which the model is
68 |       deployed. Usually 0 for the chief replica.
69 |     * num_replicas: Number of replicas to use.
70 |     * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
71 |     * worker_job_name: A name for the worker job.
72 |     * ps_job_name: A name for the parameter server job.
73 |
74 | TODO(sguada):
75 |   - describe side effect to the graph.
76 |   - what happens to summaries and update_ops.
77 |   - which graph collections are altered.
78 |   - write a tutorial on how to use this.
79 |   - analyze the possibility of calling deploy more than once.
80 |
81 |
82 | """
83 |
84 | from __future__ import absolute_import, division, print_function
85 |
86 | import collections
87 |
88 | import tensorflow as tf
89 | from tensorflow.python.ops import control_flow_ops
90 |
91 | slim = tf.contrib.slim
92 |
93 | __all__ = [
94 |     'create_clones',
95 |     'deploy',
96 |     'optimize_clones',
97 |     'DeployedModel',
98 |     'DeploymentConfig',
99 |     'Clone',
100 | ]
101 |
102 | # Namedtuple used to represent a clone during deployment.
103 | Clone = collections.namedtuple(
104 |     'Clone',
105 |     [
106 |         'outputs',  # Whatever model_fn() returned.
107 |         'scope',  # The scope used to create it.
108 |         'device',  # The device used to create.
109 |     ])
110 |
111 | # Namedtuple used to represent a DeployedModel, returned by deploy().
112 | DeployedModel = collections.namedtuple(
113 |     'DeployedModel',
114 |     [
115 |         'train_op',  # The `train_op`
116 |         'summary_op',  # The `summary_op`
117 |         'total_loss',  # The loss `Tensor`
118 |         'clones',  # A list of `Clones` tuples.
119 |     ])
120 |
121 | # Default parameters for DeploymentConfig
122 | _deployment_params = {
123 |     'num_clones': 1,
124 |     'clone_on_cpu': False,
125 |     'replica_id': 0,
126 |     'num_replicas': 1,
127 |     'num_ps_tasks': 0,
128 |     'worker_job_name': 'worker',
129 |     'ps_job_name': 'ps'
130 | }
131 |
132 |
133 | def create_clones(config, model_fn, args=None, kwargs=None):
134 |   """Creates multiple clones according to config using a `model_fn`.
135 |
136 |   The returned values of `model_fn(*args, **kwargs)` are collected along with
137 |   the scope and device used to create it in a namedtuple
138 |   `Clone(outputs, scope, device)`
139 |
140 |   Note: it is assumed that any loss created by `model_fn` is collected at
141 |   the tf.GraphKeys.LOSSES collection.
142 |
143 |   To recover the losses, summaries or update_ops created by the clone use:
144 |   ```python
145 |   losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
146 |   summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
147 |   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
148 |   ```
149 |
150 |   The deployment options are specified by the config object and support
151 |   deploying one or several clones on different GPUs and one or several replicas
152 |   of such clones.
153 |
154 |   The argument `model_fn` is called `config.num_clones` times to create the
155 |   model clones as `model_fn(*args, **kwargs)`.
156 |
157 |   If `config` specifies deployment on multiple replicas then the default
158 |   tensorflow device is set appropriately for each call to `model_fn` and for the
159 |   slim variable creation functions: model and global variables will be created
160 |   on the `ps` device, the clone operations will be on the `worker` device.
161 |
162 |   Args:
163 |     config: A DeploymentConfig object.
164 |     model_fn: A callable. Called as `model_fn(*args, **kwargs)`
165 |     args: Optional list of arguments to pass to `model_fn`.
166 |     kwargs: Optional dict of keyword arguments to pass to `model_fn`.
167 |
168 |   Returns:
169 |     A list of namedtuples `Clone`.
170 |   """
171 |   clones = []
172 |   args = args or []
173 |   kwargs = kwargs or {}
174 |   with slim.arg_scope(
175 |       [slim.model_variable, slim.variable], device=config.variables_device()):
176 |     # Create clones.
177 |     for i in range(0, config.num_clones):
178 |       with tf.name_scope(config.clone_scope(i)) as clone_scope:
179 |         clone_device = config.clone_device(i)
180 |         with tf.device(clone_device):
181 |           with tf.variable_scope(
182 |               tf.get_variable_scope(), reuse=True if i > 0 else
183 |               None):
184 |             outputs = model_fn(*args, **kwargs)
185 |           clones.append(Clone(outputs, clone_scope, clone_device))
186 |   return clones
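
# Illustrative sketch (an assumed helper, not part of the original module): a
# minimal call to create_clones, mirroring the module docstring. The model_fn
# below is a stand-in; any callable that registers a loss works.
def _example_create_clones():
  """Builds two CPU clones of a tiny logistic model (sketch only)."""
  config = DeploymentConfig(num_clones=2, clone_on_cpu=True)

  def model_fn(images, labels):
    predictions = slim.fully_connected(
        images, 1, activation_fn=tf.sigmoid, scope='fully_connected')
    slim.losses.log_loss(predictions, labels)
    return predictions

  images = tf.zeros([8, 4])
  labels = tf.zeros([8, 1])
  # The first clone creates the variables; later clones reuse them.
  return create_clones(config, model_fn, args=[images, labels])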
187 |
188 |
189 | def _gather_clone_loss(clone, num_clones, regularization_losses):
190 |   """Gather the loss for a single clone.
191 |
192 |   Args:
193 |     clone: A Clone namedtuple.
194 |     num_clones: The number of clones being deployed.
195 |     regularization_losses: Possibly empty list of regularization_losses
196 |       to add to the clone losses.
197 |
198 |   Returns:
199 |     A tensor for the total loss for the clone. Can be None.
200 |   """
201 |   # The return value.
202 |   sum_loss = None
203 |   # Individual components of the loss that will need summaries.
204 |   clone_loss = None
205 |   regularization_loss = None
206 |   # Compute and aggregate losses on the clone device.
207 |   with tf.device(clone.device):
208 |     all_losses = []
209 |     clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
210 |     if clone_losses:
211 |       clone_loss = tf.add_n(clone_losses, name='clone_loss')
212 |       if num_clones > 1:
213 |         clone_loss = tf.div(clone_loss,
214 |                             1.0 * num_clones,
215 |                             name='scaled_clone_loss')
216 |       all_losses.append(clone_loss)
217 |     if regularization_losses:
218 |       regularization_loss = tf.add_n(
219 |           regularization_losses, name='regularization_loss')
220 |       all_losses.append(regularization_loss)
221 |     if all_losses:
222 |       sum_loss = tf.add_n(all_losses)
223 |   # Add the summaries out of the clone device block.
224 |   if clone_loss is not None:
225 |     tf.summary.scalar(clone.scope + '/clone_loss', clone_loss)
226 |   if regularization_loss is not None:
227 |     tf.summary.scalar('regularization_loss', regularization_loss)
228 |   return sum_loss
229 |
230 |
231 | def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
232 |                     **kwargs):
233 |   """Compute losses and gradients for a single clone.
234 |
235 |   Args:
236 |     optimizer: A tf.Optimizer object.
237 |     clone: A Clone namedtuple.
238 |     num_clones: The number of clones being deployed.
239 |     regularization_losses: Possibly empty list of regularization_losses
240 |       to add to the clone losses.
241 |     **kwargs: Keyword arguments to pass to compute_gradients().
242 |
243 |   Returns:
244 |     A tuple (clone_loss, clone_grads_and_vars).
245 |       - clone_loss: A tensor for the total loss for the clone. Can be None.
246 |       - clone_grads_and_vars: List of (gradient, variable) for the clone.
247 |         Can be empty.
248 |   """
249 |   sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
250 |   clone_grad = None
251 |   if sum_loss is not None:
252 |     with tf.device(clone.device):
253 |       clone_grad = optimizer.compute_gradients(sum_loss, **kwargs)
254 |   return sum_loss, clone_grad
255 |
256 |
257 | def optimize_clones(clones, optimizer, regularization_losses=None, **kwargs):
258 |   """Compute clone losses and gradients for the given list of `Clones`.
259 |
260 |   Note: The regularization_losses are added to the first clone losses.
261 |
262 |   Args:
263 |     clones: List of `Clones` created by `create_clones()`.
264 |     optimizer: An `Optimizer` object.
265 |     regularization_losses: Optional list of regularization losses. If None it
266 |       will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
267 |       exclude them.
268 |     **kwargs: Optional keyword arguments to pass to `compute_gradients`.
269 |
270 |   Returns:
271 |     A tuple (total_loss, grads_and_vars).
272 |       - total_loss: A Tensor containing the average of the clone losses including
273 |         the regularization loss.
274 |       - grads_and_vars: A List of tuples (gradient, variable) containing the sum
275 |         of the gradients for each variable.
276 |
277 |   """
278 |   grads_and_vars = []
279 |   clones_losses = []
280 |   num_clones = len(clones)
281 |   if regularization_losses is None:
282 |     regularization_losses = tf.get_collection(
283 |         tf.GraphKeys.REGULARIZATION_LOSSES)
284 |   for clone in clones:
285 |     with tf.name_scope(clone.scope):
286 |       clone_loss, clone_grad = _optimize_clone(optimizer, clone,
287 |                                                num_clones,
288 |                                                regularization_losses,
289 |                                                **kwargs)
290 |       if clone_loss is not None:
291 |         clones_losses.append(clone_loss)
292 |         grads_and_vars.append(clone_grad)
293 |       # Only use regularization_losses for the first clone
294 |       regularization_losses = None
295 |   # Compute the total_loss summing all the clones_losses.
296 |   total_loss = tf.add_n(clones_losses, name='total_loss')
297 |   # Sum the gradients across clones.
298 |   grads_and_vars = _sum_clones_gradients(grads_and_vars)
299 |   return total_loss, grads_and_vars
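
# Illustrative note (an assumed helper, not part of the original module):
# _gather_clone_loss scales each clone loss by 1/num_clones, so summing the
# per-clone gradients in _sum_clones_gradients yields the gradient of the
# *average* clone loss.
def _example_loss_scaling():
  """Tiny numeric check of the 1/num_clones scaling (sketch only)."""
  losses = [4.0, 2.0]  # unscaled per-clone losses
  num_clones = len(losses)
  scaled = [loss / num_clones for loss in losses]  # as produced per clone
  total = sum(scaled)  # what tf.add_n(clones_losses) computes
  assert total == sum(losses) / num_clones  # == mean of the clone losses
  return total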
300 |
301 |
302 | def deploy(config,
303 |            model_fn,
304 |            args=None,
305 |            kwargs=None,
306 |            optimizer=None,
307 |            summarize_gradients=False):
308 |   """Deploys a Slim-constructed model across multiple clones.
309 |
310 |   The deployment options are specified by the config object and support
311 |   deploying one or several clones on different GPUs and one or several replicas
312 |   of such clones.
313 |
314 |   The argument `model_fn` is called `config.num_clones` times to create the
315 |   model clones as `model_fn(*args, **kwargs)`.
316 |
317 |   The optional argument `optimizer` is an `Optimizer` object. If not `None`,
318 |   the deployed model is configured for training with that optimizer.
319 |
320 |   If `config` specifies deployment on multiple replicas then the default
321 |   tensorflow device is set appropriately for each call to `model_fn` and for the
322 |   slim variable creation functions: model and global variables will be created
323 |   on the `ps` device, the clone operations will be on the `worker` device.
324 |
325 |   Args:
326 |     config: A `DeploymentConfig` object.
327 |     model_fn: A callable. Called as `model_fn(*args, **kwargs)`
328 |     args: Optional list of arguments to pass to `model_fn`.
329 |     kwargs: Optional dict of keyword arguments to pass to `model_fn`.
330 |     optimizer: Optional `Optimizer` object. If passed the model is deployed
331 |       for training with that optimizer.
332 |     summarize_gradients: Whether or not to add summaries to the gradients.
333 |
334 |   Returns:
335 |     A `DeployedModel` namedtuple.
336 |
337 |   """
338 |   # Gather initial summaries.
339 |   summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
340 |
341 |   # Create Clones.
342 |   clones = create_clones(config, model_fn, args, kwargs)
343 |   first_clone = clones[0]
344 |
345 |   # Gather update_ops from the first clone. These contain, for example,
346 |   # the updates for the batch_norm variables created by model_fn.
347 |   update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)
348 |
349 |   train_op = None
350 |   total_loss = None
351 |   with tf.device(config.optimizer_device()):
352 |     if optimizer:
353 |       # Place the global step on the device storing the variables.
354 |       with tf.device(config.variables_device()):
355 |         global_step = slim.get_or_create_global_step()
356 |
357 |       # Compute the gradients for the clones.
358 |       total_loss, clones_gradients = optimize_clones(clones, optimizer)
359 |
360 |       if clones_gradients:
361 |         if summarize_gradients:
362 |           # Add summaries to the gradients.
363 |           summaries |= set(_add_gradients_summaries(clones_gradients))
364 |
365 |         # Create gradient updates.
366 |         grad_updates = optimizer.apply_gradients(
367 |             clones_gradients, global_step=global_step)
368 |         update_ops.append(grad_updates)
369 |
370 |         update_op = tf.group(*update_ops)
371 |         train_op = control_flow_ops.with_dependencies(
372 |             [update_op], total_loss, name='train_op')
373 |     else:
374 |       clones_losses = []
375 |       regularization_losses = tf.get_collection(
376 |           tf.GraphKeys.REGULARIZATION_LOSSES)
377 |       for clone in clones:
378 |         with tf.name_scope(clone.scope):
379 |           clone_loss = _gather_clone_loss(clone,
380 |                                           len(clones),
381 |                                           regularization_losses)
382 |           if clone_loss is not None:
383 |             clones_losses.append(clone_loss)
384 |           # Only use regularization_losses for the first clone
385 |           regularization_losses = None
386 |       if clones_losses:
387 |         total_loss = tf.add_n(clones_losses, name='total_loss')
388 |
389 |     # Add the summaries from the first clone. These contain the summaries
390 |     # created by model_fn and either optimize_clones() or
391 |     # _gather_clone_loss().
392 |     summaries |= set(
393 |         tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone.scope))
394 |
395 |     if total_loss is not None:
396 |       # Add total_loss to summary.
397 |       summaries.add(tf.summary.scalar('total_loss', total_loss))
398 |
399 |     if summaries:
400 |       # Merge all summaries together.
401 | summary_op = tf.summary.merge(list(summaries)) 402 | else: 403 | summary_op = None 404 | 405 | return DeployedModel(train_op, summary_op, total_loss, clones) 406 | 407 | 408 | def _sum_clones_gradients(clone_grads): 409 | """Calculate the sum gradient for each shared variable across all clones. 410 | 411 | This function assumes that the clone_grads has been scaled appropriately by 412 | 1 / num_clones. 413 | 414 | Args: 415 | clone_grads: A List of List of tuples (gradient, variable), one list per 416 | `Clone`. 417 | 418 | Returns: 419 | List of tuples of (gradient, variable) where the gradient has been summed 420 | across all clones. 421 | """ 422 | sum_grads = [] 423 | for grad_and_vars in zip(*clone_grads): 424 | # Note that each grad_and_vars looks like the following: 425 | # ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN)) 426 | grads = [] 427 | var = grad_and_vars[0][1] 428 | for g, v in grad_and_vars: 429 | assert v == var 430 | if g is not None: 431 | grads.append(g) 432 | if grads: 433 | if len(grads) > 1: 434 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads') 435 | else: 436 | sum_grad = grads[0] 437 | sum_grads.append((sum_grad, var)) 438 | return sum_grads 439 | 440 | 441 | def _add_gradients_summaries(grads_and_vars): 442 | """Add histogram summaries to gradients. 443 | 444 | Note: The summaries are also added to the SUMMARIES collection. 445 | 446 | Args: 447 | grads_and_vars: A list of gradient to variable pairs (tuples). 448 | 449 | Returns: 450 | The _list_ of the added summaries for grads_and_vars. 451 | """ 452 | summaries = [] 453 | for grad, var in grads_and_vars: 454 | if grad is not None: 455 | if isinstance(grad, tf.IndexedSlices): 456 | grad_values = grad.values 457 | else: 458 | grad_values = grad 459 | summaries.append( 460 | tf.histogram_summary(var.op.name + ':gradient', grad_values)) 461 | summaries.append( 462 | tf.histogram_summary(var.op.name + ':gradient_norm', 463 | tf.global_norm([grad_values]))) 464 | else: 465 | tf.logging.info('Var %s has no gradient', var.op.name) 466 | return summaries 467 | 468 | 469 | class DeploymentConfig(object): 470 | """Configuration for deploying a model with `deploy()`. 471 | 472 | You can pass an instance of this class to `deploy()` to specify exactly 473 | how to deploy the model to build. If you do not pass one, an instance built 474 | from the default deployment_hparams will be used. 475 | """ 476 | 477 | def __init__(self, 478 | num_clones=1, 479 | clone_on_cpu=False, 480 | replica_id=0, 481 | num_replicas=1, 482 | num_ps_tasks=0, 483 | worker_job_name='worker', 484 | ps_job_name='ps'): 485 | """Create a DeploymentConfig. 486 | 487 | The config describes how to deploy a model across multiple clones and 488 | replicas. The model will be replicated `num_clones` times in each replica. 489 | If `clone_on_cpu` is True, each clone will placed on CPU. 490 | 491 | If `num_replicas` is 1, the model is deployed via a single process. In that 492 | case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored. 493 | 494 | If `num_replicas` is greater than 1, then `worker_device` and `ps_device` 495 | must specify TensorFlow devices for the `worker` and `ps` jobs and 496 | `num_ps_tasks` must be positive. 497 | 498 | Args: 499 | num_clones: Number of model clones to deploy in each replica. 500 | clone_on_cpu: If True clones would be placed on CPU. 501 | replica_id: Integer. Index of the replica for which the model is 502 | deployed. Usually 0 for the chief replica. 
503 | num_replicas: Number of replicas to use. 504 | num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas. 505 | worker_job_name: A name for the worker job. 506 | ps_job_name: A name for the parameter server job. 507 | 508 | Raises: 509 | ValueError: If the arguments are invalid. 510 | """ 511 | if num_replicas > 1: 512 | if num_ps_tasks < 1: 513 | raise ValueError( 514 | 'When using replicas num_ps_tasks must be positive') 515 | if num_replicas > 1 or num_ps_tasks > 0: 516 | if not worker_job_name: 517 | raise ValueError( 518 | 'Must specify worker_job_name when using replicas') 519 | if not ps_job_name: 520 | raise ValueError( 521 | 'Must specify ps_job_name when using parameter server') 522 | if replica_id >= num_replicas: 523 | raise ValueError('replica_id must be less than num_replicas') 524 | self._num_clones = num_clones 525 | self._clone_on_cpu = clone_on_cpu 526 | self._replica_id = replica_id 527 | self._num_replicas = num_replicas 528 | self._num_ps_tasks = num_ps_tasks 529 | self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else '' 530 | self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else '' 531 | 532 | @property 533 | def num_clones(self): 534 | return self._num_clones 535 | 536 | @property 537 | def clone_on_cpu(self): 538 | return self._clone_on_cpu 539 | 540 | @property 541 | def replica_id(self): 542 | return self._replica_id 543 | 544 | @property 545 | def num_replicas(self): 546 | return self._num_replicas 547 | 548 | @property 549 | def num_ps_tasks(self): 550 | return self._num_ps_tasks 551 | 552 | @property 553 | def ps_device(self): 554 | return self._ps_device 555 | 556 | @property 557 | def worker_device(self): 558 | return self._worker_device 559 | 560 | def caching_device(self): 561 | """Returns the device to use for caching variables. 562 | 563 | Variables are cached on the worker CPU when using replicas. 564 | 565 | Returns: 566 | A device string or None if the variables do not need to be cached. 567 | """ 568 | if self._num_ps_tasks > 0: 569 | return lambda op: op.device 570 | else: 571 | return None 572 | 573 | def clone_device(self, clone_index): 574 | """Device used to create the clone and all the ops inside the clone. 575 | 576 | Args: 577 | clone_index: Int, representing the clone_index. 578 | 579 | Returns: 580 | A value suitable for `tf.device()`. 581 | 582 | Raises: 583 | ValueError: if `clone_index` is greater or equal to the number of clones". 584 | """ 585 | if clone_index >= self._num_clones: 586 | raise ValueError('clone_index must be less than num_clones') 587 | device = '' 588 | if self._num_ps_tasks > 0: 589 | device += self._worker_device 590 | if self._clone_on_cpu: 591 | device += '/device:CPU:0' 592 | else: 593 | if self._num_clones > 1: 594 | device += '/device:GPU:%d' % clone_index 595 | return device 596 | 597 | def clone_scope(self, clone_index): 598 | """Name scope to create the clone. 599 | 600 | Args: 601 | clone_index: Int, representing the clone_index. 602 | 603 | Returns: 604 | A name_scope suitable for `tf.name_scope()`. 605 | 606 | Raises: 607 | ValueError: if `clone_index` is greater or equal to the number of clones". 608 | """ 609 | if clone_index >= self._num_clones: 610 | raise ValueError('clone_index must be less than num_clones') 611 | scope = '' 612 | if self._num_clones > 1: 613 | scope = 'clone_%d' % clone_index 614 | return scope 615 | 616 | def optimizer_device(self): 617 | """Device to use with the optimizer. 
618 | 619 | Returns: 620 | A value suitable for `tf.device()`. 621 | """ 622 | if self._num_ps_tasks > 0 or self._num_clones > 0: 623 | return self._worker_device + '/device:CPU:0' 624 | else: 625 | return '' 626 | 627 | def inputs_device(self): 628 | """Device to use to build the inputs. 629 | 630 | Returns: 631 | A value suitable for `tf.device()`. 632 | """ 633 | device = '' 634 | if self._num_ps_tasks > 0: 635 | device += self._worker_device 636 | device += '/device:CPU:0' 637 | return device 638 | 639 | def variables_device(self): 640 | """Returns the device to use for variables created inside the clone. 641 | 642 | Returns: 643 | A value suitable for `tf.device()`. 644 | """ 645 | device = '' 646 | if self._num_ps_tasks > 0: 647 | device += self._ps_device 648 | device += '/device:CPU:0' 649 | 650 | class _PSDeviceChooser(object): 651 | """Slim device chooser for variables when using PS.""" 652 | 653 | def __init__(self, device, tasks): 654 | self._device = device 655 | self._tasks = tasks 656 | self._task = 0 657 | 658 | def choose(self, op): 659 | if op.device: 660 | return op.device 661 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 662 | if node_def.op == 'Variable': 663 | t = self._task 664 | self._task = (self._task + 1) % self._tasks 665 | d = '%s/task:%d' % (self._device, t) 666 | return d 667 | else: 668 | return op.device 669 | 670 | if not self._num_ps_tasks: 671 | return device 672 | else: 673 | chooser = _PSDeviceChooser(device, self._num_ps_tasks) 674 | return chooser.choose 675 | -------------------------------------------------------------------------------- /deployment/model_deploy_test.py: -------------------------------------------------------------------------------- 1 | """Tests for model_deploy.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from deployment import model_deploy 11 | 12 | slim = tf.contrib.slim 13 | 14 | 15 | class DeploymentConfigTest(tf.test.TestCase): 16 | 17 | def testDefaults(self): 18 | deploy_config = model_deploy.DeploymentConfig() 19 | 20 | self.assertEqual(slim.get_variables(), []) 21 | self.assertEqual(deploy_config.caching_device(), None) 22 | self.assertDeviceEqual(deploy_config.clone_device(0), '') 23 | self.assertEqual(deploy_config.clone_scope(0), '') 24 | self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 25 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 26 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 27 | 28 | def testCPUonly(self): 29 | deploy_config = model_deploy.DeploymentConfig(clone_on_cpu=True) 30 | 31 | self.assertEqual(deploy_config.caching_device(), None) 32 | self.assertDeviceEqual(deploy_config.clone_device(0), 'CPU:0') 33 | self.assertEqual(deploy_config.clone_scope(0), '') 34 | self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 35 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 36 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 37 | 38 | def testMultiGPU(self): 39 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 40 | 41 | self.assertEqual(deploy_config.caching_device(), None) 42 | self.assertDeviceEqual(deploy_config.clone_device(0), 'GPU:0') 43 | self.assertDeviceEqual(deploy_config.clone_device(1), 'GPU:1') 44 | self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 45 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 46 | 
self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 47 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 48 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 49 | 50 | def testPS(self): 51 | deploy_config = model_deploy.DeploymentConfig( 52 | num_clones=1, num_ps_tasks=1) 53 | 54 | self.assertDeviceEqual(deploy_config.clone_device(0), '/job:worker') 55 | self.assertEqual(deploy_config.clone_scope(0), '') 56 | self.assertDeviceEqual(deploy_config.optimizer_device(), 57 | '/job:worker/device:CPU:0') 58 | self.assertDeviceEqual(deploy_config.inputs_device(), 59 | '/job:worker/device:CPU:0') 60 | with tf.device(deploy_config.variables_device()): 61 | a = tf.Variable(0) 62 | b = tf.Variable(0) 63 | c = tf.no_op() 64 | d = slim.variable( 65 | 'a', [], caching_device=deploy_config.caching_device()) 66 | self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0') 67 | self.assertDeviceEqual(a.device, a.value().device) 68 | self.assertDeviceEqual(b.device, '/job:ps/task:0/device:CPU:0') 69 | self.assertDeviceEqual(b.device, b.value().device) 70 | self.assertDeviceEqual(c.device, '') 71 | self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0') 72 | self.assertDeviceEqual(d.value().device, '') 73 | 74 | def testMultiGPUPS(self): 75 | deploy_config = model_deploy.DeploymentConfig( 76 | num_clones=2, num_ps_tasks=1) 77 | 78 | self.assertEqual(deploy_config.caching_device()(tf.no_op()), '') 79 | self.assertDeviceEqual( 80 | deploy_config.clone_device(0), '/job:worker/device:GPU:0') 81 | self.assertDeviceEqual( 82 | deploy_config.clone_device(1), '/job:worker/device:GPU:1') 83 | self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 84 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 85 | self.assertDeviceEqual(deploy_config.optimizer_device(), 86 | '/job:worker/device:CPU:0') 87 | self.assertDeviceEqual(deploy_config.inputs_device(), 88 | '/job:worker/device:CPU:0') 89 | 90 | def testReplicasPS(self): 91 | deploy_config = model_deploy.DeploymentConfig( 92 | num_replicas=2, num_ps_tasks=2) 93 | 94 | self.assertDeviceEqual(deploy_config.clone_device(0), '/job:worker') 95 | self.assertEqual(deploy_config.clone_scope(0), '') 96 | self.assertDeviceEqual(deploy_config.optimizer_device(), 97 | '/job:worker/device:CPU:0') 98 | self.assertDeviceEqual(deploy_config.inputs_device(), 99 | '/job:worker/device:CPU:0') 100 | 101 | def testReplicasMultiGPUPS(self): 102 | deploy_config = model_deploy.DeploymentConfig( 103 | num_replicas=2, num_clones=2, num_ps_tasks=2) 104 | self.assertDeviceEqual( 105 | deploy_config.clone_device(0), '/job:worker/device:GPU:0') 106 | self.assertDeviceEqual( 107 | deploy_config.clone_device(1), '/job:worker/device:GPU:1') 108 | self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 109 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 110 | self.assertDeviceEqual(deploy_config.optimizer_device(), 111 | '/job:worker/device:CPU:0') 112 | self.assertDeviceEqual(deploy_config.inputs_device(), 113 | '/job:worker/device:CPU:0') 114 | 115 | def testVariablesPS(self): 116 | deploy_config = model_deploy.DeploymentConfig(num_ps_tasks=2) 117 | 118 | with tf.device(deploy_config.variables_device()): 119 | a = tf.Variable(0) 120 | b = tf.Variable(0) 121 | c = tf.no_op() 122 | d = slim.variable( 123 | 'a', [], caching_device=deploy_config.caching_device()) 124 | 125 | self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0') 126 | self.assertDeviceEqual(a.device, a.value().device) 127 | 
self.assertDeviceEqual(b.device, '/job:ps/task:1/device:CPU:0') 128 | self.assertDeviceEqual(b.device, b.value().device) 129 | self.assertDeviceEqual(c.device, '') 130 | self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0') 131 | self.assertDeviceEqual(d.value().device, '') 132 | 133 | 134 | def LogisticClassifier(inputs, labels, scope=None, reuse=None): 135 | with tf.variable_scope( 136 | scope, 'LogisticClassifier', [inputs, labels], reuse=reuse): 137 | predictions = slim.fully_connected( 138 | inputs, 1, activation_fn=tf.sigmoid, scope='fully_connected') 139 | slim.losses.log_loss(predictions, labels) 140 | return predictions 141 | 142 | 143 | def BatchNormClassifier(inputs, labels, scope=None, reuse=None): 144 | with tf.variable_scope( 145 | scope, 'BatchNormClassifier', [inputs, labels], reuse=reuse): 146 | inputs = slim.batch_norm(inputs, decay=0.1) 147 | predictions = slim.fully_connected( 148 | inputs, 1, activation_fn=tf.sigmoid, scope='fully_connected') 149 | slim.losses.log_loss(predictions, labels) 150 | return predictions 151 | 152 | 153 | class CreatecloneTest(tf.test.TestCase): 154 | 155 | def setUp(self): 156 | # Create an easy training set: 157 | np.random.seed(0) 158 | 159 | self._inputs = np.zeros((16, 4)) 160 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 161 | self._logdir = self.get_temp_dir() 162 | 163 | for i in range(16): 164 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 165 | self._inputs[i, j] = 1 166 | 167 | def testCreateLogisticClassifier(self): 168 | g = tf.Graph() 169 | with g.as_default(): 170 | tf.set_random_seed(0) 171 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 172 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 173 | 174 | model_fn = LogisticClassifier 175 | clone_args = (tf_inputs, tf_labels) 176 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 177 | 178 | self.assertEqual(slim.get_variables(), []) 179 | clones = model_deploy.create_clones( 180 | deploy_config, model_fn, clone_args) 181 | clone = clones[0] 182 | self.assertEqual(len(slim.get_variables()), 2) 183 | for v in slim.get_variables(): 184 | self.assertDeviceEqual(v.device, 'CPU:0') 185 | self.assertDeviceEqual(v.value().device, 'CPU:0') 186 | self.assertEqual(clone.outputs.op.name, 187 | 'LogisticClassifier/fully_connected/Sigmoid') 188 | self.assertEqual(clone.scope, '') 189 | self.assertDeviceEqual(clone.device, '') 190 | self.assertEqual(len(slim.losses.get_losses()), 1) 191 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 192 | self.assertEqual(update_ops, []) 193 | 194 | def testCreateSingleclone(self): 195 | g = tf.Graph() 196 | with g.as_default(): 197 | tf.set_random_seed(0) 198 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 199 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 200 | 201 | model_fn = BatchNormClassifier 202 | clone_args = (tf_inputs, tf_labels) 203 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 204 | 205 | self.assertEqual(slim.get_variables(), []) 206 | clones = model_deploy.create_clones( 207 | deploy_config, model_fn, clone_args) 208 | clone = clones[0] 209 | self.assertEqual(len(slim.get_variables()), 5) 210 | for v in slim.get_variables(): 211 | self.assertDeviceEqual(v.device, 'CPU:0') 212 | self.assertDeviceEqual(v.value().device, 'CPU:0') 213 | self.assertEqual(clone.outputs.op.name, 214 | 'BatchNormClassifier/fully_connected/Sigmoid') 215 | self.assertEqual(clone.scope, '') 216 | self.assertDeviceEqual(clone.device, '') 217 | 
self.assertEqual(len(slim.losses.get_losses()), 1) 218 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 219 | self.assertEqual(len(update_ops), 2) 220 | 221 | def testCreateMulticlone(self): 222 | g = tf.Graph() 223 | with g.as_default(): 224 | tf.set_random_seed(0) 225 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 226 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 227 | 228 | model_fn = BatchNormClassifier 229 | clone_args = (tf_inputs, tf_labels) 230 | num_clones = 4 231 | deploy_config = model_deploy.DeploymentConfig( 232 | num_clones=num_clones) 233 | 234 | self.assertEqual(slim.get_variables(), []) 235 | clones = model_deploy.create_clones( 236 | deploy_config, model_fn, clone_args) 237 | self.assertEqual(len(slim.get_variables()), 5) 238 | for v in slim.get_variables(): 239 | self.assertDeviceEqual(v.device, 'CPU:0') 240 | self.assertDeviceEqual(v.value().device, 'CPU:0') 241 | self.assertEqual(len(clones), num_clones) 242 | for i, clone in enumerate(clones): 243 | self.assertEqual( 244 | clone.outputs.op.name, 245 | 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i) 246 | update_ops = tf.get_collection( 247 | tf.GraphKeys.UPDATE_OPS, clone.scope) 248 | self.assertEqual(len(update_ops), 2) 249 | self.assertEqual(clone.scope, 'clone_%d/' % i) 250 | self.assertDeviceEqual(clone.device, 'GPU:%d' % i) 251 | 252 | def testCreateOnecloneWithPS(self): 253 | g = tf.Graph() 254 | with g.as_default(): 255 | tf.set_random_seed(0) 256 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 257 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 258 | 259 | model_fn = BatchNormClassifier 260 | clone_args = (tf_inputs, tf_labels) 261 | deploy_config = model_deploy.DeploymentConfig( 262 | num_clones=1, num_ps_tasks=1) 263 | 264 | self.assertEqual(slim.get_variables(), []) 265 | clones = model_deploy.create_clones( 266 | deploy_config, model_fn, clone_args) 267 | self.assertEqual(len(clones), 1) 268 | clone = clones[0] 269 | self.assertEqual(clone.outputs.op.name, 270 | 'BatchNormClassifier/fully_connected/Sigmoid') 271 | self.assertDeviceEqual(clone.device, '/job:worker') 272 | self.assertEqual(clone.scope, '') 273 | self.assertEqual(len(slim.get_variables()), 5) 274 | for v in slim.get_variables(): 275 | self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0') 276 | self.assertDeviceEqual(v.device, v.value().device) 277 | 278 | def testCreateMulticloneWithPS(self): 279 | g = tf.Graph() 280 | with g.as_default(): 281 | tf.set_random_seed(0) 282 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 283 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 284 | 285 | model_fn = BatchNormClassifier 286 | clone_args = (tf_inputs, tf_labels) 287 | deploy_config = model_deploy.DeploymentConfig( 288 | num_clones=2, num_ps_tasks=2) 289 | 290 | self.assertEqual(slim.get_variables(), []) 291 | clones = model_deploy.create_clones( 292 | deploy_config, model_fn, clone_args) 293 | self.assertEqual(len(slim.get_variables()), 5) 294 | for i, v in enumerate(slim.get_variables()): 295 | t = i % 2 296 | self.assertDeviceEqual( 297 | v.device, '/job:ps/task:%d/device:CPU:0' % t) 298 | self.assertDeviceEqual(v.device, v.value().device) 299 | self.assertEqual(len(clones), 2) 300 | for i, clone in enumerate(clones): 301 | self.assertEqual( 302 | clone.outputs.op.name, 303 | 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i) 304 | self.assertEqual(clone.scope, 'clone_%d/' % i) 305 | self.assertDeviceEqual( 306 | clone.device, '/job:worker/device:GPU:%d' % 
i) 307 | 308 | 309 | class OptimizeclonesTest(tf.test.TestCase): 310 | 311 | def setUp(self): 312 | # Create an easy training set: 313 | np.random.seed(0) 314 | 315 | self._inputs = np.zeros((16, 4)) 316 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 317 | self._logdir = self.get_temp_dir() 318 | 319 | for i in range(16): 320 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 321 | self._inputs[i, j] = 1 322 | 323 | def testCreateLogisticClassifier(self): 324 | g = tf.Graph() 325 | with g.as_default(): 326 | tf.set_random_seed(0) 327 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 328 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 329 | 330 | model_fn = LogisticClassifier 331 | clone_args = (tf_inputs, tf_labels) 332 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 333 | 334 | self.assertEqual(slim.get_variables(), []) 335 | clones = model_deploy.create_clones( 336 | deploy_config, model_fn, clone_args) 337 | self.assertEqual(len(slim.get_variables()), 2) 338 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 339 | self.assertEqual(update_ops, []) 340 | 341 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 342 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 343 | optimizer) 344 | self.assertEqual(len(grads_and_vars), 345 | len(tf.trainable_variables())) 346 | self.assertEqual(total_loss.op.name, 'total_loss') 347 | for g, v in grads_and_vars: 348 | self.assertDeviceEqual(g.device, '') 349 | self.assertDeviceEqual(v.device, 'CPU:0') 350 | 351 | def testCreateSingleclone(self): 352 | g = tf.Graph() 353 | with g.as_default(): 354 | tf.set_random_seed(0) 355 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 356 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 357 | 358 | model_fn = BatchNormClassifier 359 | clone_args = (tf_inputs, tf_labels) 360 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 361 | 362 | self.assertEqual(slim.get_variables(), []) 363 | clones = model_deploy.create_clones( 364 | deploy_config, model_fn, clone_args) 365 | self.assertEqual(len(slim.get_variables()), 5) 366 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 367 | self.assertEqual(len(update_ops), 2) 368 | 369 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 370 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 371 | optimizer) 372 | self.assertEqual(len(grads_and_vars), 373 | len(tf.trainable_variables())) 374 | self.assertEqual(total_loss.op.name, 'total_loss') 375 | for g, v in grads_and_vars: 376 | self.assertDeviceEqual(g.device, '') 377 | self.assertDeviceEqual(v.device, 'CPU:0') 378 | 379 | def testCreateMulticlone(self): 380 | g = tf.Graph() 381 | with g.as_default(): 382 | tf.set_random_seed(0) 383 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 384 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 385 | 386 | model_fn = BatchNormClassifier 387 | clone_args = (tf_inputs, tf_labels) 388 | num_clones = 4 389 | deploy_config = model_deploy.DeploymentConfig( 390 | num_clones=num_clones) 391 | 392 | self.assertEqual(slim.get_variables(), []) 393 | clones = model_deploy.create_clones( 394 | deploy_config, model_fn, clone_args) 395 | self.assertEqual(len(slim.get_variables()), 5) 396 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 397 | self.assertEqual(len(update_ops), num_clones * 2) 398 | 399 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 400 | total_loss, grads_and_vars = 
model_deploy.optimize_clones(clones, 401 | optimizer) 402 | self.assertEqual(len(grads_and_vars), 403 | len(tf.trainable_variables())) 404 | self.assertEqual(total_loss.op.name, 'total_loss') 405 | for g, v in grads_and_vars: 406 | self.assertDeviceEqual(g.device, '') 407 | self.assertDeviceEqual(v.device, 'CPU:0') 408 | 409 | def testCreateMulticloneCPU(self): 410 | g = tf.Graph() 411 | with g.as_default(): 412 | tf.set_random_seed(0) 413 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 414 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 415 | 416 | model_fn = BatchNormClassifier 417 | model_args = (tf_inputs, tf_labels) 418 | num_clones = 4 419 | deploy_config = model_deploy.DeploymentConfig( 420 | num_clones=num_clones, clone_on_cpu=True) 421 | 422 | self.assertEqual(slim.get_variables(), []) 423 | clones = model_deploy.create_clones( 424 | deploy_config, model_fn, model_args) 425 | self.assertEqual(len(slim.get_variables()), 5) 426 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 427 | self.assertEqual(len(update_ops), num_clones * 2) 428 | 429 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 430 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 431 | optimizer) 432 | self.assertEqual(len(grads_and_vars), 433 | len(tf.trainable_variables())) 434 | self.assertEqual(total_loss.op.name, 'total_loss') 435 | for g, v in grads_and_vars: 436 | self.assertDeviceEqual(g.device, '') 437 | self.assertDeviceEqual(v.device, 'CPU:0') 438 | 439 | def testCreateOnecloneWithPS(self): 440 | g = tf.Graph() 441 | with g.as_default(): 442 | tf.set_random_seed(0) 443 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 444 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 445 | 446 | model_fn = BatchNormClassifier 447 | model_args = (tf_inputs, tf_labels) 448 | deploy_config = model_deploy.DeploymentConfig( 449 | num_clones=1, num_ps_tasks=1) 450 | 451 | self.assertEqual(slim.get_variables(), []) 452 | clones = model_deploy.create_clones( 453 | deploy_config, model_fn, model_args) 454 | self.assertEqual(len(slim.get_variables()), 5) 455 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 456 | self.assertEqual(len(update_ops), 2) 457 | 458 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 459 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 460 | optimizer) 461 | self.assertEqual(len(grads_and_vars), 462 | len(tf.trainable_variables())) 463 | self.assertEqual(total_loss.op.name, 'total_loss') 464 | for g, v in grads_and_vars: 465 | self.assertDeviceEqual(g.device, '/job:worker') 466 | self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0') 467 | 468 | 469 | class DeployTest(tf.test.TestCase): 470 | 471 | def setUp(self): 472 | # Create an easy training set: 473 | np.random.seed(0) 474 | 475 | self._inputs = np.zeros((16, 4)) 476 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 477 | self._logdir = self.get_temp_dir() 478 | 479 | for i in range(16): 480 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 481 | self._inputs[i, j] = 1 482 | 483 | def testLocalTrainOp(self): 484 | g = tf.Graph() 485 | with g.as_default(): 486 | tf.set_random_seed(0) 487 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 488 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 489 | 490 | model_fn = BatchNormClassifier 491 | model_args = (tf_inputs, tf_labels) 492 | deploy_config = model_deploy.DeploymentConfig( 493 | num_clones=2, clone_on_cpu=True) 494 | 495 | 
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 496 | 497 | self.assertEqual(slim.get_variables(), []) 498 | model = model_deploy.deploy( 499 | deploy_config, model_fn, model_args, optimizer=optimizer) 500 | 501 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 502 | self.assertEqual(len(update_ops), 4) 503 | self.assertEqual(len(model.clones), 2) 504 | self.assertEqual(model.total_loss.op.name, 'total_loss') 505 | self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op') 506 | self.assertEqual(model.train_op.op.name, 'train_op') 507 | 508 | with tf.Session() as sess: 509 | sess.run(tf.initialize_all_variables()) 510 | moving_mean = tf.contrib.framework.get_variables_by_name('moving_mean')[ 511 | 0] 512 | moving_variance = tf.contrib.framework.get_variables_by_name( 513 | 'moving_variance')[0] 514 | initial_loss = sess.run(model.total_loss) 515 | initial_mean, initial_variance = sess.run( 516 | [moving_mean, moving_variance]) 517 | self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0]) 518 | self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0]) 519 | for _ in range(10): 520 | sess.run(model.train_op) 521 | final_loss = sess.run(model.total_loss) 522 | self.assertLess(final_loss, initial_loss / 10.0) 523 | 524 | final_mean, final_variance = sess.run( 525 | [moving_mean, moving_variance]) 526 | self.assertAllClose(final_mean, [0.125, 0.25, 0.375, 0.25]) 527 | self.assertAllClose(final_variance, 528 | [0.109375, 0.1875, 0.234375, 0.1875]) 529 | 530 | def testNoSummariesOnGPU(self): 531 | with tf.Graph().as_default(): 532 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 533 | 534 | # clone function creates a fully_connected layer with a regularizer 535 | # loss. 536 | def ModelFn(): 537 | inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32) 538 | reg = tf.contrib.layers.l2_regularizer(0.001) 539 | tf.contrib.layers.fully_connected( 540 | inputs, 30, weights_regularizer=reg) 541 | 542 | model = model_deploy.deploy( 543 | deploy_config, 544 | ModelFn, 545 | optimizer=tf.train.GradientDescentOptimizer(1.0)) 546 | # The model summary op should have a few summary inputs and all of them 547 | # should be on the CPU. 548 | self.assertTrue(model.summary_op.op.inputs) 549 | for inp in model.summary_op.op.inputs: 550 | self.assertEqual('/device:CPU:0', inp.device) 551 | 552 | def testNoSummariesOnGPUForEvals(self): 553 | with tf.Graph().as_default(): 554 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 555 | 556 | # clone function creates a fully_connected layer with a regularizer 557 | # loss. 558 | def ModelFn(): 559 | inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32) 560 | reg = tf.contrib.layers.l2_regularizer(0.001) 561 | tf.contrib.layers.fully_connected( 562 | inputs, 30, weights_regularizer=reg) 563 | 564 | # No optimizer here, it's an eval. 565 | model = model_deploy.deploy(deploy_config, ModelFn) 566 | # The model summary op should have a few summary inputs and all of them 567 | # should be on the CPU. 
568 | self.assertTrue(model.summary_op.op.inputs) 569 | for inp in model.summary_op.op.inputs: 570 | self.assertEqual('/device:CPU:0', inp.device) 571 | 572 | 573 | if __name__ == '__main__': 574 | tf.test.main() 575 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """Generic evaluation script that evaluates a model using a given dataset.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import math 6 | import tensorflow as tf 7 | from datasets import dataset_factory 8 | from nets import nets_factory 9 | from preprocessing import preprocessing_factory 10 | 11 | slim = tf.contrib.slim 12 | flags = tf.app.flags 13 | 14 | flags.DEFINE_integer('batch_size', 100, 'samples in each batch') 15 | flags.DEFINE_integer('max_num_batches', None, 'Max batches; default is all') 16 | flags.DEFINE_integer('num_preprocessing_threads', 4, 'threads used for batches') 17 | flags.DEFINE_integer('eval_image_size', None, 'Eval image size') 18 | flags.DEFINE_integer('labels_offset', 0, 'Labels offset; used in VGG/ResNet') 19 | flags.DEFINE_string('master', '', 'address of the TensorFlow master') 20 | flags.DEFINE_string('checkpoint_path', '/tmp/tfmodel/', 'checkpoint dir') 21 | flags.DEFINE_string('eval_dir', '/tmp/tfmodel/', 'results are saved to') 22 | flags.DEFINE_string('dataset_name', 'imagenet', 'dataset to load') 23 | flags.DEFINE_string('dataset_split_name', 'test', 'train/test split') 24 | flags.DEFINE_string('dataset_dir', None, 'dataset files') 25 | flags.DEFINE_string('model_name', 'inception_v3', 'architecture to evaluate') 26 | flags.DEFINE_string('preprocessing_name', None, 'if None, model_name is used') 27 | flags.DEFINE_float('moving_average_decay', None, 'decay for the moving average') 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | 31 | 32 | def main(_): 33 | if not FLAGS.dataset_dir: 34 | raise ValueError('You must supply the dataset directory with --dataset_dir') 35 | 36 | tf.logging.set_verbosity(tf.logging.INFO) 37 | with tf.Graph().as_default(): 38 | tf_global_step = slim.get_or_create_global_step() 39 | 40 | # Select dataset 41 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 42 | FLAGS.dataset_split_name, 43 | FLAGS.dataset_dir) 44 | 45 | # Select model 46 | network_fn = nets_factory.get_network_fn( 47 | FLAGS.model_name, 48 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 49 | is_training=False) 50 | 51 | # Create dataset provider to load dataset 52 | provider = slim.dataset_data_provider.DatasetDataProvider( 53 | dataset, 54 | shuffle=False, 55 | common_queue_capacity=2 * FLAGS.batch_size, 56 | common_queue_min=FLAGS.batch_size) 57 | [image, label] = provider.get(['image', 'label']) 58 | # label -= FLAGS.labels_offset 59 | 60 | # Select preprocessing function 61 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 62 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 63 | preprocessing_name, is_training=False) 64 | 65 | eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size 66 | 67 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 68 | 69 | images, labels = tf.train.batch( 70 | [image, label], 71 | batch_size=FLAGS.batch_size, 72 | num_threads=FLAGS.num_preprocessing_threads, 73 | capacity=5 * FLAGS.batch_size) 74 | 75 | # Define model 76 | logits, _ = network_fn(images) 77 | 78 | if FLAGS.moving_average_decay: 79 | variable_averages = 
tf.train.ExponentialMovingAverage(
80 |           FLAGS.moving_average_decay, tf_global_step)
81 |       variables_to_restore = variable_averages.variables_to_restore(
82 |           slim.get_model_variables())
83 |       variables_to_restore[tf_global_step.op.name] = tf_global_step
84 |     else:
85 |       variables_to_restore = slim.get_variables_to_restore()
86 | 
87 |     predictions = tf.argmax(logits, 3)
88 |     labels = tf.squeeze(labels)
89 |     print("=" * 40 + ">labels<" + "=" * 40)
90 |     print(labels)
91 |     print("=" * 90)
92 |     print("=" * 40 + ">logits<" + "=" * 40)
93 |     print(logits)
94 |     print("=" * 90)
95 |     print("=" * 40 + ">predictions<" + "=" * 40)
96 |     print(predictions)
97 |     print("=" * 90)
98 |     labels = tf.to_int64(labels)
99 |     # Define the metrics.
100 |     metrics = slim.metrics
101 |     names_to_values, names_to_updates = metrics.aggregate_metric_map({
102 |         'miou':
103 |             metrics.streaming_mean_iou(predictions, labels,
104 |                                        dataset.num_classes),
105 |         'accuracy':
106 |             metrics.streaming_accuracy(predictions, labels),
107 |         'precision':
108 |             metrics.streaming_precision(predictions, labels),
109 |     })
110 | 
111 |     # Print the summaries to screen.
112 |     for name, value in names_to_values.items():
113 |       summary_name = 'eval_%s' % name
114 |       op = tf.summary.scalar(summary_name, value, collections=[])
115 |       op = tf.Print(op, [value], summary_name)
116 |       tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
117 | 
118 |     if FLAGS.max_num_batches:
119 |       num_batches = FLAGS.max_num_batches
120 |     else:
121 |       # This ensures that we make a single pass over all of the data.
122 |       num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
123 | 
124 |     if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
125 |       checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
126 |     else:
127 |       checkpoint_path = FLAGS.checkpoint_path
128 | 
129 |     tf.logging.info('Evaluating %s' % checkpoint_path)
130 | 
131 |     slim.evaluation.evaluate_once(
132 |         master=FLAGS.master,
133 |         checkpoint_path=checkpoint_path,
134 |         logdir=FLAGS.eval_dir,
135 |         num_evals=num_batches,
136 |         eval_op=list(names_to_updates.values()),
137 |         variables_to_restore=variables_to_restore)
138 | 
139 | 
140 | if __name__ == '__main__':
141 |   tf.app.run()
142 | 
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/desimone/segmentation-models/7f9e5a182891d20b2110dc572b57251c0111988c/nets/__init__.py
--------------------------------------------------------------------------------
/nets/fcn.py:
--------------------------------------------------------------------------------
1 | """Fully Convolutional Models for Semantic Segmentation.
2 | arXiv:1605.06211
3 | https://github.com/shelhamer/fcn.berkeleyvision.org
4 | """
5 | from __future__ import absolute_import, division, print_function
6 | 
7 | import tensorflow as tf
8 | 
9 | slim = tf.contrib.slim
10 | 
11 | 
12 | def fcn_arg_scope(weight_decay=0.0005):
13 |   with slim.arg_scope(
14 |       [slim.conv2d, slim.fully_connected, slim.conv2d_transpose],
15 |       activation_fn=tf.nn.relu,
16 |       weights_regularizer=slim.l2_regularizer(weight_decay),
17 |       biases_initializer=tf.zeros_initializer):
18 |     with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
19 |       return arg_sc
20 | 
21 | 
22 | def fcn_32(inputs,
23 |            num_classes=21,
24 |            is_training=True,
25 |            dropout_prob=0.5,
26 |            scope='fcn_32'):
27 |   with tf.variable_scope(scope, 'fcn_32', [inputs]) as sc:
28 |     end_points_collection = sc.name + '_end_points'
29 |     # Collect outputs for conv2d, fully_connected, conv2d_transpose and max_pool2d.
30 |     with slim.arg_scope(
31 |         [
32 |             slim.conv2d, slim.fully_connected, slim.max_pool2d,
33 |             slim.conv2d_transpose
34 |         ],
35 |         outputs_collections=end_points_collection):
36 | 
37 |       # Contracting portion is VGG-16 https://goo.gl/dM7PWe
38 |       net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
39 |       pool1 = slim.max_pool2d(net, [2, 2], scope='pool1')
40 |       net = slim.repeat(pool1, 2, slim.conv2d, 128, [3, 3], scope='conv2')
41 |       pool2 = slim.max_pool2d(net, [2, 2], scope='pool2')
42 |       net = slim.repeat(pool2, 3, slim.conv2d, 256, [3, 3], scope='conv3')
43 |       pool3 = slim.max_pool2d(net, [2, 2], scope='pool3')
44 |       net = slim.repeat(pool3, 3, slim.conv2d, 512, [3, 3], scope='conv4')
45 |       pool4 = slim.max_pool2d(net, [2, 2], scope='pool4')
46 |       net = slim.repeat(pool4, 3, slim.conv2d, 512, [3, 3], scope='conv5')
47 |       pool5 = slim.max_pool2d(net, [2, 2], scope='pool5')
48 |       # Fully connected layers from the VGG reference; the original FCN recasts these as convolutions.
49 |       net = slim.fully_connected(pool5, 4096, scope='fc6')
50 |       net = slim.dropout(
51 |           net, dropout_prob, is_training=is_training, scope='drop6')
52 |       net = slim.fully_connected(net, 4096, scope='fc7')
53 |       net = slim.dropout(
54 |           net, dropout_prob, is_training=is_training, scope='drop7')
55 |       net = slim.fully_connected(net, num_classes, scope='fc8')
56 | 
57 |       # Expanding : Upscore : https://goo.gl/wchbCq
58 | 
59 |       # n.score_fr = L.Convolution(n.drop7, num_output=21, kernel_size=1, pad=0,
60 |       #     param=[dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)])
61 |       # n.upscore = L.Deconvolution(n.score_fr,
62 |       #     convolution_param=dict(num_output=21, kernel_size=64, stride=32,
63 |       #         bias_term=False),
64 |       #     param=[dict(lr_mult=0)])
65 |       # n.score = crop(n.upscore, n.data)
66 |       net = slim.conv2d_transpose(
67 |           net, 64, [2, 2], stride=32, padding='VALID', scope='up1')  # NB: the quoted reference uses kernel_size=64 with stride=32; see nets/layers.py.
68 | 
69 |       net = slim.conv2d(net, num_classes, [1, 1], scope='score')
70 |       net = tf.argmax(net, dimension=3, name="prediction")  # NB: returns class indices, not scores.
71 |     end_points = slim.utils.convert_collection_to_dict(end_points_collection)
72 |     return net, end_points
73 | 
74 | 
75 | fcn_32.default_image_size = 448
76 | 
--------------------------------------------------------------------------------
/nets/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def get_kernel_size(factor):
5 |   """Find the kernel size given the desired factor of upsampling."""
6 |   return 2 * factor - factor % 2
7 | 
8 | 
9 | def upsample_filt(size):
10 |   """Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size."""
11 |   factor = (size + 1) // 2
12 |   if size % 2 == 1:
13 |     center = factor - 1
14 |   else:
15 |     center = factor - 0.5
16 |   og = np.ogrid[:size, :size]
17 |   return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
18 | 
19 | 
20 | def bilinear_upsample_weights(factor, number_of_classes):
21 |   """Create weights matrix for transposed convolution with bilinear filter initialization."""
22 | 
23 |   filter_size = get_kernel_size(factor)
24 | 
25 |   weights = np.zeros(
26 |       (filter_size, filter_size, number_of_classes, number_of_classes),
27 |       dtype=np.float32)
28 | 
29 |   upsample_kernel = upsample_filt(filter_size)
30 | 
31 |   for i in range(number_of_classes):  # `range`, not py2-only `xrange`
32 |     weights[:, :, i, i] = upsample_kernel  # bilinear kernel on the diagonal; classes don't mix
33 | 
34 |   return weights
35 | 
--------------------------------------------------------------------------------
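The commented-out Caffe snippet in `fcn_32` upscores with a fixed bilinear deconvolution (kernel_size=64, stride=32), and `nets/layers.py` computes exactly that kernel: `get_kernel_size(32)` is 2 * 32 - 32 % 2 = 64. A minimal sketch of how the two files could be wired together, assuming `logits` already has `num_classes` channels; the helper name `bilinear_upscore` and the `padding='SAME'` choice are illustrative, not code from this repo:

import tensorflow as tf

from nets.layers import bilinear_upsample_weights


def bilinear_upscore(logits, num_classes, factor=32):
  """Upsample `logits` by `factor` with a fixed (non-learned) bilinear kernel."""
  # Shape [64, 64, num_classes, num_classes] for factor=32; bilinear on the diagonal.
  weights = tf.constant(bilinear_upsample_weights(factor, num_classes))
  shape = tf.shape(logits)
  output_shape = tf.stack(  # tf.pack on the pre-1.0 API used elsewhere in this repo
      [shape[0], shape[1] * factor, shape[2] * factor, num_classes])
  return tf.nn.conv2d_transpose(
      logits, weights, output_shape, strides=[1, factor, factor, 1],
      padding='SAME')

The same array could instead seed a trainable `slim.conv2d_transpose` through `weights_initializer=tf.constant_initializer(bilinear_upsample_weights(32, num_classes))`, keeping the layer learnable but starting it at plain bilinear interpolation.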
/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | """Contains a factory for building various models.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import functools 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | # from tensorflow.contrib.slim.nets import (alexnet, inception, overfeat, resnet_v1, resnet_v2, vgg) 10 | from nets import fcn 11 | 12 | slim = tf.contrib.slim 13 | 14 | networks_map = { 15 | 'fcn_32': 16 | fcn.fcn_32, 17 | # included in slim contrib 18 | # 'alexnet_v2': alexnet.alexnet_v2, 19 | # 'overfeat': overfeat.overfeat, 20 | # 'vgg_a': vgg.vgg_a, 21 | # 'vgg_16': vgg.vgg_16, 22 | # 'vgg_19': vgg.vgg_19, 23 | # 'inception_v1': inception.inception_v1, 24 | # 'inception_v2': inception.inception_v2, 25 | # 'inception_v3': inception.inception_v3, 26 | # 'resnet_v1_50': resnet_v1.resnet_v1_50, 27 | # 'resnet_v1_101': resnet_v1.resnet_v1_101, 28 | # 'resnet_v1_152': resnet_v1.resnet_v1_152, 29 | # 'resnet_v1_200': resnet_v1.resnet_v1_200, 30 | # 'resnet_v2_50': resnet_v2.resnet_v2_50, 31 | # 'resnet_v2_101': resnet_v2.resnet_v2_101, 32 | # 'resnet_v2_152': resnet_v2.resnet_v2_152, 33 | # 'resnet_v2_200': resnet_v2.resnet_v2_200, 34 | } 35 | 36 | arg_scopes_map = { 37 | # custom 38 | 'fcn_32': 39 | fcn.fcn_arg_scope, 40 | # included in slim contrib 41 | # 'alexnet_v2': alexnet.alexnet_v2_arg_scope, 42 | # 'overfeat': overfeat.overfeat_arg_scope, 43 | # 'vgg_a': vgg.vgg_arg_scope, 44 | # 'vgg_16': vgg.vgg_arg_scope, 45 | # 'vgg_19': vgg.vgg_arg_scope, 46 | # 'inception_v1': inception.inception_v3_arg_scope, 47 | # 'inception_v2': inception.inception_v3_arg_scope, 48 | # 'inception_v3': inception.inception_v3_arg_scope, 49 | # 'resnet_v1_50': resnet_v1.resnet_arg_scope, 50 | # 'resnet_v1_101': resnet_v1.resnet_arg_scope, 51 | # 'resnet_v1_152': resnet_v1.resnet_arg_scope, 52 | # 'resnet_v1_200': resnet_v1.resnet_arg_scope, 53 | # 'resnet_v2_50': resnet_v2.resnet_arg_scope, 54 | # 'resnet_v2_101': resnet_v2.resnet_arg_scope, 55 | # 'resnet_v2_152': resnet_v2.resnet_arg_scope, 56 | # 'resnet_v2_200': resnet_v2.resnet_arg_scope, 57 | } 58 | 59 | 60 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 61 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 62 | 63 | Args: 64 | name: The name of the network. 65 | num_classes: The number of classes to use for classification. 66 | weight_decay: The l2 coefficient for the model weights. 67 | is_training: `True` if the model is being used for training and `False` 68 | otherwise. 69 | 70 | Returns: 71 | network_fn: A function that applies the model to a batch of images. It has 72 | the following signature: 73 | logits, end_points = network_fn(images) 74 | Raises: 75 | ValueError: If network `name` is not recognized. 
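
  Example (illustrative; `images` is any batched image tensor):
    network_fn = get_network_fn('fcn_32', num_classes=21, is_training=True)
    logits, end_points = network_fn(images)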
76 | """ 77 | if name not in networks_map: 78 | raise ValueError('Name of network unknown %s' % name) 79 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 80 | func = networks_map[name] 81 | 82 | @functools.wraps(func) 83 | def network_fn(images): 84 | with slim.arg_scope(arg_scope): 85 | return func(images, num_classes, is_training=is_training) 86 | 87 | if hasattr(func, 'default_image_size'): 88 | network_fn.default_image_size = func.default_image_size 89 | 90 | return network_fn 91 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | """Downloads and converts datatsets.""" 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import tensorflow as tf 5 | 6 | from datasets import (download_and_convert_pascal) 7 | 8 | FLAGS = tf.app.flags.FLAGS 9 | flags = tf.app.flags 10 | flags.DEFINE_string('dataset_name', None, '"pascal"') 11 | flags.DEFINE_string('dataset_dir', None, 'where to put TFRecords ') 12 | flags.DEFINE_string('dataset_archive', None, 'where to find archives') 13 | 14 | 15 | def main(_): 16 | if not FLAGS.dataset_name: 17 | raise ValueError('Must set --dataset_name') 18 | if not FLAGS.dataset_dir: 19 | raise ValueError('Must set --dataset_dir') 20 | elif FLAGS.dataset_name == 'pascal': 21 | download_and_convert_pascal.run(FLAGS.dataset_dir) 22 | 23 | else: 24 | raise ValueError('dataset_name [%s] not recognized.' % FLAGS.dataset_dir) 25 | 26 | 27 | if __name__ == '__main__': 28 | tf.app.run() 29 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/desimone/segmentation-models/7f9e5a182891d20b2110dc572b57251c0111988c/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Provides utilities to preprocess images in CIFAR-10.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import tensorflow as tf 6 | 7 | _PADDING = 4 8 | 9 | slim = tf.contrib.slim 10 | 11 | 12 | def preprocess_for_train(image, output_height, output_width, padding=_PADDING): 13 | """Preprocesses the given image for training. 14 | 15 | Note that the actual resizing scale is sampled from 16 | [`resize_size_min`, `resize_size_max`]. 17 | 18 | Args: 19 | image: A `Tensor` representing an image of arbitrary size. 20 | output_height: The height of the image after preprocessing. 21 | output_width: The width of the image after preprocessing. 22 | padding: The amound of padding before and after each dimension of the image. 23 | 24 | Returns: 25 | A preprocessed image. 26 | """ 27 | tf.summary.image('image', tf.expand_dims(image, 0)) 28 | 29 | # Transform the image to floats. 30 | image = tf.to_float(image) 31 | if padding > 0: 32 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 33 | # Randomly crop a [height, width] section of the image. 34 | distorted_image = tf.random_crop(image, [output_height, output_width, 3]) 35 | 36 | # Randomly flip the image horizontally. 
37 |   distorted_image = tf.image.random_flip_left_right(distorted_image)
38 | 
39 |   tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0))
40 | 
41 |   # Because these operations are not commutative, consider randomizing
42 |   # the order of their operations.
43 |   distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
44 |   distorted_image = tf.image.random_contrast(
45 |       distorted_image, lower=0.2, upper=1.8)
46 |   # Subtract off the mean and divide by the variance of the pixels.
47 |   return tf.image.per_image_standardization(distorted_image)
48 | 
49 | 
50 | def preprocess_for_eval(image, output_height, output_width):
51 |   """Preprocesses the given image for evaluation.
52 | 
53 |   Args:
54 |     image: A `Tensor` representing an image of arbitrary size.
55 |     output_height: The height of the image after preprocessing.
56 |     output_width: The width of the image after preprocessing.
57 | 
58 |   Returns:
59 |     A preprocessed image.
60 |   """
61 |   tf.summary.image('image', tf.expand_dims(image, 0))
62 |   # Transform the image to floats.
63 |   image = tf.to_float(image)
64 | 
65 |   # Resize and crop if needed.
66 |   resized_image = tf.image.resize_image_with_crop_or_pad(image, output_height,
67 |                                                          output_width)
68 |   tf.summary.image('resized_image', tf.expand_dims(resized_image, 0))
69 | 
70 |   # Subtract off the mean and divide by the variance of the pixels.
71 |   return tf.image.per_image_standardization(resized_image)
72 | 
73 | 
74 | def preprocess_image(image, output_height, output_width, is_training=False):
75 |   """Preprocesses the given image.
76 | 
77 |   Args:
78 |     image: A `Tensor` representing an image of arbitrary size.
79 |     output_height: The height of the image after preprocessing.
80 |     output_width: The width of the image after preprocessing.
81 |     is_training: `True` if we're preprocessing the image for training and
82 |       `False` otherwise.
83 | 
84 |   Returns:
85 |     A preprocessed image.
86 | """ 87 | if is_training: 88 | return preprocess_for_train(image, output_height, output_width) 89 | else: 90 | return preprocess_for_eval(image, output_height, output_width) 91 | -------------------------------------------------------------------------------- /preprocessing/fcn_preprocessing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.ops import control_flow_ops 5 | from PIL import Image 6 | 7 | slim = tf.contrib.slim 8 | 9 | # vgg network 10 | _R_MEAN = 123.68 11 | _G_MEAN = 116.78 12 | _B_MEAN = 103.94 13 | 14 | 15 | def _mean_image_subtraction(image, means): 16 | if image.get_shape().ndims != 3: 17 | raise ValueError('Input must be of size [height, width, C>0]') 18 | num_channels = image.get_shape().as_list()[-1] 19 | if len(means) != num_channels: 20 | raise ValueError('len(means) must match the number of channels') 21 | 22 | channels = tf.split(2, num_channels, image) 23 | for i in range(num_channels): 24 | channels[i] -= means[i] 25 | return tf.concat(2, channels) 26 | 27 | 28 | def preprocess_image(image, label, output_height, output_width, is_training): 29 | label = tf.image.rgb_to_grayscale(label) 30 | if image.dtype != tf.float32: 31 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 32 | image = tf.image.resize_image_with_crop_or_pad(image, output_width, 33 | output_height) 34 | label = tf.image.resize_image_with_crop_or_pad(label, output_width, 35 | output_height) 36 | image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 37 | return image, label 38 | -------------------------------------------------------------------------------- /preprocessing/inception_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Provides utilities to preprocess images for the Inception networks.""" 2 | 3 | from __future__ import absolute_import, division, print_function 4 | 5 | import tensorflow as tf 6 | from tensorflow.python.ops import control_flow_ops 7 | 8 | 9 | def apply_with_random_selector(x, func, num_cases): 10 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 11 | 12 | Args: 13 | x: input Tensor. 14 | func: Python function to apply. 15 | num_cases: Python int32, number of cases to sample sel from. 16 | 17 | Returns: 18 | The result of func(x, sel), where func receives the value of the 19 | selector as a python integer, but sel is sampled dynamically. 20 | """ 21 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 22 | # Pass the real x only to one of the func calls. 23 | return control_flow_ops.merge([ 24 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 25 | for case in range(num_cases) 26 | ])[0] 27 | 28 | 29 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 30 | """Distort the color of a Tensor image. 31 | 32 | Each color distortion is non-commutative and thus ordering of the color ops 33 | matters. Ideally we would randomly permute the ordering of the color ops. 34 | Rather then adding that level of complication, we select a distinct ordering 35 | of color ops for each preprocessing thread. 36 | 37 | Args: 38 | image: 3-D Tensor containing single image in [0, 1]. 39 | color_ordering: Python int, a type of distortion (valid values: 0-3). 40 | fast_mode: Avoids slower ops (random_hue and random_contrast) 41 | scope: Optional scope for name_scope. 
42 |   Returns:
43 |     3-D Tensor color-distorted image on range [0, 1]
44 |   Raises:
45 |     ValueError: if color_ordering not in [0, 3]
46 |   """
47 |   with tf.name_scope(scope, 'distort_color', [image]):
48 |     if fast_mode:
49 |       if color_ordering == 0:
50 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
51 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
52 |       else:
53 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
54 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
55 |     else:
56 |       if color_ordering == 0:
57 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
58 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
59 |         image = tf.image.random_hue(image, max_delta=0.2)
60 |         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
61 |       elif color_ordering == 1:
62 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
63 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
64 |         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
65 |         image = tf.image.random_hue(image, max_delta=0.2)
66 |       elif color_ordering == 2:
67 |         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
68 |         image = tf.image.random_hue(image, max_delta=0.2)
69 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
70 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
71 |       elif color_ordering == 3:
72 |         image = tf.image.random_hue(image, max_delta=0.2)
73 |         image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
74 |         image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
75 |         image = tf.image.random_brightness(image, max_delta=32. / 255.)
76 |       else:
77 |         raise ValueError('color_ordering must be in [0, 3]')
78 | 
79 |     # The random_* ops do not necessarily clamp.
80 |     return tf.clip_by_value(image, 0.0, 1.0)
81 | 
82 | 
83 | def distorted_bounding_box_crop(image,
84 |                                 bbox,
85 |                                 min_object_covered=0.1,
86 |                                 aspect_ratio_range=(0.75, 1.33),
87 |                                 area_range=(0.05, 1.0),
88 |                                 max_attempts=100,
89 |                                 scope=None):
90 |   """Generates cropped_image using one of the bboxes randomly distorted.
91 | 
92 |   See `tf.image.sample_distorted_bounding_box` for more documentation.
93 | 
94 |   Args:
95 |     image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
96 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
97 |       where each coordinate is [0, 1) and the coordinates are arranged
98 |       as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole
99 |       image.
100 |     min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
101 |       area of the image must contain at least this fraction of any bounding box
102 |       supplied.
103 |     aspect_ratio_range: An optional list of `floats`. The cropped area of the
104 |       image must have an aspect ratio = width / height within this range.
105 |     area_range: An optional list of `floats`. The cropped area of the image
106 |       must contain a fraction of the supplied image within this range.
107 |     max_attempts: An optional `int`. Number of attempts at generating a cropped
108 |       region of the image that meets the specified constraints. After
109 |       `max_attempts` failures, return the entire image.
110 |     scope: Optional scope for name_scope.
111 |   Returns:
112 |     A tuple, a 3-D Tensor cropped_image and the distorted bbox
113 |   """
114 |   with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
115 |     # Each bounding box has shape [1, num_boxes, box coords] and
116 |     # the coordinates are ordered [ymin, xmin, ymax, xmax].
117 | 
118 |     # A large fraction of image datasets contain a human-annotated bounding
119 |     # box delineating the region of the image containing the object of interest.
120 |     # We choose to create a new bounding box for the object which is a randomly
121 |     # distorted version of the human-annotated bounding box that obeys an
122 |     # allowed range of aspect ratios, sizes and overlap with the human-annotated
123 |     # bounding box. If no box is supplied, then we assume the bounding box is
124 |     # the entire image.
125 |     sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
126 |         tf.shape(image),
127 |         bounding_boxes=bbox,
128 |         min_object_covered=min_object_covered,
129 |         aspect_ratio_range=aspect_ratio_range,
130 |         area_range=area_range,
131 |         max_attempts=max_attempts,
132 |         use_image_if_no_bounding_boxes=True)
133 |     bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
134 | 
135 |     # Crop the image to the specified bounding box.
136 |     cropped_image = tf.slice(image, bbox_begin, bbox_size)
137 |     return cropped_image, distort_bbox
138 | 
139 | 
140 | def preprocess_for_train(image, height, width, bbox, fast_mode=True,
141 |                          scope=None):
142 |   """Distort one image for training a network.
143 | 
144 |   Distorting images provides a useful technique for augmenting the data
145 |   set during training in order to make the network invariant to aspects
146 |   of the image that do not affect the label.
147 | 
148 |   Additionally it creates image_summaries to display the different
149 |   transformations applied to the image.
150 | 
151 |   Args:
152 |     image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
153 |       [0, 1], otherwise it would be converted to tf.float32 assuming that the
154 |       range is [0, MAX], where MAX is largest positive representable number for
155 |       int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
156 |     height: integer
157 |     width: integer
158 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
159 |       where each coordinate is [0, 1) and the coordinates are arranged
160 |       as [ymin, xmin, ymax, xmax].
161 |     fast_mode: Optional boolean, if True avoids slower transformations (i.e.
162 |       bi-cubic resizing, random_hue or random_contrast).
163 |     scope: Optional scope for name_scope.
164 |   Returns:
165 |     3-D float Tensor of distorted image used for training with range [-1, 1].
166 |   """
167 |   with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
168 |     if bbox is None:
169 |       bbox = tf.constant(
170 |           [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
171 |     if image.dtype != tf.float32:
172 |       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
173 |     # Each bounding box has shape [1, num_boxes, box coords] and
174 |     # the coordinates are ordered [ymin, xmin, ymax, xmax].
175 |     image_with_box = tf.image.draw_bounding_boxes(
176 |         tf.expand_dims(image, 0), bbox)
177 |     tf.image_summary('image_with_bounding_boxes', image_with_box)
178 | 
179 |     distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
180 |     # Restore the shape since the dynamic slice based upon the bbox_size loses
181 |     # the third dimension.
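    # (Only the channel count is pinned to 3; height and width stay dynamic
    # until the resize below.)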
182 |     distorted_image.set_shape([None, None, 3])
183 |     image_with_distorted_box = tf.image.draw_bounding_boxes(
184 |         tf.expand_dims(image, 0), distorted_bbox)
185 |     tf.image_summary('images_with_distorted_bounding_box',
186 |                      image_with_distorted_box)
187 | 
188 |     # This resizing operation may distort the images because the aspect
189 |     # ratio is not respected. We select a resize method in a round robin
190 |     # fashion based on the thread number.
191 |     # Note that ResizeMethod contains 4 enumerated resizing methods.
192 | 
193 |     # We select only 1 case for fast_mode bilinear.
194 |     num_resize_cases = 1 if fast_mode else 4
195 |     distorted_image = apply_with_random_selector(
196 |         distorted_image,
197 |         lambda x, method: tf.image.resize_images(
198 |             x, [height, width], method=method),
199 |         num_cases=num_resize_cases)
200 | 
201 |     tf.image_summary('cropped_resized_image',
202 |                      tf.expand_dims(distorted_image, 0))
203 | 
204 |     # Randomly flip the image horizontally.
205 |     distorted_image = tf.image.random_flip_left_right(distorted_image)
206 | 
207 |     # Randomly distort the colors. There are 4 ways to do it.
208 |     distorted_image = apply_with_random_selector(
209 |         distorted_image,
210 |         lambda x, ordering: distort_color(x, ordering, fast_mode),
211 |         num_cases=4)
212 | 
213 |     tf.image_summary('final_distorted_image',
214 |                      tf.expand_dims(distorted_image, 0))
215 |     distorted_image = tf.sub(distorted_image, 0.5)
216 |     distorted_image = tf.mul(distorted_image, 2.0)
217 |     return distorted_image
218 | 
219 | 
220 | def preprocess_for_eval(image,
221 |                         height,
222 |                         width,
223 |                         central_fraction=0.875,
224 |                         scope=None):
225 |   """Prepare one image for evaluation.
226 | 
227 |   If height and width are specified it would output an image with that size by
228 |   applying resize_bilinear.
229 | 
230 |   If central_fraction is specified it would crop the central fraction of the
231 |   input image.
232 | 
233 |   Args:
234 |     image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
235 |       [0, 1], otherwise it would be converted to tf.float32 assuming that the
236 |       range is [0, MAX], where MAX is largest positive representable number for
237 |       int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
238 |     height: integer
239 |     width: integer
240 |     central_fraction: Optional Float, fraction of the image to crop.
241 |     scope: Optional scope for name_scope.
242 |   Returns:
243 |     3-D float Tensor of prepared image.
244 |   """
245 |   with tf.name_scope(scope, 'eval_image', [image, height, width]):
246 |     if image.dtype != tf.float32:
247 |       image = tf.image.convert_image_dtype(image, dtype=tf.float32)
248 |     # Crop the central region of the image with an area containing 87.5% of
249 |     # the original image.
250 |     if central_fraction:
251 |       image = tf.image.central_crop(image, central_fraction=central_fraction)
252 | 
253 |     if height and width:
254 |       # Resize the image to the specified height and width.
255 |       image = tf.expand_dims(image, 0)
256 |       image = tf.image.resize_bilinear(
257 |           image, [height, width], align_corners=False)
258 |       image = tf.squeeze(image, [0])
259 |     image = tf.sub(image, 0.5)
260 |     image = tf.mul(image, 2.0)
261 |     return image
262 | 
263 | 
264 | def preprocess_image(image,
265 |                      height,
266 |                      width,
267 |                      is_training=False,
268 |                      bbox=None,
269 |                      fast_mode=True):
270 |   """Pre-process one image for training or evaluation.
271 | 
272 |   Args:
273 |     image: 3-D Tensor [height, width, channels] with the image.
274 |     height: integer, image expected height.
275 |     width: integer, image expected width.
276 |     is_training: Boolean. If true it would transform an image for training,
277 |       otherwise it would transform it for evaluation.
278 |     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
279 |       where each coordinate is [0, 1) and the coordinates are arranged as
280 |       [ymin, xmin, ymax, xmax].
281 |     fast_mode: Optional boolean, if True avoids slower transformations.
282 | 
283 |   Returns:
284 |     3-D float Tensor containing an appropriately scaled image
285 | 
286 |   Note:
287 |     If `bbox` is None, the whole image is used as the bounding box in training.
288 |   """
289 |   if is_training:
290 |     return preprocess_for_train(image, height, width, bbox, fast_mode)
291 |   else:
292 |     return preprocess_for_eval(image, height, width)
293 | 
--------------------------------------------------------------------------------
/preprocessing/lenet_preprocessing.py:
--------------------------------------------------------------------------------
1 | """Provides utilities for preprocessing."""
2 | 
3 | from __future__ import absolute_import, division, print_function
4 | 
5 | import tensorflow as tf
6 | 
7 | slim = tf.contrib.slim
8 | 
9 | 
10 | def preprocess_image(image, output_height, output_width, is_training):
11 |   """Preprocesses the given image.
12 | 
13 |   Args:
14 |     image: A `Tensor` representing an image of arbitrary size.
15 |     output_height: The height of the image after preprocessing.
16 |     output_width: The width of the image after preprocessing.
17 |     is_training: `True` if we're preprocessing the image for training and
18 |       `False` otherwise.
19 | 
20 |   Returns:
21 |     A preprocessed image.
22 |   """
23 |   image = tf.to_float(image)
24 |   image = tf.image.resize_image_with_crop_or_pad(image, output_height,
25 |                                                  output_width)
26 |   image = tf.sub(image, 128.0)
27 |   image = tf.div(image, 128.0)
28 |   return image
29 | 
--------------------------------------------------------------------------------
/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | """Contains a factory for building various preprocessing functions."""
2 | 
3 | from __future__ import absolute_import, division, print_function
4 | 
5 | import tensorflow as tf
6 | 
7 | from preprocessing import fcn_preprocessing
8 | 
9 | 
10 | def get_preprocessing(name, is_training=False):
11 |   preprocessing_fn_map = {'fcn_32': fcn_preprocessing,}
12 | 
13 |   if name not in preprocessing_fn_map:
14 |     raise ValueError('Preprocessing name [%s] was not recognized' % name)
15 | 
16 |   def preprocessing_fn(image, label, output_height, output_width, **kwargs):
17 |     return preprocessing_fn_map[name].preprocess_image(
18 |         image,
19 |         label,
20 |         output_height,
21 |         output_width,
22 |         is_training=is_training,
23 |         **kwargs)
24 | 
25 |   return preprocessing_fn
26 | 
--------------------------------------------------------------------------------
/preprocessing/vgg_preprocessing.py:
--------------------------------------------------------------------------------
1 | """Provides utilities to preprocess images.
2 | 3 | The preprocessing steps for VGG were introduced in the following technical 4 | report: 5 | 6 | Very Deep Convolutional Networks For Large-Scale Image Recognition 7 | Karen Simonyan and Andrew Zisserman 8 | arXiv technical report, 2015 9 | PDF: http://arxiv.org/pdf/1409.1556.pdf 10 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 11 | CC-BY-4.0 12 | 13 | More information can be obtained from the VGG website: 14 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 15 | """ 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | slim = tf.contrib.slim 26 | 27 | _R_MEAN = 123.68 28 | _G_MEAN = 116.78 29 | _B_MEAN = 103.94 30 | 31 | _RESIZE_SIDE_MIN = 256 32 | _RESIZE_SIDE_MAX = 512 33 | 34 | 35 | def _crop(image, offset_height, offset_width, crop_height, crop_width): 36 | """Crops the given image using the provided offsets and sizes. 37 | 38 | Note that the method doesn't assume we know the input image size but it does 39 | assume we know the input image rank. 40 | 41 | Args: 42 | image: an image of shape [height, width, channels]. 43 | offset_height: a scalar tensor indicating the height offset. 44 | offset_width: a scalar tensor indicating the width offset. 45 | crop_height: the height of the cropped image. 46 | crop_width: the width of the cropped image. 47 | 48 | Returns: 49 | the cropped (and resized) image. 50 | 51 | Raises: 52 | InvalidArgumentError: if the rank is not 3 or if the image dimensions are 53 | less than the crop size. 54 | """ 55 | original_shape = tf.shape(image) 56 | 57 | rank_assertion = tf.Assert( 58 | tf.equal(tf.rank(image), 3), ['Rank of image must be equal to 3.']) 59 | cropped_shape = control_flow_ops.with_dependencies( 60 | [rank_assertion], tf.pack([crop_height, crop_width, original_shape[2]])) 61 | 62 | size_assertion = tf.Assert( 63 | tf.logical_and( 64 | tf.greater_equal(original_shape[0], crop_height), 65 | tf.greater_equal(original_shape[1], crop_width)), 66 | ['Crop size greater than the image size.']) 67 | 68 | offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0])) 69 | 70 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to 71 | # define the crop size. 72 | image = control_flow_ops.with_dependencies( 73 | [size_assertion], tf.slice(image, offsets, cropped_shape)) 74 | return tf.reshape(image, cropped_shape) 75 | 76 | 77 | def _random_crop(image_list, crop_height, crop_width): 78 | """Crops the given list of images. 79 | 80 | The function applies the same crop to each image in the list. This can be 81 | effectively applied when there are multiple image inputs of the same 82 | dimension such as: 83 | 84 | image, depths, normals = _random_crop([image, depths, normals], 120, 150) 85 | 86 | Args: 87 | image_list: a list of image tensors of the same dimension but possibly 88 | varying channel. 89 | crop_height: the new height. 90 | crop_width: the new width. 91 | 92 | Returns: 93 | the image_list with cropped images. 94 | 95 | Raises: 96 | ValueError: if there are multiple image inputs provided with different size 97 | or the images are smaller than the crop dimensions. 98 | """ 99 | if not image_list: 100 | raise ValueError('Empty image_list.') 101 | 102 | # Compute the rank assertions. 
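  # Every image in the list must be rank 3; the assertion ops below are wired
  # in as control dependencies, so a bad input fails at graph execution time.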
103 |   rank_assertions = []
104 |   for i in range(len(image_list)):
105 |     image_rank = tf.rank(image_list[i])
106 |     rank_assert = tf.Assert(
107 |         tf.equal(image_rank, 3), [
108 |             'Wrong rank for tensor %s [expected] [actual]', image_list[i].name,
109 |             3, image_rank
110 |         ])
111 |     rank_assertions.append(rank_assert)
112 | 
113 |   image_shape = control_flow_ops.with_dependencies([rank_assertions[0]],
114 |                                                    tf.shape(image_list[0]))
115 |   image_height = image_shape[0]
116 |   image_width = image_shape[1]
117 |   crop_size_assert = tf.Assert(
118 |       tf.logical_and(
119 |           tf.greater_equal(image_height, crop_height),
120 |           tf.greater_equal(image_width, crop_width)),
121 |       ['Crop size greater than the image size.'])
122 | 
123 |   asserts = [rank_assertions[0], crop_size_assert]
124 | 
125 |   for i in range(1, len(image_list)):
126 |     image = image_list[i]
127 |     asserts.append(rank_assertions[i])
128 |     shape = control_flow_ops.with_dependencies([rank_assertions[i]],
129 |                                                tf.shape(image))
130 |     height = shape[0]
131 |     width = shape[1]
132 | 
133 |     height_assert = tf.Assert(
134 |         tf.equal(height, image_height), [
135 |             'Wrong height for tensor %s [expected][actual]', image.name, height,
136 |             image_height
137 |         ])
138 |     width_assert = tf.Assert(
139 |         tf.equal(width, image_width), [
140 |             'Wrong width for tensor %s [expected][actual]', image.name, width,
141 |             image_width
142 |         ])
143 |     asserts.extend([height_assert, width_assert])
144 | 
145 |   # Create a random bounding box.
146 |   #
147 |   # Use tf.random_uniform and not numpy.random.rand as doing the former would
148 |   # generate random numbers at graph eval time, unlike the latter which
149 |   # generates random numbers at graph definition time.
150 |   max_offset_height = control_flow_ops.with_dependencies(
151 |       asserts, tf.reshape(image_height - crop_height + 1, []))
152 |   max_offset_width = control_flow_ops.with_dependencies(
153 |       asserts, tf.reshape(image_width - crop_width + 1, []))
154 |   offset_height = tf.random_uniform(
155 |       [], maxval=max_offset_height, dtype=tf.int32)
156 |   offset_width = tf.random_uniform([], maxval=max_offset_width, dtype=tf.int32)
157 | 
158 |   return [
159 |       _crop(image, offset_height, offset_width, crop_height, crop_width)
160 |       for image in image_list
161 |   ]
162 | 
163 | 
164 | def _central_crop(image_list, crop_height, crop_width):
165 |   """Performs central crops of the given image list.
166 | 
167 |   Args:
168 |     image_list: a list of image tensors of the same dimension but possibly
169 |       varying channel.
170 |     crop_height: the height of the image following the crop.
171 |     crop_width: the width of the image following the crop.
172 | 
173 |   Returns:
174 |     the list of cropped images.
175 |   """
176 |   outputs = []
177 |   for image in image_list:
178 |     image_height = tf.shape(image)[0]
179 |     image_width = tf.shape(image)[1]
180 | 
181 |     offset_height = (image_height - crop_height) / 2
182 |     offset_width = (image_width - crop_width) / 2
183 | 
184 |     outputs.append(
185 |         _crop(image, offset_height, offset_width, crop_height, crop_width))
186 |   return outputs
187 | 
188 | 
189 | def _mean_image_subtraction(image, means):
190 |   """Subtracts the given means from each image channel.
191 | 
192 |   For example:
193 |     means = [123.68, 116.779, 103.939]
194 |     image = _mean_image_subtraction(image, means)
195 | 
196 |   Note that the rank of `image` must be known.
197 | 
198 |   Args:
199 |     image: a tensor of size [height, width, C].
200 |     means: a C-vector of values to subtract from each channel.
201 | 
202 |   Returns:
203 |     the centered image.
204 | 
205 |   Raises:
206 |     ValueError: If the rank of `image` is unknown, if `image` has a rank other
207 |       than three or if the number of channels in `image` doesn't match the
208 |       number of values in `means`.
209 |   """
210 |   if image.get_shape().ndims != 3:
211 |     raise ValueError('Input must be of size [height, width, C>0]')
212 |   num_channels = image.get_shape().as_list()[-1]
213 |   if len(means) != num_channels:
214 |     raise ValueError('len(means) must match the number of channels')
215 | 
216 |   channels = tf.split(2, num_channels, image)
217 |   for i in range(num_channels):
218 |     channels[i] -= means[i]
219 |   return tf.concat(2, channels)
220 | 
221 | 
222 | def _smallest_size_at_least(height, width, smallest_side):
223 |   """Computes new shape with the smallest side equal to `smallest_side`.
224 | 
225 |   Computes new shape with the smallest side equal to `smallest_side` while
226 |   preserving the original aspect ratio.
227 | 
228 |   Args:
229 |     height: an int32 scalar tensor indicating the current height.
230 |     width: an int32 scalar tensor indicating the current width.
231 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
232 |       the smallest side after resize.
233 | 
234 |   Returns:
235 |     new_height: an int32 scalar tensor indicating the new height.
236 |     new_width: an int32 scalar tensor indicating the new width.
237 |   """
238 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
239 | 
240 |   height = tf.to_float(height)
241 |   width = tf.to_float(width)
242 |   smallest_side = tf.to_float(smallest_side)
243 | 
244 |   scale = tf.cond(
245 |       tf.greater(height, width), lambda: smallest_side / width,
246 |       lambda: smallest_side / height)
247 |   new_height = tf.to_int32(height * scale)
248 |   new_width = tf.to_int32(width * scale)
249 |   return new_height, new_width
250 | 
251 | 
252 | def _aspect_preserving_resize(image, smallest_side):
253 |   """Resize images preserving the original aspect ratio.
254 | 
255 |   Args:
256 |     image: A 3-D image `Tensor`.
257 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
258 |       the smallest side after resize.
259 | 
260 |   Returns:
261 |     resized_image: A 3-D tensor containing the resized image.
262 |   """
263 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
264 | 
265 |   shape = tf.shape(image)
266 |   height = shape[0]
267 |   width = shape[1]
268 |   new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
269 |   image = tf.expand_dims(image, 0)
270 |   resized_image = tf.image.resize_bilinear(
271 |       image, [new_height, new_width], align_corners=False)
272 |   resized_image = tf.squeeze(resized_image)
273 |   resized_image.set_shape([None, None, 3])
274 |   return resized_image
275 | 
276 | 
277 | def preprocess_for_train(image,
278 |                          output_height,
279 |                          output_width,
280 |                          resize_side_min=_RESIZE_SIDE_MIN,
281 |                          resize_side_max=_RESIZE_SIDE_MAX):
282 |   """Preprocesses the given image for training.
283 | 
284 |   Note that the actual resizing scale is sampled from
285 |   [`resize_side_min`, `resize_side_max`].
286 | 
287 |   Args:
288 |     image: A `Tensor` representing an image of arbitrary size.
289 |     output_height: The height of the image after preprocessing.
290 |     output_width: The width of the image after preprocessing.
291 |     resize_side_min: The lower bound for the smallest side of the image for
292 |       aspect-preserving resizing.
293 |     resize_side_max: The upper bound for the smallest side of the image for
294 |       aspect-preserving resizing.
295 | 
296 |   Returns:
297 |     A preprocessed image.
298 |   """
299 |   resize_side = tf.random_uniform(
300 |       [], minval=resize_side_min, maxval=resize_side_max + 1, dtype=tf.int32)
301 | 
302 |   image = _aspect_preserving_resize(image, resize_side)
303 |   image = _random_crop([image], output_height, output_width)[0]
304 |   image.set_shape([output_height, output_width, 3])
305 |   image = tf.to_float(image)
306 |   image = tf.image.random_flip_left_right(image)
307 |   return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
308 | 
309 | 
310 | def preprocess_for_eval(image, output_height, output_width, resize_side):
311 |   """Preprocesses the given image for evaluation.
312 | 
313 |   Args:
314 |     image: A `Tensor` representing an image of arbitrary size.
315 |     output_height: The height of the image after preprocessing.
316 |     output_width: The width of the image after preprocessing.
317 |     resize_side: The smallest side of the image for aspect-preserving resizing.
318 | 
319 |   Returns:
320 |     A preprocessed image.
321 |   """
322 |   image = _aspect_preserving_resize(image, resize_side)
323 |   image = _central_crop([image], output_height, output_width)[0]
324 |   image.set_shape([output_height, output_width, 3])
325 |   image = tf.to_float(image)
326 |   return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
327 | 
328 | 
329 | def preprocess_image(image,
330 |                      output_height,
331 |                      output_width,
332 |                      is_training=False,
333 |                      resize_side_min=_RESIZE_SIDE_MIN,
334 |                      resize_side_max=_RESIZE_SIDE_MAX):
335 |   """Preprocesses the given image.
336 | 
337 |   Args:
338 |     image: A `Tensor` representing an image of arbitrary size.
339 |     output_height: The height of the image after preprocessing.
340 |     output_width: The width of the image after preprocessing.
341 |     is_training: `True` if we're preprocessing the image for training and
342 |       `False` otherwise.
343 |     resize_side_min: The lower bound for the smallest side of the image for
344 |       aspect-preserving resizing. If `is_training` is `False`, then this value
345 |       is used for rescaling.
346 |     resize_side_max: The upper bound for the smallest side of the image for
347 |       aspect-preserving resizing. If `is_training` is `False`, this value is
348 |       ignored. Otherwise, the resize side is sampled from
349 |       [`resize_side_min`, `resize_side_max`].
350 | 
351 |   Returns:
352 |     A preprocessed image.
353 |   """
354 |   if is_training:
355 |     return preprocess_for_train(image, output_height, output_width,
356 |                                 resize_side_min, resize_side_max)
357 |   else:
358 |     return preprocess_for_eval(image, output_height, output_width,
359 |                                resize_side_min)
360 | 
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Run evaluation. Note: fcn_32 is the only name registered in
4 | python eval.py \
5 |   --checkpoint_path=${TRAIN_DIR} \
6 |   --eval_dir=${TRAIN_DIR} \
7 |   --dataset_name=${DATASET} \
8 |   --dataset_dir=${DATASET_DIR} \
9 |   --model_name=${MODEL} \
10 |   --preprocessing_name=fcn_32 \
11 |   --dataset_split_name=validation
12 | 
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Run training.
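# TRAIN_DIR, DATASET, DATASET_DIR and MODEL must be set in the environment
# before running this script; scripts/unet-pascal.sh shows example values.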
3 | python train.py \
4 |   --train_dir=${TRAIN_DIR} \
5 |   --dataset_name=${DATASET} \
6 |   --dataset_dir=${DATASET_DIR} \
7 |   --model_name=${MODEL} \
8 |   --save_summaries_secs=60 \
9 |   --save_interval_secs=60 \
10 |   --dataset_split_name=train \
11 |   --preprocessing_name=fcn_32 \
12 |   --max_number_of_steps=10000 \
13 |   --batch_size=32 \
14 |   --log_every_n_steps=100 \
15 |   --optimizer=adam \
16 |   --train_image_size=240
17 | 
--------------------------------------------------------------------------------
/scripts/unet-pascal.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | MODEL=fcn_32
3 | DATASET=pascal
4 | BASE_DIR=/mnt/data
5 | TRAIN_DIR=${BASE_DIR}/cache/${DATASET}_${MODEL}
6 | DATASET_DIR=${BASE_DIR}/datasets/${DATASET}
7 | 
8 | # Download the dataset
9 | python prepare_data.py \
10 |   --dataset_name=${DATASET} \
11 |   --dataset_dir=${DATASET_DIR}
12 | 
13 | # Run training.
14 | python train.py \
15 |   --train_dir=${TRAIN_DIR} \
16 |   --dataset_name=${DATASET} \
17 |   --dataset_dir=${DATASET_DIR} \
18 |   --model_name=${MODEL} \
19 |   --save_summaries_secs=60 \
20 |   --save_interval_secs=60 \
21 |   --dataset_split_name=train \
22 |   --max_number_of_steps=10000 \
23 |   --batch_size=32 \
24 |   --log_every_n_steps=100 \
25 |   --optimizer=adam
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Generic training script that trains a model using a given dataset."""
2 | 
3 | from __future__ import absolute_import, division, print_function
4 | 
5 | import tensorflow as tf
6 | from tensorflow.python.ops import control_flow_ops
7 | 
8 | from datasets import dataset_factory
9 | from deployment import model_deploy
10 | from nets import nets_factory
11 | from preprocessing import preprocessing_factory
12 | 
13 | slim = tf.contrib.slim
14 | flags = tf.app.flags
15 | 
16 | flags.DEFINE_string('master', '', 'address of the TensorFlow master to use')
17 | flags.DEFINE_string('train_dir', '/tmp/tfmodel/', 'checkpoints and event logs')
18 | flags.DEFINE_integer('num_clones', 1, 'model clones to deploy')
19 | flags.DEFINE_boolean('clone_on_cpu', False, 'Use CPUs to deploy clones')
20 | flags.DEFINE_integer('worker_replicas', 1, 'worker replicas')
21 | flags.DEFINE_integer('num_ps_tasks', 0, 'param servers. If 0, handle locally')
22 | flags.DEFINE_integer('num_readers', 4, 'parallel dataset readers')
23 | flags.DEFINE_integer('num_preprocessing_threads', 4, 'batch data threads')
24 | flags.DEFINE_integer('log_every_n_steps', 10, 'how often logs are printed')
25 | flags.DEFINE_integer('save_summaries_secs', 600, 'summaries saved every x sec')
26 | flags.DEFINE_integer('save_interval_secs', 600, 'model saved every x sec')
27 | flags.DEFINE_integer('task', 0, 'Task id of the replica running the training')
28 | 
29 | # Optimization Flags
30 | flags.DEFINE_float('weight_decay', 0.00004, 'weight decay on the model weights')
31 | flags.DEFINE_string('optimizer', 'rmsprop', '"adadelta", "adagrad", "adam", '
32 |                     '"ftrl", "momentum", "sgd" or "rmsprop"')
33 | flags.DEFINE_float('adadelta_rho', 0.95, 'decay rate for adadelta')
34 | flags.DEFINE_float('adagrad_initial_accumulator_value', 0.1, 'initial AdaGrad')
35 | flags.DEFINE_float('adam_beta1', 0.9, 'exp. decay for 1st moment estimates')
36 | flags.DEFINE_float('adam_beta2', 0.999, 'exp. decay for 2nd moment estimates')
37 | flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for optimizer')
38 | flags.DEFINE_float('ftrl_learning_rate_power', -0.5, 'learning rate power')
39 | flags.DEFINE_float('ftrl_initial_accumulator_value', 0.1, 'initial FTRL')
40 | flags.DEFINE_float('ftrl_l1', 0.0, 'FTRL l1 regularization strength')
41 | flags.DEFINE_float('ftrl_l2', 0.0, 'FTRL l2 regularization strength')
42 | flags.DEFINE_float('momentum', 0.9, 'MomentumOptimizer and RMSPropOptimizer')
43 | flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum')
44 | flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp')
45 | 
46 | # Learning rate Flags
47 | flags.DEFINE_string('learning_rate_decay_type', 'polynomial',
48 |                     'exponential/fixed/polynomial')
49 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate')
50 | flags.DEFINE_float('end_learning_rate', 0.0001, 'min end LR polynomial decay')
51 | flags.DEFINE_float('label_smoothing', 0.0, 'amount of label smoothing')
52 | flags.DEFINE_float('learning_rate_decay_factor', 0.94, 'Learning rate decay')
53 | flags.DEFINE_float('num_epochs_per_decay', 2.0, 'epochs when LR decays')
54 | flags.DEFINE_bool('sync_replicas', False, 'synchronize the replicas?')
55 | flags.DEFINE_integer('replicas_to_aggregate', 1, 'gradients before updating')
56 | flags.DEFINE_float('moving_average_decay', None, 'If None, not used')
57 | 
58 | # Dataset Flags
59 | flags.DEFINE_string('dataset_name', 'imagenet', 'dataset to load')
60 | flags.DEFINE_string('dataset_split_name', 'train', 'name of train/test split')
61 | flags.DEFINE_string('dataset_dir', None, 'where dataset files are stored')
62 | flags.DEFINE_integer('labels_offset', 0, 'Labels offset; used in VGG/ResNet')
63 | flags.DEFINE_string('model_name', 'inception_v3', 'architecture to train')
64 | flags.DEFINE_string('preprocessing_name', None, 'If `None`, model_name is used')
65 | flags.DEFINE_integer('batch_size', 32, 'samples in each batch')
66 | flags.DEFINE_integer('train_image_size', None, 'Train image size')
67 | flags.DEFINE_integer('max_number_of_steps', None, 'maximum training steps')
68 | 
69 | # Fine-Tuning Flags
70 | flags.DEFINE_string('checkpoint_path', None, 'path to a checkpoint to finetune')
71 | flags.DEFINE_string(
72 |     'checkpoint_exclude_scopes', None,
73 |     'Comma-separated list of scopes of variables to exclude when restoring '
74 |     'from a checkpoint')
75 | flags.DEFINE_string(
76 |     'trainable_scopes', None,
77 |     'Comma-separated list of scopes to filter the set of variables to train. '
78 |     'By default, None trains all the variables')
79 | flags.DEFINE_boolean(
80 |     'ignore_missing_vars', False,
81 |     'When restoring a checkpoint, ignore missing variables')
82 | 
83 | FLAGS = flags.FLAGS
84 | 
85 | 
86 | def _configure_learning_rate(num_samples_per_epoch, global_step):
87 |   """Configures the learning rate.
88 | 
89 |   Args:
90 |     num_samples_per_epoch: The samples in each epoch of training.
91 |     global_step: The global_step tensor.
92 | 
93 |   Returns:
94 |     A `Tensor` representing the learning rate.
95 | 
96 |   Raises:
97 |     ValueError: if `learning_rate_decay_type` is not recognized.
98 |   """
99 |   decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
100 |                     FLAGS.num_epochs_per_decay)
101 |   if FLAGS.sync_replicas:
102 |     decay_steps /= FLAGS.replicas_to_aggregate
103 | 
104 |   if FLAGS.learning_rate_decay_type == 'exponential':
105 |     return tf.train.exponential_decay(
106 |         FLAGS.learning_rate,
107 |         global_step,
108 |         decay_steps,
109 |         FLAGS.learning_rate_decay_factor,
110 |         staircase=True,
111 |         name='exponential_decay_learning_rate')
112 |   elif FLAGS.learning_rate_decay_type == 'fixed':
113 |     return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
114 |   elif FLAGS.learning_rate_decay_type == 'polynomial':
115 |     return tf.train.polynomial_decay(
116 |         FLAGS.learning_rate,
117 |         global_step,
118 |         decay_steps,
119 |         FLAGS.end_learning_rate,
120 |         power=1.0,
121 |         cycle=False,
122 |         name='polynomial_decay_learning_rate')
123 |   else:
124 |     raise ValueError('learning_rate_decay_type [%s] was not recognized' %
125 |                      FLAGS.learning_rate_decay_type)
126 | 
127 | 
128 | def _configure_optimizer(learning_rate):
129 |   """Configures the optimizer used for training.
130 | 
131 |   Args: learning_rate: A scalar or `Tensor` learning rate.
132 | 
133 |   Returns: An instance of an optimizer.
134 | 
135 |   Raises:
136 |     ValueError: if FLAGS.optimizer is not recognized.
137 |   """
138 |   if FLAGS.optimizer == 'adadelta':
139 |     optimizer = tf.train.AdadeltaOptimizer(
140 |         learning_rate, rho=FLAGS.adadelta_rho, epsilon=FLAGS.opt_epsilon)
141 |   elif FLAGS.optimizer == 'adagrad':
142 |     optimizer = tf.train.AdagradOptimizer(
143 |         learning_rate,
144 |         initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
145 |   elif FLAGS.optimizer == 'adam':
146 |     optimizer = tf.train.AdamOptimizer(
147 |         learning_rate,
148 |         beta1=FLAGS.adam_beta1,
149 |         beta2=FLAGS.adam_beta2,
150 |         epsilon=FLAGS.opt_epsilon)
151 |   elif FLAGS.optimizer == 'ftrl':
152 |     optimizer = tf.train.FtrlOptimizer(
153 |         learning_rate,
154 |         learning_rate_power=FLAGS.ftrl_learning_rate_power,
155 |         initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
156 |         l1_regularization_strength=FLAGS.ftrl_l1,
157 |         l2_regularization_strength=FLAGS.ftrl_l2)
158 |   elif FLAGS.optimizer == 'momentum':
159 |     optimizer = tf.train.MomentumOptimizer(
160 |         learning_rate, momentum=FLAGS.momentum, name='Momentum')
161 |   elif FLAGS.optimizer == 'rmsprop':
162 |     optimizer = tf.train.RMSPropOptimizer(
163 |         learning_rate,
164 |         decay=FLAGS.rmsprop_decay,
165 |         momentum=FLAGS.rmsprop_momentum,
166 |         epsilon=FLAGS.opt_epsilon)
167 |   elif FLAGS.optimizer == 'sgd':
168 |     optimizer = tf.train.GradientDescentOptimizer(learning_rate)
169 |   else:
170 |     raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
171 |   return optimizer
172 | 
173 | 
174 | def _add_variables_summaries(learning_rate):
175 |   summaries = []
176 |   for variable in slim.get_model_variables():
177 |     summaries.append(tf.summary.histogram(variable.op.name, variable))
178 |   summaries.append(tf.summary.scalar('training_lr', learning_rate))
179 |   return summaries
180 | 
181 | 
182 | def _get_init_fn():
183 |   """Returns a function run by the chief worker to warm-start the training.
184 | 
185 |   init_fn is only run when initializing the model on the first global step.
186 | 
187 |   Returns:
188 |     An init function run by the supervisor.
189 |   """
190 |   if FLAGS.checkpoint_path is None:
191 |     return None
192 | 
193 |   # Warn the user if a checkpoint already exists in train_dir; if one does,
194 |   # --checkpoint_path is ignored.
195 |   if tf.train.latest_checkpoint(FLAGS.train_dir):
196 |     tf.logging.info(
197 |         'Ignoring --checkpoint_path because a checkpoint already exists in %s' %
198 |         FLAGS.train_dir)
199 |     return None
200 | 
201 |   exclusions = []
202 |   if FLAGS.checkpoint_exclude_scopes:
203 |     exclusions = [
204 |         scope.strip() for scope in FLAGS.checkpoint_exclude_scopes.split(',')
205 |     ]
206 | 
207 |   # TODO(sguada) variables.filter_variables()
208 |   variables_to_restore = []
209 |   for var in slim.get_model_variables():
210 |     excluded = False
211 |     for exclusion in exclusions:
212 |       if var.op.name.startswith(exclusion):
213 |         excluded = True
214 |         break
215 |     if not excluded:
216 |       variables_to_restore.append(var)
217 | 
218 |   if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
219 |     checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
220 |   else:
221 |     checkpoint_path = FLAGS.checkpoint_path
222 | 
223 |   tf.logging.info('Fine-tuning from %s' % checkpoint_path)
224 | 
225 |   return slim.assign_from_checkpoint_fn(
226 |       checkpoint_path,
227 |       variables_to_restore,
228 |       ignore_missing_vars=FLAGS.ignore_missing_vars)
229 | 
230 | 
231 | def _get_variables_to_train():
232 |   """Returns a list of variables to train.
233 | 
234 |   Returns:
235 |     A list of variables to train by the optimizer.
236 |   """
237 |   if FLAGS.trainable_scopes is None:
238 |     return tf.trainable_variables()
239 |   else:
240 |     scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
241 | 
242 |   variables_to_train = []
243 |   for scope in scopes:
244 |     variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
245 |     variables_to_train.extend(variables)
246 |   return variables_to_train
247 | 
248 | 
249 | def main(_):
250 |   if not FLAGS.dataset_dir:
251 |     raise ValueError('You must supply the dataset directory with --dataset_dir')
252 | 
253 |   tf.logging.set_verbosity(tf.logging.INFO)
254 |   with tf.Graph().as_default():
255 |     # Configure model_deploy
256 |     deploy_config = model_deploy.DeploymentConfig(
257 |         num_clones=FLAGS.num_clones,
258 |         clone_on_cpu=FLAGS.clone_on_cpu,
259 |         replica_id=FLAGS.task,
260 |         num_replicas=FLAGS.worker_replicas,
261 |         num_ps_tasks=FLAGS.num_ps_tasks)
262 | 
263 |     # Create global_step
264 |     with tf.device(deploy_config.variables_device()):
265 |       global_step = slim.create_global_step()
266 | 
267 |     # Select the dataset #
268 |     dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
269 |                                           FLAGS.dataset_split_name,
270 |                                           FLAGS.dataset_dir)
271 | 
272 |     # Select the network
273 |     network_fn = nets_factory.get_network_fn(
274 |         FLAGS.model_name,
275 |         num_classes=(dataset.num_classes - FLAGS.labels_offset),
276 |         weight_decay=FLAGS.weight_decay,
277 |         is_training=True)
278 | 
279 |     # Select the preprocessing function
280 |     preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
281 |     image_preprocessing_fn = preprocessing_factory.get_preprocessing(
282 |         preprocessing_name, is_training=True)
283 | 
284 |     # Create a dataset provider that loads data from the dataset #
285 | 
286 |     with tf.device(deploy_config.inputs_device()):
287 |       provider = slim.dataset_data_provider.DatasetDataProvider(
288 |           dataset,
289 |           num_readers=FLAGS.num_readers,
290 |           common_queue_capacity=20 * FLAGS.batch_size,
291 |           common_queue_min=10 * FLAGS.batch_size)
292 |       [image, label] = provider.get(['image', 'label'])
293 |       label = image
294 |       # TODO(bdd): use real image-wise labels; the image is a placeholder.
295 |       # label -= FLAGS.labels_offset
296 |       train_image_size = FLAGS.train_image_size or network_fn.default_image_size
297 |       image, label = image_preprocessing_fn(image, label, train_image_size,
298 |                                             train_image_size)
299 |       images, labels = tf.train.batch(
300 |           [image, label],
301 |           batch_size=FLAGS.batch_size,
302 |           num_threads=FLAGS.num_preprocessing_threads,
303 |           capacity=5 * FLAGS.batch_size)
304 | 
305 |       # labels = slim.one_hot_encoding(labels, dataset.num_classes)
306 |       batch_queue = slim.prefetch_queue.prefetch_queue(
307 |           [images, labels], capacity=2 * deploy_config.num_clones)
308 | 
309 |     # Define the model #
310 |     def clone_fn(batch_queue):
311 |       """Allows data parallelism by creating multiple clones of network_fn."""
312 |       images, labels = batch_queue.dequeue()
313 |       logits, end_points = network_fn(images)
314 |       print("=" * 40 + ">images<" + "=" * 40)
315 |       print(images)
316 |       print("=" * 40 + ">labels<" + "=" * 40)
317 |       print(labels)
318 |       print("=" * 40 + ">logits<" + "=" * 40)
319 |       print(logits)
320 |       tf.contrib.losses.softmax_cross_entropy(
321 |           logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
322 |       return end_points
323 | 
324 |     # Gather initial summaries.
325 |     summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
326 | 
327 |     clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
328 |     clone_scope = deploy_config.clone_scope(0)
329 |     # Gather update_ops from the first clone. These contain, for example,
330 |     # the updates for batch_norm variables created by network_fn.
331 |     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone_scope)
332 | 
333 |     # Add summaries for end_points.
334 |     end_points = clones[0].outputs
335 |     for end_point in end_points:
336 |       x = end_points[end_point]
337 |       summaries.add(tf.summary.histogram('activations_' + end_point, x))
338 |       summaries.add(
339 |           tf.summary.scalar('sparsity_' + end_point, tf.nn.zero_fraction(x)))
340 | 
341 |     # Add summaries for losses.
342 |     for loss in tf.get_collection(tf.GraphKeys.LOSSES, clone_scope):
343 |       summaries.add(tf.summary.scalar('losses_%s' % loss.op.name, loss))
344 | 
345 |     # Add summaries for variables.
346 |     for variable in slim.get_model_variables():
347 |       summaries.add(tf.summary.histogram(variable.op.name, variable))
348 | 
349 |     # Configure the moving averages
350 |     if FLAGS.moving_average_decay:
351 |       moving_average_variables = slim.get_model_variables()
352 |       variable_averages = tf.train.ExponentialMovingAverage(
353 |           FLAGS.moving_average_decay, global_step)
354 |     else:
355 |       moving_average_variables, variable_averages = None, None
356 | 
357 |     # Configure the optimization procedure. #
358 |     with tf.device(deploy_config.optimizer_device()):
359 |       learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
360 |       optimizer = _configure_optimizer(learning_rate)
361 |       summaries.add(tf.summary.scalar('learning_rate', learning_rate))
362 | 
363 |     if FLAGS.sync_replicas:
364 |       # If sync_replicas is enabled, the averaging will be done in the chief
365 |       # queue runner.
366 |       optimizer = tf.train.SyncReplicasOptimizer(
367 |           opt=optimizer,
368 |           replicas_to_aggregate=FLAGS.replicas_to_aggregate,
369 |           variable_averages=variable_averages,
370 |           variables_to_average=moving_average_variables,
371 |           replica_id=tf.constant(
372 |               FLAGS.task, tf.int32, shape=()),
373 |           total_num_replicas=FLAGS.worker_replicas)
374 |     elif FLAGS.moving_average_decay:
375 |       # Update ops executed locally by trainer.
376 |       update_ops.append(variable_averages.apply(moving_average_variables))
377 | 
378 |     # Variables to train.
379 |     variables_to_train = _get_variables_to_train()
380 | 
381 |     # Optimize over all clones and collect the gradients to apply.
382 |     total_loss, cl_gradients = model_deploy.optimize_clones(
383 |         clones, optimizer, var_list=variables_to_train)
384 |     # Add total_loss to summary.
385 |     summaries.add(tf.summary.scalar('total_loss', total_loss))
386 | 
387 |     # Create gradient updates.
388 |     grad_updates = optimizer.apply_gradients(
389 |         cl_gradients, global_step=global_step)
390 |     update_ops.append(grad_updates)
391 | 
392 |     update_op = tf.group(*update_ops)
393 |     train_tensor = control_flow_ops.with_dependencies([update_op], total_loss)
394 | 
395 |     # Add the summaries from the first clone. These contain the summaries
396 |     # created by model_fn and either optimize_clones() or
397 |     # _gather_clone_loss().
398 |     summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, clone_scope))
399 | 
400 |     # Merge all summaries together.
401 |     summary_op = tf.summary.merge(list(summaries), name='summary_op')
402 | 
403 |     # Kick off the training.
404 |     slim.learning.train(
405 |         train_tensor,
406 |         logdir=FLAGS.train_dir,
407 |         master=FLAGS.master,
408 |         is_chief=(FLAGS.task == 0),
409 |         init_fn=_get_init_fn(),
410 |         summary_op=summary_op,
411 |         number_of_steps=FLAGS.max_number_of_steps,
412 |         log_every_n_steps=FLAGS.log_every_n_steps,
413 |         save_summaries_secs=FLAGS.save_summaries_secs,
414 |         save_interval_secs=FLAGS.save_interval_secs,
415 |         sync_optimizer=optimizer if FLAGS.sync_replicas else None)
416 | 
417 | 
418 | if __name__ == '__main__':
419 |   tf.app.run()
420 | 
--------------------------------------------------------------------------------
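
Usage sketch: preprocessing_factory.get_preprocessing returns a closure over
`is_training`, so a caller picks the mode once and then applies the function
to each (image, label) pair. A minimal sketch, assuming a TF 1.x-era
environment with tf.contrib.slim; the placeholder shapes are illustrative
stand-ins for the tensors train.py pulls from its DatasetDataProvider:

import tensorflow as tf

from preprocessing import preprocessing_factory

# Illustrative decoded inputs; train.py obtains these from provider.get().
image = tf.placeholder(tf.uint8, shape=[None, None, 3])
label = tf.placeholder(tf.uint8, shape=[None, None, 1])

# 'fcn_32' is the only name registered in preprocessing_factory.py.
preprocess_fn = preprocessing_factory.get_preprocessing(
    'fcn_32', is_training=True)

# The returned closure forwards to fcn_preprocessing.preprocess_image and,
# as used in train.py, yields a preprocessed (image, label) pair.
# 240 mirrors --train_image_size in scripts/train.sh.
image_out, label_out = preprocess_fn(image, label, 240, 240)

Because the mode is captured by the closure, train.py (and presumably
eval.py) can share the same call signature while getting training- or
evaluation-mode pipelines respectively.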