├── .pylintrc ├── Makefile ├── README.md ├── export.py ├── model.py ├── predict.py ├── serve.py └── train.py /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Profiled execution. 11 | profile=no 12 | 13 | # Add files or directories to the blacklist. They should be base names, not 14 | # paths. 15 | ignore=CVS 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=yes 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | 25 | [MESSAGES CONTROL] 26 | 27 | # Enable the message, report, category or checker with the given id(s). You can 28 | # either give multiple identifiers separated by comma (,) or put this option 29 | # multiple times. See also the "--disable" option for examples. 30 | enable=indexing-exception,old-raise-syntax 31 | 32 | # Disable the message, report, category or checker with the given id(s). You 33 | # can either give multiple identifiers separated by comma (,) or put this 34 | # option multiple times (only on the command line, not in the configuration 35 | # file where it should appear only once). You can also use "--disable=all" to 36 | # disable everything first and then reenable specific checks. For example, if 37 | # you want to run only the similarities checker, you can use "--disable=all 38 | # --enable=similarities". 
If you want to run only the classes checker, but have 39 | # no Warning level messages displayed, use "--disable=all --enable=classes 40 | # --disable=W" 41 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager 42 | 43 | 44 | # Set the cache size for astng objects. 45 | cache-size=500 46 | 47 | 48 | [REPORTS] 49 | 50 | # Set the output format. Available formats are text, parseable, colorized, msvs 51 | # (visual studio) and html. You can also give a reporter class, eg 52 | # mypackage.mymodule.MyReporterClass. 53 | output-format=text 54 | 55 | # Put messages in a separate file for each module / package specified on the 56 | # command line instead of printing them on stdout. Reports (if any) will be 57 | # written in a file named "pylint_global.[txt|html]". 58 | files-output=no 59 | 60 | # Tells whether to display a full report or only the messages 61 | reports=no 62 | 63 | # Python expression which should return a note less than 10 (10 is the highest 64 | # note). You have access to the variables error, warning, statement which 65 | # respectively contain the number of errors / warnings messages and the total 66 | # number of statements analyzed. This is used by the global evaluation report 67 | # (RP0004). 68 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 69 | 70 | # Add a comment according to your evaluation note. This is used by the global 71 | # evaluation report (RP0004). 72 | comment=no 73 | 74 | # Template used to display messages. This is a python new-style format string 75 | # used to format the message information. 
See doc for all details 76 | #msg-template= 77 | 78 | 79 | [TYPECHECK] 80 | 81 | # Tells whether missing members accessed in mixin class should be ignored. A 82 | # mixin class is detected if its name ends with "mixin" (case insensitive). 83 | ignore-mixin-members=yes 84 | 85 | # List of classes names for which member attributes should not be checked 86 | # (useful for classes with attributes dynamically set). 87 | ignored-classes=SQLObject 88 | 89 | # When zope mode is activated, add a predefined set of Zope acquired attributes 90 | # to generated-members. 91 | zope=no 92 | 93 | # List of members which are set dynamically and missed by pylint inference 94 | # system, and so shouldn't trigger E0201 when accessed. Python regular 95 | # expressions are accepted. 96 | generated-members=REQUEST,acl_users,aq_parent 97 | 98 | # List of decorators that create context managers from functions, such as 99 | # contextlib.contextmanager. 100 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 101 | 102 | 103 | [VARIABLES] 104 | 105 | # Tells whether we should check for unused import in __init__ files. 106 | init-import=no 107 | 108 | # A regular expression matching the beginning of the name of dummy variables 109 | # (i.e. not used). 110 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 111 | 112 | # List of additional names supposed to be defined in builtins. Remember that 113 | # you should avoid to define new builtins when possible. 114 | additional-builtins= 115 | 116 | 117 | [BASIC] 118 | 119 | # Required attributes for module, separated by a comma 120 | required-attributes= 121 | 122 | # List of builtins function names that should not be used, separated by a comma 123 | bad-functions=apply,input,reduce 124 | 125 | 126 | # Disable the report(s) with the given id(s). 127 | # All non-Google reports are disabled by default. 
128 | disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 129 | 130 | # Regular expression which should only match correct module names 131 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 132 | 133 | # Regular expression which should only match correct module level names 134 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 135 | 136 | # Regular expression which should only match correct class names 137 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 138 | 139 | # Regular expression which should only match correct function names 140 | function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 141 | 142 | # Regular expression which should only match correct method names 143 | method-rgx=^(?:(?P<exempt>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 144 | 145 | # Regular expression which should only match correct instance attribute names 146 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 147 | 148 | # Regular expression which should only match correct argument names 149 | argument-rgx=^[a-z][a-z0-9_]*$ 150 | 151 | # Regular expression which should only match correct variable names 152 | variable-rgx=^[a-z][a-z0-9_]*$ 153 | 154 | # Regular expression which should only match correct attribute names in class 155 | # bodies 156 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 157 | 158 | # Regular expression which should only match correct list comprehension / 159 | # generator expression variable names 160 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 161 | 162 | # Good variable names which should always be accepted, separated by a comma 163 | good-names=main,_ 164 | 165 | # Bad variable names which should always be refused, separated by a comma 166 | bad-names= 167 | 168 | # Regular expression which should only match function or class names that do 169 | # not require a docstring. 
170 | no-docstring-rgx=(__.*__|main) 171 | 172 | # Minimum line length for functions/classes that require docstrings, shorter 173 | # ones are exempt. 174 | docstring-min-length=10 175 | 176 | 177 | [FORMAT] 178 | 179 | # Maximum number of characters on a single line. 180 | max-line-length=100 181 | 182 | # Regexp for a line that is allowed to be longer than the limit. 183 | ignore-long-lines=(?x) 184 | (^\s*(import|from)\s 185 | |\$Id:\s\/\/depot\/.+#\d+\s\$ 186 | |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') 187 | |^\s*\#\ LINT\.ThenChange 188 | |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ 189 | |pylint 190 | |""" 191 | |\# 192 | |lambda 193 | |(https?|ftp):) 194 | 195 | # Allow the body of an if to be on the same line as the test if there is no 196 | # else. 197 | single-line-if-stmt=y 198 | 199 | # List of optional constructs for which whitespace checking is disabled 200 | no-space-check= 201 | 202 | # Maximum number of lines in a module 203 | max-module-lines=99999 204 | 205 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 206 | # tab). 207 | indent-string=' ' 208 | 209 | 210 | [SIMILARITIES] 211 | 212 | # Minimum lines number of a similarity. 213 | min-similarity-lines=4 214 | 215 | # Ignore comments when computing similarities. 216 | ignore-comments=yes 217 | 218 | # Ignore docstrings when computing similarities. 219 | ignore-docstrings=yes 220 | 221 | # Ignore imports when computing similarities. 222 | ignore-imports=no 223 | 224 | 225 | [MISCELLANEOUS] 226 | 227 | # List of note tags to take in consideration, separated by a comma. 228 | notes= 229 | 230 | 231 | [IMPORTS] 232 | 233 | # Deprecated modules which should not be used, separated by a comma 234 | deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets 235 | 236 | # Create a graph of every (i.e. 
internal and external) dependencies in the 237 | # given file (report RP0402 must not be disabled) 238 | import-graph= 239 | 240 | # Create a graph of external dependencies in the given file (report RP0402 must 241 | # not be disabled) 242 | ext-import-graph= 243 | 244 | # Create a graph of internal dependencies in the given file (report RP0402 must 245 | # not be disabled) 246 | int-import-graph= 247 | 248 | 249 | [CLASSES] 250 | 251 | # List of interface methods to ignore, separated by a comma. This is used for 252 | # instance to not check methods defined in Zope's Interface base class. 253 | ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by 254 | 255 | # List of method names used to declare (i.e. assign) instance attributes. 256 | defining-attr-methods=__init__,__new__,setUp 257 | 258 | # List of valid names for the first argument in a class method. 259 | valid-classmethod-first-arg=cls,class_ 260 | 261 | # List of valid names for the first argument in a metaclass class method. 262 | valid-metaclass-classmethod-first-arg=mcs 263 | 264 | 265 | [DESIGN] 266 | 267 | # Maximum number of arguments for function / method 268 | max-args=5 269 | 270 | # Argument names that match this expression will be ignored. Defaults to names 271 | # with a leading underscore 272 | ignored-argument-names=_.* 273 | 274 | # Maximum number of locals for function / method body 275 | max-locals=15 276 | 277 | # Maximum number of return / yield for function / method body 278 | max-returns=6 279 | 280 | # Maximum number of branches for function / method body 281 | max-branches=12 282 | 283 | # Maximum number of statements in function / method body 284 | max-statements=50 285 | 286 | # Maximum number of parents for a class (see R0901). 
287 | max-parents=7 288 | 289 | # Maximum number of attributes for a class (see R0902). 290 | max-attributes=7 291 | 292 | # Minimum number of public methods for a class (see R0903). 293 | min-public-methods=2 294 | 295 | # Maximum number of public methods for a class (see R0904). 296 | max-public-methods=20 297 | 298 | 299 | [EXCEPTIONS] 300 | 301 | # Exceptions that will emit a warning when being caught. Defaults to 302 | # "Exception" 303 | overgeneral-exceptions=Exception,StandardError,BaseException 304 | 305 | 306 | [AST] 307 | 308 | # Maximum line length for lambdas 309 | short-func-length=1 310 | 311 | # List of module members that should be marked as deprecated. 312 | # All of the string functions are listed in 4.1.4 Deprecated string functions 313 | # in the Python 2.4 docs. 314 | deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc 315 | 316 | 317 | [DOCSTRING] 318 | 319 | # List of exceptions that do not need to be mentioned in the Raises section of 320 | # a docstring. 321 | ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError 322 | 323 | 324 | 325 | [TOKENS] 326 | 327 | # Number of spaces of indent required when the last token on the preceding line 328 | # is an open (, [, or {. 
329 | indent-after-paren=4 330 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | run: 2 | python train.py 3 | python export.py 4 | python predict.py 5 | python serve.py 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Estimator Basics 2 | 3 | Train, predict, export and reload a `tf.estimator` for inference on a dummy example. 4 | 5 | [Read the blog post](https://guillaumegenthial.github.io/serving-tensorflow-estimator.html) 6 | 7 | 8 | 9 | 10 | ## Quickstart 11 | 12 | ``` 13 | make run 14 | ``` 15 | 16 | ## Details 17 | 18 | - `model.py` defines the `model_fn` 19 | - `train.py` trains an Estimator using the `model_fn` 20 | - `export.py` exports the Estimator as a `saved_model` 21 | - `predict.py` reloads an Estimator and uses it for prediction 22 | - `serve.py` reloads the inference graph from the `saved_model` format and uses it for prediction 23 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | """Export estimator as a saved_model""" 2 | 3 | __author__ = "Guillaume Genthial" 4 | 5 | import tensorflow as tf 6 | 7 | from model import model_fn 8 | 9 | 10 | def serving_input_receiver_fn(): 11 | """Serving input_fn that builds features from placeholders 12 | 13 | Returns 14 | ------- 15 | tf.estimator.export.ServingInputReceiver 16 | """ 17 | number = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='number') 18 | receiver_tensors = {'number': number} 19 | features = tf.tile(number, multiples=[1, 2]) 20 | return tf.estimator.export.ServingInputReceiver(features, receiver_tensors) 21 | 22 | 23 | if __name__ == '__main__': 24 | estimator = 
tf.estimator.Estimator(model_fn, 'model', params={}) 25 | estimator.export_saved_model('saved_model', serving_input_receiver_fn) 26 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """Dummy model_fn""" 2 | 3 | __author__ = "Guillaume Genthial" 4 | 5 | 6 | import tensorflow as tf 7 | 8 | 9 | def model_fn(features, labels, mode, params): 10 | # pylint: disable=unused-argument 11 | """Dummy model_fn""" 12 | if isinstance(features, dict): # For serving 13 | features = features['feature'] 14 | 15 | predictions = tf.layers.dense(features, 1) 16 | 17 | if mode == tf.estimator.ModeKeys.PREDICT: 18 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 19 | else: 20 | loss = tf.nn.l2_loss(predictions - labels) 21 | if mode == tf.estimator.ModeKeys.EVAL: 22 | return tf.estimator.EstimatorSpec( 23 | mode, loss=loss) 24 | 25 | elif mode == tf.estimator.ModeKeys.TRAIN: 26 | train_op = tf.train.AdamOptimizer(learning_rate=0.5).minimize( 27 | loss, global_step=tf.train.get_global_step()) 28 | return tf.estimator.EstimatorSpec( 29 | mode, loss=loss, train_op=train_op) 30 | else: 31 | raise NotImplementedError() 32 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """Predict using estimator.predict""" 2 | 3 | __author__ = "Guillaume Genthial" 4 | 5 | 6 | import functools 7 | from pathlib import Path 8 | import logging 9 | import sys 10 | import time 11 | 12 | import tensorflow as tf 13 | 14 | from model import model_fn 15 | 16 | 17 | def example_input_fn(number): 18 | """Dummy input_fn""" 19 | dataset = tf.data.Dataset.from_generator( 20 | lambda: ([number, number] for _ in range(1)), 21 | output_types=tf.float32, output_shapes=(2,)) 22 | iterator = dataset.batch(1).make_one_shot_iterator() 23 | next_element = 
iterator.get_next() 24 | return next_element, None 25 | 26 | 27 | def my_service(): 28 | """Some service yielding numbers""" 29 | start, end = 100, 110 30 | for number in range(start, end): 31 | yield number 32 | 33 | 34 | if __name__ == '__main__': 35 | # Logging 36 | Path('model').mkdir(exist_ok=True) 37 | tf.logging.set_verbosity(logging.INFO) 38 | handlers = [ 39 | logging.FileHandler('model/predict.log'), 40 | logging.StreamHandler(sys.stdout) 41 | ] 42 | logging.getLogger('tensorflow').handlers = handlers 43 | 44 | # Instantiate estimator 45 | estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='model', 46 | params={}) 47 | 48 | # Predict using the estimator 49 | tic = time.time() 50 | for nb in my_service(): 51 | example_inpf = functools.partial(example_input_fn, nb) 52 | for pred in estimator.predict(example_inpf): 53 | # print((pred - 2*nb)**2) 54 | pass 55 | 56 | toc = time.time() 57 | print('Average time in predict.py: {}s'.format((toc - tic) / 10)) 58 | -------------------------------------------------------------------------------- /serve.py: -------------------------------------------------------------------------------- 1 | """Reload and serve a saved model""" 2 | 3 | __author__ = "Guillaume Genthial" 4 | 5 | from pathlib import Path 6 | import time 7 | 8 | from tensorflow.contrib import predictor 9 | 10 | 11 | def my_service(): 12 | """Some service yielding numbers""" 13 | start, end = 100, 110 14 | for number in range(start, end): 15 | yield number 16 | 17 | 18 | if __name__ == '__main__': 19 | export_dir = 'saved_model' 20 | subdirs = [x for x in Path(export_dir).iterdir() 21 | if x.is_dir() and 'temp' not in str(x)] 22 | latest = str(sorted(subdirs)[-1]) 23 | predict_fn = predictor.from_saved_model(latest) 24 | tic = time.time() 25 | for nb in my_service(): 26 | pred = predict_fn({'number': [[nb]]})['output'] 27 | # print((pred - 2*nb)**2) 28 | toc = time.time() 29 | print('Average time in serve.py: {}s'.format((toc - tic) / 10)) 30 | 
# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
"""Train a tf.estimator on dummy data"""

__author__ = "Guillaume Genthial"

from pathlib import Path
import logging
import sys

import tensorflow as tf

from model import model_fn


def train_generator_fn():
    """Generator of dummy training examples

    Yields
    ------
    tuple
        ``([number, number], [2 * number])`` for every ``number`` in
        ``range(100)``: a 2-element feature list and a 1-element label list.
    """
    for number in range(100):
        yield [number, number], [2 * number]


def train_input_fn():
    """Training input_fn that wraps `train_generator_fn` in a dataset

    Returns
    -------
    tf.data.Dataset
        (features, labels) pairs in batches of 20, repeated 200 times.
    """
    # One shape per generator output (features: 2 entries, labels: 1 entry).
    # Written as per-component shapes instead of the bare `(2, 1)`, which is
    # interpreted the same way but reads like a single rank-2 shape.
    shapes, types = ((2,), (1,)), (tf.float32, tf.float32)
    dataset = tf.data.Dataset.from_generator(
        train_generator_fn, output_types=types, output_shapes=shapes)
    dataset = dataset.batch(20).repeat(200)
    return dataset


if __name__ == '__main__':
    # Logging
    # The log file lives in the model directory, so create it first.
    Path('model').mkdir(exist_ok=True)
    tf.logging.set_verbosity(logging.INFO)
    handlers = [
        logging.FileHandler('model/train.log'),
        logging.StreamHandler(sys.stdout)
    ]
    # Replace the 'tensorflow' logger's handlers so its output goes to both
    # model/train.log and stdout.
    logging.getLogger('tensorflow').handlers = handlers

    # Train estimator
    estimator = tf.estimator.Estimator(model_fn, 'model', params={})
    estimator.train(train_input_fn)