├── .codacy.yml ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── _config.yml ├── coverage.sh ├── examples ├── README.md ├── img │ ├── airplane.jpg │ └── cat.jpg ├── keras_boston_neural_net.py ├── keras_imagenet_resnet50.py ├── pytorch_imagenet_resnet50.py ├── sklearn_boston_linear_regression.py └── sklearn_iris_logistic_regression.py ├── requirements.txt ├── serveit ├── __init__.py ├── config.py ├── log_utils.py ├── server.py └── utils.py ├── setup.cfg ├── setup.py └── tests ├── SuccessKid.jpg ├── __init__.py ├── keras ├── __init__.py └── test_server.py ├── pytorch ├── __init__.py └── test_server.py ├── sklearn ├── __init__.py └── test_server.py ├── test_callback_utils.py ├── test_serialization_utils.py └── test_server.py /.codacy.yml: -------------------------------------------------------------------------------- 1 | --- 2 | exclude_paths: 3 | - tests/** 4 | - examples/** 5 | - *.py 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # VIM files 2 | *.swp 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | - "3.6" 5 | install: 6 | - pip install -r requirements.txt 7 | script: 8 | - python -m pytest tests --ignore tests/keras --ignore tests/pytorch --ignore tests/test_callback_utils.py 9 | after_script: ./coverage.sh 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2010-2018 Google, Inc. http://angularjs.org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ServeIt 2 | [![Build Status](https://travis-ci.org/rtlee9/serveit.svg?branch=master)](https://travis-ci.org/rtlee9/serveit) 3 | [![Codacy Grade Badge](https://api.codacy.com/project/badge/Grade/2af32a3840d5441e815f3956659b091f)](https://www.codacy.com/app/ryantlee9/serveit) 4 | [![Codacy Coverage Badge](https://api.codacy.com/project/badge/Coverage/2af32a3840d5441e815f3956659b091f)](https://www.codacy.com/app/ryantlee9/serveit) 5 | [![PyPI version](https://badge.fury.io/py/ServeIt.svg)](https://badge.fury.io/py/ServeIt) 6 | 7 | ServeIt lets you serve model predictions and supplementary information from a RESTful API using your favorite Python ML library in as little as one line of code: 8 | 9 | ```python 10 | from serveit.server import ModelServer 11 | from sklearn.linear_model import LogisticRegression 12 | from sklearn.datasets import load_iris 13 | 14 | # fit logistic regression on Iris data 15 | clf = LogisticRegression() 16 | data = load_iris() 17 | clf.fit(data.data, data.target) 18 | 19 | # initialize server with a model and start serving predictions 20 | ModelServer(clf, clf.predict).serve() 21 | ``` 22 | 23 | Your new API is now accepting `POST` requests at `localhost:5000/predictions`! Please see the [examples](examples) directory for detailed examples across domains (e.g., regression, image classification), including live examples. 24 | 25 | #### Features 26 | Current ServeIt features include: 27 | 28 | 1. Model inference serving via RESTful API endpoint 29 | 1. Extensible library for inference-time data loading, preprocessing, input validation, and postprocessing 30 | 1. Supplementary information endpoint creation 31 | 1. Automatic JSON serialization of responses 32 | 1. Configurable request and response logging (work in progress) 33 | 34 | #### Supported libraries 35 | The following libraries are currently supported: 36 | * Scikit-Learn 37 | * Keras 38 | * PyTorch 39 | 40 | ## Installation: Python 2.7 and Python 3.6 41 | Installation is easy with pip: `pip install serveit` 42 | 43 | ## Building 44 | You can build locally with: `python setup.py` 45 | 46 | ## License 47 | [MIT](LICENSE.md) 48 | 49 | Please consider buying me a coffee if you like my work: 50 | 51 | Buy Me A Coffee 52 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /coverage.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | coverage run --source serveit -a -m tests.test_utils 3 | coverage run --source serveit -a -m tests.sklearn.test_server 4 | coverage xml --omit serveit/config.py 5 | python-codacy-coverage -r coverage.xml 6 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # ServeIt examples 2 | 3 | ## Basic example: Iris predictions with Scikit-learn 4 | 5 | Let's train and deploy a logistic regression model to classify irises. We'll start by fitting a model: 6 | ```python 7 | from sklearn.datasets import load_iris 8 | from sklearn.linear_model import LogisticRegression 9 | 10 | # fit a model on the Iris dataset 11 | data = load_iris() 12 | clf = LogisticRegression() 13 | clf.fit(data.data, data.target) 14 | ``` 15 | Now we can serve our trained model: 16 | ```python 17 | from serveit.server import ModelServer 18 | 19 | # initialize server 20 | server = ModelServer(clf, clf.predict) 21 | 22 | # optional: add informational endpoints 23 | server.create_info_endpoint('features', data.feature_names) 24 | server.create_info_endpoint('target_labels', data.target_names.tolist()) 25 | 26 | # start serving predictions from API 27 | server.serve() 28 | ``` 29 | 30 | Behold: 31 | 32 | ```bash 33 | curl -XPOST 'localhost:5000/predictions'\ 34 | -H "Content-Type: application/json"\ 35 | -d "[[5.6, 2.9, 3.6, 1.3], [4.4, 2.9, 1.4, 0.2], [5.5, 2.4, 3.8, 1.1], [5.0, 3.4, 1.5, 0.2], [5.7, 2.5, 5.0, 2.0]]" 36 | # [1, 0, 1, 0, 2] 37 | 38 | curl -XGET 'localhost:5000/info/model' 39 | # {"penalty": "l2", "tol": 0.0001, "C": 1.0, "classes_": [0, 1, 2], "coef_": [[0.4150, 1.4613, -2.2621, -1.0291], ...], ...} 40 | 41 | curl -XGET 'localhost:5000/info/features' 42 | # ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"] 43 | 44 | curl -XGET 'localhost:5000/info/target_labels' 45 | # ["setosa", "versicolor", "virginica"] 46 | ``` 47 | 48 | ## Advanced example: image classification with Keras 49 | 50 | ServeIt accepts optional pre/postprocessing callback methods, making it easy start serving more complex models. Let's deploy a pre-trained Keras model to a new API endpoint so that we can classify images on the fly. We'll start by loading a ResNet50 model pre-trained on ImageNet: 51 | 52 | ```python 53 | from keras.applications.resnet50 import ResNet50 54 | 55 | # load Resnet50 model pretrained on ImageNet 56 | model = ResNet50(weights='imagenet') 57 | ``` 58 | 59 | Next we define methods for loading and preprocessing an image from a URL... 60 | ```python 61 | from keras.preprocessing import image 62 | from keras.applications.resnet50 import preprocess_input 63 | from flask import request 64 | import requests 65 | from serveit.utils import make_serializable, get_bytes_to_image_callback 66 | 67 | # define a loader callback for the API to fetch the relevant data and 68 | # preprocessor callbacks to map to a format expected by the model 69 | def loader(): 70 | """Load image from URL, and preprocess for Resnet.""" 71 | url = request.args.get('url') # read image URL as a request URL param 72 | response = requests.get(url) # make request to static image file 73 | return response.content 74 | 75 | # get a bytes-to-image callback, resizing the image to 224x224 for ImageNet 76 | bytes_to_image = get_bytes_to_image_callback(image_dims=(224, 224)) 77 | 78 | # create a list of different preprocessors to chain multiple steps 79 | preprocessor = [bytes_to_image, preprocess_input] 80 | ``` 81 | 82 | ... and import a decoder for postprocessing the model predictions for the API response: 83 | ```python 84 | from keras.applications.resnet50 import decode_predictions 85 | ``` 86 | 87 | And now we're ready to start serving our image classifier: 88 | ```python 89 | from serveit.server import ModelServer 90 | 91 | # deploy model to a ModelServer 92 | server = ModelServer( 93 | model, 94 | model.predict, 95 | data_loader=loader, 96 | preprocessor=preprocessor, 97 | postprocessor=decode_predictions, 98 | ) 99 | 100 | # start serving 101 | server.serve() 102 | ``` 103 | 104 | Behold: 105 | ![cat picture](img/cat.jpg) 106 | ```bash 107 | curl -XPOST 'localhost:5000/predictions?url=https://cdn.pixabay.com/photo/2017/11/14/13/06/kitty-2948404_640.jpg' 108 | # [[["n02123159", "tiger_cat", 0.598746120929718], ["n02127052", "lynx", 0.32807421684265137], ["n02123045", "tabby", 0.042475175112485886]]] 109 | ``` 110 | 111 | ![plane picture](img/airplane.jpg) 112 | ```bash 113 | curl -XPOST 'localhost:5000/predictions?url=https://cdn.pixabay.com/photo/2012/06/28/08/26/plane-50893_640.jpg' 114 | # [[["n02690373", "airliner", 0.5599709749221802], ["n04592741", "wing", 0.286420077085495], ["n04552348", "warplane", 0.14331381022930145]]] 115 | ``` 116 | 117 | You can interact with a live DenseNet121 demo server at `https://imagenet-keras.ryanlee.site/predictions` (source code and sample requests [here](https://github.com/rtlee9/serveit-demo-imagenet-keras/)). 118 | 119 | ## Advanced example: serving with gunicorn 120 | If you have a preference for a specific WSGI HTTP server, you can easily retrieve the underlying app from the server to serve separately. Once you've initialized the ModelServer class, fetch the underlying app in the global scope of a Python script like so: 121 | 122 | ```python 123 | # main.py 124 | app = server.get_app() 125 | ``` 126 | 127 | Now all you have to do in your shell (or Procfile) is: 128 | ```bash 129 | # shell 130 | gunicorn main:app 131 | 132 | # Procfile 133 | web: gunicorn main:app 134 | ``` 135 | 136 | [View all examples](https://github.com/rtlee9/serveit/tree/master/examples) 137 | -------------------------------------------------------------------------------- /examples/img/airplane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/examples/img/airplane.jpg -------------------------------------------------------------------------------- /examples/img/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/examples/img/cat.jpg -------------------------------------------------------------------------------- /examples/keras_boston_neural_net.py: -------------------------------------------------------------------------------- 1 | """Sample ServeIt prediction server.""" 2 | from sklearn.datasets import load_boston 3 | from serveit.server import ModelServer 4 | from keras.models import Sequential 5 | from keras.layers import Dense 6 | 7 | 8 | def get_model(input_dim): 9 | """Create and compile simple model.""" 10 | model = Sequential() 11 | model.add(Dense(100, input_dim=input_dim, activation='sigmoid')) 12 | model.add(Dense(1)) 13 | model.compile(loss='mean_squared_error', optimizer='SGD') 14 | return model 15 | 16 | # fit a model on the Boston housing dataset 17 | data = load_boston() 18 | model = get_model(data.data.shape[1]) 19 | model.fit(data.data, data.target) 20 | 21 | 22 | def validator(input_data): 23 | """Simple model input validator. 24 | 25 | Validator ensures the input data array is 26 | - two dimensional 27 | - has the correct number of features. 28 | """ 29 | global data 30 | # check num dims 31 | if input_data.ndim != 2: 32 | return False, 'Data should have two dimensions.' 33 | # check number of columns 34 | if input_data.shape[1] != data.data.shape[1]: 35 | reason = '{} features required, {} features provided'.format( 36 | data.data.shape[1], input_data.shape[1]) 37 | return False, reason 38 | # validation passed 39 | return True, None 40 | 41 | # deploy model to a ModelServer 42 | server = ModelServer(model, model.predict, validator) 43 | 44 | # add informational endpoints 45 | server.create_info_endpoint('features', data.feature_names) 46 | 47 | # start API 48 | server.serve() 49 | -------------------------------------------------------------------------------- /examples/keras_imagenet_resnet50.py: -------------------------------------------------------------------------------- 1 | """Serve Keras ResNet50 model trained on ImageNet. 2 | 3 | Prediction endpoint, served at `/predictions` takes a URL pointing to an image 4 | and returns a list of class probabilities. 5 | """ 6 | from serveit.server import ModelServer 7 | from serveit.utils import get_bytes_to_image_callback 8 | 9 | from keras.applications.resnet50 import ResNet50 10 | from keras.applications.resnet50 import decode_predictions 11 | from keras.applications.resnet50 import preprocess_input 12 | 13 | from flask import request 14 | import requests 15 | 16 | # load Resnet50 model pretrained on ImageNet 17 | model = ResNet50(weights='imagenet') 18 | 19 | 20 | # define a loader callback for the API to fetch the relevant data and 21 | # convert to a format expected by the prediction function 22 | def loader(): 23 | """Load image from URL, and preprocess for Resnet.""" 24 | url = request.args.get('url') # read image URL as a request URL param 25 | response = requests.get(url) # make request to static image file 26 | return response.content 27 | 28 | # get a bytes-to-image callback, resizing the image to 224x224 for ImageNet 29 | bytes_to_image = get_bytes_to_image_callback(image_dims=(224, 224)) 30 | 31 | # deploy model to a ModelServer 32 | server = ModelServer( 33 | model, 34 | model.predict, 35 | data_loader=loader, 36 | preprocessor=[bytes_to_image, preprocess_input], 37 | postprocessor=decode_predictions, 38 | ) 39 | 40 | # start API 41 | server.serve() 42 | -------------------------------------------------------------------------------- /examples/pytorch_imagenet_resnet50.py: -------------------------------------------------------------------------------- 1 | """Serve PyTorch ResNet50 model trained on ImageNet. 2 | 3 | Prediction endpoint, served at `/predictions` takes a URL pointing to an image 4 | and returns a list of class probabilities. 5 | """ 6 | from serveit.server import ModelServer 7 | from serveit.utils import get_bytes_to_image_callback 8 | 9 | import torchvision.models as models 10 | import torchvision.transforms as transforms 11 | import torch 12 | 13 | from flask import request 14 | import requests 15 | 16 | # URL for ImageNet labels in JSON 17 | LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json' 18 | 19 | # parse labels into lookup 20 | labels = { 21 | int(key): value for (key, value) 22 | in requests.get(LABELS_URL).json().items() 23 | } 24 | 25 | # load Resnet50 model pretrained on ImageNet 26 | model = models.resnet50(pretrained=True) 27 | model.eval() 28 | 29 | 30 | # define a loader callback for the API to fetch the relevant data and 31 | # convert to a format expected by the prediction function 32 | def loader(): 33 | """Load image from URL, and preprocess for Resnet.""" 34 | url = request.args.get('url') # read image URL as a request URL param 35 | response = requests.get(url) # make request to static image file 36 | return response.content 37 | 38 | # define preprocessing callback chain 39 | preprocessor = [ 40 | get_bytes_to_image_callback(image_dims=(224, 224)), # convert bytes to image of size 224 x 224 41 | lambda img: torch.from_numpy(img.swapaxes(3, 1).swapaxes(2, 3).copy()) / 255, # convert to tensor, rescale 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # normalize pixel intensities 43 | torch.autograd.Variable, # convert to PyTorch Variable 44 | ] 45 | 46 | 47 | # define a postprocessor callback for the API to transform the model predictions 48 | def postprocessor(prediction): 49 | """Map prediction tensor to labels.""" 50 | prediction = prediction.data.numpy()[0] 51 | top_predictions = prediction.argsort()[-3:][::-1] 52 | return [labels[prediction] for prediction in top_predictions] 53 | 54 | # deploy model to a ModelServer 55 | server = ModelServer( 56 | model, 57 | model, 58 | data_loader=loader, 59 | preprocessor=preprocessor, 60 | postprocessor=postprocessor, 61 | to_numpy=False 62 | ) 63 | 64 | # start API 65 | server.serve() 66 | -------------------------------------------------------------------------------- /examples/sklearn_boston_linear_regression.py: -------------------------------------------------------------------------------- 1 | """Sample ServeIt Scikit-Learn server.""" 2 | from sklearn.datasets import load_boston 3 | from sklearn.linear_model import LinearRegression 4 | from serveit.server import ModelServer 5 | 6 | # fit a model on the Boston housing dataset 7 | data = load_boston() 8 | reg = LinearRegression() 9 | reg.fit(data.data, data.target) 10 | 11 | 12 | def validator(input_data): 13 | """Simple model input validator. 14 | 15 | Validator ensures the input data array is 16 | - two dimensional 17 | - has the correct number of features. 18 | """ 19 | global data 20 | # check num dims 21 | if input_data.ndim != 2: 22 | return False, 'Data should have two dimensions.' 23 | # check number of columns 24 | if input_data.shape[1] != data.data.shape[1]: 25 | reason = '{} features required, {} features provided'.format( 26 | data.data.shape[1], input_data.shape[1]) 27 | return False, reason 28 | # validation passed 29 | return True, None 30 | 31 | # deploy model to a SkLearnServer 32 | server = ModelServer(reg, reg.predict, validator) 33 | 34 | # add informational endpoints 35 | server.create_info_endpoint('features', data.feature_names) 36 | 37 | # start API 38 | server.serve() 39 | -------------------------------------------------------------------------------- /examples/sklearn_iris_logistic_regression.py: -------------------------------------------------------------------------------- 1 | """Sample ServeIt Scikit-Learn server.""" 2 | from sklearn.datasets import load_iris 3 | from sklearn.linear_model import LogisticRegression 4 | from serveit.server import ModelServer 5 | 6 | # fit a model on the Iris dataset 7 | data = load_iris() 8 | clf = LogisticRegression() 9 | clf.fit(data.data, data.target) 10 | 11 | 12 | def validator(input_data): 13 | """Simple model input validator. 14 | 15 | Validator ensures the input data array is 16 | - two dimensional 17 | - has the correct number of features. 18 | """ 19 | global data 20 | # check num dims 21 | if input_data.ndim != 2: 22 | return False, 'Data should have two dimensions.' 23 | # check number of columns 24 | if input_data.shape[1] != data.data.shape[1]: 25 | reason = '{} features required, {} features provided'.format( 26 | data.data.shape[1], input_data.shape[1]) 27 | return False, reason 28 | # validation passed 29 | return True, None 30 | 31 | # deploy model to a SkLearnServer 32 | server = ModelServer(clf, clf.predict, validator) 33 | 34 | # add informational endpoints 35 | server.create_info_endpoint('features', data.feature_names) 36 | server.create_info_endpoint('target_labels', data.target_names.tolist()) 37 | 38 | # start API 39 | server.serve() 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==0.12.3 2 | scikit-learn==0.19.1 3 | scipy==1.0.0 4 | numpy==1.13.3 5 | Flask-RESTful==0.3.6 6 | gunicorn==19.7.1 7 | tensorflow==1.5.0 8 | Keras==2.1.4 9 | meinheld==0.6.1 10 | codacy-coverage==1.3.10 11 | coverage==4.5.1 12 | h5py==2.7.1 13 | Pillow==5.0.0 14 | -------------------------------------------------------------------------------- /serveit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/serveit/__init__.py -------------------------------------------------------------------------------- /serveit/config.py: -------------------------------------------------------------------------------- 1 | """Static and environment variables.""" 2 | from os import getenv 3 | 4 | WSGI_HOST = getenv('WSGI_HOST', '127.0.0.1') 5 | WSGI_PORT = int(getenv('WSGI_PORT', 5000)) 6 | -------------------------------------------------------------------------------- /serveit/log_utils.py: -------------------------------------------------------------------------------- 1 | """Logger setup.""" 2 | import logging 3 | from os import getenv 4 | 5 | logging.basicConfig(level=logging.WARNING) 6 | 7 | 8 | def get_logger(name): 9 | """Get a logger with the specified name.""" 10 | logger = logging.getLogger(name) 11 | logger.setLevel(getenv('LOGLEVEL', 'INFO')) 12 | return logger 13 | -------------------------------------------------------------------------------- /serveit/server.py: -------------------------------------------------------------------------------- 1 | """Base class for serving predictions.""" 2 | from flask import Flask, jsonify 3 | from flask_restful import Resource, Api 4 | import numpy as np 5 | 6 | from .utils import make_serializable, json_numpy_loader 7 | from .log_utils import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def exception_log_and_respond(exception, logger, message, status_code): 13 | """Log an error and send jsonified respond.""" 14 | logger.error(message, exc_info=True) 15 | return make_response( 16 | message, 17 | status_code, 18 | dict(exception_type=type(exception).__name__, exception_message=str(exception)), 19 | ) 20 | 21 | 22 | def make_response(message, status_code, details=None): 23 | """Make a jsonified response with specified message and status code.""" 24 | response_body = dict(message=message) 25 | if details: 26 | response_body['details'] = details 27 | response = jsonify(response_body) 28 | response.status_code = status_code 29 | return response 30 | 31 | 32 | class ModelServer(object): 33 | """Easy deploy class.""" 34 | 35 | def __init__( 36 | self, 37 | model, 38 | predict, 39 | input_validation=lambda data: (True, None), 40 | data_loader=json_numpy_loader, 41 | preprocessor=lambda x: x, 42 | postprocessor=make_serializable, 43 | to_numpy=True): 44 | """Initialize class with prediction function. 45 | 46 | Arguments: 47 | - predict (fn): function that takes a numpy array of features as input, 48 | and returns a prediction of targets 49 | - input_validation (fn): takes a numpy array as input; 50 | returns True if validation passes and False otherwise 51 | - data_loader (fn): reads flask request and returns data preprocessed to be 52 | used in the `predict` method 53 | - postprocessor (fn): transforms the predictions from the `predict` method 54 | """ 55 | self.model = model 56 | self.predict = predict 57 | self.data_loader = data_loader 58 | self.preprocessor = preprocessor 59 | self.postprocessor = postprocessor 60 | self.app = Flask('{}_{}'.format(self.__class__.__name__, type(predict).__name__)) 61 | self.api = Api(self.app, catch_all_404s=True) 62 | self._create_prediction_endpoint( 63 | data_loader=data_loader, 64 | input_validation=input_validation, 65 | preprocessor=preprocessor, 66 | postprocessor=postprocessor, 67 | to_numpy=to_numpy, 68 | ) 69 | logger.info('Model predictions registered to endpoint /predictions (available via POST)') 70 | self.app.logger.setLevel(logger.level) # TODO: separate configuration for API loglevel 71 | self._create_model_info_endpoint() 72 | 73 | def __repr__(self): 74 | """String representation.""" 75 | return ''.format(type(self.predict).__name__) 76 | 77 | def _create_prediction_endpoint( 78 | self, 79 | to_numpy=True, 80 | data_loader=json_numpy_loader, 81 | preprocessor=lambda x: x, 82 | input_validation=lambda data: (True, None), 83 | postprocessor=lambda x: x, 84 | make_serializable_post=True): 85 | """Create an endpoint to serve predictions. 86 | 87 | Arguments: 88 | - input_validation (fn): takes a numpy array as input; 89 | returns True if validation passes and False otherwise 90 | - data_loader (fn): reads flask request and returns data preprocessed to be 91 | used in the `predict` method 92 | - postprocessor (fn): transforms the predictions from the `predict` method 93 | """ 94 | # copy instance variables to local scope for resource class 95 | predict = self.predict 96 | logger = self.app.logger 97 | 98 | # create restful resource 99 | class Predictions(Resource): 100 | @staticmethod 101 | def post(): 102 | # read data from API request 103 | try: 104 | data = data_loader() 105 | except Exception as e: 106 | return exception_log_and_respond(e, logger, 'Unable to fetch data', 400) 107 | 108 | try: 109 | if hasattr(preprocessor, '__iter__'): 110 | for preprocessor_step in preprocessor: 111 | data = preprocessor_step(data) 112 | else: 113 | data = preprocessor(data) # preprocess data 114 | data = np.array(data) if to_numpy else data # convert to numpy 115 | except Exception as e: 116 | return exception_log_and_respond(e, logger, 'Could not preprocess data', 400) 117 | 118 | # sanity check using user defined callback (default is no check) 119 | validation_pass, validation_reason = input_validation(data) 120 | if not validation_pass: 121 | # if validation fails, log the reason code, log the data, and send a 400 response 122 | validation_message = 'Input validation failed with reason: {}'.format(validation_reason) 123 | logger.error(validation_message) 124 | logger.debug('Data: {}'.format(data)) 125 | return make_response(validation_message, 400) 126 | 127 | try: 128 | prediction = predict(data) 129 | except Exception as e: 130 | # log exception and return the message in a 500 response 131 | logger.debug('Data: {}'.format(data)) 132 | return exception_log_and_respond(e, logger, 'Unable to make prediction', 500) 133 | logger.debug(prediction) 134 | try: 135 | # preprocess data 136 | if hasattr(postprocessor, '__iter__'): 137 | for postprocessor_step in postprocessor: 138 | prediction = postprocessor_step(prediction) 139 | else: 140 | prediction = postprocessor(prediction) 141 | 142 | # cast to serializable types 143 | if make_serializable_post: 144 | return make_serializable(prediction) 145 | else: 146 | return prediction 147 | 148 | except Exception as e: 149 | return exception_log_and_respond(e, logger, 'Postprocessing failed', 500) 150 | 151 | # map resource to endpoint 152 | self.api.add_resource(Predictions, '/predictions') 153 | 154 | def create_info_endpoint(self, name, data): 155 | """Create an endpoint to serve info GET requests.""" 156 | # make sure data is serializable 157 | data = make_serializable(data) 158 | 159 | # create generic restful resource to serve static JSON data 160 | class InfoBase(Resource): 161 | @staticmethod 162 | def get(): 163 | return data 164 | 165 | def info_factory(name): 166 | """Return an Info derivative resource.""" 167 | class NewClass(InfoBase): 168 | pass 169 | NewClass.__name__ = "{}_{}".format(name, InfoBase.__name__) 170 | return NewClass 171 | 172 | path = '/info/{}'.format(name) 173 | self.api.add_resource(info_factory(name), path) 174 | logger.info('Regestered informational resource to {} (available via GET)'.format(path)) 175 | logger.debug('Endpoint {} will now serve the following static data:\n{}'.format(path, data)) 176 | 177 | def _create_model_info_endpoint(self, path='/info/model'): 178 | """Create an endpoint to serve info GET requests.""" 179 | model = self.model 180 | 181 | # parse model details 182 | model_details = {} 183 | for key, value in model.__dict__.items(): 184 | model_details[key] = make_serializable(value) 185 | 186 | # create generic restful resource to serve model information as JSON 187 | class ModelInfo(Resource): 188 | @staticmethod 189 | def get(): 190 | return model_details 191 | 192 | self.api.add_resource(ModelInfo, path) 193 | self.app.logger.info('Regestered informational resource to {} (available via GET)'.format(path)) 194 | self.app.logger.debug('Endpoint {} will now serve the following static data:\n{}'.format(path, model_details)) 195 | 196 | def serve(self, host='127.0.0.1', port=5000): 197 | """Serve predictions as an API endpoint.""" 198 | from meinheld import server, middleware 199 | # self.app.run(host=host, port=port) 200 | server.listen((host, port)) 201 | server.run(middleware.WebSocketMiddleware(self.app)) 202 | 203 | def get_app(self): 204 | """Return the underlying Flask app.""" 205 | return self.app 206 | -------------------------------------------------------------------------------- /serveit/utils.py: -------------------------------------------------------------------------------- 1 | """Utility methods.""" 2 | import json 3 | from flask import request 4 | 5 | from .log_utils import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | def make_serializable(data): 11 | """Ensure data is serializable.""" 12 | if is_serializable(data): 13 | return data 14 | 15 | # if numpy array convert to list 16 | try: 17 | return data.tolist() 18 | except AttributeError: 19 | pass 20 | except Exception as e: 21 | logger.debug('{} exception ({}): {}'.format(type(e).__name__, e, data)) 22 | 23 | # try serializing each child element 24 | if isinstance(data, dict): 25 | return {key: make_serializable(value) for key, value in data.items()} 26 | try: 27 | return [make_serializable(element) for element in data] 28 | except TypeError: # not iterable 29 | pass 30 | except Exception: 31 | logger.debug('Could not serialize {}; converting to string'.format(data)) 32 | 33 | # last resort: convert to string 34 | return str(data) 35 | 36 | 37 | def is_serializable(data): 38 | """Check if data is serializable.""" 39 | try: 40 | json.dumps(data) 41 | return True 42 | except TypeError: 43 | return False 44 | 45 | 46 | def json_numpy_loader(): 47 | """Load data from JSON request and convert to numpy array.""" 48 | data = request.get_json() 49 | logger.debug('Received JSON data of length {:,}'.format(len(data))) 50 | return data 51 | 52 | 53 | def get_bytes_to_image_callback(image_dims=(224, 224)): 54 | """Return a callback to process image bytes for ImageNet.""" 55 | from keras.preprocessing import image 56 | import numpy as np 57 | from PIL import Image 58 | from io import BytesIO 59 | 60 | def preprocess_image_bytes(data_bytes): 61 | """Process image bytes for ImageNet.""" 62 | try: 63 | img = Image.open(BytesIO(data_bytes)) # open image 64 | except OSError as e: 65 | raise ValueError('Please provide a raw image') 66 | img = img.resize(image_dims, Image.ANTIALIAS) # model requires 224x224 pixels 67 | x = image.img_to_array(img) # convert image to numpy array 68 | x = np.expand_dims(x, axis=0) # model expects dim 0 to be iterable across images 69 | return x 70 | return preprocess_image_bytes 71 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | # This flag says to generate wheels that support both Python 2 and Python 3 | # 3. If your code will not run unchanged on both Python 2 and 3, you will 4 | # need to generate separate wheels for each Python version that you 5 | # support. 6 | universal=1 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup module.""" 2 | from setuptools import setup, find_packages 3 | from codecs import open 4 | from os import path 5 | import logging 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | logger = logging.getLogger(__name__) 9 | 10 | # Get the long description from the README file 11 | try: 12 | # try converting readme markdown formatting to rst (supported by pypi) 13 | import pypandoc 14 | long_description = pypandoc.convert('README.md', 'rst') 15 | except(IOError, ImportError) as e: 16 | logger.error('{} error: {}'.format(type(e).__name__, e)) 17 | long_description = open('README.md').read() 18 | 19 | 20 | # Arguments marked as "Required" below must be included for upload to PyPI. 21 | # Fields marked as "Optional" may be commented out. 22 | 23 | setup( 24 | # This is the name of your project. The first time you publish this 25 | # package, this name will be registered for you. It will determine how 26 | # users can install this project, e.g.: 27 | # 28 | # $ pip install sampleproject 29 | # 30 | # And where it will live on PyPI: https://pypi.org/project/sampleproject/ 31 | # 32 | # There are some restrictions on what makes a valid project name 33 | # specification here: 34 | # https://packaging.python.org/specifications/core-metadata/#name 35 | name='ServeIt', # Required 36 | 37 | # Versions should comply with PEP 440: 38 | # https://www.python.org/dev/peps/pep-0440/ 39 | # 40 | # For a discussion on single-sourcing the version across setup.py and the 41 | # project code, see 42 | # https://packaging.python.org/en/latest/single_source_version.html 43 | version='0.0.9', # Required 44 | 45 | # This is a one-line description or tagline of what your project does. This 46 | # corresponds to the "Summary" metadata field: 47 | # https://packaging.python.org/specifications/core-metadata/#summary 48 | description='Machine learning prediction serving', # Required 49 | 50 | # This is an optional longer description of your project that represents 51 | # the body of text which users will see when they visit PyPI. 52 | # 53 | # Often, this is the same as your README, so you can just read it in from 54 | # that file directly (as we have already done above) 55 | # 56 | # This field corresponds to the "Description" metadata field: 57 | # https://packaging.python.org/specifications/core-metadata/#description-optional 58 | long_description=long_description, # Optional 59 | 60 | # This should be a valid link to your project's main homepage. 61 | # 62 | # This field corresponds to the "Home-Page" metadata field: 63 | # https://packaging.python.org/specifications/core-metadata/#home-page-optional 64 | url='https://github.com/rtlee9/serveit', # Optional 65 | 66 | # This should be your name or the name of the organization which owns the 67 | # project. 68 | author='Ryan Lee', # Optional 69 | 70 | # This should be a valid email address corresponding to the author listed 71 | # above. 72 | author_email='ryantlee9@gmail.com', # Optional 73 | 74 | # Classifiers help users find your project by categorizing it. 75 | # 76 | # For a list of valid classifiers, see 77 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 78 | classifiers=[ # Optional 79 | # How mature is this project? Common values are 80 | # 3 - Alpha 81 | # 4 - Beta 82 | # 5 - Production/Stable 83 | 'Development Status :: 3 - Alpha', 84 | 85 | # Indicate who your project is intended for 86 | 'Intended Audience :: Developers', 87 | 'Topic :: Software Development :: Build Tools', 88 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 89 | 'Topic :: Internet :: WWW/HTTP :: WSGI :: Server', 90 | 91 | # Pick your license as you wish 92 | 'License :: OSI Approved :: MIT License', 93 | 94 | # Specify the Python versions you support here. In particular, ensure 95 | # that you indicate whether you support Python 2, Python 3 or both. 96 | # 'Programming Language :: Python :: 2', 97 | 'Programming Language :: Python :: 2.7', 98 | # 'Programming Language :: Python :: 3', 99 | # 'Programming Language :: Python :: 3.4', 100 | # 'Programming Language :: Python :: 3.5', 101 | 'Programming Language :: Python :: 3.6', 102 | 'Programming Language :: Python :: 3.7', 103 | ], 104 | 105 | # This field adds keywords for your project which will appear on the 106 | # project page. What does your project relate to? 107 | # 108 | # Note that this is a string of words separated by whitespace, not a list. 109 | keywords='machine learning model deployment serving API RESTful', # Optional 110 | 111 | # You can just specify package directories manually here if your project is 112 | # simple. Or you can use find_packages(). 113 | # 114 | # Alternatively, if you just want to distribute a single Python file, use 115 | # the `py_modules` argument instead as follows, which will expect a file 116 | # called `my_module.py` to exist: 117 | # 118 | # py_modules=["my_module"], 119 | # 120 | packages=find_packages(exclude=['contrib', 'docs', 'tests']), # Required 121 | 122 | # This field lists other packages that your project depends on to run. 123 | # Any package you put here will be installed by pip when your project is 124 | # installed, so they must be valid existing projects. 125 | # 126 | # For an analysis of "install_requires" vs pip's requirements files see: 127 | # https://packaging.python.org/en/latest/requirements.html 128 | install_requires=['flask', 'flask-restful', 'meinheld'], # Optional 129 | 130 | # List additional groups of dependencies here (e.g. development 131 | # dependencies). Users will be able to install these using the "extras" 132 | # syntax, for example: 133 | # 134 | # $ pip install sampleproject[dev] 135 | # 136 | # Similar to `install_requires` above, these must be valid existing 137 | # projects. 138 | extras_require={ # Optional 139 | 'dev': ['check-manifest'], 140 | 'test': ['coverage'], 141 | }, 142 | 143 | # To provide executable scripts, use entry points in preference to the 144 | # "scripts" keyword. Entry points provide cross-platform support and allow 145 | # `pip` to create the appropriate form of executable for the target 146 | # platform. 147 | # 148 | # For example, the following would provide a command called `sample` which 149 | # executes the function `main` from this package when invoked: 150 | entry_points={ # Optional 151 | }, 152 | ) 153 | -------------------------------------------------------------------------------- /tests/SuccessKid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/SuccessKid.jpg -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/__init__.py -------------------------------------------------------------------------------- /tests/keras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/keras/__init__.py -------------------------------------------------------------------------------- /tests/keras/test_server.py: -------------------------------------------------------------------------------- 1 | """Test ModelServer with Keras models.""" 2 | import unittest 3 | from sklearn.datasets import load_boston 4 | 5 | from tests.test_server import ModelServerTest 6 | 7 | 8 | class BostonKerasNNTest(unittest.TestCase, ModelServerTest): 9 | """Test ModelServer with Keras nerual net fitted on housing data.""" 10 | 11 | def setUp(self): 12 | """Unittest set up.""" 13 | data = load_boston() 14 | self.model = self.get_model(data.data.shape[1]) 15 | super(BostonKerasNNTest, self)._setup(self.model, self.model.fit, data) 16 | 17 | @staticmethod 18 | def get_model(input_dim): 19 | """Create and compile simple model.""" 20 | from keras.models import Sequential 21 | from keras.layers import Dense 22 | model = Sequential() 23 | model.add(Dense(100, input_dim=input_dim, activation='sigmoid')) 24 | model.add(Dense(1)) 25 | model.compile(loss='mean_squared_error', optimizer='SGD') 26 | return model 27 | 28 | 29 | if __name__ == '__main__': 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /tests/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/pytorch/__init__.py -------------------------------------------------------------------------------- /tests/pytorch/test_server.py: -------------------------------------------------------------------------------- 1 | """Test ModelServer with Keras models.""" 2 | import unittest 3 | from sklearn.datasets import load_boston 4 | import torch 5 | from torch.autograd import Variable 6 | from torch import optim 7 | import numpy as np 8 | 9 | from tests.test_server import ModelServerTest 10 | 11 | 12 | class BostonPytorchNNTest(unittest.TestCase, ModelServerTest): 13 | """Test ModelServer with Keras nerual net fitted on housing data.""" 14 | 15 | def setUp(self): 16 | """Unittest set up.""" 17 | data = load_boston() 18 | self.model = self.get_model(data.data.shape[1], 1) 19 | 20 | # define preprocessing callback chain 21 | preprocessor = [ 22 | np.array, # convert to numpy array 23 | torch.from_numpy, # convert to tensor 24 | lambda tensor: tensor.type(torch.FloatTensor), # convert to FloatTensor 25 | torch.autograd.Variable, # convert to PyTorch Variable 26 | ] 27 | 28 | super(BostonPytorchNNTest, self)._setup( 29 | self.model, self.train, data, preprocessor=preprocessor, predict=self.model, to_numpy=False, postprocessor=lambda variable: variable.data.numpy()) 30 | 31 | def train(self, x, y): 32 | loss = torch.nn.MSELoss(size_average=True) 33 | optimizer = optim.SGD(self.model.parameters(), lr=1e-8) 34 | x = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=False) 35 | y = Variable(torch.from_numpy(y).type(torch.FloatTensor), requires_grad=False) 36 | 37 | for i in range(100): 38 | # Reset gradient 39 | optimizer.zero_grad() 40 | 41 | # Forward 42 | fx = self.model.forward(x) 43 | output = loss.forward(fx, y) 44 | 45 | # Backward 46 | output.backward() 47 | 48 | # Update parameters 49 | optimizer.step() 50 | 51 | return output.data[0] 52 | 53 | @staticmethod 54 | def get_model(input_dim, output_dim): 55 | """Create and compile simple model.""" 56 | model = torch.nn.Sequential() 57 | model.add_module("linear", torch.nn.Linear(input_dim, output_dim, bias=False)) 58 | return model 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /tests/sklearn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/sklearn/__init__.py -------------------------------------------------------------------------------- /tests/sklearn/test_server.py: -------------------------------------------------------------------------------- 1 | """Test ModelServer with Scikit-Learn models.""" 2 | import unittest 3 | from sklearn.datasets import load_iris, load_boston 4 | 5 | from tests.test_server import ModelServerTest 6 | 7 | 8 | class IrisLogisticRegressionTest(unittest.TestCase, ModelServerTest): 9 | """Test ModelServer with LogisticRegression fitted on iris data.""" 10 | 11 | def setUp(self): 12 | """Unittest set up.""" 13 | from sklearn.linear_model import LogisticRegression 14 | self.model = LogisticRegression() 15 | super(IrisLogisticRegressionTest, self)._setup(self.model, self.model.fit, load_iris()) 16 | 17 | 18 | class IrisSvcTest(unittest.TestCase, ModelServerTest): 19 | """Test ModelServer with SVC fitted on iris data.""" 20 | 21 | def setUp(self): 22 | """Unittest set up.""" 23 | from sklearn.svm import SVC 24 | self.model = SVC() 25 | super(IrisSvcTest, self)._setup(self.model, self.model.fit, load_iris()) 26 | 27 | 28 | class IrisRfcTest(unittest.TestCase, ModelServerTest): 29 | """Test ModelServer with RandomForestClassifier fitted on iris data.""" 30 | 31 | def setUp(self): 32 | """Unittest set up.""" 33 | from sklearn.ensemble import RandomForestClassifier 34 | self.model = RandomForestClassifier() 35 | super(IrisRfcTest, self)._setup(self.model, self.model.fit, load_iris()) 36 | 37 | 38 | class BostonLinearRegressionTest(unittest.TestCase, ModelServerTest): 39 | """Test ModelServer with LogisticRegression fitted on housing data.""" 40 | 41 | def setUp(self): 42 | """Unittest set up.""" 43 | from sklearn.linear_model import LinearRegression 44 | self.model = LinearRegression() 45 | super(BostonLinearRegressionTest, self)._setup(self.model, self.model.fit, load_boston()) 46 | 47 | 48 | class BostonSvrTest(unittest.TestCase, ModelServerTest): 49 | """Test ModelServer with SVR fitted on housing data.""" 50 | 51 | def setUp(self): 52 | """Unittest set up.""" 53 | from sklearn.svm import SVR 54 | self.model = SVR() 55 | super(BostonSvrTest, self)._setup(self.model, self.model.fit, load_boston()) 56 | 57 | 58 | class BostonRfrTest(unittest.TestCase, ModelServerTest): 59 | """Test ModelServer with LogisticRegression fitted on housing data.""" 60 | 61 | def setUp(self): 62 | """Unittest set up.""" 63 | from sklearn.ensemble import RandomForestRegressor 64 | self.model = RandomForestRegressor() 65 | super(BostonRfrTest, self)._setup(self.model, self.model.fit, load_boston()) 66 | 67 | if __name__ == '__main__': 68 | unittest.main() 69 | -------------------------------------------------------------------------------- /tests/test_callback_utils.py: -------------------------------------------------------------------------------- 1 | """Test utility methods.""" 2 | import unittest 3 | 4 | from serveit.utils import get_bytes_to_image_callback 5 | 6 | 7 | class CallbackTest(unittest.TestCase): 8 | """Test utility callbacks.""" 9 | 10 | def _test_get_bytes_to_image_callback(self, image_dims): 11 | """Convert image bytes to an image for ImageNet.""" 12 | with open('tests/SuccessKid.jpg', 'rb') as f: 13 | image_bytes = f.read() 14 | bytes_to_image_callback = get_bytes_to_image_callback(image_dims=image_dims) 15 | image = bytes_to_image_callback(image_bytes) 16 | self.assertEqual((1, *image_dims, 3), image.shape) 17 | 18 | def test_get_bytes_to_image_callback_224_224(self): 19 | """Convert image bytes to 224x224 image for ImageNet.""" 20 | self._test_get_bytes_to_image_callback((224, 224)) 21 | 22 | def test_get_bytes_to_image_callback_128_128(self): 23 | """Convert image bytes to 128x128 image for ImageNet.""" 24 | self._test_get_bytes_to_image_callback((128, 128)) 25 | 26 | if __name__ == '__main__': 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_serialization_utils.py: -------------------------------------------------------------------------------- 1 | """Test utility methods.""" 2 | import unittest 3 | import numpy as np 4 | 5 | from serveit.utils import is_serializable, make_serializable 6 | 7 | 8 | class SeralizationTest(unittest.TestCase): 9 | """Test serialization.""" 10 | 11 | def setUp(self): 12 | """Unittest setup.""" 13 | self.serializable_data = [ 14 | [1, 2, 3], 15 | ['a', 'b', 'c'], 16 | {'a': 1, 'b': 'c'}, 17 | {'a': [1, 2, 3], 'b': 'c'}, 18 | ] 19 | 20 | class DummyClass(object): 21 | def __repr__(self): 22 | return 'XpBheIxCcm' 23 | 24 | self.dummy_class = DummyClass() 25 | self.unserializable_data = [ 26 | np.array([1, 2, 3]), 27 | np.array(['a', 'b', 'c']), 28 | self.dummy_class, 29 | ] 30 | 31 | def test_is_serializable(self): 32 | """Test is_serializable against example data.""" 33 | for data in self.serializable_data: 34 | self.assertTrue(is_serializable(data)) 35 | for data in self.unserializable_data: 36 | self.assertFalse(is_serializable(data)) 37 | 38 | def test_make_serializable_data_serializable(self): 39 | """make_serializable should return the same object if serializable.""" 40 | for data in self.serializable_data: 41 | self.assertEqual(make_serializable(data), data) 42 | 43 | def test_make_serializable_numpy_data(self): 44 | """make_serializable should cast numpy array to list.""" 45 | self.assertEqual(make_serializable(np.array([1, 2, 3])), [1, 2, 3]) 46 | self.assertEqual(make_serializable(np.array(['a', 'b', 'c'])), ['a', 'b', 'c']) 47 | 48 | def test_make_serializable_object_data(self): 49 | """make_serializable should return an objects __repr__ if no `tolist` method.""" 50 | self.assertEqual(make_serializable(self.dummy_class), 'XpBheIxCcm') 51 | 52 | 53 | if __name__ == '__main__': 54 | unittest.main() 55 | -------------------------------------------------------------------------------- /tests/test_server.py: -------------------------------------------------------------------------------- 1 | """Base ModelServer test class.""" 2 | import json 3 | import numpy as np 4 | 5 | from serveit.server import ModelServer 6 | 7 | 8 | class ModelServerTest(object): 9 | """Base class to test the prediction server. 10 | 11 | ModelServerTest should be inherited by a class that has a `model` attribute, 12 | and calls `ModelServerTest._setup()` after instantiation. That class should 13 | also inherit from `unittest.TestCase` to ensure tests are executed. 14 | """ 15 | 16 | def _setup(self, model, fit, data, predict=None, **kwargs): 17 | """Set up method to be called before each unit test. 18 | 19 | Arguments: 20 | - fit (callable): model training method; must accept args (data, target) 21 | """ 22 | self.data = data 23 | fit(self.data.data, self.data.target) 24 | self.predict = predict or self.model.predict 25 | self.server_kwargs = kwargs 26 | self.server = ModelServer(self.model, self.predict, **kwargs) 27 | self.app = self.server.app.test_client() 28 | 29 | @staticmethod 30 | def _prediction_post(app, data): 31 | """Make a POST request to `app` with JSON body `data`.""" 32 | return app.post( 33 | '/predictions', 34 | headers={'Content-Type': 'application/json'}, 35 | data=json.dumps(data), 36 | ) 37 | 38 | def _get_sample_data(self, n=100): 39 | """Return a sample of size n of self.data.""" 40 | sample_idx = np.random.randint(self.data.data.shape[0], size=n) 41 | return self.data.data[sample_idx, :] 42 | 43 | def test_404_media(self): 44 | """Make sure API serves 404 response with JSON.""" 45 | response = self.app.get('/fake-endpoint') 46 | self.assertEqual(response.status_code, 404) 47 | response_data_raw = response.get_data() 48 | self.assertIsNotNone(response_data_raw) 49 | response_data = json.loads(response_data_raw) 50 | self.assertGreater(len(response_data), 0) 51 | 52 | def test_features_info_none(self): 53 | """Verify 404 response if '/info/features' endpoint not yet created.""" 54 | response = self.app.get('/info/features') 55 | self.assertEqual(response.status_code, 404) 56 | 57 | def test_features_info(self): 58 | """Test features info endpoint.""" 59 | self.server.create_info_endpoint('features', self.data.feature_names) 60 | app = self.server.app.test_client() 61 | response = app.get('/info/features') 62 | response_data = json.loads(response.get_data()) 63 | self.assertEqual(len(response_data), self.data.data.shape[1]) 64 | try: 65 | self.assertCountEqual(response_data, self.data.feature_names) 66 | except AttributeError: # Python 2 67 | self.assertItemsEqual(response_data, self.data.feature_names) 68 | 69 | def test_target_labels_info_none(self): 70 | """Verify 404 response if '/info/target_labels' endpoint not yet created.""" 71 | response = self.app.get('/info/target_labels') 72 | self.assertEqual(response.status_code, 404) 73 | 74 | def test_target_labels_info(self): 75 | """Test target labels info endpoint.""" 76 | if not hasattr(self.data, 'target_names'): 77 | return 78 | self.server.create_info_endpoint('target_labels', self.data.target_names.tolist()) 79 | app = self.server.app.test_client() 80 | response = app.get('/info/target_labels') 81 | response_data = json.loads(response.get_data()) 82 | self.assertEqual(len(response_data), self.data.target_names.shape[0]) 83 | try: 84 | self.assertCountEqual(response_data, self.data.target_names) 85 | except AttributeError: # Python 2 86 | self.assertItemsEqual(response_data, self.data.target_names) 87 | 88 | def test_predictions(self): 89 | """Test predictions endpoint.""" 90 | sample_data = self._get_sample_data() 91 | response = self._prediction_post(self.app, sample_data.tolist()) 92 | response_data = json.loads(response.get_data()) 93 | self.assertEqual(len(response_data), len(sample_data)) 94 | if self.data.target.ndim > 1: 95 | # for multiclass each prediction should be one of the training labels 96 | for prediction in response_data: 97 | self.assertIn(prediction, self.data.target) 98 | else: 99 | # the average regression prediction for a sample of data should be similar 100 | # to the population mean 101 | # TODO: remove variance from this test (i.e., no chance of false negative) 102 | pred_pct_diff = np.array(response_data).mean() / self.data.target.mean() - 1 103 | self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1) 104 | 105 | def test_input_validation(self): 106 | """Add simple input validator and make sure it triggers.""" 107 | # model input validator 108 | def feature_count_check(data): 109 | try: 110 | # convert PyTorch variables to numpy arrays 111 | data = data.data.numpy() 112 | except: 113 | pass 114 | # check num dims 115 | if data.ndim != 2: 116 | return False, 'Data should have two dimensions.' 117 | # check number of columns 118 | if data.shape[1] != self.data.data.shape[1]: 119 | reason = '{} features required, {} features provided'.format( 120 | data.shape[1], self.data.data.shape[1]) 121 | return False, reason 122 | # validation passed 123 | return True, None 124 | 125 | # set up test server 126 | server = ModelServer(self.model, self.predict, feature_count_check, **self.server_kwargs) 127 | app = server.app.test_client() 128 | 129 | # generate sample data 130 | sample_data = self._get_sample_data() 131 | 132 | # post good data, verify 200 response 133 | response = self._prediction_post(app, sample_data.tolist()) 134 | self.assertEqual(response.status_code, 200) 135 | 136 | # post bad data (drop a single column), verify 400 response 137 | response = self._prediction_post(app, sample_data[:, :-1].tolist()) 138 | self.assertEqual(response.status_code, 400) 139 | response_data = json.loads(response.get_data()) 140 | expected_reason = '{} features required, {} features provided'.format( 141 | self.data.data.shape[1] - 1, self.data.data.shape[1]) 142 | self.assertIn(expected_reason, response_data['message']) 143 | 144 | def test_model_info(self): 145 | """Test model info endpoint.""" 146 | response = self.app.get('/info/model') 147 | response_data = json.loads(response.get_data()) 148 | self.assertGreater(len(response_data), 3) # TODO: expand test scope 149 | 150 | def test_data_loader(self): 151 | """Test model prediction with a custom data loader callback.""" 152 | # TODO: test alternative request method (e.g., URL params) 153 | # define custom data loader 154 | def read_json_from_dict(): 155 | from flask import request 156 | # read data as the value of the 'data' key 157 | data = request.get_json() 158 | return np.array(data['data']) 159 | 160 | # create test client 161 | server = ModelServer(self.model, self.predict, data_loader=read_json_from_dict, **self.server_kwargs) 162 | app = server.app.test_client() 163 | 164 | # generate sample data, and wrap in dict keyed by 'data' 165 | sample_data = self._get_sample_data() 166 | data_dict = dict(data=sample_data.tolist()) 167 | 168 | response = self._prediction_post(app, data_dict) 169 | response_data = json.loads(response.get_data()) 170 | self.assertEqual(len(response_data), len(sample_data)) 171 | if self.data.target.ndim > 1: 172 | # for multiclass each prediction should be one of the training labels 173 | for prediction in response_data: 174 | self.assertIn(prediction, self.data.target) 175 | else: 176 | # the average regression prediction for a sample of data should be similar 177 | # to the population mean 178 | # TODO: remove variance from this test (i.e., no chance of false negative) 179 | pred_pct_diff = np.array(response_data).mean() / self.data.target.mean() - 1 180 | self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1) 181 | 182 | def _update_kwargs_item(self, item, key_name, position='first'): 183 | """Prepend a method to the existing preprocessing chain, add to self's kwargs and return.""" 184 | kwargs = self.server_kwargs 185 | if key_name in self.server_kwargs: 186 | existing_items = kwargs[key_name] 187 | if not isinstance(existing_items, (list, tuple)): 188 | existing_items = [existing_items] 189 | else: 190 | existing_items = [] 191 | if position == 'first': 192 | kwargs[key_name] = [item] + existing_items 193 | if position == 'last': 194 | kwargs[key_name] = existing_items + [item] 195 | return kwargs 196 | 197 | def test_preprocessing(self): 198 | """Test predictions endpoint with custom preprocessing callback.""" 199 | # create test client with postprocessor that unraps data from a dict as the value of the 'data' key 200 | kwargs = self._update_kwargs_item(lambda d: d['data'], 'preprocessor') 201 | server = ModelServer(self.model, self.predict, **kwargs) 202 | app = server.app.test_client() 203 | 204 | # generate sample data, and wrap in dict keyed by 'data' 205 | sample_data = self._get_sample_data() 206 | data_dict = dict(data=sample_data.tolist()) 207 | 208 | response = self._prediction_post(app, data_dict) 209 | response_data = json.loads(response.get_data()) 210 | self.assertEqual(len(response_data), len(sample_data)) 211 | if self.data.target.ndim > 1: 212 | # for multiclass each prediction should be one of the training labels 213 | for prediction in response_data: 214 | self.assertIn(prediction, self.data.target) 215 | else: 216 | # the average regression prediction for a sample of data should be similar 217 | # to the population mean 218 | # TODO: remove variance from this test (i.e., no chance of false negative) 219 | pred_pct_diff = np.array(response_data).mean() / self.data.target.mean() - 1 220 | self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1) 221 | 222 | def test_preprocessing_list(self): 223 | """Test predictions endpoint with chained preprocessing callbacks.""" 224 | # create test client with postprocessor that unraps data from a dict as the value of the 'data' key 225 | kwargs = self._update_kwargs_item(lambda d: d['data'], 'preprocessor') 226 | kwargs['preprocessor'] = [lambda d: d['data2']] + kwargs['preprocessor'] 227 | server = ModelServer( 228 | self.model, 229 | self.predict, 230 | **kwargs 231 | ) 232 | app = server.app.test_client() 233 | 234 | # generate sample data, and wrap in dict keyed by 'data' 235 | sample_data = self._get_sample_data() 236 | data_dict = dict(data2=dict(data=sample_data.tolist())) 237 | 238 | response = self._prediction_post(app, data_dict) 239 | response_data = json.loads(response.get_data()) 240 | self.assertEqual(len(response_data), len(sample_data)) 241 | if self.data.target.ndim > 1: 242 | # for multiclass each prediction should be one of the training labels 243 | for prediction in response_data: 244 | self.assertIn(prediction, self.data.target) 245 | else: 246 | # the average regression prediction for a sample of data should be similar 247 | # to the population mean 248 | # TODO: remove variance from this test (i.e., no chance of false negative) 249 | pred_pct_diff = np.array(response_data).mean() / self.data.target.mean() - 1 250 | self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1) 251 | 252 | def test_postprocessing(self): 253 | """Test predictions endpoint with custom postprocessing callback.""" 254 | # create test client with postprocessor that wraps predictions in a dictionary 255 | kwargs = self._update_kwargs_item(lambda x: dict(prediction=x.tolist()), 'postprocessor', 'last') 256 | server = ModelServer(self.model, self.predict, **kwargs) 257 | app = server.app.test_client() 258 | 259 | # generate sample data 260 | sample_data = self._get_sample_data() 261 | 262 | response = self._prediction_post(app, sample_data.tolist()) 263 | response_data = json.loads(response.get_data())['prediction'] # predictions are nested under 'prediction' key 264 | self.assertEqual(len(response_data), len(sample_data)) 265 | if self.data.target.ndim > 1: 266 | # for multiclass each prediction should be one of the training labels 267 | for prediction in response_data: 268 | self.assertIn(prediction, self.data.target) 269 | else: 270 | # the average regression prediction for a sample of data should be similar 271 | # to the population mean 272 | # TODO: remove variance from this test (i.e., no chance of false negative) 273 | pred_pct_diff = np.array(response_data).mean() / self.data.target.mean() - 1 274 | self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1) 275 | 276 | def test_get_app(self): 277 | """Make sure get_app method returns the same app.""" 278 | self.assertEqual(self.server.get_app(), self.server.app) 279 | 280 | def test_400_no_content_type(self): 281 | """Check 400 response if no Content-Type header specified.""" 282 | response = self.app.post( 283 | '/predictions', 284 | ) 285 | self.assertEqual(response.status_code, 400) 286 | response_body = json.loads(response.get_data()) 287 | self.assertEqual(response_body['message'], 'Unable to fetch data') 288 | self.assertGreaterEqual(len(response_body['details']), 2) 289 | --------------------------------------------------------------------------------