├── tests
    ├── __init__.py
    ├── iris_train_test.py
    └── iris_predict_test.py
├── docs
    ├── source
    │   ├── _static
    │   │   └── .gitkeep
    │   ├── ml_model_abc.rst
    │   ├── index.rst
    │   ├── iris_train.rst
    │   ├── iris_predict.rst
    │   └── conf.py
    ├── Makefile
    └── make.bat
├── iris_model
    ├── model_files
    │   └── svc_model.pickle
    ├── __init__.py
    ├── iris_train.py
    └── iris_predict.py
├── README.md
├── requirements.txt
├── Makefile
├── LICENSE
├── setup.py
├── .gitignore
├── ml_model_abc.py
└── blog_post
    └── post.md


/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/docs/source/_static/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/iris_model/model_files/svc_model.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/schmidtbri/ml-model-abc-improvements/HEAD/iris_model/model_files/svc_model.pickle
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ml-model-abc-improvements
Code showing how to add metadata, documentation, and versioning to an ML model base class.

The code in this repository goes along with [this blog post](https://towardsdatascience.com/improving-the-mlmodel-base-class-eded137629bd).

--------------------------------------------------------------------------------
/docs/source/ml_model_abc.rst:
--------------------------------------------------------------------------------
ML Model Abstract Base Class
============================

.. autoclass:: ml_model_abc.MLModel
   :members:

   .. automethod:: __init__

.. autoclass:: ml_model_abc.MLModelException
   :members:

.. autoclass:: ml_model_abc.MLModelSchemaValidationException
   :members:
--------------------------------------------------------------------------------
/iris_model/__init__.py:
--------------------------------------------------------------------------------
__version_info__ = (0, 1, 0)
__version__ = ".".join([str(n) for n in __version_info__])

# a display name for the model
__display_name__ = "Iris Model"

# returning the package name as the qualified name for the model
__qualified_name__ = __name__.split(".")[0]

# a description of the model
__description__ = "A machine learning model for predicting the species of a flower based on its measurements."

--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
Welcome to mlmodel-base-class-improvements's documentation!
===========================================================

.. toctree::
   :maxdepth: 2
   :caption: Iris Model:

   iris_predict.rst
   iris_train.rst

.. toctree::
   :maxdepth: 2
   :caption: MLModel Base Class:

   ml_model_abc.rst


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

--------------------------------------------------------------------------------
/docs/source/iris_train.rst:
--------------------------------------------------------------------------------
Iris Model Train Module
=======================

Iris Training Code CLI Documentation
************************************
This documentation shows how to train the Iris model through the CLI interface.

.. argparse::
   :module: iris_model.iris_train
   :func: argument_parser
   :prog: iris_train


Iris Training Code Documentation
********************************
This documentation shows how to train the Iris model by importing the code directly.

.. automodule:: iris_model.iris_train
   :members:

--------------------------------------------------------------------------------
/docs/source/iris_predict.rst:
--------------------------------------------------------------------------------
Iris Model Predict Module
=========================

Iris Model Input Schema
***********************

This section describes the input data structure for the predict() method.

.. jsonschema:: ../build/input_schema.json

Iris Model Output Schema
************************

This section describes the output data structure of the predict() method.

.. jsonschema:: ../build/output_schema.json

Iris Model Prediction Code
**************************

.. autoclass:: iris_model.iris_predict.IrisModel
   :members:

   .. automethod:: __init__
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

view-docs:
	open ./build/html/index.html
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
alabaster==0.7.12
atomicwrites==1.3.0
attrs==19.1.0
Babel==2.6.0
certifi==2019.3.9
chardet==3.0.4
contextlib2==0.5.5
docutils==0.14
ghp-import==0.5.5
idna==2.8
imagesize==1.1.0
Jinja2==2.11.3
joblib==0.13.2
jsonpointer==2.0
MarkupSafe==1.1.1
more-itertools==7.0.0
numpy==1.16.3
packaging==19.0
pluggy==0.11.0
py==1.10.0
Pygments==2.7.4
pyparsing==2.4.0
pytest==4.5.0
pytz==2019.1
PyYAML==5.4
requests==2.21.0
schema==0.7.0
scikit-learn==0.21.0
scipy==1.2.1
six==1.12.0
snowballstemmer==1.2.1
Sphinx==2.0.1
sphinx-argparse==0.2.5
sphinx-jsonschema==1.8
sphinx-rtd-theme==0.4.3
sphinxcontrib-applehelp==1.0.1
sphinxcontrib-devhelp==1.0.1
sphinxcontrib-htmlhelp==1.0.2
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.2
sphinxcontrib-serializinghtml==1.1.3
urllib3==1.24.3
wcwidth==0.1.7

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
TEST_PATH=./

.DEFAULT_GOAL := help

.PHONY: help

help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'

clean-pyc: ## Remove python artifacts.
	find . -name '*.pyc' -exec rm -f {} +
	find . -name '*.pyo' -exec rm -f {} +
	find . -name '*~' -exec rm -f {} +

clean-build: ## Remove build artifacts.
	rm -fr build/
	rm -fr dist/
	rm -fr *.egg-info

test: clean-pyc ## Run unit test suite.
	py.test --verbose --color=yes $(TEST_PATH)

clean-docs: ## Delete all files in the docs html build directory
	mkdir -p docs/build
	rm -rf docs/build

html-docs: clean-docs ## Build the html documentation
	sphinx-build -b html docs/source docs/build/html

view-docs: ## Open a web browser pointed at the documentation
	open docs/build/html/index.html

gh-pages: ## import docs to gh-pages branch and push to origin
	ghp-import docs/build/html -p -n -m "Autogenerated documentation" -r origin -b gh-pages

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 schmidtbri

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from iris_model import __version__

from setuptools import setup
from os import path

# Get the long description from the README file
with open(path.join(path.abspath(path.dirname(__file__)), 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='iris_model',
    version=__version__,
    description='A simple ML model example project.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    url='https://github.com/schmidtbri/ml-model-abc-improvements',
    author='Brian Schmidt',
    author_email='6666331+schmidtbri@users.noreply.github.com',
    py_modules=["ml_model_abc"],
    packages=["iris_model"],
    install_requires=[
        'scikit-learn==0.21.0',
        'schema==0.7.0'
    ],
    package_data={'iris_model': [
        'model_files/svc_model.pickle'
    ]},
    include_package_data=True,
    entry_points={
        'console_scripts': [
            'iris_train=iris_model.iris_train:main',
        ]
    }
)

--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
import os
import sys
import json

sys.path.insert(0, os.path.abspath('../../'))

# this code exports the input and output schema info to a json schema file
# the json schema files are then used to auto generate documentation
from iris_model import __version__
from iris_model.iris_predict import IrisModel

input_json_schema_string = json.dumps(IrisModel.input_schema.json_schema("https://example.com/input-schema.json"))
output_json_schema_string = json.dumps(IrisModel.output_schema.json_schema("https://example.com/output-schema.json"))

# create the build directory if it doesn't exist
docs_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if not os.path.exists(os.path.join(docs_path, "build")):
    os.makedirs(os.path.join(docs_path, "build"))

with open(os.path.abspath('../build/input_schema.json'), "w") as file:
    file.write(input_json_schema_string)

with open(os.path.abspath('../build/output_schema.json'), "w") as file:
    file.write(output_json_schema_string)

# -- Project information -----------------------------------------------------
project = 'ML Model Base Class Improvements'
copyright = '2019, Brian Schmidt'
author = 'Brian Schmidt'

# The full version, including alpha/beta/rc tags
release = __version__


# -- General configuration ---------------------------------------------------
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx-jsonschema', 'sphinxarg.ext']

templates_path = ['_templates']

exclude_patterns = []

# -- Options for HTML output -------------------------------------------------
html_theme = 'sphinx_rtd_theme'

html_static_path = ['_static']

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# pycharm stuff
.idea/

.DS_Store

--------------------------------------------------------------------------------
/tests/iris_train_test.py:
--------------------------------------------------------------------------------
import os
import sys
import unittest
from argparse import ArgumentParser

# this adds the project root to sys.path if it's not already there, which makes it easier to run the unit tests
if os.path.dirname(os.path.dirname(os.path.abspath(__file__))) not in sys.path:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from iris_model.iris_train import train, argument_parser


class IrisModelTrainTest(unittest.TestCase):
    def test1(self):
        """ testing the train() function with no hyperparameters """
        # arrange, act
        exception_raised = False
        try:
            result = train()
        except Exception as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test2(self):
        """ testing the train() function with hyperparameter c only """
        # arrange, act
        exception_raised = False
        try:
            result = train(c=100.0)
        except Exception as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test3(self):
        """ testing the train() function with hyperparameter gamma only """
        # arrange, act
        exception_raised = False
        try:
            result = train(gamma=0.001)
        except Exception as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test4(self):
        """ testing the train() function with all hyperparameters """
        # arrange, act
        exception_raised = False
        try:
            result = train(gamma=0.001, c=100.0)
        except Exception as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test5(self):
        """ testing the argument parser function """
        # arrange, act
        result = argument_parser()

        # assert
        self.assertTrue(type(result) is ArgumentParser)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/iris_model/iris_train.py:
--------------------------------------------------------------------------------
import sys
import os
import pickle
import argparse
import traceback
from sklearn import datasets
from sklearn import svm


def train(gamma=0.001, c=100.0):
    """ Function to train, serialize, and save the Iris model.

    :param gamma: gamma hyperparameter passed to the scikit-learn SVC constructor.
    :type gamma: float
    :param c: C (regularization) hyperparameter passed to the scikit-learn SVC constructor.
    :type c: float
    :rtype: None -- function will save trained model to the iris_model/model_files directory

    .. note::
        This code is from: https://scikit-learn.org/stable/tutorial/basic/tutorial.html

    """
    # loading the Iris dataset
    iris = datasets.load_iris()

    # instantiating an SVM model from scikit-learn
    svm_model = svm.SVC(gamma=gamma, C=c)

    # fitting the model
    svm_model.fit(iris.data[:-1], iris.target[:-1])

    # serializing the model and saving it to the /model_files folder
    dir_path = os.path.dirname(os.path.realpath(__file__))
    with open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'wb') as file:
        pickle.dump(svm_model, file)


def argument_parser():
    """ This function creates and returns the argument parser used for parsing the CLI arguments.

    .. note::
        Besides being called by main(), this function is used by sphinx-argparse to auto generate the CLI documentation.

    """
    parser = argparse.ArgumentParser(description='Command to train the Iris model.')
    parser.add_argument('-gamma', action="store", dest="gamma", type=float,
                        help='Gamma value used to train the SVM model.')
    parser.add_argument('-c', action="store", dest="c", type=float, help='C value used to train the SVM model.')
    return parser


def main():
    """ Entry point for the cli interface """
    parser = argument_parser()
    results = parser.parse_args()

    try:
        # we need these if else statements to handle hyperparameters that are not provided when the cli is called
        if results.gamma is None and results.c is None:
            train()
        elif results.gamma is not None and results.c is None:
            train(gamma=results.gamma)
        elif results.gamma is None and results.c is not None:
            train(c=results.c)
        else:
            train(gamma=results.gamma, c=results.c)
    except Exception as e:
        # printing the error to the screen
        traceback.print_exc()
        # returning error code
        sys.exit(os.EX_SOFTWARE)

    # returning code 0
    sys.exit(os.EX_OK)

--------------------------------------------------------------------------------
/ml_model_abc.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod


class MLModel(ABC):
    """ An abstract base class for ML model prediction code """
    @property
    @abstractmethod
    def display_name(self):
        """ This abstract property returns a display name for the model.

        .. note::
            This is a name for the model that looks good in user interfaces.

        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def qualified_name(self):
        """ This abstract property returns the qualified name of the model.

        .. note::
            A qualified name is an unambiguous identifier for the model. It should be possible to embed it in a URL.

        """
        raise NotImplementedError()

    @property
    @abstractmethod
    def description(self):
        """ This abstract property returns a description of the model. """
        raise NotImplementedError()

    @property
    @abstractmethod
    def major_version(self):
        """ This abstract property returns the model's major version as an integer. """
        raise NotImplementedError()

    @property
    @abstractmethod
    def minor_version(self):
        """ This abstract property returns the model's minor version as an integer. """
        raise NotImplementedError()

    @property
    @abstractmethod
    def input_schema(self):
        """ This abstract property returns the schema of the data accepted by the predict() method. """
        raise NotImplementedError()

    @property
    @abstractmethod
    def output_schema(self):
        """ This abstract property returns the schema of the data returned by the predict() method. """
        raise NotImplementedError()

    @abstractmethod
    def __init__(self):
        """ This method holds any deserialization and initialization code for the model. """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, data):
        """ Method to make a prediction with the model.

        :param data: data used by the model for making a prediction
        :type data: object -- can be any python type
        :rtype: python object -- can be any python type

        """
        raise NotImplementedError()


class MLModelException(Exception):
    """ Exception type used to raise exceptions within MLModel derived classes """
    def __init__(self, *args):
        Exception.__init__(self, *args)


class MLModelSchemaValidationException(MLModelException):
    """ Exception type used to raise schema validation exceptions within MLModel derived classes """
    def __init__(self, *args):
        MLModelException.__init__(self, *args)

--------------------------------------------------------------------------------
/iris_model/iris_predict.py:
--------------------------------------------------------------------------------
import os
import pickle
from schema import Schema
from numpy import array

from ml_model_abc import MLModel, MLModelSchemaValidationException
from iris_model import __version_info__, __display_name__, __qualified_name__, __description__


class IrisModel(MLModel):
    """ A demonstration of how to use the MLModel base class """
    # accessing the package metadata
    display_name = __display_name__
    qualified_name = __qualified_name__
    description = __description__
    major_version = __version_info__[0]
    minor_version = __version_info__[1]

    # stating the input schema of the model as a Schema object
    input_schema = Schema({'sepal_length': float,
                           'sepal_width': float,
                           'petal_length': float,
                           'petal_width': float})

    # stating the output schema of the model as a Schema object
    output_schema = Schema({'species': str})

    def __init__(self):
        """ Class constructor that loads and deserializes the iris model parameters.

        .. note::
            The trained model parameters are loaded from the "model_files" directory.

        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'rb') as file:
            self._svm_model = pickle.load(file)

    def predict(self, data):
        """ Method to make a prediction with the Iris model.

        :param data: Data for making a prediction with the Iris model. Object must meet requirements of the input schema.
        :type data: dict
        :rtype: dict -- The result of the prediction, the output object will meet the requirements of the output schema.

        """
        try:
            self.input_schema.validate(data)
        except Exception as e:
            raise MLModelSchemaValidationException("Failed to validate input data: {}".format(str(e)))

        # converting the incoming dictionary into a numpy array that can be accepted by the scikit-learn model
        X = array([data["sepal_length"], data["sepal_width"], data["petal_length"], data["petal_width"]]).reshape(1, -1)

        # making the prediction and extracting the result from the array
        y_hat = int(self._svm_model.predict(X)[0])

        # converting the prediction into a string that will match the output schema of the model
        # this list will map the output of the scikit-learn model to the output string expected by the schema
        targets = ['setosa', 'versicolor', 'virginica']

        # this hides the actual output of the model, which is just a number, it will ensure that any user of the
        # model receives output that is easily interpretable, in this case the output will be the species name
        species = targets[y_hat]

        return {"species": species}

--------------------------------------------------------------------------------
/tests/iris_predict_test.py:
--------------------------------------------------------------------------------
import os
import sys
import unittest
from sklearn import svm
from schema import SchemaError

# this adds the project root to sys.path if it's not already there, which makes it easier to run the unit tests
if os.path.dirname(os.path.dirname(os.path.abspath(__file__))) not in sys.path:
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from iris_model.iris_predict import IrisModel
from ml_model_abc import MLModelSchemaValidationException


class IrisModelPredictTests(unittest.TestCase):
    def test1(self):
        """ testing the __init__() method """
        # arrange, act
        model = IrisModel()

        # assert
        self.assertTrue(type(model._svm_model) is svm.SVC)

    def test2(self):
        """ testing the input schema with wrong data """
        # arrange
        data = {'name': 'Sue', 'age': '28', 'gender': 'Squid'}

        # act
        exception_raised = False
        try:
            validated_data = IrisModel.input_schema.validate(data)
        except SchemaError as e:
            exception_raised = True

        # assert
        self.assertTrue(exception_raised)

    def test3(self):
        """ testing the input schema with correct data """
        # arrange
        data = {'sepal_length': 1.0,
                'sepal_width': 1.0,
                'petal_length': 1.0,
                'petal_width': 1.0}

        # act
        exception_raised = False
        try:
            validated_data = IrisModel.input_schema.validate(data)
        except SchemaError as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test4(self):
        """ testing the output schema with incorrect data """
        # arrange
        data = {'species': 1.0}

        # act
        exception_raised = False
        try:
            validated_data = IrisModel.output_schema.validate(data)
        except SchemaError as e:
            exception_raised = True

        # assert
        self.assertTrue(exception_raised)

    def test5(self):
        """ testing the output schema with correct data """
        # arrange
        data = {'species': 'setosa'}

        # act
        exception_raised = False
        try:
            validated_data = IrisModel.output_schema.validate(data)
        except SchemaError as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)

    def test6(self):
        """ testing the predict() method throws schema exception when given bad data """
        # arrange
        model = IrisModel()

        # act
        exception_raised = False
        try:
            prediction = model.predict({'name': 'Sue', 'age': '28', 'gender': 'Squid'})
        except MLModelSchemaValidationException as e:
            exception_raised = True

        # assert
        self.assertTrue(exception_raised)

    def test7(self):
        """ testing the predict() method with good data """
        # arrange
        model = IrisModel()

        # act
        prediction = model.predict(data={'sepal_length': 1.0,
                                         'sepal_width': 1.0,
                                         'petal_length': 1.0,
                                         'petal_width': 1.0})

        # Schema.validate() raises SchemaError (not MLModelSchemaValidationException) on invalid output
        exception_raised = False
        try:
            IrisModel.output_schema.validate(prediction)
        except SchemaError as e:
            exception_raised = True

        # assert
        self.assertFalse(exception_raised)
        self.assertTrue(type(prediction) is dict)
        self.assertTrue(prediction["species"] == 'setosa')

    def test8(self):
        """ testing JSON schema generation """
        # arrange
        model = IrisModel()

        # act
        json_schema = model.output_schema.json_schema("https://example.com/my-schema.json")

        # assert
        print(json_schema)
        self.assertTrue(type(json_schema) is dict)

    def test9(self):
        """ testing the properties of the model object """
        # arrange, act
        model = IrisModel()

        # assert
        self.assertTrue(model.display_name == "Iris Model")
        self.assertTrue(model.qualified_name == "iris_model")
        self.assertTrue(model.description == "A machine learning model for predicting the species of a flower based on its measurements.")
        self.assertTrue(model.major_version == 0)
        self.assertTrue(model.minor_version == 1)


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/blog_post/post.md:
--------------------------------------------------------------------------------
Title: Improving the MLModel Base Class
Date: 2019-06-12 09:21
Category: Blog
Slug: improving-the-mlmodel-base-class
Authors: Brian Schmidt
Summary: In the previous blog post in this series I showed an object-oriented design for a base class that does Machine Learning model prediction. The design of the base class was intentionally very simple so that I could show a simple example of how to use the base class with a scikit-learn model. I showed an easy way to publish schema metadata about the model inputs and outputs, and how to write model deserialization code so that it is hidden from the users of the model. I also showed how to hide the implementation details of the model by translating the user's input to the model's input so that the user of the model doesn't have to know how to use pandas or numpy. In this blog post I will continue to make improvements to the MLModel class and the example that I used in the previous post.

This blog post continues with the ideas developed in the previous post
in this series.

All of the code shown in this post can be found in [this GitHub repository](https://github.com/schmidtbri/ml-model-abc-improvements).

In the previous blog post in this series I showed an object-oriented design for a base class that does Machine Learning model prediction. The design of the base class was intentionally very simple so that I could show a simple example of how to use the base class with a scikit-learn model. I showed an easy way to publish schema metadata about the model inputs and outputs, and how to write model deserialization code so that it is hidden from the users of the model. I also showed how to hide the implementation details of the model by translating the user's input to the model's input so that the user of the model doesn't have to know how to use pandas or numpy. In this blog post I will continue to make improvements to the MLModel class and the example that I used in the previous post.

In this blog post I will turn the iris example code from the previous post into a full Python package with many features that make the iris model easier to install and use from other Python packages. I will also continue to improve the MLModel base class. In general, I want to show how to make ML code easier to install and use.

While researching this blog post I found a great [blog post](https://towardsdatascience.com/building-package-for-machine-learning-project-in-python-3fc16f541693) by [Mateusz Bednarski](https://towardsdatascience.com/@mbednarski) showing how to build machine learning models as Python packages. There are some similarities between what I will show here and that blog post. However, this post focuses more on deploying ML models into production systems, whereas Mateusz's post focuses on packaging the training code.

This blog post assumes that you have some experience with Python. I will reference resources for learning the tools that I use in the blog post.

## Making the Iris Model into a Python Package

Another improvement that we can make to the example code is to turn it into a full-fledged Python package. This makes it easier to install and use in other projects. The goal here is to treat ML models as just another Python package, which makes it possible to leverage all of the tools that Python has for packaging and reusing code. A good guide for structuring Python packages can be found [here](https://python-packaging.readthedocs.io/en/latest/#).

A common problem with ML code is that it is almost always hard to use and deploy. Teams that do machine learning know this very well, since the code written by a Data Scientist almost always needs to be rewritten by a software engineer before it can be deployed into production systems. Luckily, we have a lot of tools that make the transition from experimental model to production model smoother. In this section I will show a few simple steps that turn the example model from [the last blog post]({filename}/articles/a-simple-ml-model-base-class/post.md) into an installable Python package. To accomplish this, we will add version information to the package, add a command line interface to the training script, add Sphinx documentation, and add a setup.py file to the project. As an additional touch, we will automate the documentation process for the interface of the ML model.

First of all, we need to reorganize the code in the project a little bit:

```
- project_root
    - docs (a folder, package documentation goes in here)
    - iris_model (a folder, iris package code goes in here)
        - model_files (a folder, the model files go in here)
        - __init__.py (this file marks the folder as a Python package)
        - iris_predict.py (the prediction code goes here)
        - iris_train.py (the training script goes here)
    - tests (a folder, unit tests for the iris_model package go here)
    - ml_model_abc.py (the MLModel base class goes here)
    - requirements.txt
    - setup.py (the package installation script goes here)
```

A lot of this code is shared with the [previous blog post]({filename}/articles/a-simple-ml-model-base-class/post.md). It is reorganized here so that the ML model can be installed as a Python package.

## Adding Package Versioning

Python packages are usually versioned using [semantic versioning](https://semver.org/). Software packages that use semantic versioning must declare a public API. This is complicated for ML models because they have two APIs: the API for making model predictions and the API for training the model. We can deal with this complexity by tying different components of the package's semantic version to the prediction API and to the training API.

I chose to version the prediction API of the model using the major and minor version components of the semantic versioning standard. Many users are affected by changes in the prediction API, but fewer users are affected by changes in the training API, because ML models are usually used by many people but trained by a few experts.
The patch number of the version can be used to version changes to the training API.

For example, whenever the ML model prediction API changes in a backward-incompatible way the major version number goes up, and whenever it changes in a backward-compatible way the minor version goes up. This approach ensures that any user of the ML model package knows how changes in the prediction API will affect them when they install the package. A simple rule of thumb is to increase the major or minor version number whenever the input and output schemas of the model change. Lastly, any change to the model training API causes the patch version number to go up.

A [common approach](https://packaging.python.org/guides/single-sourcing-package-version/) for storing version information in a Python package is to put a "\_\_version\_\_" property into the \_\_init\_\_.py module in the root of the package:

```python
__version_info__ = (0, 1, 0)
__version__ = ".".join([str(n) for n in __version_info__])
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/__init__.py#L1-L2).

I like to think of an ML model as a software component like any other, the only difference being that an ML model is statistical in nature. Of course, being statistical adds a lot of complexity, but at the end of the day ML models are just code that can be managed like any other piece of code. In this section we take a step in that direction by attaching version information to the iris\_model package.

Although semantic versioning is not designed for versioning models, we can apply it here to version the model code and gloss over the more complicated aspects of ML models.
For example, we can't use semantic versioning to version model parameters, since they are not part of the codebase and don't have an API. This is a problem that I will tackle in another blog post.

## Adding a CLI interface to the Training Script

When building ML models, the training code is often written in Jupyter notebooks. While there are ways to automate the training process with notebooks, it's a lot easier to do it through the command line. To do this we will add a simple command line interface to the Iris model training script. We will create the interface using the [argparse](https://docs.python.org/3/library/argparse.html) package and then create a function that calls the train() function when the iris\_train.py script is called from the command line.

To create the argparse ArgumentParser object we create a dedicated function (the reason for this is explained below):

```python
import argparse


def argument_parser():
    parser = argparse.ArgumentParser(description='Command to train the Iris model.')
    parser.add_argument('-gamma', action="store", dest="gamma", type=float, help='Gamma value used to train the SVM model.')
    parser.add_argument('-c', action="store", dest="c", type=float, help='C value used to train the SVM model.')
    return parser
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/iris_train.py#L39-L50).

To call the train() function from the command line, I created a new function called main().
The function gets a parser object, parses the incoming parameters, and calls the train() function:

```python
import os
import sys
import traceback


def main():
    parser = argument_parser()
    results = parser.parse_args()
    try:
        if results.gamma is None and results.c is None:
            train()
        elif results.gamma is not None and results.c is None:
            train(gamma=results.gamma)
        elif results.gamma is None and results.c is not None:
            train(c=results.c)
        else:
            train(gamma=results.gamma, c=results.c)
    except Exception:
        # print the full traceback, then report failure to the operating system
        traceback.print_exc()
        sys.exit(os.EX_SOFTWARE)  # the os.EX_* exit codes are only available on Unix
    sys.exit(os.EX_OK)
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/iris_train.py#L53-L75).

The main() function wraps the train() function so that main() can be registered as an entry point when the iris\_model package is installed. The main() function parses the command line arguments, calls the train() function, handles exceptions, and returns a success or error code to the operating system when the training process is done. Another benefit of this approach is that the train() function can still be imported into other code and called as a function, while also being available from the command line.

## Adding Sphinx Documentation

One of the great parts of working in the Python ecosystem is the Sphinx package, which is used for creating documentation from source files. There are a lot of [great](https://pythonhosted.org/an_example_pypi_project/sphinx.html) [guides](https://www.sphinx-doc.org/en/1.5/tutorial.html) for documenting your package using Sphinx, so I won't cover that again here.
For this blog post, I followed these guides to create a simple documentation page and [hosted it on GitHub Pages](https://schmidtbri.github.io/ml-model-abc-improvements/). Adding documentation is a simple process, and almost all Python packages with more than a few users do it. After putting together the basic documentation, I took a few extra steps to fully automate the creation of the documentation for the model.

First of all, I added docstrings to all classes and methods in the iris\_model package where it made sense. [Here is an example](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/iris_predict.py#L41-L47) of how I documented the predict() method using its docstring in the .py file. The docstring is formatted so that the Sphinx autodoc extension can build it automatically. This extension makes it easy to extract docstrings from Python packages and modules and build documentation from them. A good guide for using the autodoc extension can be found [here](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html).

However, the MLModel base class creates a problem for documentation: the predict() method of a class that inherits from MLModel accepts only a single parameter called "data" as input. This makes it hard to document the input schema of the model through autodoc, since the data structure accepted by the model can't easily be described in the docstring. The same problem arises when we try to document the return type of the predict() method. Luckily, we can automatically extract the JSON Schema representation of the input and output schemas of the model.
To take advantage of this, I used the [sphinx-jsonschema](https://sphinx-jsonschema.readthedocs.io/en/latest/) extension to automatically add the schema information to a documentation page. The process is simple. I just had to add this code to an .rst file:

```
.. jsonschema:: ../build/input_schema.json
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/docs/source/iris_predict.rst).

The only problem is that the input and output JSON Schema strings are not saved to disk where the jsonschema extension can access them. Instead, they are available from an instance of the IrisModel class. To fix this, I added [this code](https://github.com/schmidtbri/ml-model-abc-improvements/blob/00d5e558f9af7571d824d597107412ed86681e8b/docs/source/conf.py#L29-L39) to the conf.py file that configures the Sphinx documentation build. The code instantiates an IrisModel object, extracts the JSON Schema strings, and saves them to a location where the Sphinx documentation generator can read them. The generated documentation can be seen [here](https://schmidtbri.github.io/ml-model-abc-improvements/iris_predict.html#iris-model-input-schema) and [here](https://schmidtbri.github.io/ml-model-abc-improvements//iris_predict.html#iris-model-output-schema).

Since we are using the argparse library to create the CLI for the training script, we can use the [sphinxarg.ext](https://sphinx-argparse.readthedocs.io/en/stable/index.html) Sphinx extension to generate its documentation automatically. This was as easy as adding this code to the .rst file that describes the training script:

```
.. argparse::
   :module: iris_model.iris_train
   :func: argument_parser
   :prog: iris_train
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/docs/source/iris_train.rst).

The sphinxarg.ext extension then imports the iris\_train module and calls the argument\_parser() function. That function returns an ArgumentParser object, which the extension uses to generate the documentation. The results can be seen in the documentation [here](https://schmidtbri.github.io/ml-model-abc-improvements//iris_train.html#iris-training-code-cli-documentation).

This section shows how to write the code of an ML model so that its documentation can be created automatically. Exposing the input and output schemas of the model as JSON Schema strings lets a Data Scientist clearly communicate the requirements of the model to its end users. At the same time, exposing the hyperparameters of the training script as command line options makes it possible to document the training process automatically. Because the ML model code is written this way, any change to the code is reflected in the documentation whenever it is regenerated.

## Adding a setup.py File

Now that the ML model code is structured as a Python package, versioned, and documented, we'll add a [setup.py](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/setup.py) file to the project folder. The setuptools package uses the setup.py file to install Python packages, and this file makes the ML model easy to install in a virtual environment. A great guide for writing the setup.py file for your package can be found [here](https://github.com/kennethreitz/setup.py).
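
The setup.py file also has to declare the package version. Because the version already lives in iris\_model/\_\_init\_\_.py, the single-sourcing guide linked in the versioning section lists reading it from that file as one option. Below is a minimal sketch that assumes setup.py runs from the project root, next to the iris\_model folder. The helpers read_version() and version_from_source() are hypothetical names, and the repository's actual setup.py may get its version differently. The sketch parses the file with Python's ast module and evaluates only the `__version_info__` tuple, so none of the package's code is executed at install time.

```python
import ast
from pathlib import Path


def version_from_source(source):
    """Build a dotted version string from the __version_info__ tuple in Python source."""
    tree = ast.parse(source)
    for node in tree.body:
        # only look at top-level assignments such as: __version_info__ = (0, 1, 0)
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "__version_info__":
                    # literal_eval only accepts literals, so nothing in the file is executed
                    version_info = ast.literal_eval(node.value)
                    return ".".join(str(n) for n in version_info)
    raise RuntimeError("__version_info__ not found")


def read_version(init_path="iris_model/__init__.py"):
    """Hypothetical helper: read the version from __init__.py without importing the package."""
    # the default path assumes setup.py is executed from the project root
    return version_from_source(Path(init_path).read_text(encoding="utf-8"))
```

In setup.py this could be passed as `version=read_version(),` alongside the `packages`, `py_modules`, `package_data`, and `entry_points` arguments described below. The `__version__` string is still built at runtime from the same tuple, so the version is defined in only one place.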

Most of the fields in the iris\_model package's setup.py file are easy to understand and are better explained in other guides. In this blog post, I'll focus on the parts of the setup.py file that had to be modified specifically for the ML model package. First of all, we want to point at the folder that contains the iris\_model package, which we can do with this line in the setup.py file:

```python
packages=["iris_model"],
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/setup.py#L20).

Next, we need to make sure that the ml\_model\_abc.py Python module is installed along with the iris\_model package. In the future, it would be better to move this code into a separate Python package that the iris\_model package depends on, but for now we just need this line of code:

```python
py_modules=["ml_model_abc"],
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/setup.py#L19).

Next, we take care of the model parameters. The ML model needs its parameters to be available for loading at prediction time, and the setup.py file handles this with these lines of code:

```python
package_data={'iris_model': ['model_files/svc_model.pickle']},
include_package_data=True,
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/setup.py#L25-L28).

This ensures that when the package is installed into an environment, the model parameters are copied along with the model\_files folder.

Next, we have to register the iris\_train.py script as an entry point.
This makes it possible to run the training script from the command line in any environment where the iris\_model package is installed:

```python
entry_points={
    'console_scripts': ['iris_train=iris_model.iris_train:main'],
},
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/setup.py#L29-L32).

Once all of this is in the setup.py file, we can try a pip install into a new virtual environment. To keep things simple, we will install the package directly from the git repository. These are the shell commands:

```bash
mkdir example
cd example

# creating a virtual environment
python3 -m venv venv

# activating the virtual environment on macOS (or Linux)
source venv/bin/activate

# installing the iris_model package from the GitHub repository
pip install git+https://github.com/schmidtbri/ml-model-abc-improvements#egg=iris_model
```

Now we can test the installation by starting an interactive Python interpreter and executing this Python code:

```python
>>> from iris_model.iris_predict import IrisModel
>>> model = IrisModel()
>>> model
<iris_model.iris_predict.IrisModel object at 0x...>
>>> model.input_schema
Schema({'sepal_length': <class 'float'>,
        'sepal_width': <class 'float'>,
        'petal_length': <class 'float'>,
        'petal_width': <class 'float'>})
>>> model.output_schema
Schema({'species': <class 'str'>})
```

Next, we can test the CLI for the training code by running this command:

```bash
iris_train -c=10.0 -gamma=0.01
```

This section showed how to install the iris\_model Python package using common Python packaging tools, and how to use and retrain the model in a different Python environment.
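
The command above passes values in the `-option=value` form, which argparse accepts for the options defined in argument\_parser(). The sketch below runs without installing anything. It copies argument\_parser() from the training script and adds train\_kwargs(), a hypothetical helper that is not part of the repository and keeps only the options the user actually supplied. Forwarding that dict as `train(**kwargs)` is a more compact alternative to the if/elif chain in main(). This assumes that train() declares its own defaults for gamma and c, as that chain implies.

```python
import argparse


def argument_parser():
    # same options as the argument_parser() function shown earlier in this post
    parser = argparse.ArgumentParser(description='Command to train the Iris model.')
    parser.add_argument('-gamma', action="store", dest="gamma", type=float,
                        help='Gamma value used to train the SVM model.')
    parser.add_argument('-c', action="store", dest="c", type=float,
                        help='C value used to train the SVM model.')
    return parser


def train_kwargs(argv=None):
    """Hypothetical helper: parse argv (sys.argv[1:] when None) and drop unset options."""
    args = argument_parser().parse_args(argv)
    # options the user did not pass are None, so leaving them out lets train() use its defaults
    return {name: value for name, value in vars(args).items() if value is not None}


if __name__ == "__main__":
    # the same arguments as the iris_train command above
    print(train_kwargs(["-c=10.0", "-gamma=0.01"]))  # {'gamma': 0.01, 'c': 10.0}
```

Inside main(), the body of the try block could then shrink to `train(**train_kwargs())`. The existing exception handling and exit codes would stay as they are.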

## Model Metadata in the MLModel Base Class

In the [previous blog post](https://towardsdatascience.com/a-simple-ml-model-base-class-ab40e2febf13) we showed an MLModel base class with two required abstract properties: "input\_schema" and "output\_schema". Any class that derived from the MLModel base class had to provide these two properties, which were used to publish schema metadata about the input and output data of the model. To keep things simple, I chose not to expose more metadata through class properties. However, several other pieces of metadata would be useful to expose to the outside world. For example:

- display_name, a property that returns a display name for the model
- qualified_name, a property that returns the qualified name of the model (a qualified name is an unambiguous identifier for the model)
- description, a property that returns a description of the model
- major_version, a property that returns the model's major version as an integer
- minor_version, a property that returns the model's minor version as an integer

These properties are exposed as object properties and can be accessed the same way as the input\_schema and output\_schema properties.
The new code for the MLModel base class now looks like this:

```python
from abc import ABC, abstractmethod


class MLModel(ABC):
    @property
    @abstractmethod
    def display_name(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def qualified_name(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def description(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def major_version(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def minor_version(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def input_schema(self):
        raise NotImplementedError()

    @property
    @abstractmethod
    def output_schema(self):
        raise NotImplementedError()

    @abstractmethod
    def __init__(self):
        raise NotImplementedError()

    @abstractmethod
    def predict(self, data):
        # subclasses call super().predict(data=data) to validate the input
        self.input_schema.validate(data)
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/ml_model_abc.py#L4-L74).

The new MLModel base class works like the previous implementation, but it now also requires the properties described above to be published as instance properties.

This metadata is added in the \_\_init\_\_.py file of the iris\_model package, since it applies to the whole package:

```python
# a display name for the model
__display_name__ = "Iris Model"

# returning the package name as the qualified name for the model
__qualified_name__ = __name__.split(".")[0]

# a description of the model
__description__ = "A machine learning model for predicting the species of a flower based on its measurements."
```

The code above can be found [here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/__init__.py#L4-L11).

To show how a class that derives from the MLModel base class can publish these properties, we can modify the Iris model example used in the [previous blog post](https://towardsdatascience.com/a-simple-ml-model-base-class-ab40e2febf13). The Iris model class now looks like this:

```python
import os
import pickle

from numpy import array
from schema import Schema

from ml_model_abc import MLModel
from iris_model import __version_info__, __display_name__, __qualified_name__, __description__


class IrisModel(MLModel):
    # accessing the package metadata
    display_name = __display_name__
    qualified_name = __qualified_name__
    description = __description__
    major_version = __version_info__[0]
    minor_version = __version_info__[1]

    # stating the input schema of the model as a Schema object
    input_schema = Schema({'sepal_length': float,
                           'sepal_width': float,
                           'petal_length': float,
                           'petal_width': float})

    # stating the output schema of the model as a Schema object
    output_schema = Schema({'species': str})

    def __init__(self):
        # the pickled model ships inside the package (see package_data in setup.py)
        dir_path = os.path.dirname(os.path.realpath(__file__))
        with open(os.path.join(dir_path, "model_files", "svc_model.pickle"), 'rb') as file:
            self._svm_model = pickle.load(file)

    def predict(self, data):
        # validate the input against input_schema in the MLModel base class
        super().predict(data=data)
        X = array([data["sepal_length"],
                   data["sepal_width"],
                   data["petal_length"],
                   data["petal_width"]]).reshape(1, -1)

        y_hat = int(self._svm_model.predict(X)[0])
        targets = ['setosa', 'versicolor', 'virginica']
        species = targets[y_hat]
        return {"species": species}
```

The code above can be found
[here](https://github.com/schmidtbri/ml-model-abc-improvements/blob/master/iris_model/iris_predict.py#L1-L65).

The display name, qualified name, and description properties are set as string class attributes in the IrisModel class, and their values are read from the \_\_init\_\_ module. The major and minor version properties are extracted from the \_\_version\_info\_\_ tuple.

In some situations, a single Python package will hold more than one MLModel-derived class. In that case, the display name, qualified name, and description metadata would be set individually in each MLModel-derived class instead of being read from the package-wide metadata stored in the \_\_init\_\_ module.

The properties are now easily accessible from the model object. To show this, we can instantiate the object and access the properties:

```python
>>> from iris_model.iris_predict import IrisModel
>>> iris_model = IrisModel()
>>> iris_model.qualified_name
'iris_model'
>>> iris_model.display_name
'Iris Model'
```

These new metadata properties make it easier to introspect information about the model, and they also make it easier to manage many MLModel objects in the same Python process.

## Future Improvements

In this blog post we showed how to version an ML model using standard Python packaging conventions. However, the model parameters of the Iris model also need to be versioned over time, and metadata about them also needs to be kept. This is a problem that I will tackle in a future blog post.

Another problem that we did not tackle in this blog post is how to give ML models a more complex API.
For example, the Iris model is only allowed to have one predict() method, which makes it impossible to do more complex operations with the Iris model. In a future blog post I will show how to modify the MLModel base class to allow this.

--------------------------------------------------------------------------------