├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── client.py
├── createpickles.py
├── models
│   ├── .gitignore
│   ├── __init__.py
│   ├── modelcreator.py
│   └── modelserver.py
├── pickles
│   ├── my_model_architecture.json
│   ├── my_model_weights.h5
│   ├── scalar_x.pickle
│   └── scalar_y.pickle
├── requirements.txt
├── server.py
└── settings.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
### 0.4.0 [20200904]
#### Changes:
- Yearly dependency upgrades, check the `requirements.txt` file
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ankur Srivastava

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Keras-rest-server: A simple rest implementation for loading and serving keras models
------------------
## About:
This repository contains a very simple server implemented in Flask which loads a
simple neural network model trained using Keras from its saved weights and
model architecture.

In this example a very simple case of XOR is considered.
## Getting started:
---
1. Install Anaconda:
```
https://docs.continuum.io/anaconda/install
```

2. Clone this repository
```
git clone https://github.com/ansrivas/keras-rest-server.git
cd keras-rest-server
```

3. Create a new environment (change python=2 or python=3) and activate it:
```
conda create --name keras-server -y python=2
source activate keras-server
```

4. Install all the dependencies:
```
conda env update -n keras-server --file requirements.txt
```

5. To remove the environment run:
```
conda remove -n keras-server --all -y
```

### Usage
------------------

### Run to generate pickle files:
```
python createpickles.py
```

### Run the server (defaults to http://localhost:7171)
```
python server.py
```

### Send a post request to this server to test your model
```
python client.py
```
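### Or test the endpoint directly with curl
A minimal sketch, assuming the server is running locally on the default port: the `/predict` endpoint expects a JSON body with an `X_input` key and replies with a `pred_val` key.
```
curl -X POST http://localhost:7171/predict \
     -H "Content-Type: application/json" \
     -d '{"X_input": [1.0, 0.0]}'
```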
--------------------------------------------------------------------------------
/client.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Client to send inputs for predictions."""

import requests


def get_predictions(X_input):
    """Get predictions from a rest backend for your input."""
    print("Requesting prediction for XOR with {0}".format(X_input))
    r = requests.post("http://localhost:7171/predict", json={'X_input': X_input})
    print(r.status_code, r.reason)
    resp = r.json()
    prediction = resp['pred_val'][0]
    print("XOR of input: {0} is {1}".format(X_input, prediction))


if __name__ == '__main__':

    X_inputs = [[1., 1.], [1., 0.]]

    for x_input in X_inputs:
        get_predictions(x_input)
--------------------------------------------------------------------------------
/createpickles.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Default module to train an XOR classifier and write weights to disk."""

from keras.models import Sequential
from keras.layers.core import Dense, Activation
import keras.optimizers as kop
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
try:
    import cPickle as pickle
except Exception:
    import pickle


def check_dir_exists(dirname='./pickles'):
    """Check if the given dirname exists; it will contain all the pickle files."""
    if not os.path.exists(dirname):
        print("Directory to store pickles does not exist. Creating one now: ./pickles")
        os.mkdir(dirname)


def save_x_y_scalar(X_train, Y_train):
    """Use a normalization method on the current dataset and save the coefficients.

    Args:
        X_train: Input X_train
        Y_train: Labels Y_train
    Returns:
        Normalized X_train, Y_train (currently using StandardScaler from scikit-learn)
    """
    scalar_x = StandardScaler()
    X_train = scalar_x.fit_transform(X_train)

    scalar_y = StandardScaler()
    Y_train = scalar_y.fit_transform(Y_train)

    print('dumping StandardScaler objects ..')
    pickle.dump(scalar_y,
                open('pickles/scalar_y.pickle', "wb"),
                protocol=pickle.HIGHEST_PROTOCOL)
    pickle.dump(scalar_x,
                open('pickles/scalar_x.pickle', "wb"),
                protocol=pickle.HIGHEST_PROTOCOL)
    return X_train, Y_train


def create_model(X_train, Y_train):
    """Create a very simple neural net model and save its weights in a predefined directory.

    Args:
        X_train: Input X_train
        Y_train: Labels Y_train
    """
    xin = X_train.shape[1]

    model = Sequential()
    model.add(Dense(units=4, input_shape=(xin, )))
    model.add(Activation('tanh'))
    model.add(Dense(4))
    model.add(Activation('linear'))
    model.add(Dense(1))

    rms = kop.RMSprop()

    print('compiling now..')
    model.compile(loss='mse', optimizer=rms)

    model.fit(X_train, Y_train, epochs=1000, batch_size=1, verbose=2)
    score = model.evaluate(X_train, Y_train, batch_size=1)
    print("Evaluation results:", score)
    open('pickles/my_model_architecture.json', 'w').write(model.to_json())

    print("Saving weights in: ./pickles/my_model_weights.h5")
    model.save_weights('pickles/my_model_weights.h5')


if __name__ == '__main__':
    X_train = np.array([[1., 1.], [1., 0.], [0., 1.], [0., 0.]])
    Y_train = np.array([[0.], [1.], [1.], [0.]])

    check_dir_exists(dirname='./pickles')
    X_train, Y_train = save_x_y_scalar(X_train, Y_train)
    create_model(X_train, Y_train)
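The `StandardScaler` objects pickled by `save_x_y_scalar` are what later let the server map raw network outputs back to label space. A small, self-contained sketch of that round trip (illustrative only, not a file in this repository):
```
import numpy as np
from sklearn.preprocessing import StandardScaler

Y_train = np.array([[0.], [1.], [1.], [0.]])

scaler_y = StandardScaler()
Y_scaled = scaler_y.fit_transform(Y_train)          # what createpickles.py pickles to disk
Y_restored = scaler_y.inverse_transform(Y_scaled)   # what the server applies to raw predictions

print(np.allclose(Y_train, Y_restored))             # True: the transform is fully reversible
```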
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
# Scala-IDE specific
.scala_dependencies
.worksheet

# IntelliJ specific
.idea
/bin/

*.class
*.log
*.out
# sbt specific
.cache
.history
.lib/
dist/*
target/
lib_managed/
src_managed/
project/boot/
project/plugins/project/

# Scala-IDE specific
.scala_dependencies
.worksheet
# Eclipse
.project
.pydevproject
.pyc
bin/
tmp/
.target
.factorypath
.classpath
.settings/
*.tmp
*.bak
*.swp
*~.nib
.loadpath
.cache-main
# IntelliJ specific
.idea
/bin/
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ansrivas/keras-rest-server/dde45af92bb956cb8ead439acc6037d307119c5e/models/__init__.py
--------------------------------------------------------------------------------
/models/modelcreator.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Base module to load a saved Keras model together with its normalization coefficients."""

from keras.models import model_from_json

try:
    import cPickle as pickle
except Exception:
    import pickle


class ModelOperations(object):
    """ModelOperations deals with saving/loading the weights and configuration of Keras models."""

    def __init__(self):
        """Initialize ModelOperations class."""
        pass

    def load_model(self, json_path, weights_path):
        """Load a model from a given json path and weights path."""
        try:
            model = model_from_json(open(json_path).read())
            model.load_weights(weights_path)
            return model
        except Exception:
            raise Exception('Failed to load model/weights')

    def load_normalizer(self, sk_normalized):
        """Load the sklearn.preprocessing.StandardScaler object.

        This is the scaler that was used to normalize the original dataset.
        """
        try:
            f = open(sk_normalized, 'rb')
            scalar = pickle.load(f)
            f.close()
            return scalar
        except Exception:
            raise Exception('Failed to load normalizer')

    def save_model(self, model, json_path, weights_path):
        """Helper wrapper around model.to_json() and model.save_weights() to dump the configuration and weights."""
        json_string = model.to_json()
        with open(json_path, 'w') as f:
            f.write(json_string)
        model.save_weights(weights_path)


class Predictor(object):
    """Predictor does the heavy lifting: compiling the Keras model, normalizing inputs, denormalizing outputs."""

    def __init__(self, json_path, weights_path, normalized_x, normalized_y, **kwargs):
        """Initialize Predictor class."""
        modoperations = ModelOperations()
        self.model = modoperations.load_model(json_path, weights_path)
        self.scalar_x = modoperations.load_normalizer(normalized_x)
        self.scalar_y = modoperations.load_normalizer(normalized_y)

    def compile_model(self, loss, optimizer, **kwargs):
        """Similar to the Keras compile function; expects at least a loss type and an optimizer."""
        self.model.compile(loss=loss, optimizer=optimizer, **kwargs)

    def _normalize_input(self, X_input):
        """Normalize the input according to the scaler used during the training process."""
        X_input = self.scalar_x.transform(X_input)
        return X_input

    def _denormalize_prediction(self, x_pred):
        """De-normalize x_pred back to the actual value range of the dataset."""
        value = self.scalar_y.inverse_transform(x_pred)
        return value

    def predict(self, X_input):
        """Make predictions, given some input data.

        This normalizes the input with the stored scaler, runs the model,
        and then denormalizes the resulting prediction.

        Args:
            X_input: Input vector for prediction
        """
        x_normed = self._normalize_input(X_input=X_input)
        x_pred = self.model.predict(x_normed)
        prediction = self._denormalize_prediction(x_pred)
        return prediction
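`Predictor` is the class the Flask layer wraps, but it can also be driven directly from a Python shell. A rough usage sketch, assuming the artifacts written by `createpickles.py` exist under `pickles/` and the interpreter is started from the repository root:
```
import numpy as np
from models.modelcreator import Predictor

predictor = Predictor(json_path='pickles/my_model_architecture.json',
                      weights_path='pickles/my_model_weights.h5',
                      normalized_x='pickles/scalar_x.pickle',
                      normalized_y='pickles/scalar_y.pickle')
predictor.compile_model(loss='mse', optimizer='rmsprop')

# Input must be shaped (n_samples, 2), matching the (1, 2) reshape done in ModelLoader.post().
print(predictor.predict(np.array([[1., 0.]])))
```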
--------------------------------------------------------------------------------
/models/modelserver.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Model server with some hardcoded paths."""

from flask.views import MethodView

from flask import Flask, request, jsonify
from .modelcreator import Predictor
from gevent.pywsgi import WSGIServer
import numpy as np

app = Flask(__name__)
predictor = None


class ModelLoader(MethodView):
    """ModelLoader initializes the model params and waits for a post request to serve predictions."""

    def __init__(self):
        """Initialize ModelLoader class."""
        pass

    def post(self):
        """Accept a post request and serve predictions."""
        content = request.get_json()
        X_input = content['X_input']
        # The JSON payload arrives as a plain list; reshape it into a (1, 2) array for the model.
        X_in = np.reshape(np.array(X_input), newshape=(1, 2))
        pred_val = predictor.predict(X_input=X_in)
        pred_val = pred_val.tolist()
        return jsonify({'pred_val': pred_val})


def initialize_models(json_path, weights_path, normalized_x, normalized_y):
    """Initialize models and use this in the Flask server."""
    global predictor
    predictor = Predictor(json_path, weights_path, normalized_x, normalized_y)
    predictor.compile_model(loss='mse', optimizer='rmsprop')


def run(host='0.0.0.0', port=7171):
    """Run a WSGI server using gevent."""
    app.add_url_rule('/predict', view_func=ModelLoader.as_view('predict'))
    print('running server http://{0}'.format(host + ':' + str(port)))
    WSGIServer((host, port), app).serve_forever()
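Because the `/predict` route is only registered inside `run()`, exercising the view without opening a network socket takes one extra line. A hypothetical in-process smoke test (not part of the repository), again assuming the `pickles/` artifacts exist and the script runs from the repository root:
```
from models import modelserver
import settings

modelserver.initialize_models(json_path=settings.path_model_json,
                              weights_path=settings.path_model_weight,
                              normalized_x=settings.path_x_normalizer,
                              normalized_y=settings.path_y_normalizer)
# Register the view manually, since run() normally does this right before serving.
modelserver.app.add_url_rule('/predict',
                             view_func=modelserver.ModelLoader.as_view('predict'))

with modelserver.app.test_client() as client:
    resp = client.post('/predict', json={'X_input': [1., 0.]})
    print(resp.status_code, resp.get_json())
```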
--------------------------------------------------------------------------------
/pickles/my_model_architecture.json:
--------------------------------------------------------------------------------
{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": [null, 2], "dtype": "float32", "sparse": false, "ragged": false, "name": "dense_input"}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "batch_input_shape": [null, 2], "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Activation", "config": {"name": "activation", "trainable": true, "dtype": "float32", "activation": "tanh"}}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 4, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Activation", "config": {"name": "activation_1", "trainable": true, "dtype": "float32", "activation": "linear"}}, {"class_name": "Dense", "config": {"name": "dense_2", "trainable": true, "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.4.0", "backend": "tensorflow"}
--------------------------------------------------------------------------------
/pickles/my_model_weights.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ansrivas/keras-rest-server/dde45af92bb956cb8ead439acc6037d307119c5e/pickles/my_model_weights.h5
--------------------------------------------------------------------------------
/pickles/scalar_x.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ansrivas/keras-rest-server/dde45af92bb956cb8ead439acc6037d307119c5e/pickles/scalar_x.pickle
--------------------------------------------------------------------------------
/pickles/scalar_y.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ansrivas/keras-rest-server/dde45af92bb956cb8ead439acc6037d307119c5e/pickles/scalar_y.pickle
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==2.3.0

Keras==2.4.3
gevent==20.6.2
scikit-learn==0.23.2
Flask==1.1.2
Flask-Cors==3.0.9
requests==2.24.0
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""A very simple flask server to serve models."""

from models import modelserver
import settings
import sys
sys.path.append("./models")

modelserver.initialize_models(json_path=settings.path_model_json,
                              weights_path=settings.path_model_weight,
                              normalized_x=settings.path_x_normalizer,
                              normalized_y=settings.path_y_normalizer)
modelserver.run()
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Settings module."""

from os.path import dirname, abspath
__dir__ = dirname(abspath(__file__))

path_model_json = 'pickles/my_model_architecture.json'
path_model_weight = 'pickles/my_model_weights.h5'
path_y_normalizer = 'pickles/scalar_y.pickle'
path_x_normalizer = 'pickles/scalar_x.pickle'
--------------------------------------------------------------------------------