├── .gitattributes ├── save ├── checkpoint ├── test.ckpt.meta ├── test.ckpt.index └── test.ckpt.data-00000-of-00001 ├── serve └── test │ └── 1 │ ├── saved_model.pb │ └── variables │ ├── variables.index │ └── variables.data-00000-of-00001 ├── README.md ├── .idea ├── modules.xml ├── misc.xml ├── serving.iml └── workspace.xml ├── client.py ├── LICENSE ├── serve.py ├── .gitignore └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /save/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "test.ckpt" 2 | all_model_checkpoint_paths: "test.ckpt" 3 | -------------------------------------------------------------------------------- /save/test.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/save/test.ckpt.meta -------------------------------------------------------------------------------- /save/test.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/save/test.ckpt.index -------------------------------------------------------------------------------- /serve/test/1/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/serve/test/1/saved_model.pb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Serving Example 2 | 3 | Code example for my medium article 
https://towardsdatascience.com/deploy-tensorflow-models-9813b5a705d5 4 | -------------------------------------------------------------------------------- /save/test.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/save/test.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /serve/test/1/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/serve/test/1/variables/variables.index -------------------------------------------------------------------------------- /serve/test/1/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FrancescoSaverioZuppichini/TensorFlow-Serving-Example/HEAD/serve/test/1/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ApexVCS 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/serving.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /client.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from predict_client.prod_client import 
"""Small Flask front-end that forwards prediction requests to a model server."""
import numpy as np
from predict_client.prod_client import ProdClient
from flask import Flask
from flask import request
from flask import jsonify

# Address of the model server and the model/version to query.
HOST = 'localhost:9000'
MODEL_NAME = 'test'
MODEL_VERSION = 1

app = Flask(__name__)
client = ProdClient(HOST, MODEL_NAME, MODEL_VERSION)


def convert_data(raw_data):
    """Return ``raw_data`` converted to a float32 NumPy array.

    NumPy raises ``ValueError`` or ``TypeError`` when the payload cannot be
    interpreted as numbers; callers decide how to report that.
    """
    return np.array(raw_data, dtype=np.float32)


def get_prediction_from_model(data):
    """Send ``data`` to the model server and return the client's response.

    The request names the input tensor ``inputs`` with dtype ``DT_FLOAT``.
    ``request_timeout=10`` is passed straight to the client
    (presumably seconds -- TODO confirm against predict_client).
    """
    req_data = [{'in_tensor_name': 'inputs', 'in_tensor_dtype': 'DT_FLOAT', 'data': data}]
    return client.predict(req_data, request_timeout=10)


@app.route("/prediction", methods=['POST'])
def get_prediction():
    """Handle ``POST /prediction``.

    Expects a JSON object with a ``data`` field holding the numeric input.
    Returns ``{"predictions": [...]}`` on success, or a 400 response with an
    ``error`` message when the body is missing, malformed or non-numeric.
    """
    # silent=True makes a missing or unparsable body come back as None
    # instead of raising, so every bad-input case is handled in one place.
    req_data = request.get_json(silent=True)
    if not isinstance(req_data, dict) or 'data' not in req_data:
        return jsonify({'error': "request body must be a JSON object with a 'data' field"}), 400

    try:
        data = convert_data(req_data['data'])
    except (TypeError, ValueError):
        return jsonify({'error': "'data' must be an array of numbers"}), 400

    prediction = get_prediction_from_model(data)

    # ndarray cannot be converted to JSON, hence tolist().
    return jsonify({'predictions': prediction['outputs'].tolist()})


if __name__ == '__main__':
    app.run(host='localhost', port=3000)
"""Export the latest training checkpoint as a SavedModel for a model server."""
import tensorflow as tf
import os

SAVE_PATH = './save'
MODEL_NAME = 'test'
VERSION = 1
SERVE_PATH = './serve/{}/{}'.format(MODEL_NAME, VERSION)

checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
if checkpoint is None:
    # latest_checkpoint returns None when nothing has been saved yet; fail
    # with a clear message instead of a TypeError on `None + '.meta'` below.
    raise RuntimeError('No checkpoint found in {}'.format(SAVE_PATH))

tf.reset_default_graph()

with tf.Session() as sess:
    # Import the saved graph definition into the (now empty) default graph.
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    # Get the graph for this session.
    graph = tf.get_default_graph()
    # Give every variable a value, then overwrite with the trained values
    # stored in the checkpoint. Without the restore the exported model would
    # contain freshly initialized weights instead of the trained ones.
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, checkpoint)

    # Look up the tensors that form the serving interface.
    inputs = graph.get_tensor_by_name('inputs:0')
    predictions = graph.get_tensor_by_name('prediction/Sigmoid:0')

    # Describe those tensors for the signature.
    model_input = tf.saved_model.utils.build_tensor_info(inputs)
    model_output = tf.saved_model.utils.build_tensor_info(predictions)

    # Build the signature definition: one input named 'inputs', one output
    # named 'outputs', exposed through the predict method.
    signature_definition = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'inputs': model_input},
        outputs={'outputs': model_output},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    # NOTE(review): the builder is expected to refuse an export directory that
    # already exists -- confirm, and bump VERSION or remove SERVE_PATH if so.
    builder = tf.saved_model.builder.SavedModelBuilder(SERVE_PATH)

    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_definition
        })
    # Save the model so we can serve it with a model server.
    builder.save()
"""Train a tiny dense network on random data, or evaluate a saved checkpoint.

When no checkpoint exists in SAVE_PATH the model is trained and saved there.
Otherwise the latest checkpoint is restored and evaluated on random test data.
"""
import tensorflow as tf
import numpy as np
import os
import sys

DATA_SIZE = 100
SAVE_PATH = './save'
EPOCHS = 1000
LEARNING_RATE = 0.01
MODEL_NAME = 'test'

if not os.path.exists(SAVE_PATH):
    os.mkdir(SAVE_PATH)

# (features, targets) pairs of uniform random numbers: two features, one target.
data = (np.random.rand(DATA_SIZE, 2), np.random.rand(DATA_SIZE, 1))
test = (np.random.rand(DATA_SIZE // 8, 2), np.random.rand(DATA_SIZE // 8, 1))

tf.reset_default_graph()

x = tf.placeholder(tf.float32, shape=[None, 2], name='inputs')
y = tf.placeholder(tf.float32, shape=[None, 1], name='targets')

# Two 16-unit ReLU layers followed by a single sigmoid output.
net = tf.layers.dense(x, 16, activation=tf.nn.relu)
net = tf.layers.dense(net, 16, activation=tf.nn.relu)
pred = tf.layers.dense(net, 1, activation=tf.nn.sigmoid, name='prediction')

loss = tf.reduce_mean(tf.squared_difference(y, pred), name='loss')
train_step = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

# Train only when there is nothing to restore.
checkpoint = tf.train.latest_checkpoint(SAVE_PATH)
should_train = checkpoint is None

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if should_train:
        print("Training")
        saver = tf.train.Saver()
        for epoch in range(EPOCHS):
            _, curr_loss = sess.run([train_step, loss], feed_dict={x: data[0], y: data[1]})
            print('EPOCH = {}, LOSS = {:0.4f}'.format(epoch, curr_loss))
        path = saver.save(sess, SAVE_PATH + '/' + MODEL_NAME + '.ckpt')
        print("saved at {}".format(path))
    else:
        print("Restoring")
        # The model is already defined above, so restore the checkpoint
        # directly into these variables. Importing the meta graph here as well
        # would add a second copy of the model to the default graph and
        # restore into that copy, leaving `loss` and `pred` on the randomly
        # initialized variables.
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)

        test_loss = sess.run(loss, feed_dict={x: test[0], y: test[1]})
        print(sess.run(pred, feed_dict={x: np.random.rand(10, 2)}))
        print("TEST LOSS = {:0.4f}".format(test_loss))