├── requirements.txt ├── model_files └── svc_model.pickle ├── README.md ├── iris_train.py ├── ml_model_abc.py ├── tests ├── ml_model_abc_tests.py └── iris_model_tests.py ├── LICENSE ├── iris_predict.py ├── .gitignore └── blog_post └── post.md /requirements.txt: -------------------------------------------------------------------------------- 1 | contextlib2==0.5.5 2 | numpy==1.16.2 3 | schema==0.7.0 4 | scikit-learn==0.20.3 5 | scipy==1.2.1 6 | -------------------------------------------------------------------------------- /model_files/svc_model.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schmidtbri/simple-ml-model-abc/HEAD/model_files/svc_model.pickle -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # simple-ml-model-abc 2 | Code demonstrating a simple Machine Learning model abstract base class and its uses. 3 | 4 | This code goes along with this [blog post](https://towardsdatascience.com/a-simple-ml-model-base-class-ab40e2febf13). 
5 | -------------------------------------------------------------------------------- /iris_train.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn import svm 3 | import pickle 4 | import os 5 | 6 | 7 | def train(): 8 | """ This code is from: https://scikit-learn.org/stable/tutorial/basic/tutorial.html """ 9 | iris = datasets.load_iris() 10 | 11 | svm_model = svm.SVC(gamma=0.001, C=100.0) 12 | 13 | svm_model.fit(iris.data[:-1], iris.target[:-1]) 14 | 15 | dir_path = os.path.dirname(os.path.realpath(__file__)) 16 | file = open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'wb') 17 | pickle.dump(svm_model, file) 18 | file.close() 19 | 20 | 21 | if __name__ == "__main__": 22 | train() 23 | -------------------------------------------------------------------------------- /ml_model_abc.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class MLModel(ABC): 5 | """ An abstract base class for ML model prediction code """ 6 | @property 7 | @abstractmethod 8 | def input_schema(self): 9 | raise NotImplementedError() 10 | 11 | @property 12 | @abstractmethod 13 | def output_schema(self): 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def __init__(self): 18 | raise NotImplementedError() 19 | 20 | @abstractmethod 21 | def predict(self, data): 22 | self.input_schema.validate(data) 23 | 24 | 25 | class MLModelException(Exception): 26 | """ Exception type used to raise exceptions within MLModel derived classes """ 27 | def __init__(self,*args,**kwargs): 28 | Exception.__init__(self, *args, **kwargs) 29 | -------------------------------------------------------------------------------- /tests/ml_model_abc_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | 5 | # this adds the project root to the 
PYTHONPATH if its not already there, it makes it easier to run the unit tests 6 | if os.path.dirname(os.path.dirname(os.path.abspath(__file__))) not in sys.path: 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | from ml_model_abc import MLModelException 10 | 11 | 12 | class TestMLModel(unittest.TestCase): 13 | def test1(self): 14 | """ testing the __init__() method """ 15 | # arrange, act 16 | exception_raised = False 17 | exception = None 18 | try: 19 | raise MLModelException("Testing raising MLModelException.") 20 | except MLModelException as e: 21 | exception_raised = True 22 | exception = e 23 | 24 | # assert 25 | self.assertTrue(type(exception) is MLModelException) 26 | self.assertTrue(exception_raised) 27 | 28 | 29 | if __name__ == '__main__': 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 schmidtbri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /iris_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from schema import Schema, Or 4 | from numpy import array 5 | 6 | from ml_model_abc import MLModel 7 | 8 | 9 | class IrisSVCModel(MLModel): 10 | """ A demonstration of how to use """ 11 | input_schema = Schema({'sepal_length': float, 12 | 'sepal_width': float, 13 | 'petal_length': float, 14 | 'petal_width': float}) 15 | 16 | # the output of the model will be one of three strings 17 | output_schema = Schema({'species': Or("setosa", "versicolor", "virginica")}) 18 | 19 | def __init__(self): 20 | dir_path = os.path.dirname(os.path.realpath(__file__)) 21 | file = open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'rb') 22 | self._svm_model = pickle.load(file) 23 | file.close() 24 | 25 | def predict(self, data): 26 | # calling the super method to validate against the input_schema 27 | super().predict(data=data) 28 | 29 | # converting the incoming dictionary into a numpy array that can be accepted by the scikit-learn model 30 | X = array([data["sepal_length"], data["sepal_width"], data["petal_length"], data["petal_width"]]).reshape(1, -1) 31 | 32 | # making the prediction and extracting the result from the array 33 | y_hat = int(self._svm_model.predict(X)[0]) 34 | 35 | #converting the prediction into a string that will match the output schema of the model 36 | # this list will map the output of the scikit-learn model to the output string expected by the schema 37 | targets = ['setosa', 'versicolor', 'virginica'] 38 | species = targets[y_hat] 39 | 40 | return 
{"species": species} 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | /.idea 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 
-------------------------------------------------------------------------------- /tests/iris_model_tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | from sklearn import svm 5 | from schema import SchemaError 6 | import json 7 | 8 | # this adds the project root to the PYTHONPATH if its not already there, it makes it easier to run the unit tests 9 | if os.path.dirname(os.path.dirname(os.path.abspath(__file__))) not in sys.path: 10 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 11 | 12 | from iris_predict import IrisSVCModel 13 | 14 | 15 | class TestIrisSVCModel(unittest.TestCase): 16 | def test1(self): 17 | """ testing the __init__() method """ 18 | # arrange, act 19 | model = IrisSVCModel() 20 | 21 | # assert 22 | self.assertTrue(type(model._svm_model) is svm.SVC) 23 | 24 | def test2(self): 25 | """ testing the input schema with wrong data """ 26 | # arrange 27 | data = {'name': 'Sue', 'age': '28', 'gender': 'Squid'} 28 | 29 | # act 30 | exception_raised = False 31 | try: 32 | validated_data = IrisSVCModel.input_schema.validate(data) 33 | except SchemaError as e: 34 | exception_raised = True 35 | 36 | # assert 37 | self.assertTrue(exception_raised) 38 | 39 | def test3(self): 40 | """ testing the input schema with correct data """ 41 | # arrange 42 | data = {'sepal_length': 1.0, 43 | 'sepal_width': 1.0, 44 | 'petal_length': 1.0, 45 | 'petal_width': 1.0} 46 | 47 | # act 48 | exception_raised = False 49 | try: 50 | validated_data = IrisSVCModel.input_schema.validate(data) 51 | except SchemaError as e: 52 | exception_raised = True 53 | 54 | # assert 55 | self.assertFalse(exception_raised) 56 | 57 | def test4(self): 58 | """ testing the output schema with incorrect data """ 59 | # arrange 60 | data = {'species': 1.0} 61 | 62 | # act 63 | exception_raised = False 64 | try: 65 | validated_data = IrisSVCModel.output_schema.validate(data) 66 | except 
SchemaError as e: 67 | exception_raised = True 68 | 69 | # assert 70 | self.assertTrue(exception_raised) 71 | 72 | def test5(self): 73 | """ testing the output schema with correct data """ 74 | # arrange 75 | data = {'species': 'setosa'} 76 | 77 | # act 78 | exception_raised = False 79 | try: 80 | validated_data = IrisSVCModel.output_schema.validate(data) 81 | except SchemaError as e: 82 | exception_raised = True 83 | 84 | # assert 85 | self.assertFalse(exception_raised) 86 | 87 | def test6(self): 88 | """ testing the predict() method throws schems exception when given bad data """ 89 | # arrange 90 | model = IrisSVCModel() 91 | 92 | # act 93 | exception_raised = False 94 | try: 95 | prediction = model.predict({'name': 'Sue', 'age': '28', 'gender': 'Squid'}) 96 | except SchemaError as e: 97 | exception_raised = True 98 | 99 | # assert 100 | self.assertTrue(exception_raised) 101 | 102 | def test7(self): 103 | """ testing the predict() method with good data""" 104 | # arrange 105 | model = IrisSVCModel() 106 | 107 | # act 108 | prediction = model.predict(data={'sepal_length': 1.0, 109 | 'sepal_width': 1.0, 110 | 'petal_length': 1.0, 111 | 'petal_width': 1.0}) 112 | 113 | exception_raised = False 114 | try: 115 | IrisSVCModel.output_schema.validate(prediction) 116 | except SchemaError as e: 117 | exception_raised = True 118 | 119 | # assert 120 | self.assertFalse(exception_raised) 121 | self.assertTrue(type(prediction) is dict) 122 | self.assertTrue(prediction["species"] == 'setosa') 123 | self.assertFalse(exception_raised) 124 | 125 | def test8(self): 126 | """ testing JSON schema generation """ 127 | # arrange 128 | model = IrisSVCModel() 129 | 130 | # act 131 | json_schema = json.dumps(model.input_schema.json_schema("https://example.com/my-schema.json")) 132 | 133 | # assert 134 | print(json_schema) 135 | self.assertTrue(type(json_schema) is str) 136 | 137 | 138 | if __name__ == '__main__': 139 | unittest.main() 140 | 
-------------------------------------------------------------------------------- /blog_post/post.md: -------------------------------------------------------------------------------- 1 | Title: A Simple ML Model Base Class 2 | Date: 2019-04-02 09:20 3 | Category: Blog 4 | Slug: a-simple-ml-model-base-class 5 | Authors: Brian Schmidt 6 | Summary: When creating software it is often useful to write abstract classes to help define different interfaces that classes can implement and inherit from. By creating a base class, a standard can be defined that simplifies the design of the whole system and clarifies every decision moving forward. 7 | 8 | When creating software it is often useful to write abstract classes to 9 | help define different interfaces that classes can implement and inherit 10 | from. By creating a base class, a standard can be defined that 11 | simplifies the design of the whole system and clarifies every decision 12 | moving forward. 13 | 14 | The integration of ML models with other software components is often 15 | complicated and can benefit greatly from using an object-oriented 16 | approach. Recently, I've been seeing this problem solved in many 17 | different ways, so I decided to try to implement my own solution. 18 | 19 | In this post I will describe a simple implementation of a base class for 20 | Machine Learning Models. This post will focus on making predictions with 21 | ML models and integrating ML models with other software components. 22 | Training code will be kept to a minimum to keep the code simple. The code in 23 | this post is written in Python; if you aren't familiar with 24 | abstract base classes in Python, 25 | [here](https://www.python-course.eu/python3_abstract_classes.php) 26 | is a good place to learn.
27 | 28 | ## Scikit-learn's Approach to Base Classes 29 | 30 | The best-known ML software package in Python is scikit-learn, and 31 | it provides a set of abstract base classes in the [base.py 32 | module](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/base.py). 33 | The scikit-learn API is a great place to learn about machine learning 34 | software engineering in general, but in this case we want to focus on 35 | its approach to base classes for making predictions with ML models. 36 | 37 | Scikit-learn defines an abstract base class called Estimator, which is 38 | meant to be the base class for any class that is able to learn from a 39 | data set; a class that derives from Estimator must implement a "fit" 40 | method. Scikit-learn also defines a Predictor base class that is meant 41 | to be the base class for any class that is able to infer from learned 42 | parameters when presented with new data; a class that derives from 43 | Predictor must implement a "predict" method. These two base classes are 44 | some of the most commonly used abstractions in the Scikit-learn package. 45 | By defining these base classes, the Scikit-learn project provides a 46 | strong base for coding ML algorithms. 47 | 48 | These two interfaces are broad enough to take us far, but what about 49 | serialization and deserialization? ML models need to be loaded from 50 | storage before they can be used. On this front scikit-learn is mostly 51 | silent, and no standard interface for hiding the details of model 52 | serialization and deserialization is provided. Also, what if we need to 53 | publish schema information about the input and output data that a model 54 | needs for scoring? Scikit-learn does not provide a way to do this 55 | either, since it uses numpy arrays for input and output. 56 | 57 | Because of these factors, using Scikit-learn's API is not necessarily 58 | the best way to integrate ML models with other software components.
59 | Integrating a Scikit-learn model with other software components by using 60 | the Scikit-learn API exposes internal details about how the model is 61 | serialized and how information is passed into the model. For example, if 62 | a Data Scientist hands over a scikit-learn model in a pickled file along 63 | with some code, a software engineer would have to be familiar with how 64 | to deserialize the model object and how to structure a Numpy array in 65 | such a way that it will be accepted by the model's predict() method. The 66 | best way to solve this problem is to hide these implementation details 67 | behind an interface. 68 | 69 | In summary, to simplify the use of ML models within production systems, 70 | it would be useful to solve a couple of issues: 71 | 72 | - How to consistently and transparently send data to the model 73 | 74 | - How to load serialized model assets when instantiating a model 75 | 76 | - How to publish input and output data schema information 77 | 78 | ## Some Solutions 79 | 80 | Over the last few years, a few big tech companies have been developing 81 | proprietary in-house machine learning infrastructure and software. Some 82 | of these companies sell access to their ML platform and others have 83 | published details about their approach to ML infrastructure. Also, there 84 | have been a few open source projects that seek to simplify the 85 | deployment of ML models to production systems. In this section I will 86 | describe some solutions that have emerged recently for the problems 87 | described above. 88 | 89 | ### AWS Sagemaker 90 | 91 | AWS Sagemaker is a platform for training and deploying ML models within 92 | the AWS ecosystem. The platform has several ready-made ML algorithms 93 | that can be leveraged without writing a lot of code. However, a way to 94 | deploy custom ML code to the platform is provided. 
To deploy a 95 | prediction endpoint on top of the Sagemaker service, a Python Flask 96 | application with "/ping" and "/invocations" endpoints must be created 97 | and deployed within a Docker container. 98 | 99 | In the Sagemaker example published 100 | [here](https://github.com/awslabs/amazon-sagemaker-examples/blob/35941a33425b3a441275abc7243eb1f959a584e4/advanced_functionality/scikit_bring_your_own/container/decision_trees/predictor.py#L24-L43), 101 | we can see the recommended way to run the model prediction code within 102 | the Flask application. In the example, the scikit-learn model object is 103 | deserialized and saved as a class property, and the model is then 104 | accessed by the "predict" method. This implementation does not provide a 105 | way to publish schema metadata about the model and does not enforce any 106 | specific implementation on the model code. The AWS Sagemaker library 107 | does not provide a base class to help write the model code. 108 | 109 | ### Facebook 110 | 111 | Facebook published a blog post about their ML systems 112 | [here](https://code.fb.com/ml-applications/introducing-fblearner-flow-facebook-s-ai-backbone/). 113 | The FBLearner Flow system is made up of workflows and operators. A 114 | workflow is a single unit of work with a specific set of inputs and 115 | outputs; a workflow is made up of operators which do simple operations 116 | on data. The blog post shows how to train a Decision Tree model on the 117 | iris data set. The blog post does not provide many implementation 118 | details about their internal Python packages. An interesting part of the 119 | approach taken is the fact that schema metadata is attached to every 120 | workflow created, ensuring type safety at runtime. There are no details 121 | about loading and storing model assets.
Facebook's FBLearner Flow Python package 122 | does not use base classes that developers can inherit from to write 123 | code, but uses function annotations to attach metadata to ML model code. 124 | 125 | ### Uber 126 | 127 | Uber published a blog post about their approach to custom ML models 128 | [here](https://eng.uber.com/michelangelo-pyml/). Uber's 129 | PyML package is used to deploy ML models that are not natively supported 130 | by Uber's Michelangelo ML platform, which is described 131 | [here](https://eng.uber.com/michelangelo/). The PyML 132 | package does not specify how to write model training code, but does 133 | provide a base class for writing ML model prediction code. The base 134 | class is called DataFrameModel. The interface is very simple; it only 135 | has two methods: the \_\_init\_\_() method, and the predict() method. 136 | The model assets are required to be deserialized in the class 137 | constructor and all prediction code is in the predict method of the 138 | class. 139 | 140 | The DataFrameModel interface requires the use of Pandas dataframes or 141 | tensors when giving data to the model for prediction. This design 142 | decision can backfire because there is no way to tell the user of the 143 | model how to structure the input data to the model. However, the use of 144 | the \_\_init\_\_() method for loading model assets helps to hide the 145 | complexity of the model from the user. Also, by using base classes that 146 | must be inherited from in order to deploy code to the production 147 | systems, certain requirements can be more easily checked. 148 | 149 | ### Seldon Core 150 | 151 | Seldon Core is an open source project for hosting ML models. It supports 152 | custom Python models, as described 153 | [here](https://docs.seldon.io/projects/seldon-core/en/latest/python/python_component.html).
154 | The model code is required to be in a Python class with an 155 | \_\_init\_\_() method and a predict() method, it follows Uber's design 156 | closely but does not use an abstract base class to enforce the 157 | interface. Another difference is that Seldon allows the model class to 158 | return results in several different ways, and not just in Pandas 159 | dataframes. Seldon also allows the model class to return column name 160 | metadata for the model inputs, but no type metadata. 161 | 162 | A Simple ML Model Base Class 163 | ============================ 164 | 165 | NOTE: All of the code shown in this section can be found in [this 166 | Github 167 | repository](https://github.com/schmidtbri/simple-ml-model-abc). 168 | 169 | In this section I will present a simple abstract base class that 170 | combines the strengths of the approaches shown above into one abstract 171 | base class for ML models. I will also explain the reasoning behind the 172 | design. 173 | 174 | Here is the code for the abstract base class: 175 | 176 | ```python 177 | class MLModel(ABC): 178 | """ An abstract base class for ML model prediction code """ 179 | @property 180 | @abstractmethod 181 | def input_schema(self): 182 | raise NotImplementedError() 183 | 184 | @property 185 | @abstractmethod 186 | def output_schema(self): 187 | raise NotImplementedError() 188 | 189 | @abstractmethod 190 | def __init__(self): 191 | raise NotImplementedError() 192 | 193 | @abstractmethod 194 | def predict(self, data): 195 | self.input_schema.validate(data) 196 | ``` 197 | 198 | The code looks very similar to Uber's and Seldon Core's approach. The 199 | model file deserialization code is still expected to be implemented in 200 | the \_\_init\_\_() method, and the prediction code is still expected to 201 | be in the predict() method. Any model that needs to be used by other 202 | software packages is expected to derive from the MLModel abstract base 203 | class and implement these two methods. 
204 | 205 | However, there are some differences. The input to the predict method is 206 | not expected to be of any particular type; it can be any Python type as 207 | long as the input data is packaged into a single input parameter called 208 | "data". This is different from Seldon Core's and Uber's approaches, which 209 | require Numpy arrays or Pandas dataframes. 210 | 211 | Another difference is that the base class shown above requires the model 212 | creator to attach schema metadata to their implementation. The base 213 | class has two extra properties that are not present in the Seldon Core 214 | and Uber implementations: the "input\_schema" and "output\_schema" 215 | properties are meant to publish the schema of the data that the model 216 | will accept in the predict method and the schema of the data that the 217 | model will output from the predict method. To do this, I will use the 218 | Python schema package, but there are many options for writing and 219 | enforcing schema, for example the marshmallow and schematics 220 | Python packages. 221 | 222 | We also need to define a way for a model creator to raise exceptions. 223 | For this we can write a simple custom Exception: 224 | 225 | ```python 226 | class MLModelException(Exception): 227 | """ Exception type for use within MLModel derived classes """ 228 | def __init__(self, *args, **kwargs): 229 | Exception.__init__(self, *args, **kwargs) 230 | ``` 231 | 232 | Using the Base Class 233 | ==================== 234 | 235 | This blog post deals purely with the ML code that will be used for 236 | predicting in production and not with the model training code. However, 237 | we still need to have a model to work with.
Here's a simple scikit-learn 238 | model training script: 239 | 240 | ```python 241 | iris = datasets.load_iris() 242 | svm_model = svm.SVC(gamma=0.001, C=100.0) 243 | svm_model.fit(iris.data[:-1], iris.target[:-1]) 244 | 245 | dir_path = os.path.dirname(os.path.realpath(__file__)) 246 | file = open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'wb') 247 | pickle.dump(svm_model, file) 248 | file.close() 249 | ``` 250 | 251 | Now that we have a trained model, we can write the class that will inherit from MLModel and make predictions: 252 | 253 | ```python 254 | class IrisSVCModel(MLModel): 255 | """ A demonstration of how to use """ 256 | input_schema = Schema({'sepal_length': float, 257 | 'sepal_width': float, 258 | 'petal_length': float, 259 | 'petal_width': float}) 260 | 261 | # the output of the model will be one of three strings 262 | output_schema = Schema({'species': Or("setosa", 263 | "versicolor", 264 | "virginica")}) 265 | 266 | def __init__(self): 267 | dir_path = os.path.dirname(os.path.realpath(__file__)) 268 | file = open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'rb') 269 | self._svm_model = pickle.load(file) 270 | file.close() 271 | 272 | def predict(self, data): 273 | # calling the super method to validate against the 274 | # input_schema 275 | super().predict(data=data) 276 | 277 | # converting the incoming dictionary into a numpy array 278 | # that can be accepted by the scikit-learn model 279 | X = array([data["sepal_length"], 280 | data["sepal_width"], 281 | data["petal_length"], 282 | data["petal_width"]]).reshape(1, -1) 283 | 284 | # making the prediction 285 | y_hat = int(self._svm_model.predict(X)[0]) 286 | 287 | # converting the prediction into a string that will match 288 | # the output schema of the model, this list will map the 289 | # output of the scikit-learn model to the string expected by 290 | # the output schema 291 | targets = ['setosa', 'versicolor', 'virginica'] 292 | species = targets[y_hat] 293 |
294 | return {"species": species} 295 | ``` 296 | 297 | One useful thing about using the schema package for building the input 298 | and output schemas of the model is that it supports exporting the schema 299 | in the JSON schema format: 300 | 301 | ```python 302 | >>> model = IrisSVCModel() 303 | >>> print(json.dumps(model.input_schema.json_schema("https://example.com/my-schema.json"))) 304 | {"type": "object", "properties": {"sepal_length": {"type": "number"}, "sepal_width": {"type": "number"}, 305 | ... 306 | ... 307 | ``` 308 | 309 | ## Conclusion 310 | 311 | In this post I showed a few different approaches to deploying ML model 312 | code to production systems. I also showed an implementation of a Python 313 | base class that brings together the best features of the different 314 | approaches discussed. In conclusion I will discuss some of the benefits 315 | of the approach I sketched out above. 316 | 317 | The MLModel base class has very few dependencies. It does not require 318 | the model creator to use Pandas, numpy, or any other Python package to 319 | transfer data to the model. This also means that it does not force the 320 | user of the model to know any internal implementation details about the 321 | model. On the other hand, Uber's solution requires that the user of the 322 | model know how to work with Pandas dataframes. However, if the model 323 | creator still wishes to accept numpy arrays or Pandas dataframes in 324 | their model, the MLModel base class shown above still allows this. 325 | 326 | By using Python dictionaries for model input and output, the model is 327 | easier to use. There is no need to understand how to use numpy arrays or 328 | Pandas dataframes, remember the order of the columns, or know how the 329 | output columns are encoded in order to use the model. 330 | 331 | By stating the input and output schemas of a model programmatically, it 332 | is possible to compare different models' schemas through automated 333 | tools.
This can be useful when tracking model changes across many 334 | different versions of a model. Facebook's approach allows schema 335 | metadata to be attached to ML models, but no other approach discussed 336 | above does this. 337 | 338 | By hiding the deserialization code behind the \_\_init\_\_() method, the 339 | deserialization technique or the storage location of model files can be 340 | changed without affecting the code that uses the model. In the same way, 341 | I can replace the code in the predict() method without affecting the 342 | user of the model, as long as the input and output schemas remain the 343 | same. This is the benefit of using Object Oriented Programming to hide 344 | implementation details from users of your code. 345 | 346 | There are some other improvements that can be added to the MLModel base 347 | class shown in this post, but these will be shown in a later blog post. 348 | --------------------------------------------------------------------------------