├── .codacy.yml
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── _config.yml
├── coverage.sh
├── examples
├── README.md
├── img
│ ├── airplane.jpg
│ └── cat.jpg
├── keras_boston_neural_net.py
├── keras_imagenet_resnet50.py
├── pytorch_imagenet_resnet50.py
├── sklearn_boston_linear_regression.py
└── sklearn_iris_logistic_regression.py
├── requirements.txt
├── serveit
├── __init__.py
├── config.py
├── log_utils.py
├── server.py
└── utils.py
├── setup.cfg
├── setup.py
└── tests
├── SuccessKid.jpg
├── __init__.py
├── keras
├── __init__.py
└── test_server.py
├── pytorch
├── __init__.py
└── test_server.py
├── sklearn
├── __init__.py
└── test_server.py
├── test_callback_utils.py
├── test_serialization_utils.py
└── test_server.py
/.codacy.yml:
--------------------------------------------------------------------------------
---
# Paths Codacy should not analyze.
# Patterns starting with `*` must be quoted: an unquoted `*` begins a YAML
# alias, which makes the whole file fail to parse.
exclude_paths:
  - tests/**
  - examples/**
  - '*.py'
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # VIM files
2 | *.swp
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "2.7"
4 | - "3.6"
5 | install:
6 | - pip install -r requirements.txt
7 | script:
8 | - python -m pytest tests --ignore tests/keras --ignore tests/pytorch --ignore tests/test_callback_utils.py
9 | after_script: ./coverage.sh
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2010-2018 Google, Inc. http://angularjs.org
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ServeIt
2 | [Build Status](https://travis-ci.org/rtlee9/serveit)
3 | [Codacy Grade](https://www.codacy.com/app/ryantlee9/serveit)
4 | [Codacy Coverage](https://www.codacy.com/app/ryantlee9/serveit)
5 | [PyPI version](https://badge.fury.io/py/ServeIt)
6 |
7 | ServeIt lets you serve model predictions and supplementary information from a RESTful API using your favorite Python ML library in as little as one line of code:
8 |
9 | ```python
10 | from serveit.server import ModelServer
11 | from sklearn.linear_model import LogisticRegression
12 | from sklearn.datasets import load_iris
13 |
14 | # fit logistic regression on Iris data
15 | clf = LogisticRegression()
16 | data = load_iris()
17 | clf.fit(data.data, data.target)
18 |
19 | # initialize server with a model and start serving predictions
20 | ModelServer(clf, clf.predict).serve()
21 | ```
22 |
23 | Your new API is now accepting `POST` requests at `localhost:5000/predictions`! Please see the [examples](examples) directory for detailed examples across domains (e.g., regression, image classification), including live examples.
24 |
25 | #### Features
26 | Current ServeIt features include:
27 |
28 | 1. Model inference serving via RESTful API endpoint
29 | 1. Extensible library for inference-time data loading, preprocessing, input validation, and postprocessing
30 | 1. Supplementary information endpoint creation
31 | 1. Automatic JSON serialization of responses
32 | 1. Configurable request and response logging (work in progress)
33 |
34 | #### Supported libraries
35 | The following libraries are currently supported:
36 | * Scikit-Learn
37 | * Keras
38 | * PyTorch
39 |
40 | ## Installation: Python 2.7 and Python 3.6
41 | Installation is easy with pip: `pip install serveit`
42 |
43 | ## Building
44 | You can build distributable packages locally with: `python setup.py sdist bdist_wheel`
45 |
46 | ## License
47 | [MIT](LICENSE)
48 |
49 | Please consider buying me a coffee if you like my work:
50 |
51 |
52 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/coverage.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Collect coverage for the suites CI runs (keras/pytorch/callback tests are
# skipped on Travis) and upload the combined report to Codacy.
# `-a` appends each run to the same .coverage data file.
coverage run --source serveit -a -m tests.test_serialization_utils
coverage run --source serveit -a -m tests.sklearn.test_server
coverage xml --omit serveit/config.py
python-codacy-coverage -r coverage.xml
6 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # ServeIt examples
2 |
3 | ## Basic example: Iris predictions with Scikit-learn
4 |
5 | Let's train and deploy a logistic regression model to classify irises. We'll start by fitting a model:
6 | ```python
7 | from sklearn.datasets import load_iris
8 | from sklearn.linear_model import LogisticRegression
9 |
10 | # fit a model on the Iris dataset
11 | data = load_iris()
12 | clf = LogisticRegression()
13 | clf.fit(data.data, data.target)
14 | ```
15 | Now we can serve our trained model:
16 | ```python
17 | from serveit.server import ModelServer
18 |
19 | # initialize server
20 | server = ModelServer(clf, clf.predict)
21 |
22 | # optional: add informational endpoints
23 | server.create_info_endpoint('features', data.feature_names)
24 | server.create_info_endpoint('target_labels', data.target_names.tolist())
25 |
26 | # start serving predictions from API
27 | server.serve()
28 | ```
29 |
30 | Behold:
31 |
32 | ```bash
33 | curl -XPOST 'localhost:5000/predictions'\
34 | -H "Content-Type: application/json"\
35 | -d "[[5.6, 2.9, 3.6, 1.3], [4.4, 2.9, 1.4, 0.2], [5.5, 2.4, 3.8, 1.1], [5.0, 3.4, 1.5, 0.2], [5.7, 2.5, 5.0, 2.0]]"
36 | # [1, 0, 1, 0, 2]
37 |
38 | curl -XGET 'localhost:5000/info/model'
39 | # {"penalty": "l2", "tol": 0.0001, "C": 1.0, "classes_": [0, 1, 2], "coef_": [[0.4150, 1.4613, -2.2621, -1.0291], ...], ...}
40 |
41 | curl -XGET 'localhost:5000/info/features'
42 | # ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
43 |
44 | curl -XGET 'localhost:5000/info/target_labels'
45 | # ["setosa", "versicolor", "virginica"]
46 | ```
47 |
48 | ## Advanced example: image classification with Keras
49 |
50 | ServeIt accepts optional pre/postprocessing callback methods, making it easy to start serving more complex models. Let's deploy a pre-trained Keras model to a new API endpoint so that we can classify images on the fly. We'll start by loading a ResNet50 model pre-trained on ImageNet:
51 |
52 | ```python
53 | from keras.applications.resnet50 import ResNet50
54 |
55 | # load Resnet50 model pretrained on ImageNet
56 | model = ResNet50(weights='imagenet')
57 | ```
58 |
59 | Next we define methods for loading and preprocessing an image from a URL...
60 | ```python
61 | from keras.preprocessing import image
62 | from keras.applications.resnet50 import preprocess_input
63 | from flask import request
64 | import requests
65 | from serveit.utils import make_serializable, get_bytes_to_image_callback
66 |
67 | # define a loader callback for the API to fetch the relevant data and
68 | # preprocessor callbacks to map to a format expected by the model
69 | def loader():
70 | """Load image from URL, and preprocess for Resnet."""
71 | url = request.args.get('url') # read image URL as a request URL param
72 | response = requests.get(url) # make request to static image file
73 | return response.content
74 |
75 | # get a bytes-to-image callback, resizing the image to 224x224 for ImageNet
76 | bytes_to_image = get_bytes_to_image_callback(image_dims=(224, 224))
77 |
78 | # create a list of different preprocessors to chain multiple steps
79 | preprocessor = [bytes_to_image, preprocess_input]
80 | ```
81 |
82 | ... and import a decoder for postprocessing the model predictions for the API response:
83 | ```python
84 | from keras.applications.resnet50 import decode_predictions
85 | ```
86 |
87 | And now we're ready to start serving our image classifier:
88 | ```python
89 | from serveit.server import ModelServer
90 |
91 | # deploy model to a ModelServer
92 | server = ModelServer(
93 | model,
94 | model.predict,
95 | data_loader=loader,
96 | preprocessor=preprocessor,
97 | postprocessor=decode_predictions,
98 | )
99 |
100 | # start serving
101 | server.serve()
102 | ```
103 |
104 | Behold:
105 | ![Cat](img/cat.jpg)
106 | ```bash
107 | curl -XPOST 'localhost:5000/predictions?url=https://cdn.pixabay.com/photo/2017/11/14/13/06/kitty-2948404_640.jpg'
108 | # [[["n02123159", "tiger_cat", 0.598746120929718], ["n02127052", "lynx", 0.32807421684265137], ["n02123045", "tabby", 0.042475175112485886]]]
109 | ```
110 |
111 | ![Airplane](img/airplane.jpg)
112 | ```bash
113 | curl -XPOST 'localhost:5000/predictions?url=https://cdn.pixabay.com/photo/2012/06/28/08/26/plane-50893_640.jpg'
114 | # [[["n02690373", "airliner", 0.5599709749221802], ["n04592741", "wing", 0.286420077085495], ["n04552348", "warplane", 0.14331381022930145]]]
115 | ```
116 |
117 | You can interact with a live DenseNet121 demo server at `https://imagenet-keras.ryanlee.site/predictions` (source code and sample requests [here](https://github.com/rtlee9/serveit-demo-imagenet-keras/)).
118 |
119 | ## Advanced example: serving with gunicorn
120 | If you have a preference for a specific WSGI HTTP server, you can easily retrieve the underlying app from the server to serve separately. Once you've initialized the ModelServer class, fetch the underlying app in the global scope of a Python script like so:
121 |
122 | ```python
123 | # main.py
124 | app = server.get_app()
125 | ```
126 |
127 | Now all you have to do in your shell (or Procfile) is:
128 | ```bash
129 | # shell
130 | gunicorn main:app
131 |
132 | # Procfile
133 | web: gunicorn main:app
134 | ```
135 |
136 | [View all examples](https://github.com/rtlee9/serveit/tree/master/examples)
137 |
--------------------------------------------------------------------------------
/examples/img/airplane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/examples/img/airplane.jpg
--------------------------------------------------------------------------------
/examples/img/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/examples/img/cat.jpg
--------------------------------------------------------------------------------
/examples/keras_boston_neural_net.py:
--------------------------------------------------------------------------------
1 | """Sample ServeIt prediction server."""
2 | from sklearn.datasets import load_boston
3 | from serveit.server import ModelServer
4 | from keras.models import Sequential
5 | from keras.layers import Dense
6 |
7 |
def get_model(input_dim):
    """Build and compile a one-hidden-layer regression network.

    Arguments:
    - input_dim (int): number of input features

    Returns a compiled keras Sequential model trained with SGD on mean squared error.
    """
    network = Sequential([
        Dense(100, input_dim=input_dim, activation='sigmoid'),
        Dense(1),
    ])
    network.compile(loss='mean_squared_error', optimizer='SGD')
    return network


# train the regression network on the Boston housing data
data = load_boston()
n_features = data.data.shape[1]
model = get_model(n_features)
model.fit(data.data, data.target)
20 |
21 |
def validator(input_data):
    """Simple model input validator.

    Checks that the input array is two dimensional and has one column per
    training feature. The expected feature count is read from the
    module-level training ``data``; reading a global needs no ``global``
    declaration.

    Returns a ``(passed, reason)`` tuple, with ``reason`` set to None on success.
    """
    n_required = data.data.shape[1]
    # check num dims
    if input_data.ndim != 2:
        return False, 'Data should have two dimensions.'
    # check number of columns
    if input_data.shape[1] != n_required:
        reason = '{} features required, {} features provided'.format(
            n_required, input_data.shape[1])
        return False, reason
    # validation passed
    return True, None


# deploy model to a ModelServer
server = ModelServer(model, model.predict, validator)

# add informational endpoints
server.create_info_endpoint('features', data.feature_names)

# start API
server.serve()
49 |
--------------------------------------------------------------------------------
/examples/keras_imagenet_resnet50.py:
--------------------------------------------------------------------------------
1 | """Serve Keras ResNet50 model trained on ImageNet.
2 |
3 | Prediction endpoint, served at `/predictions` takes a URL pointing to an image
4 | and returns a list of class probabilities.
5 | """
6 | from serveit.server import ModelServer
7 | from serveit.utils import get_bytes_to_image_callback
8 |
9 | from keras.applications.resnet50 import ResNet50
10 | from keras.applications.resnet50 import decode_predictions
11 | from keras.applications.resnet50 import preprocess_input
12 |
13 | from flask import request
14 | import requests
15 |
# load Resnet50 model pretrained on ImageNet
model = ResNet50(weights='imagenet')


# define a loader callback for the API to fetch the raw image; the
# preprocessor chain below converts it to the format the model expects
def loader():
    """Download the image at the ``url`` request parameter and return its raw bytes.

    Raises ValueError when the ``url`` parameter is missing, and requests'
    HTTPError when the image host answers with an error status. This avoids
    handing an error page to the image decoder; ModelServer turns loader
    exceptions into a 400 response.
    """
    url = request.args.get('url')  # read image URL as a request URL param
    if not url:
        raise ValueError("Missing required 'url' query parameter")
    # NOTE: this fetches arbitrary user-supplied URLs from the server; restrict
    # the allowed hosts before exposing the API publicly (SSRF risk)
    response = requests.get(url, timeout=10)  # don't let a slow host hang the worker
    response.raise_for_status()
    return response.content
27 |
# resize incoming images to the 224x224 input used by ImageNet models, then
# apply ResNet50's own input preprocessing
bytes_to_image = get_bytes_to_image_callback(image_dims=(224, 224))
preprocessing_chain = [bytes_to_image, preprocess_input]

# serve the model: raw bytes -> preprocessed batch -> predictions -> decoded labels
server = ModelServer(
    model,
    model.predict,
    data_loader=loader,
    preprocessor=preprocessing_chain,
    postprocessor=decode_predictions,
)

# start API
server.serve()
42 |
--------------------------------------------------------------------------------
/examples/pytorch_imagenet_resnet50.py:
--------------------------------------------------------------------------------
1 | """Serve PyTorch ResNet50 model trained on ImageNet.
2 |
3 | Prediction endpoint, served at `/predictions` takes a URL pointing to an image
4 | and returns a list of class probabilities.
5 | """
6 | from serveit.server import ModelServer
7 | from serveit.utils import get_bytes_to_image_callback
8 |
9 | import torchvision.models as models
10 | import torchvision.transforms as transforms
11 | import torch
12 |
13 | from flask import request
14 | import requests
15 |
# URL for ImageNet labels in JSON
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'

# fetch the labels once at startup and index them by integer class id
labels_response = requests.get(LABELS_URL, timeout=30)
labels_response.raise_for_status()
labels = {int(key): value for key, value in labels_response.json().items()}

# load Resnet50 model pretrained on ImageNet and switch it to evaluation mode
model = models.resnet50(pretrained=True)
model.eval()


# define a loader callback for the API to fetch the raw image; the
# preprocessor chain below converts it to the format the model expects
def loader():
    """Download the image at the ``url`` request parameter and return its raw bytes.

    Raises ValueError when the ``url`` parameter is missing, and requests'
    HTTPError when the image host answers with an error status. This avoids
    handing an error page to the image decoder; ModelServer turns loader
    exceptions into a 400 response.
    """
    url = request.args.get('url')  # read image URL as a request URL param
    if not url:
        raise ValueError("Missing required 'url' query parameter")
    # NOTE: this fetches arbitrary user-supplied URLs from the server; restrict
    # the allowed hosts before exposing the API publicly (SSRF risk)
    response = requests.get(url, timeout=10)  # don't let a slow host hang the worker
    response.raise_for_status()
    return response.content
37 |
# ImageNet per-channel statistics used to normalize pixel intensities
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# define preprocessing callback chain
preprocessor = [
    # convert bytes to a single-image batch of size 224 x 224
    get_bytes_to_image_callback(image_dims=(224, 224)),
    # convert to tensor, reorder (N, H, W, C) -> (N, C, H, W), rescale to [0, 1];
    # assumes Keras' default channels-last image layout
    lambda img: torch.from_numpy(img.swapaxes(3, 1).swapaxes(2, 3).copy()) / 255,
    # Normalize works on one (C, H, W) image, so apply it to each image in the batch
    lambda batch: torch.stack([normalize(img) for img in batch]),
    torch.autograd.Variable,  # convert to PyTorch Variable
]


# define a postprocessor callback for the API to transform the model predictions
def postprocessor(prediction, top_k=3):
    """Map the first image's class scores to its ``top_k`` most likely labels.

    Arguments:
    - prediction: model output whose ``.data.numpy()[0]`` holds the class
      scores of the first image in the batch
    - top_k (int): number of labels to return, best first

    Returns a list of label strings looked up in ``labels``.
    """
    scores = prediction.data.numpy()[0]
    top_classes = scores.argsort()[::-1][:top_k]  # class ids, highest score first
    return [labels[int(class_id)] for class_id in top_classes]
53 |
# serve the model itself as the predict callable: it is called on the
# preprocessed batch, and to_numpy=False keeps that batch a PyTorch Variable
server = ModelServer(
    model,
    model,
    data_loader=loader,
    preprocessor=preprocessor,
    postprocessor=postprocessor,
    to_numpy=False,
)

# start API
server.serve()
66 |
--------------------------------------------------------------------------------
/examples/sklearn_boston_linear_regression.py:
--------------------------------------------------------------------------------
1 | """Sample ServeIt Scikit-Learn server."""
2 | from sklearn.datasets import load_boston
3 | from sklearn.linear_model import LinearRegression
4 | from serveit.server import ModelServer
5 |
# fit a model on the Boston housing dataset
data = load_boston()
reg = LinearRegression()
reg.fit(data.data, data.target)


def validator(input_data):
    """Simple model input validator.

    Checks that the input array is two dimensional and has one column per
    training feature. The expected feature count is read from the
    module-level training ``data``; reading a global needs no ``global``
    declaration.

    Returns a ``(passed, reason)`` tuple, with ``reason`` set to None on success.
    """
    n_required = data.data.shape[1]
    # check num dims
    if input_data.ndim != 2:
        return False, 'Data should have two dimensions.'
    # check number of columns
    if input_data.shape[1] != n_required:
        reason = '{} features required, {} features provided'.format(
            n_required, input_data.shape[1])
        return False, reason
    # validation passed
    return True, None


# deploy model to a ModelServer
server = ModelServer(reg, reg.predict, validator)

# add informational endpoints
server.create_info_endpoint('features', data.feature_names)

# start API
server.serve()
39 |
--------------------------------------------------------------------------------
/examples/sklearn_iris_logistic_regression.py:
--------------------------------------------------------------------------------
1 | """Sample ServeIt Scikit-Learn server."""
2 | from sklearn.datasets import load_iris
3 | from sklearn.linear_model import LogisticRegression
4 | from serveit.server import ModelServer
5 |
# fit a model on the Iris dataset
data = load_iris()
clf = LogisticRegression()
clf.fit(data.data, data.target)


def validator(input_data):
    """Simple model input validator.

    Checks that the input array is two dimensional and has one column per
    training feature. The expected feature count is read from the
    module-level training ``data``; reading a global needs no ``global``
    declaration.

    Returns a ``(passed, reason)`` tuple, with ``reason`` set to None on success.
    """
    n_required = data.data.shape[1]
    # check num dims
    if input_data.ndim != 2:
        return False, 'Data should have two dimensions.'
    # check number of columns
    if input_data.shape[1] != n_required:
        reason = '{} features required, {} features provided'.format(
            n_required, input_data.shape[1])
        return False, reason
    # validation passed
    return True, None


# deploy model to a ModelServer
server = ModelServer(clf, clf.predict, validator)

# add informational endpoints
server.create_info_endpoint('features', data.feature_names)
server.create_info_endpoint('target_labels', data.target_names.tolist())

# start API
server.serve()
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask==0.12.3
2 | scikit-learn==0.19.1
3 | scipy==1.0.0
4 | numpy==1.13.3
5 | Flask-RESTful==0.3.6
6 | gunicorn==19.7.1
7 | tensorflow==1.5.0
8 | Keras==2.1.4
9 | meinheld==0.6.1
10 | codacy-coverage==1.3.10
11 | coverage==4.5.1
12 | h5py==2.7.1
13 | Pillow==5.0.0
14 |
--------------------------------------------------------------------------------
/serveit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/serveit/__init__.py
--------------------------------------------------------------------------------
/serveit/config.py:
--------------------------------------------------------------------------------
1 | """Static and environment variables."""
2 | from os import getenv
3 |
# WSGI bind address, overridable via the WSGI_HOST / WSGI_PORT environment variables
_DEFAULT_WSGI_HOST = '127.0.0.1'
_DEFAULT_WSGI_PORT = 5000

WSGI_HOST = getenv('WSGI_HOST', _DEFAULT_WSGI_HOST)
WSGI_PORT = int(getenv('WSGI_PORT', _DEFAULT_WSGI_PORT))
6 |
--------------------------------------------------------------------------------
/serveit/log_utils.py:
--------------------------------------------------------------------------------
1 | """Logger setup."""
2 | import logging
3 | from os import getenv
4 |
logging.basicConfig(level=logging.WARNING)


def get_logger(name):
    """Return the logger called ``name`` with its level taken from $LOGLEVEL.

    LOGLEVEL defaults to INFO when unset or blank. It may be a level name in
    any case (``debug``, ``INFO``) or a numeric level (``10``).
    """
    logger = logging.getLogger(name)
    level = getenv('LOGLEVEL', '').strip() or 'INFO'
    # Logger.setLevel only accepts upper-case level names or integers
    logger.setLevel(int(level) if level.isdigit() else level.upper())
    return logger
13 |
--------------------------------------------------------------------------------
/serveit/server.py:
--------------------------------------------------------------------------------
1 | """Base class for serving predictions."""
2 | from flask import Flask, jsonify
3 | from flask_restful import Resource, Api
4 | import numpy as np
5 |
6 | from .utils import make_serializable, json_numpy_loader
7 | from .log_utils import get_logger
8 |
logger = get_logger(__name__)


def exception_log_and_respond(exception, logger, message, status_code):
    """Log ``message`` with the current traceback and return it as a JSON error response.

    The response body's ``details`` field carries the exception's class name
    and text, so clients can see why the request failed.
    """
    logger.error(message, exc_info=True)
    error_details = {
        'exception_type': type(exception).__name__,
        'exception_message': str(exception),
    }
    return make_response(message, status_code, error_details)


def make_response(message, status_code, details=None):
    """Build a JSON response ``{"message": ..., "details": ...}`` with ``status_code``.

    ``details`` is included only when it is truthy.
    """
    payload = {'message': message}
    if details:
        payload['details'] = details
    json_response = jsonify(payload)
    json_response.status_code = status_code
    return json_response
30 |
31 |
class ModelServer(object):
    """Serve a model's predictions from a RESTful Flask API.

    Constructing a ``ModelServer`` registers two endpoints on a new Flask app:

    - ``/predictions`` (POST): each request runs
      data_loader -> preprocessor(s) -> optional ``np.array`` conversion ->
      input_validation -> predict -> postprocessor(s) -> make_serializable.
    - ``/info/model`` (GET): the model's instance attributes, made JSON-serializable.

    Further static endpoints can be added with ``create_info_endpoint``.
    Run with ``serve``, or fetch the Flask app with ``get_app`` to host it in
    another WSGI server.
    """

    def __init__(
            self,
            model,
            predict,
            input_validation=lambda data: (True, None),
            data_loader=json_numpy_loader,
            preprocessor=lambda x: x,
            postprocessor=make_serializable,
            to_numpy=True):
        """Initialize the server with a model and its prediction function.

        Arguments:
        - model: the model object; its instance attributes are served at ``/info/model``
        - predict (fn): takes the preprocessed input data and returns predictions
        - input_validation (fn): takes the preprocessed input data and returns a
          ``(passed, reason)`` tuple; when ``passed`` is falsy the request is
          rejected with a 400 response whose message includes ``reason``
        - data_loader (fn): called with no arguments while handling a request;
          reads the flask request and returns the raw input data
        - preprocessor (fn or iterable of fns): transforms the loaded data;
          an iterable is applied step by step, in order
        - postprocessor (fn or iterable of fns): transforms the output of
          ``predict``; an iterable is applied step by step, in order
        - to_numpy (bool): convert the preprocessed data with ``np.array``
          before validation and prediction
        """
        self.model = model
        self.predict = predict
        self.data_loader = data_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.app = Flask('{}_{}'.format(self.__class__.__name__, type(predict).__name__))
        self.api = Api(self.app, catch_all_404s=True)
        self._create_prediction_endpoint(
            data_loader=data_loader,
            input_validation=input_validation,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            to_numpy=to_numpy,
        )
        logger.info('Model predictions registered to endpoint /predictions (available via POST)')
        self.app.logger.setLevel(logger.level)  # TODO: separate configuration for API loglevel
        self._create_model_info_endpoint()

    def __repr__(self):
        """Return a short description naming the server class and the predict callable's type."""
        return '<{}: {}>'.format(type(self).__name__, type(self.predict).__name__)

    def _create_prediction_endpoint(
            self,
            to_numpy=True,
            data_loader=json_numpy_loader,
            preprocessor=lambda x: x,
            input_validation=lambda data: (True, None),
            postprocessor=lambda x: x,
            make_serializable_post=True):
        """Register the ``/predictions`` POST endpoint.

        Arguments:
        - to_numpy (bool): convert the preprocessed data with ``np.array``
        - data_loader (fn): reads the flask request and returns the raw data
        - preprocessor (fn or iterable of fns): applied to the loaded data
        - input_validation (fn): returns a ``(passed, reason)`` tuple for the
          preprocessed data
        - postprocessor (fn or iterable of fns): applied to the predictions
        - make_serializable_post (bool): pass the postprocessed result through
          ``make_serializable`` before returning it

        Failures become JSON error responses. Loading, preprocessing and
        validation problems give 400; prediction and postprocessing problems
        give 500.
        """
        def as_steps(callbacks):
            """Normalize a callable or an iterable of callables into a list of steps."""
            # Materialize once so that one-shot iterables (e.g. generators)
            # are not exhausted after the first request.
            if hasattr(callbacks, '__iter__'):
                return list(callbacks)
            return [callbacks]

        preprocessor_steps = as_steps(preprocessor)
        postprocessor_steps = as_steps(postprocessor)

        # copy instance variables to local scope for the resource class
        predict = self.predict
        logger = self.app.logger

        # create restful resource
        class Predictions(Resource):
            @staticmethod
            def post():
                # read data from API request
                try:
                    data = data_loader()
                except Exception as e:
                    return exception_log_and_respond(e, logger, 'Unable to fetch data', 400)

                # preprocess, then optionally convert to numpy
                try:
                    for step in preprocessor_steps:
                        data = step(data)
                    if to_numpy:
                        data = np.array(data)
                except Exception as e:
                    return exception_log_and_respond(e, logger, 'Could not preprocess data', 400)

                # sanity check using user defined callback (default is no check);
                # a crashing validator is reported as a bad request, not an
                # unhandled server error
                try:
                    validation_pass, validation_reason = input_validation(data)
                except Exception as e:
                    logger.debug('Data: %s', data)
                    return exception_log_and_respond(e, logger, 'Could not validate data', 400)
                if not validation_pass:
                    # log the reason code, log the data, and send a 400 response
                    validation_message = 'Input validation failed with reason: {}'.format(validation_reason)
                    logger.error(validation_message)
                    logger.debug('Data: %s', data)
                    return make_response(validation_message, 400)

                try:
                    prediction = predict(data)
                except Exception as e:
                    # log exception and return the message in a 500 response
                    logger.debug('Data: %s', data)
                    return exception_log_and_respond(e, logger, 'Unable to make prediction', 500)
                logger.debug('Prediction: %s', prediction)

                # postprocess predictions and cast to serializable types
                try:
                    for step in postprocessor_steps:
                        prediction = step(prediction)
                    if make_serializable_post:
                        return make_serializable(prediction)
                    return prediction
                except Exception as e:
                    return exception_log_and_respond(e, logger, 'Postprocessing failed', 500)

        # map resource to endpoint
        self.api.add_resource(Predictions, '/predictions')

    def create_info_endpoint(self, name, data):
        """Register ``/info/<name>`` to serve ``data`` as JSON on GET requests.

        ``data`` is passed through ``make_serializable`` once, at registration
        time; every request receives that same converted value.
        """
        # make sure data is serializable
        data = make_serializable(data)

        # generic restful resource serving the static JSON data
        class InfoBase(Resource):
            @staticmethod
            def get():
                return data

        # Flask-RESTful derives the endpoint name from the resource class name,
        # so each info resource needs a unique name to avoid endpoint clashes
        InfoBase.__name__ = '{}_{}'.format(name, 'InfoBase')

        path = '/info/{}'.format(name)
        self.api.add_resource(InfoBase, path)
        logger.info('Registered informational resource to %s (available via GET)', path)
        logger.debug('Endpoint %s will now serve the following static data:\n%s', path, data)

    def _create_model_info_endpoint(self, path='/info/model'):
        """Register ``path`` to serve the model's instance attributes as JSON on GET requests.

        Attribute values are made serializable once, at registration time.
        A model without a ``__dict__`` is served as an empty object.
        """
        model_details = {
            key: make_serializable(value)
            for key, value in getattr(self.model, '__dict__', {}).items()
        }

        # create generic restful resource to serve model information as JSON
        class ModelInfo(Resource):
            @staticmethod
            def get():
                return model_details

        self.api.add_resource(ModelInfo, path)
        self.app.logger.info('Registered informational resource to %s (available via GET)', path)
        self.app.logger.debug('Endpoint %s will now serve the following static data:\n%s', path, model_details)

    def serve(self, host='127.0.0.1', port=5000):
        """Serve the API with the meinheld WSGI server.

        Arguments:
        - host (str): interface to bind to
        - port (int): port to listen on
        """
        from meinheld import server, middleware
        server.listen((host, port))
        server.run(middleware.WebSocketMiddleware(self.app))

    def get_app(self):
        """Return the underlying Flask app."""
        return self.app
206 |
--------------------------------------------------------------------------------
/serveit/utils.py:
--------------------------------------------------------------------------------
1 | """Utility methods."""
2 | import json
3 | from flask import request
4 |
5 | from .log_utils import get_logger
6 |
7 | logger = get_logger(__name__)
8 |
9 |
def make_serializable(data):
    """Return a JSON-serializable version of ``data``.

    Conversion is attempted in this order:

    1. ``data`` itself, if ``is_serializable`` accepts it.
    2. ``data.tolist()``, for objects that provide it (e.g. numpy arrays).
    3. For dicts, a new dict whose values are converted recursively. Keys that
       ``json.dumps`` cannot encode (e.g. tuples) are replaced by ``str(key)``.
    4. For other iterables, a list of recursively converted elements.
    5. ``str(data)`` as a last resort.
    """
    if is_serializable(data):
        return data

    # if numpy array (or anything else with `tolist`) convert to list
    try:
        return data.tolist()
    except AttributeError:
        pass
    except Exception as e:
        logger.debug('{} exception ({}): {}'.format(type(e).__name__, e, data))

    # try serializing each child element
    if isinstance(data, dict):
        return {
            _make_serializable_key(key): make_serializable(value)
            for key, value in data.items()
        }
    try:
        return [make_serializable(element) for element in data]
    except TypeError:  # not iterable
        pass
    except Exception:
        logger.debug('Could not serialize {}; converting to string'.format(data))

    # last resort: convert to string
    return str(data)


def _make_serializable_key(key):
    """Return ``key`` if JSON can use it as an object key, otherwise ``str(key)``."""
    # json.dumps raises TypeError for keys it cannot encode (e.g. tuples);
    # keys it accepts are left for json.dumps to handle as before
    if is_serializable({key: None}):
        return key
    return str(key)
35 |
36 |
def is_serializable(data):
    """Check if data is serializable with ``json.dumps``.

    Returns:
        bool: True if ``json.dumps(data)`` succeeds, False otherwise.

    ``json.dumps`` signals unsupported types or dict keys with ``TypeError``.
    It signals circular references with ``ValueError``, so both are treated
    as "not serializable" instead of escaping from this check.
    """
    try:
        json.dumps(data)
    except (TypeError, ValueError):
        return False
    return True
44 |
45 |
def json_numpy_loader():
    """Return the parsed JSON body of the current Flask request.

    Despite the name, no numpy conversion happens here: the value from
    ``request.get_json()`` is returned unchanged after its length is logged.

    Raises:
        TypeError: if the parsed payload has no ``len()`` (e.g. a JSON number,
            or ``None`` if Flask returns that for a non-JSON request).
    """
    payload = request.get_json()
    n_items = len(payload)
    logger.debug('Received JSON data of length {:,}'.format(n_items))
    return payload
51 |
52 |
def get_bytes_to_image_callback(image_dims=(224, 224)):
    """Return a callback to process image bytes for ImageNet.

    Args:
        image_dims (tuple): target size passed to ``PIL.Image.resize``. PIL
            reads this as ``(width, height)``, so for non-square sizes the
            array's spatial axes are ``(image_dims[1], image_dims[0])``.

    Returns:
        callable: ``preprocess_image_bytes(data_bytes)``. It decodes the bytes,
        converts the image to RGB, resizes it, converts it with Keras'
        ``img_to_array`` and adds a leading batch axis of size 1.
    """
    from keras.preprocessing import image
    import numpy as np
    from PIL import Image
    from io import BytesIO

    # `Image.ANTIALIAS` was removed in Pillow 10; `Image.LANCZOS` is the same
    # filter and is available in both old and new Pillow releases.
    resample = getattr(Image, 'LANCZOS', None)
    if resample is None:
        resample = Image.ANTIALIAS
    size = tuple(image_dims)

    def preprocess_image_bytes(data_bytes):
        """Process image bytes for ImageNet.

        Raises:
            ValueError: if the bytes cannot be opened/decoded as an image.
        """
        try:
            img = Image.open(BytesIO(data_bytes))  # open image
            # Image.open is lazy; decode now so corrupt or truncated data is
            # reported here as ValueError rather than failing later in resize
            img.load()
        except (IOError, OSError):  # Python 2's IOError is not an OSError
            raise ValueError('Please provide a raw image')
        if img.mode != 'RGB':
            # e.g. RGBA PNGs or grayscale images: force 3 colour channels
            img = img.convert('RGB')
        img = img.resize(size, resample)  # model requires image_dims pixels
        x = image.img_to_array(img)  # convert image to numpy array
        x = np.expand_dims(x, axis=0)  # model expects dim 0 to be iterable across images
        return x
    return preprocess_image_bytes
71 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | # This flag says to generate wheels that support both Python 2 and Python
3 | # 3. If your code will not run unchanged on both Python 2 and 3, you will
4 | # need to generate separate wheels for each Python version that you
5 | # support.
6 | universal=1
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setup module."""
from setuptools import setup, find_packages
from codecs import open
from os import path
import logging

here = path.abspath(path.dirname(__file__))
logger = logging.getLogger(__name__)
# Resolve the README relative to this file so the build works from any CWD.
readme_path = path.join(here, 'README.md')

# Get the long description from the README file
try:
    # try converting readme markdown formatting to rst (supported by pypi);
    # `convert_file` replaces pypandoc's deprecated `convert` function
    import pypandoc
    long_description = pypandoc.convert_file(readme_path, 'rst')
    long_description_content_type = 'text/x-rst'
except (IOError, OSError, ImportError, RuntimeError, AttributeError) as e:
    # e.g. pypandoc not installed, pandoc binary missing, conversion failure,
    # or a pypandoc release without `convert_file`: fall back to raw markdown
    logger.warning('{} error: {}'.format(type(e).__name__, e))
    # `open` is codecs.open here, so `encoding` works on Python 2 and 3
    with open(readme_path, encoding='utf-8') as readme_file:
        long_description = readme_file.read()
    long_description_content_type = 'text/markdown'


# Arguments marked as "Required" below must be included for upload to PyPI.
# Fields marked as "Optional" may be commented out.

setup(
    # Project name on PyPI (`pip install ServeIt`).
    name='ServeIt',  # Required

    # Versions should comply with PEP 440.
    version='0.0.9',  # Required

    # One-line summary ("Summary" metadata field).
    description='Machine learning prediction serving',  # Required

    # Longer description shown on PyPI ("Description" metadata field), plus
    # its format so PyPI renders rst or markdown correctly.
    long_description=long_description,  # Optional
    long_description_content_type=long_description_content_type,  # Optional

    url='https://github.com/rtlee9/serveit',  # Optional
    author='Ryan Lee',  # Optional
    author_email='ryantlee9@gmail.com',  # Optional

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[  # Optional
        # How mature is this project? (3 - Alpha, 4 - Beta, 5 - Production/Stable)
        'Development Status :: 3 - Alpha',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Topic :: Software Development :: Build Tools',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Internet :: WWW/HTTP :: WSGI :: Server',

        'License :: OSI Approved :: MIT License',

        # Supported Python versions
        # 'Programming Language :: Python :: 2',
        'Programming Language :: Python :: 2.7',
        # 'Programming Language :: Python :: 3',
        # 'Programming Language :: Python :: 3.4',
        # 'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
    ],

    # Whitespace-separated keywords for the project page.
    keywords='machine learning model deployment serving API RESTful',  # Optional

    # Exclude the test suite including its subpackages (tests.keras,
    # tests.pytorch, tests.sklearn); 'tests' alone only excludes the top level.
    packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tests.*']),  # Required

    # Runtime dependencies installed by pip alongside this project.
    install_requires=['flask', 'flask-restful', 'meinheld'],  # Optional

    # Optional dependency groups, e.g. `pip install ServeIt[dev]`.
    extras_require={  # Optional
        'dev': ['check-manifest'],
        'test': ['coverage'],
    },

    # No console scripts are provided.
    entry_points={  # Optional
    },
)
153 |
--------------------------------------------------------------------------------
/tests/SuccessKid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/SuccessKid.jpg
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/__init__.py
--------------------------------------------------------------------------------
/tests/keras/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/keras/__init__.py
--------------------------------------------------------------------------------
/tests/keras/test_server.py:
--------------------------------------------------------------------------------
1 | """Test ModelServer with Keras models."""
2 | import unittest
3 | from sklearn.datasets import load_boston
4 |
5 | from tests.test_server import ModelServerTest
6 |
7 |
class BostonKerasNNTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against a Keras neural net on housing data.

    NOTE(review): `load_boston` was removed in scikit-learn 1.2, so this test
    needs an older scikit-learn (or a replacement dataset).
    """

    def setUp(self):
        """Build and fit the Keras model, then create the test server."""
        boston = load_boston()
        n_features = boston.data.shape[1]
        self.model = self.get_model(n_features)
        self._setup(self.model, self.model.fit, boston)

    @staticmethod
    def get_model(input_dim):
        """Return a compiled Keras regressor for `input_dim` input features.

        One sigmoid hidden layer of 100 units feeds a single linear output;
        the model is compiled with mean squared error loss and SGD.
        """
        from keras.layers import Dense
        from keras.models import Sequential
        hidden_layer = Dense(100, input_dim=input_dim, activation='sigmoid')
        output_layer = Dense(1)
        model = Sequential([hidden_layer, output_layer])
        model.compile(loss='mean_squared_error', optimizer='SGD')
        return model


if __name__ == '__main__':
    unittest.main()
31 |
--------------------------------------------------------------------------------
/tests/pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/pytorch/__init__.py
--------------------------------------------------------------------------------
/tests/pytorch/test_server.py:
--------------------------------------------------------------------------------
"""Test ModelServer with PyTorch models."""
import unittest
from sklearn.datasets import load_boston
import torch
from torch.autograd import Variable
from torch import optim
import numpy as np

from tests.test_server import ModelServerTest


class BostonPytorchNNTest(unittest.TestCase, ModelServerTest):
    """Test ModelServer with a PyTorch linear model fitted on housing data."""

    def setUp(self):
        """Unittest set up."""
        data = load_boston()
        self.model = self.get_model(data.data.shape[1], 1)

        # define preprocessing callback chain
        preprocessor = [
            np.array,  # convert to numpy array
            torch.from_numpy,  # convert to tensor
            lambda tensor: tensor.type(torch.FloatTensor),  # convert to FloatTensor
            torch.autograd.Variable,  # convert to PyTorch Variable
        ]

        super(BostonPytorchNNTest, self)._setup(
            self.model, self.train, data, preprocessor=preprocessor, predict=self.model,
            to_numpy=False, postprocessor=lambda variable: variable.data.numpy())

    def train(self, x, y):
        """Fit `self.model` on `(x, y)` with 100 SGD steps on mean squared error.

        Arguments:
            - x (numpy.ndarray): feature matrix
            - y (numpy.ndarray): regression targets

        Returns:
            float: the loss value from the final step.
        """
        loss = torch.nn.MSELoss(size_average=True)
        optimizer = optim.SGD(self.model.parameters(), lr=1e-8)
        x = Variable(torch.from_numpy(x).type(torch.FloatTensor), requires_grad=False)
        y = Variable(torch.from_numpy(y).type(torch.FloatTensor), requires_grad=False)
        if y.dim() == 1:
            # the model built in setUp outputs shape (n, 1); give 1-D targets the
            # same shape so MSELoss compares element-wise instead of broadcasting
            y = y.view(-1, 1)

        for _ in range(100):
            # Reset gradient
            optimizer.zero_grad()

            # Forward
            fx = self.model.forward(x)
            output = loss.forward(fx, y)

            # Backward
            output.backward()

            # Update parameters
            optimizer.step()

        try:
            return output.item()  # PyTorch >= 0.4: the loss is a 0-dim tensor
        except AttributeError:
            return output.data[0]  # older PyTorch: the loss is a 1-element Variable

    @staticmethod
    def get_model(input_dim, output_dim):
        """Create a single bias-free linear layer model."""
        model = torch.nn.Sequential()
        model.add_module("linear", torch.nn.Linear(input_dim, output_dim, bias=False))
        return model


if __name__ == '__main__':
    unittest.main()
63 |
--------------------------------------------------------------------------------
/tests/sklearn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rtlee9/serveit/d97b5fbe56bec78d6c0193d6fd2ea2a0c1cbafdc/tests/sklearn/__init__.py
--------------------------------------------------------------------------------
/tests/sklearn/test_server.py:
--------------------------------------------------------------------------------
1 | """Test ModelServer with Scikit-Learn models."""
2 | import unittest
3 | from sklearn.datasets import load_iris, load_boston
4 |
5 | from tests.test_server import ModelServerTest
6 |
7 |
class IrisLogisticRegressionTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against LogisticRegression on iris."""

    def setUp(self):
        """Fit LogisticRegression on iris and create the test server."""
        from sklearn.linear_model import LogisticRegression
        self.model = LogisticRegression()
        self._setup(self.model, self.model.fit, load_iris())


class IrisSvcTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against SVC on iris."""

    def setUp(self):
        """Fit SVC on iris and create the test server."""
        from sklearn.svm import SVC
        self.model = SVC()
        self._setup(self.model, self.model.fit, load_iris())


class IrisRfcTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against RandomForestClassifier on iris."""

    def setUp(self):
        """Fit RandomForestClassifier on iris and create the test server."""
        from sklearn.ensemble import RandomForestClassifier
        self.model = RandomForestClassifier()
        self._setup(self.model, self.model.fit, load_iris())
36 |
37 |
class BostonLinearRegressionTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against LinearRegression on housing data."""

    def setUp(self):
        """Fit LinearRegression on housing data and create the test server."""
        from sklearn.linear_model import LinearRegression
        self.model = LinearRegression()
        self._setup(self.model, self.model.fit, load_boston())


class BostonSvrTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against SVR on housing data."""

    def setUp(self):
        """Fit SVR on housing data and create the test server."""
        from sklearn.svm import SVR
        self.model = SVR()
        self._setup(self.model, self.model.fit, load_boston())


class BostonRfrTest(unittest.TestCase, ModelServerTest):
    """Run the shared ModelServer tests against RandomForestRegressor on housing data."""

    def setUp(self):
        """Fit RandomForestRegressor on housing data and create the test server."""
        from sklearn.ensemble import RandomForestRegressor
        self.model = RandomForestRegressor()
        self._setup(self.model, self.model.fit, load_boston())


if __name__ == '__main__':
    unittest.main()
69 |
--------------------------------------------------------------------------------
/tests/test_callback_utils.py:
--------------------------------------------------------------------------------
1 | """Test utility methods."""
2 | import unittest
3 |
4 | from serveit.utils import get_bytes_to_image_callback
5 |
6 |
class CallbackTest(unittest.TestCase):
    """Test utility callbacks."""

    def _test_get_bytes_to_image_callback(self, image_dims):
        """Convert the fixture JPEG to a (1, *image_dims, 3) array and check its shape.

        Only square sizes are exercised; PIL's resize takes (width, height).
        """
        import os
        # resolve the fixture next to this file so the test does not depend on
        # the current working directory
        image_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'SuccessKid.jpg')
        with open(image_path, 'rb') as f:
            image_bytes = f.read()
        bytes_to_image_callback = get_bytes_to_image_callback(image_dims=image_dims)
        image = bytes_to_image_callback(image_bytes)
        # tuple concatenation rather than `(1, *image_dims, 3)`, which is a
        # SyntaxError on Python 2.7
        expected_shape = (1,) + tuple(image_dims) + (3,)
        self.assertEqual(expected_shape, image.shape)

    def test_get_bytes_to_image_callback_224_224(self):
        """Convert image bytes to 224x224 image for ImageNet."""
        self._test_get_bytes_to_image_callback((224, 224))

    def test_get_bytes_to_image_callback_128_128(self):
        """Convert image bytes to 128x128 image for ImageNet."""
        self._test_get_bytes_to_image_callback((128, 128))


if __name__ == '__main__':
    unittest.main()
28 |
--------------------------------------------------------------------------------
/tests/test_serialization_utils.py:
--------------------------------------------------------------------------------
1 | """Test utility methods."""
2 | import unittest
3 | import numpy as np
4 |
5 | from serveit.utils import is_serializable, make_serializable
6 |
7 |
class SeralizationTest(unittest.TestCase):
    """Exercise is_serializable and make_serializable."""

    def setUp(self):
        """Build fixtures JSON can encode directly and fixtures it cannot."""
        self.serializable_data = [
            [1, 2, 3],
            ['a', 'b', 'c'],
            {'a': 1, 'b': 'c'},
            {'a': [1, 2, 3], 'b': 'c'},
        ]

        class DummyClass(object):
            """Opaque object with a fixed repr and no `tolist` method."""

            def __repr__(self):
                return 'XpBheIxCcm'

        self.dummy_class = DummyClass()
        self.unserializable_data = [
            np.array([1, 2, 3]),
            np.array(['a', 'b', 'c']),
            self.dummy_class,
        ]

    def test_is_serializable(self):
        """is_serializable accepts plain JSON data and rejects the rest."""
        for good in self.serializable_data:
            self.assertTrue(is_serializable(good))
        for bad in self.unserializable_data:
            self.assertFalse(is_serializable(bad))

    def test_make_serializable_data_serializable(self):
        """Already-serializable data comes back equal to the input."""
        for item in self.serializable_data:
            self.assertEqual(make_serializable(item), item)

    def test_make_serializable_numpy_data(self):
        """Numpy arrays are converted to plain lists."""
        int_array = np.array([1, 2, 3])
        str_array = np.array(['a', 'b', 'c'])
        self.assertEqual(make_serializable(int_array), [1, 2, 3])
        self.assertEqual(make_serializable(str_array), ['a', 'b', 'c'])

    def test_make_serializable_object_data(self):
        """Objects without `tolist` fall back to their string form (here, __repr__)."""
        self.assertEqual(make_serializable(self.dummy_class), 'XpBheIxCcm')


if __name__ == '__main__':
    unittest.main()
55 |
--------------------------------------------------------------------------------
/tests/test_server.py:
--------------------------------------------------------------------------------
1 | """Base ModelServer test class."""
2 | import json
3 | import numpy as np
4 |
5 | from serveit.server import ModelServer
6 |
7 |
class ModelServerTest(object):
    """Base class to test the prediction server.

    ModelServerTest should be inherited by a class that also inherits from
    `unittest.TestCase` (so the `test_*` methods below are collected) and whose
    `setUp` calls `ModelServerTest._setup()`.
    """

    def _setup(self, model, fit, data, predict=None, **kwargs):
        """Set up method to be called before each unit test.

        Fits the model, then builds a ModelServer and a Flask test client.

        Arguments:
            - model: model to serve; stored as `self.model`
            - fit (callable): model training method; must accept args (data, target)
            - data: dataset exposing `data` and `target` arrays (e.g. a
              scikit-learn Bunch); stored as `self.data`
            - predict (callable): prediction function; defaults to `model.predict`
            - **kwargs: extra ModelServer keyword arguments; kept in
              `self.server_kwargs` so tests can build additional servers
        """
        # use the `model` argument instead of relying on the caller to have
        # already assigned `self.model`
        self.model = model
        self.data = data
        fit(self.data.data, self.data.target)
        # explicit None check: a callable predictor (e.g. a model object) may
        # define __len__ and therefore be falsy
        self.predict = predict if predict is not None else self.model.predict
        self.server_kwargs = kwargs
        self.server = ModelServer(self.model, self.predict, **kwargs)
        self.app = self.server.app.test_client()

    @staticmethod
    def _prediction_post(app, data):
        """Make a POST request to `app` with JSON body `data`."""
        return app.post(
            '/predictions',
            headers={'Content-Type': 'application/json'},
            data=json.dumps(data),
        )

    def _get_sample_data(self, n=100):
        """Return a sample of size n of self.data."""
        sample_idx = np.random.randint(self.data.data.shape[0], size=n)
        return self.data.data[sample_idx, :]

    def _assert_valid_predictions(self, predictions, sample_data):
        """Check that `predictions` returned for `sample_data` look plausible.

        There must be one prediction per sample. If the training target is
        multi-dimensional, each prediction must appear in it. Otherwise the
        mean prediction is compared loosely with the target mean.
        """
        self.assertEqual(len(predictions), len(sample_data))
        if self.data.target.ndim > 1:
            # for multiclass each prediction should be one of the training labels
            for prediction in predictions:
                self.assertIn(prediction, self.data.target)
        else:
            # the average regression prediction for a sample of data should be similar
            # to the population mean
            # NOTE(review): after dividing by 1e4, places=1 only fails when the
            # relative difference exceeds 500 (i.e. 50,000%), so this is very loose
            # TODO: remove variance from this test (i.e., no chance of false negative)
            pred_pct_diff = np.array(predictions).mean() / self.data.target.mean() - 1
            self.assertAlmostEqual(pred_pct_diff / 1e4, 0, places=1)

    def test_404_media(self):
        """Make sure API serves 404 response with JSON."""
        response = self.app.get('/fake-endpoint')
        self.assertEqual(response.status_code, 404)
        response_data_raw = response.get_data()
        self.assertIsNotNone(response_data_raw)
        response_data = json.loads(response_data_raw)
        self.assertGreater(len(response_data), 0)

    def test_features_info_none(self):
        """Verify 404 response if '/info/features' endpoint not yet created."""
        response = self.app.get('/info/features')
        self.assertEqual(response.status_code, 404)

    def test_features_info(self):
        """Test features info endpoint."""
        self.server.create_info_endpoint('features', self.data.feature_names)
        app = self.server.app.test_client()
        response = app.get('/info/features')
        response_data = json.loads(response.get_data())
        self.assertEqual(len(response_data), self.data.data.shape[1])
        try:
            self.assertCountEqual(response_data, self.data.feature_names)
        except AttributeError:  # Python 2
            self.assertItemsEqual(response_data, self.data.feature_names)

    def test_target_labels_info_none(self):
        """Verify 404 response if '/info/target_labels' endpoint not yet created."""
        response = self.app.get('/info/target_labels')
        self.assertEqual(response.status_code, 404)

    def test_target_labels_info(self):
        """Test target labels info endpoint."""
        if not hasattr(self.data, 'target_names'):
            # report a skip rather than a silent pass
            self.skipTest('dataset provides no target_names')
        self.server.create_info_endpoint('target_labels', self.data.target_names.tolist())
        app = self.server.app.test_client()
        response = app.get('/info/target_labels')
        response_data = json.loads(response.get_data())
        self.assertEqual(len(response_data), self.data.target_names.shape[0])
        try:
            self.assertCountEqual(response_data, self.data.target_names)
        except AttributeError:  # Python 2
            self.assertItemsEqual(response_data, self.data.target_names)

    def test_predictions(self):
        """Test predictions endpoint."""
        sample_data = self._get_sample_data()
        response = self._prediction_post(self.app, sample_data.tolist())
        response_data = json.loads(response.get_data())
        self._assert_valid_predictions(response_data, sample_data)

    def test_input_validation(self):
        """Add simple input validator and make sure it triggers."""
        # model input validator
        def feature_count_check(data):
            try:
                # convert PyTorch variables to numpy arrays
                data = data.data.numpy()
            except AttributeError:
                # not a PyTorch variable (e.g. already a numpy array)
                pass
            # check num dims
            if data.ndim != 2:
                return False, 'Data should have two dimensions.'
            # check number of columns
            if data.shape[1] != self.data.data.shape[1]:
                reason = '{} features required, {} features provided'.format(
                    self.data.data.shape[1], data.shape[1])
                return False, reason
            # validation passed
            return True, None

        # set up test server
        server = ModelServer(self.model, self.predict, feature_count_check, **self.server_kwargs)
        app = server.app.test_client()

        # generate sample data
        sample_data = self._get_sample_data()

        # post good data, verify 200 response
        response = self._prediction_post(app, sample_data.tolist())
        self.assertEqual(response.status_code, 200)

        # post bad data (drop a single column), verify 400 response
        response = self._prediction_post(app, sample_data[:, :-1].tolist())
        self.assertEqual(response.status_code, 400)
        response_data = json.loads(response.get_data())
        expected_reason = '{} features required, {} features provided'.format(
            self.data.data.shape[1], self.data.data.shape[1] - 1)
        self.assertIn(expected_reason, response_data['message'])

    def test_model_info(self):
        """Test model info endpoint."""
        response = self.app.get('/info/model')
        response_data = json.loads(response.get_data())
        self.assertGreater(len(response_data), 3)  # TODO: expand test scope

    def test_data_loader(self):
        """Test model prediction with a custom data loader callback."""
        # TODO: test alternative request method (e.g., URL params)
        # define custom data loader
        def read_json_from_dict():
            from flask import request
            # read data as the value of the 'data' key
            data = request.get_json()
            return np.array(data['data'])

        # create test client
        server = ModelServer(self.model, self.predict, data_loader=read_json_from_dict, **self.server_kwargs)
        app = server.app.test_client()

        # generate sample data, and wrap in dict keyed by 'data'
        sample_data = self._get_sample_data()
        data_dict = dict(data=sample_data.tolist())

        response = self._prediction_post(app, data_dict)
        response_data = json.loads(response.get_data())
        self._assert_valid_predictions(response_data, sample_data)

    def _update_kwargs_item(self, item, key_name, position='first'):
        """Add `item` to the `key_name` callback chain in self's kwargs and return them.

        `self.server_kwargs[key_name]` may be missing, a single callable, or a
        list/tuple of callables. It is replaced with a new list that has `item`
        prepended (`position='first'`) or appended (`position='last'`). The
        returned dict is `self.server_kwargs` itself, modified in place.

        Raises:
            ValueError: if `position` is neither 'first' nor 'last'.
        """
        kwargs = self.server_kwargs
        existing_items = kwargs.get(key_name, [])
        if not isinstance(existing_items, (list, tuple)):
            existing_items = [existing_items]
        # normalize to a list: `[item] + some_tuple` would raise TypeError
        existing_items = list(existing_items)
        if position == 'first':
            kwargs[key_name] = [item] + existing_items
        elif position == 'last':
            kwargs[key_name] = existing_items + [item]
        else:
            raise ValueError("position must be 'first' or 'last', got {!r}".format(position))
        return kwargs

    def test_preprocessing(self):
        """Test predictions endpoint with custom preprocessing callback."""
        # create test client with a preprocessor that unwraps data from a dict as the value of the 'data' key
        kwargs = self._update_kwargs_item(lambda d: d['data'], 'preprocessor')
        server = ModelServer(self.model, self.predict, **kwargs)
        app = server.app.test_client()

        # generate sample data, and wrap in dict keyed by 'data'
        sample_data = self._get_sample_data()
        data_dict = dict(data=sample_data.tolist())

        response = self._prediction_post(app, data_dict)
        response_data = json.loads(response.get_data())
        self._assert_valid_predictions(response_data, sample_data)

    def test_preprocessing_list(self):
        """Test predictions endpoint with chained preprocessing callbacks."""
        # create test client with preprocessors that unwrap data nested under 'data2' then 'data'
        kwargs = self._update_kwargs_item(lambda d: d['data'], 'preprocessor')
        kwargs['preprocessor'] = [lambda d: d['data2']] + kwargs['preprocessor']
        server = ModelServer(
            self.model,
            self.predict,
            **kwargs
        )
        app = server.app.test_client()

        # generate sample data, and wrap in dicts keyed by 'data2' then 'data'
        sample_data = self._get_sample_data()
        data_dict = dict(data2=dict(data=sample_data.tolist()))

        response = self._prediction_post(app, data_dict)
        response_data = json.loads(response.get_data())
        self._assert_valid_predictions(response_data, sample_data)

    def test_postprocessing(self):
        """Test predictions endpoint with custom postprocessing callback."""
        # create test client with postprocessor that wraps predictions in a dictionary
        kwargs = self._update_kwargs_item(lambda x: dict(prediction=x.tolist()), 'postprocessor', 'last')
        server = ModelServer(self.model, self.predict, **kwargs)
        app = server.app.test_client()

        # generate sample data
        sample_data = self._get_sample_data()

        response = self._prediction_post(app, sample_data.tolist())
        response_data = json.loads(response.get_data())['prediction']  # predictions are nested under 'prediction' key
        self._assert_valid_predictions(response_data, sample_data)

    def test_get_app(self):
        """Make sure get_app method returns the same app."""
        self.assertEqual(self.server.get_app(), self.server.app)

    def test_400_no_content_type(self):
        """Check 400 response if no Content-Type header specified."""
        response = self.app.post(
            '/predictions',
        )
        self.assertEqual(response.status_code, 400)
        response_body = json.loads(response.get_data())
        self.assertEqual(response_body['message'], 'Unable to fetch data')
        self.assertGreaterEqual(len(response_body['details']), 2)
289 |
--------------------------------------------------------------------------------