├── MANIFEST.in ├── src └── tensorflow_serving_python │ ├── __init__.py │ ├── protos │ ├── __init__.py │ ├── prediction_service.proto │ ├── model.proto │ ├── resource_handle.proto │ ├── predict.proto │ ├── tensor_shape.proto │ ├── types.proto │ ├── tensor.proto │ └── wrappers.proto │ ├── proto_util.py │ └── client.py ├── requirements.txt ├── .gitignore ├── examples └── request.py ├── setup.py └── README.md /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | grpcio 3 | grpcio-tools 4 | tensorflow 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv 2 | build 3 | dist 4 | *_pb2.py 5 | *.egg-info 6 | .idea 7 | *.pyc 8 | .eggs 9 | .DS_Store 10 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/proto_util.py: -------------------------------------------------------------------------------- 1 | def copy_message(src, dst): 2 | """ 3 | Copy the contents of a src proto message to a destination proto message via string serialization 4 | :param src: Source proto; it is serialized with SerializeToString() 5 | :param dst: Destination proto; the serialized bytes are parsed into it with ParseFromString() 6 | :return: The destination proto (the same dst object that was passed in) 7 | """ 8 | dst.ParseFromString(src.SerializeToString()) 9 | return dst 10 | 
-------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/prediction_service.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | option cc_enable_arenas = true; 5 | 6 | import "tensorflow_serving_python/protos/predict.proto"; 7 | 8 | // PredictionService provides access to machine-learned models loaded by 9 | // model_servers. 10 | service PredictionService { 11 | // Predict -- provides access to loaded TensorFlow model. 12 | rpc Predict(PredictRequest) returns (PredictResponse); 13 | } 14 | -------------------------------------------------------------------------------- /examples/request.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | 4 | from tensorflow_serving_python.client import TFClient 5 | 6 | pp = pprint.PrettyPrinter(indent=2) 7 | # Script entry point: sends one image file through a TFClient and pretty-prints the prediction 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser(description='RPC Test.') 10 | # All three arguments are required and parsed as strings 11 | parser.add_argument('--host', required=True, type=str, help='Hostname to query') 12 | parser.add_argument('--port', required=True, type=str, help='Port to query') 13 | parser.add_argument('--image', required=True, type=str, help='Image to send (JPG format)') 14 | args = parser.parse_args() 15 | # Read the image in binary mode. NOTE(review): the file object returned by open() is never closed explicitly 16 | data = open(args.image, "rb").read() 17 | client = TFClient(args.host, args.port) 18 | pp.pprint(client.make_prediction(data, timeout=10)) 19 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/model.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | option cc_enable_arenas = true; 5 | 6 | import "tensorflow_serving_python/protos/wrappers.proto"; 7 | 8 | // Metadata for an inference request such as the model name and version. 
9 | message ModelSpec { 10 | // Required servable name. 11 | string name = 1; 12 | 13 | // Optional version. If unspecified, will use the latest (numerical) version. 14 | // Typically not needed unless coordinating across multiple models that were 15 | // co-trained and/or have inter-dependencies on the versions used at inference 16 | // time. 17 | google.protobuf.Int64Value version = 2; 18 | 19 | // A named signature to evaluate. If unspecified, the default signature will 20 | // be used. Note that only MultiInference will initially support this. 21 | string signature_name = 3; 22 | } 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from setuptools.command.install import install 3 | 4 | import os 5 | 6 | # Install command: runs the regular install, then calls grpc.tools' build_package_protos on the package dir ('src') 7 | class BuildPackageProtos(install): 8 | def run(self): 9 | install.run(self) 10 | from grpc.tools import command 11 | command.build_package_protos(self.distribution.package_dir['']) 12 | 13 | # NOTE(review): BuildPackageProtos subclasses install but is registered for both 'install' and 'develop' below - confirm this is intended 14 | setup( 15 | name='tensorflow_serving_python', 16 | version='0.1', 17 | description='Python client for tensorflow serving', 18 | author="Sebastian Schlecht", 19 | license="MIT", 20 | packages=['tensorflow_serving_python', 'tensorflow_serving_python.protos'], 21 | package_dir={'': 'src'}, 22 | setup_requires=['cython'], 23 | install_requires=[ 24 | 'grpcio', 'grpcio-tools', 25 | 'tensorflow', 26 | ], 27 | cmdclass={ 28 | 'install': BuildPackageProtos, 29 | 'develop': BuildPackageProtos, 30 | } 31 | ) 32 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/resource_handle.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "ResourceHandleProto"; 6 | option java_multiple_files = true; 7 | 
option java_package = "org.tensorflow.framework"; 8 | 9 | // Protocol buffer representing a handle to a tensorflow resource. Handles are 10 | // not valid across executions, but can be serialized back and forth from within 11 | // a single run. 12 | message ResourceHandle { 13 | // Unique name for the device containing the resource. 14 | string device = 1; 15 | 16 | // Container in which this resource is placed. 17 | string container = 2; 18 | 19 | // Unique name of this resource. 20 | string name = 3; 21 | 22 | // Hash code for the type of the resource. Is only valid in the same device 23 | // and in the same execution. 24 | uint64 hash_code = 4; 25 | 26 | // For debug-only, the name of the type pointed to by this handle, if 27 | // available. 28 | string maybe_type_name = 5; 29 | }; 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflow-serving-python # 2 | Minimal Python client to communicate with TensorFlow Serving. 3 | 4 | ### About ### 5 | Python client to communicate with the TensorFlow RPC API. Similar to the examples in the documentation of TensorFlow Serving, 6 | we use gRPC as the RPC client. 7 | 8 | ### TODOs ### 9 | - Link TensorFlow Serving as a sub-module here. To compile protobufs properly during build. 10 | - When distributing, get rid of tf.contrib functions to remove TensorFlow as a dependency. 11 | 12 | ### Installation ### 13 | Check out the source e.g. via ```git clone https://github.com/sebastian-schlecht/tensorflow-serving-python.git``` and run 14 | ```python setup.py install```. 
15 | 16 | This package depends on the following pip modules, so you might want to install those first: 17 | - Cython 18 | - grpcio 19 | - grpcio-tools 20 | - tensorflow 21 | 22 | 23 | Alternatively, we offer an unofficial python wheel to be installed with: 24 | ```pip install http://www.mealomi.com/storage/tensorflow_serving_python/tensorflow_serving_python-0.1-py2-none-any.whl``` 25 | 26 | *Note* 27 | When running on Linux, make sure pip is upgraded via ```pip install --upgrade pip``` such that TensorFlow is installed correctly. 28 | 29 | ### Examples ### 30 | Examples can be found in the examples folder. 31 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/predict.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow.serving; 4 | option cc_enable_arenas = true; 5 | 6 | import "tensorflow_serving_python/protos/tensor.proto"; 7 | import "tensorflow_serving_python/protos/model.proto"; 8 | 9 | // PredictRequest specifies which TensorFlow model to run, as well as 10 | // how inputs are mapped to tensors and how outputs are filtered before 11 | // returning to user. 12 | message PredictRequest { 13 | // Model Specification. 14 | ModelSpec model_spec = 1; 15 | 16 | // Input tensors. 17 | // Names of input tensor are alias names. The mapping from aliases to real 18 | // input tensor names is expected to be stored as named generic signature 19 | // under the key "inputs" in the model export. 20 | // Each alias listed in a generic signature named "inputs" should be provided 21 | // exactly once in order to run the prediction. 22 | map<string, TensorProto> inputs = 2; 23 | 24 | // Output filter. 25 | // Names specified are alias names. The mapping from aliases to real output 26 | // tensor names is expected to be stored as named generic signature under 27 | // the key "outputs" in the model export. 
28 | // Only tensors specified here will be run/fetched and returned, with the 29 | // exception that when none is specified, all tensors specified in the 30 | // named signature will be run/fetched and returned. 31 | repeated string output_filter = 3; 32 | } 33 | 34 | // Response for PredictRequest on successful run. 35 | message PredictResponse { 36 | // Output tensors. 37 | map<string, TensorProto> outputs = 1; 38 | } 39 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/tensor_shape.proto: -------------------------------------------------------------------------------- 1 | // Protocol buffer representing the shape of tensors. 2 | 3 | syntax = "proto3"; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorShapeProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | package tensorflow; 10 | 11 | // Dimensions of a tensor. 12 | message TensorShapeProto { 13 | // One dimension of the tensor. 14 | message Dim { 15 | // Size of the tensor in that dimension. 16 | // This value must be >= -1, but values of -1 are reserved for "unknown" 17 | // shapes (values of -1 mean "unknown" dimension). Certain wrappers 18 | // that work with TensorShapeProto may fail at runtime when deserializing 19 | // a TensorShapeProto containing a dim value of -1. 20 | int64 size = 1; 21 | 22 | // Optional name of the tensor dimension. 23 | string name = 2; 24 | }; 25 | 26 | // Dimensions of the tensor, such as {"input", 30}, {"output", 40} 27 | // for a 30 x 40 2D tensor. If an entry has size -1, this 28 | // corresponds to a dimension of unknown size. The names are 29 | // optional. 30 | // 31 | // The order of entries in "dim" matters: It indicates the layout of the 32 | // values in the tensor in-memory representation. 33 | // 34 | // The first entry in "dim" is the outermost dimension used to layout the 35 | // values, the last entry is the innermost dimension. 
This matches the 36 | // in-memory layout of RowMajor Eigen tensors. 37 | // 38 | // If "dim.size()" > 0, "unknown_rank" must be false. 39 | repeated Dim dim = 2; 40 | 41 | // If true, the number of dimensions in the shape is unknown. 42 | // 43 | // If true, "dim.size()" must be 0. 44 | bool unknown_rank = 3; 45 | }; 46 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/types.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TypesProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | // LINT.IfChange 10 | enum DataType { 11 | // Not a legal value for DataType. Used to indicate a DataType field 12 | // has not been set. 13 | DT_INVALID = 0; 14 | 15 | // Data types that all computation devices are expected to be 16 | // capable to support. 17 | DT_FLOAT = 1; 18 | DT_DOUBLE = 2; 19 | DT_INT32 = 3; 20 | DT_UINT8 = 4; 21 | DT_INT16 = 5; 22 | DT_INT8 = 6; 23 | DT_STRING = 7; 24 | DT_COMPLEX64 = 8; // Single-precision complex 25 | DT_INT64 = 9; 26 | DT_BOOL = 10; 27 | DT_QINT8 = 11; // Quantized int8 28 | DT_QUINT8 = 12; // Quantized uint8 29 | DT_QINT32 = 13; // Quantized int32 30 | DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. 31 | DT_QINT16 = 15; // Quantized int16 32 | DT_QUINT16 = 16; // Quantized uint16 33 | DT_UINT16 = 17; 34 | DT_COMPLEX128 = 18; // Double-precision complex 35 | DT_HALF = 19; 36 | DT_RESOURCE = 20; 37 | 38 | // TODO(josh11b): DT_GENERIC_PROTO = ??; 39 | // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? 40 | 41 | // Do not use! These are only for parameters. Every enum above 42 | // should have a corresponding value below (verified by types_test). 
43 | DT_FLOAT_REF = 101; 44 | DT_DOUBLE_REF = 102; 45 | DT_INT32_REF = 103; 46 | DT_UINT8_REF = 104; 47 | DT_INT16_REF = 105; 48 | DT_INT8_REF = 106; 49 | DT_STRING_REF = 107; 50 | DT_COMPLEX64_REF = 108; 51 | DT_INT64_REF = 109; 52 | DT_BOOL_REF = 110; 53 | DT_QINT8_REF = 111; 54 | DT_QUINT8_REF = 112; 55 | DT_QINT32_REF = 113; 56 | DT_BFLOAT16_REF = 114; 57 | DT_QINT16_REF = 115; 58 | DT_QUINT16_REF = 116; 59 | DT_UINT16_REF = 117; 60 | DT_COMPLEX128_REF = 118; 61 | DT_HALF_REF = 119; 62 | DT_RESOURCE_REF = 120; 63 | } 64 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/client.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from grpc.beta import implementations 3 | 4 | from tensorflow_serving_python.protos import prediction_service_pb2, predict_pb2 5 | from tensorflow_serving_python.proto_util import copy_message 6 | 7 | 8 | class TFClient(object): 9 | """ 10 | TFClient class to use for RPC calls 11 | """ 12 | def __init__(self, host, port): 13 | """ 14 | Store host and port, open an insecure gRPC channel and create the PredictionService stub 15 | :param host: Server hostname 16 | :param port: Server port (converted with int() before opening the channel) 17 | :return: 18 | """ 19 | self.host = host 20 | self.port = port 21 | 22 | # Setup channel 23 | self.channel = implementations.insecure_channel(self.host, int(self.port)) 24 | self.stub = prediction_service_pb2.beta_create_PredictionService_stub(self.channel) 25 | 26 | def execute(self, request, timeout=10.0): 27 | """ 28 | Execute the RPC request 29 | :param request: Request proto 30 | :param timeout: Timeout in seconds passed to the stub's Predict call 31 | :return: Response returned by the stub's Predict call 32 | """ 33 | return self.stub.Predict(request, timeout) 34 | 35 | def make_prediction(self, data, name='inception', timeout=10., convert_to_dict=True): 36 | """ 37 | Make a prediction on a buffer full of image data (tested .jpg as of now) 38 | :param data: Data buffer 39 | :param name: Name of the model_spec to use 40 
| :param timeout: Timeout in seconds forwarded to execute() 41 | :return: Dict mapping each output key to an ndarray if convert_to_dict is true, otherwise the raw response 42 | """ 43 | request = predict_pb2.PredictRequest() 44 | request.model_spec.name = name 45 | proto = tf.contrib.util.make_tensor_proto(data, shape=[1]) 46 | 47 | # TODO dst.CopyFrom(src) fails here because we compile custom protocolbuffers 48 | # TODO Proper compiling would speed up the next line by a factor of 10 49 | copy_message(proto, request.inputs['images']) 50 | response = self.execute(request, timeout=timeout) 51 | 52 | if not convert_to_dict: 53 | return response 54 | 55 | # Convert to friendly python object 56 | results_dict = {} 57 | for key in response.outputs: 58 | tensor_proto = response.outputs[key] 59 | nd_array = tf.contrib.util.make_ndarray(tensor_proto) 60 | results_dict[key] = nd_array 61 | 62 | return results_dict -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package tensorflow; 4 | option cc_enable_arenas = true; 5 | option java_outer_classname = "TensorProtos"; 6 | option java_multiple_files = true; 7 | option java_package = "org.tensorflow.framework"; 8 | 9 | import "tensorflow_serving_python/protos/resource_handle.proto"; 10 | import "tensorflow_serving_python/protos/tensor_shape.proto"; 11 | import "tensorflow_serving_python/protos/types.proto"; 12 | 13 | // Protocol buffer representing a tensor. 14 | message TensorProto { 15 | DataType dtype = 1; 16 | 17 | // Shape of the tensor. TODO(touts): sort out the 0-rank issues. 18 | TensorShapeProto tensor_shape = 2; 19 | 20 | // Only one of the representations below is set, one of "tensor_contents" and 21 | // the "xxx_val" attributes. We are not using oneof because as oneofs cannot 22 | // contain repeated fields it would require another extra set of messages. 23 | 24 | // Version number. 
25 | // 26 | // In version 0, if the "repeated xxx" representations contain only one 27 | // element, that element is repeated to fill the shape. This makes it easy 28 | // to represent a constant Tensor with a single value. 29 | int32 version_number = 3; 30 | 31 | // Serialized content from Tensor::AsProtoTensorContent(). This representation 32 | // can be used for all tensor types. 33 | bytes tensor_content = 4; 34 | 35 | // Type specific representations that make it easy to create tensor protos in 36 | // all languages. Only the representation corresponding to "dtype" can 37 | // be set. The values hold the flattened representation of the tensor in 38 | // row major order. 39 | 40 | // DT_HALF. Note that since protobuf has no int16 type, we'll have some 41 | // pointless zero padding for each value here. 42 | repeated int32 half_val = 13 [packed = true]; 43 | 44 | // DT_FLOAT. 45 | repeated float float_val = 5 [packed = true]; 46 | 47 | // DT_DOUBLE. 48 | repeated double double_val = 6 [packed = true]; 49 | 50 | // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. 51 | repeated int32 int_val = 7 [packed = true]; 52 | 53 | // DT_STRING 54 | repeated bytes string_val = 8; 55 | 56 | // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real 57 | // and imaginary parts of i-th single precision complex. 58 | repeated float scomplex_val = 9 [packed = true]; 59 | 60 | // DT_INT64 61 | repeated int64 int64_val = 10 [packed = true]; 62 | 63 | // DT_BOOL 64 | repeated bool bool_val = 11 [packed = true]; 65 | 66 | // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real 67 | // and imaginary parts of i-th double precision complex. 
68 | repeated double dcomplex_val = 12 [packed = true]; 69 | 70 | // DT_RESOURCE 71 | repeated ResourceHandle resource_handle_val = 14; 72 | }; 73 | -------------------------------------------------------------------------------- /src/tensorflow_serving_python/protos/wrappers.proto: -------------------------------------------------------------------------------- 1 | // Protocol Buffers - Google's data interchange format 2 | // Copyright 2008 Google Inc. All rights reserved. 3 | // https://developers.google.com/protocol-buffers/ 4 | // 5 | // Redistribution and use in source and binary forms, with or without 6 | // modification, are permitted provided that the following conditions are 7 | // met: 8 | // 9 | // * Redistributions of source code must retain the above copyright 10 | // notice, this list of conditions and the following disclaimer. 11 | // * Redistributions in binary form must reproduce the above 12 | // copyright notice, this list of conditions and the following disclaimer 13 | // in the documentation and/or other materials provided with the 14 | // distribution. 15 | // * Neither the name of Google Inc. nor the names of its 16 | // contributors may be used to endorse or promote products derived from 17 | // this software without specific prior written permission. 18 | // 19 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | // A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 23 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | 31 | // Wrappers for primitive (non-message) types. These types are useful 32 | // for embedding primitives in the `google.protobuf.Any` type and for places 33 | // where we need to distinguish between the absence of a primitive 34 | // typed field and its default value. 35 | 36 | syntax = "proto3"; 37 | 38 | package google.protobuf; 39 | 40 | option csharp_namespace = "Google.Protobuf.WellKnownTypes"; 41 | option cc_enable_arenas = true; 42 | option go_package = "github.com/golang/protobuf/ptypes/wrappers"; 43 | option java_package = "com.google.protobuf"; 44 | option java_outer_classname = "WrappersProto"; 45 | option java_multiple_files = true; 46 | option objc_class_prefix = "GPB"; 47 | 48 | // Wrapper message for `double`. 49 | // 50 | // The JSON representation for `DoubleValue` is JSON number. 51 | message DoubleValue { 52 | // The double value. 53 | double value = 1; 54 | } 55 | 56 | // Wrapper message for `float`. 57 | // 58 | // The JSON representation for `FloatValue` is JSON number. 59 | message FloatValue { 60 | // The float value. 61 | float value = 1; 62 | } 63 | 64 | // Wrapper message for `int64`. 65 | // 66 | // The JSON representation for `Int64Value` is JSON string. 67 | message Int64Value { 68 | // The int64 value. 69 | int64 value = 1; 70 | } 71 | 72 | // Wrapper message for `uint64`. 73 | // 74 | // The JSON representation for `UInt64Value` is JSON string. 
75 | message UInt64Value { 76 | // The uint64 value. 77 | uint64 value = 1; 78 | } 79 | 80 | // Wrapper message for `int32`. 81 | // 82 | // The JSON representation for `Int32Value` is JSON number. 83 | message Int32Value { 84 | // The int32 value. 85 | int32 value = 1; 86 | } 87 | 88 | // Wrapper message for `uint32`. 89 | // 90 | // The JSON representation for `UInt32Value` is JSON number. 91 | message UInt32Value { 92 | // The uint32 value. 93 | uint32 value = 1; 94 | } 95 | 96 | // Wrapper message for `bool`. 97 | // 98 | // The JSON representation for `BoolValue` is JSON `true` and `false`. 99 | message BoolValue { 100 | // The bool value. 101 | bool value = 1; 102 | } 103 | 104 | // Wrapper message for `string`. 105 | // 106 | // The JSON representation for `StringValue` is JSON string. 107 | message StringValue { 108 | // The string value. 109 | string value = 1; 110 | } 111 | 112 | // Wrapper message for `bytes`. 113 | // 114 | // The JSON representation for `BytesValue` is JSON string. 115 | message BytesValue { 116 | // The bytes value. 117 | bytes value = 1; 118 | } 119 | --------------------------------------------------------------------------------