├── .github
│   └── ISSUE_TEMPLATE.md
├── .gitignore
├── .travis.yml
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── LICENSE
├── MANIFEST.in
├── README.rst
├── ml2rt
│   ├── __init__.py
│   ├── exporter.py
│   ├── importer.py
│   ├── utils.py
│   └── version.py
├── requirements_dev.txt
├── setup.cfg
├── setup.py
├── tests
│   ├── test_mlut.py
│   └── testdata
│       └── script.txt
└── tox.ini

/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
 1 | * ml2rt version:
 2 | * Python version:
 3 | * Operating System:
 4 | 
 5 | ### Description
 6 | 
 7 | Describe what you were trying to get done.
 8 | Tell us what happened, what went wrong, and what you expected to happen.
 9 | 
10 | ### What I Did
11 | 
12 | ```
13 | Paste the command(s) you ran and the output.
14 | If there was a crash, please include the traceback here.
15 | ```
16 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | env/
 12 | build/
 13 | develop-eggs/
 14 | dist/
 15 | downloads/
 16 | eggs/
 17 | .eggs/
 18 | lib/
 19 | lib64/
 20 | parts/
 21 | sdist/
 22 | var/
 23 | wheels/
 24 | *.egg-info/
 25 | .installed.cfg
 26 | *.egg
 27 | 
 28 | # PyInstaller
 29 | # Usually these files are written by a python script from a template
 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
 31 | *.manifest
 32 | *.spec
 33 | 
 34 | # Installer logs
 35 | pip-log.txt
 36 | pip-delete-this-directory.txt
 37 | 
 38 | # Unit test / coverage reports
 39 | htmlcov/
 40 | .tox/
 41 | .coverage
 42 | .coverage.*
 43 | .cache
 44 | nosetests.xml
 45 | coverage.xml
 46 | *.cover
 47 | .hypothesis/
 48 | .pytest_cache/
 49 | 
 50 | # Translations
 51 | *.mo
 52 | *.pot
 53 | 
 54 | # Django stuff:
 55 | *.log
 56 | local_settings.py
 57 | 
 58 | # Flask stuff:
 59 | instance/
 60 | .webassets-cache
 61 | 
 62 | # Scrapy stuff:
 63 | .scrapy
 64 | 
 65 | # Sphinx documentation
 66 | docs/_build/
 67 | 
 68 | # PyBuilder
 69 | target/
 70 | 
 71 | # Jupyter Notebook
 72 | .ipynb_checkpoints
 73 | 
 74 | # pyenv
 75 | .python-version
 76 | 
 77 | # celery beat schedule file
 78 | celerybeat-schedule
 79 | 
 80 | # SageMath parsed files
 81 | *.sage.py
 82 | 
 83 | # dotenv
 84 | .env
 85 | 
 86 | # virtualenv
 87 | .venv
 88 | venv/
 89 | ENV/
 90 | 
 91 | # Spyder project settings
 92 | .spyderproject
 93 | .spyproject
 94 | 
 95 | # Rope project settings
 96 | .ropeproject
 97 | 
 98 | # mkdocs documentation
 99 | /site
100 | 
101 | # mypy
102 | .mypy_cache/
103 | 
104 | .idea
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
 1 | # Config file for automatic testing at travis-ci.org
 2 | 
 3 | language: python
 4 | python:
 5 |   - 3.7
 6 |   - 3.6
 7 | 
 8 | # Command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors
 9 | install: pip install -U tox-travis
10 | 
11 | # Command to run tests, e.g. python setup.py test
12 | script: tox
13 | 
14 | 
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
 1 | =======
 2 | Credits
 3 | =======
 4 | 
 5 | Development Lead
 6 | ----------------
 7 | 
 8 | * Sherin Thomas
 9 | 
10 | Contributors
11 | ------------
12 | 
13 | None yet. Why not be the first?
14 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
  1 | .. highlight:: shell
  2 | 
  3 | ============
  4 | Contributing
  5 | ============
  6 | 
  7 | Contributions are welcome, and they are greatly appreciated! Every little bit
  8 | helps, and credit will always be given.
  9 | 
 10 | You can contribute in many ways:
 11 | 
 12 | Types of Contributions
 13 | ----------------------
 14 | 
 15 | Report Bugs
 16 | ~~~~~~~~~~~
 17 | 
 18 | Report bugs at https://github.com/hhsecond/ml2rt/issues.
 19 | 
 20 | If you are reporting a bug, please include:
 21 | 
 22 | * Your operating system name and version.
 23 | * Any details about your local setup that might be helpful in troubleshooting.
 24 | * Detailed steps to reproduce the bug.
 25 | 
 26 | Fix Bugs
 27 | ~~~~~~~~
 28 | 
 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
 30 | wanted" is open to whoever wants to implement it.
 31 | 
 32 | Implement Features
 33 | ~~~~~~~~~~~~~~~~~~
 34 | 
 35 | Look through the GitHub issues for features. Anything tagged with "enhancement"
 36 | and "help wanted" is open to whoever wants to implement it.
 37 | 
 38 | Write Documentation
 39 | ~~~~~~~~~~~~~~~~~~~
 40 | 
 41 | ml2rt could always use more documentation, whether as part of the
 42 | official ml2rt docs, in docstrings, or even on the web in blog posts,
 43 | articles, and such.
 44 | 
 45 | Submit Feedback
 46 | ~~~~~~~~~~~~~~~
 47 | 
 48 | The best way to send feedback is to file an issue at https://github.com/hhsecond/ml2rt/issues.
 49 | 
 50 | If you are proposing a feature:
 51 | 
 52 | * Explain in detail how it would work.
 53 | * Keep the scope as narrow as possible, to make it easier to implement.
 54 | * Remember that this is a volunteer-driven project, and that contributions
 55 |   are welcome :)
 56 | 
 57 | Get Started!
 58 | ------------
 59 | 
 60 | Ready to contribute? Here's how to set up `ml2rt` for local development.
 61 | 
 62 | 1. Fork the `ml2rt` repo on GitHub.
 63 | 2. Clone your fork locally::
 64 | 
 65 |     $ git clone git@github.com:your_name_here/ml2rt.git
 66 | 
 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
 68 | 
 69 |     $ mkvirtualenv ml2rt
 70 |     $ cd ml2rt/
 71 |     $ python setup.py develop
 72 | 
 73 | 4. Create a branch for local development::
 74 | 
 75 |     $ git checkout -b name-of-your-bugfix-or-feature
 76 | 
 77 |    Now you can make your changes locally.
 78 | 
 79 | 5. When you're done making changes, check that your changes pass flake8 and the
 80 |    tests, including testing other Python versions with tox::
 81 | 
 82 |     $ flake8 ml2rt tests
 83 |     $ python setup.py test  # or: py.test
 84 |     $ tox
 85 | 
 86 |    To get flake8 and tox, just pip install them into your virtualenv.
 87 | 
 88 | 6. Commit your changes and push your branch to GitHub::
 89 | 
 90 |     $ git add .
 91 |     $ git commit -m "Your detailed description of your changes."
 92 |     $ git push origin name-of-your-bugfix-or-feature
 93 | 
 94 | 7. Submit a pull request through the GitHub website.
 95 | 
 96 | Pull Request Guidelines
 97 | -----------------------
 98 | 
 99 | Before you submit a pull request, check that it meets these guidelines:
100 | 
101 | 1. The pull request should include tests.
102 | 2. If the pull request adds functionality, the docs should be updated. Put
103 |    your new functionality into a function with a docstring, and add the
104 |    feature to the list in README.rst.
105 | 3. The pull request should work for Python 3.6 and 3.7. Check
106 |    https://travis-ci.org/hhsecond/ml2rt/pull_requests
107 |    and make sure that the tests pass for all supported Python versions.
108 | 
109 | Tips
110 | ----
111 | 
112 | To run a subset of tests::
113 | 
114 |     $ py.test tests/test_mlut.py
115 | 
116 | 
117 | Deploying
118 | ---------
119 | 
120 | A reminder for the maintainers on how to deploy.
121 | Make sure all your changes are committed (including an entry in HISTORY.rst).
122 | Then run::
123 | 
124 |     $ bumpversion patch # possible: major / minor / patch
125 |     $ git push
126 |     $ git push --tags
127 | 
128 | Travis will then deploy to PyPI if tests pass.
129 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | Apache Software License 2.0
 2 | 
 3 | Copyright (c) 2019, [tensor]werk
 4 | 
 5 | Licensed under the Apache License, Version 2.0 (the "License");
 6 | you may not use this file except in compliance with the License.
 7 | You may obtain a copy of the License at
 8 | 
 9 |     http://www.apache.org/licenses/LICENSE-2.0
10 | 
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
 1 | include AUTHORS.rst
 2 | include CONTRIBUTING.rst
 3 | include LICENSE
 4 | include README.rst
 5 | 
 6 | recursive-include tests *
 7 | recursive-exclude * __pycache__
 8 | recursive-exclude * *.py[co]
 9 | 
10 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
 1 | =====================================================
 2 | ml2rt - Utilities for taking ML to different runtimes
 3 | =====================================================
 4 | 
 5 | 
 6 | Machine learning utilities for model conversion, serialization, loading etc.
 7 | 
 8 | 
 9 | * Free software: Apache Software License 2.0
10 | 
11 | Installation
12 | ------------
13 | 
14 | ::
15 | 
16 |     pip install ml2rt
17 | 
18 | 
19 | Documentation
20 | -------------
21 | 
22 | ml2rt provides some convenient functions to convert, save & load machine learning models. It currently supports Tensorflow, PyTorch, Sklearn, Spark and ONNX; support for frameworks like xgboost and coreml is on the way.
23 | 
24 | Saving Tensorflow model
25 | ***********************
26 | 
27 | .. code-block:: python
28 | 
29 |     import tensorflow as tf
30 |     from ml2rt import save_tensorflow
31 |     # train your model here
32 |     sess = tf.Session()
33 |     save_tensorflow(sess, path, output=['output'])
34 | 
35 | Saving PyTorch model
36 | ********************
37 | 
38 | .. code-block:: python
39 | 
40 |     # it has to be a torchscript graph made by tracing / scripting
41 |     from ml2rt import save_torch
42 |     save_torch(torch_script_graph, path)
43 | 
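44 | If you don't already have a torchscript graph, one can be made with
45 | ``torch.jit.trace`` or ``torch.jit.script``. A minimal sketch (the module and
46 | the example input below are illustrative placeholders, not part of ml2rt):
47 | 
48 | .. code-block:: python
49 | 
50 |     import torch
51 | 
52 |     class TwoLayerNet(torch.nn.Module):  # hypothetical model
53 |         def __init__(self):
54 |             super().__init__()
55 |             self.fc1 = torch.nn.Linear(4, 8)
56 |             self.fc2 = torch.nn.Linear(8, 1)
57 | 
58 |         def forward(self, x):
59 |             return self.fc2(torch.relu(self.fc1(x)))
60 | 
61 |     # tracing records the ops executed for an example input
62 |     torch_script_graph = torch.jit.trace(TwoLayerNet(), torch.rand(1, 4))
63 |     save_torch(torch_script_graph, path)
64 | 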
 65 | Saving ONNX model
 66 | *****************
 67 | 
 68 | .. code-block:: python
 69 | 
 70 |     from ml2rt import save_onnx
 71 |     save_onnx(onnx_model, path)
 72 | 
 73 | Saving sklearn model
 74 | ********************
 75 | 
 76 | .. code-block:: python
 77 | 
 78 |     from ml2rt import save_sklearn
 79 |     prototype = np.zeros(some_shape, dtype=some_dtype)  # equivalent to the input of the model
 80 |     save_sklearn(sklearn_model, path, prototype=prototype)
 81 | 
 82 |     # or
 83 | 
 84 |     # some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str object
 85 |     save_sklearn(sklearn_model, path, shape=some_shape, dtype=some_dtype)
 86 | 
 87 |     # or
 88 | 
 89 |     # some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str object
 90 |     initial_type = utils.guess_onnx_tensortype(shape=some_shape, dtype=some_dtype)
 91 |     save_sklearn(sklearn_model, path, initial_types=[initial_type])
 92 | 
 93 | Saving sparkml model
 94 | ********************
 95 | 
 96 | .. code-block:: python
 97 | 
 98 |     from ml2rt import save_sparkml
 99 |     prototype = np.zeros(some_shape, dtype=some_dtype)  # equivalent to the input of the model
100 |     save_sparkml(spark_model, path, prototype=prototype)
101 | 
102 |     # or
103 | 
104 |     # some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str object
105 |     save_sparkml(spark_model, path, shape=some_shape, dtype=some_dtype)
106 | 
107 |     # or
108 | 
109 |     # some_shape has to be a tuple and some_dtype has to be a np.dtype, np.dtype.type or str object
110 |     initial_type = utils.guess_onnx_tensortype(shape=some_shape, dtype=some_dtype)
111 |     save_sparkml(spark_model, path, initial_types=[initial_type])
112 | 
113 | Sklearn and sparkml models are converted to ONNX first and then saved to disk. These models can be executed using ONNXRuntime, RedisAI etc. The ONNX conversion needs to know the type of the input nodes, so we have to pass either shape & dtype, a prototype from which the utility can infer shape & dtype, or an initial_types object that is understood by the conversion utility. Frameworks like sparkml allow heterogeneous inputs with more than one type; in such cases, use `guess_onnx_tensortype` to create more than one initial_type and pass them to the save function as a list.
114 | 
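115 | For example, a hypothetical pipeline that takes two inputs with different
116 | numeric types could declare one tensortype per input (the node names, shapes
117 | and dtypes below are illustrative assumptions, not fixed API values):
118 | 
119 | .. code-block:: python
120 | 
121 |     from ml2rt import utils
122 | 
123 |     initial_types = [
124 |         utils.guess_onnx_tensortype(shape=(1, 13), dtype='float32', node_name='features'),
125 |         utils.guess_onnx_tensortype(shape=(1, 1), dtype='float64', node_name='weights'),
126 |     ]
127 |     save_sparkml(spark_model, path, initial_types=initial_types)
128 | 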
129 | 
130 | Loading model & script
131 | **********************
132 | The loading function can load single-file models (a frozen tensorflow graph, a
133 | torchscript graph or an onnx model) as well as a SavedModel directory from
134 | tensorflow. For a SavedModel it returns a ``model_container`` namedtuple with
135 | the binary data and the input & output node names.
136 | 
137 | .. code-block:: python
138 | 
139 |     model = ml2rt.load_model(path)
140 | 
141 |     script = ml2rt.load_script(path)
142 | 
--------------------------------------------------------------------------------
/ml2rt/__init__.py:
--------------------------------------------------------------------------------
1 | from .exporter import save_tensorflow, save_onnx, save_sklearn, save_sparkml, save_torch
2 | from .importer import load_model, load_script
3 | 
4 | 
5 | __author__ = """[tensor]werk"""
6 | __email__ = 'sherin@tensorwerk.com'
7 | 
--------------------------------------------------------------------------------
/ml2rt/exporter.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import warnings
  3 | from typing import Collection
  4 | 
  5 | from . import utils
  6 | 
  7 | 
  8 | def save_tensorflow(sess, path: str, output: Collection[str]):
  9 |     """
 10 |     Serialize a tensorflow session object to disk using TF utilities.
 11 |     :param sess: Tensorflow session object
 12 |     :param path: Path to which the object will be serialized
 13 |     :param output: List of output node names, required by TF to freeze the graph
 14 |     """
 15 | 
 16 |     # TODO: TF 1.14+ has issue with __spec__
 17 |     if not utils.is_installed('tensorflow'):
 18 |         raise RuntimeError('Please install Tensorflow to use this feature.')
 19 |     import tensorflow as tf
 20 |     graph_def = sess.graph_def
 21 | 
 22 |     # clearing device information
 23 |     for node in graph_def.node:
 24 |         node.device = ""
 25 |     frozen = tf.graph_util.convert_variables_to_constants(
 26 |         sess, graph_def, output)
 27 |     directory = os.path.dirname(path)
 28 |     file = os.path.basename(path)
 29 |     tf.io.write_graph(frozen, directory, file, as_text=False)
 30 | 
 31 | 
 32 | def save_torch(graph, path: str):
 33 |     """
 34 |     Serialize a torchscript object to disk using PyTorch utilities.
 35 |     :param graph: torchscript object
 36 |     :param path: Path to which the object will be serialized
 37 |     """
 38 |     if not utils.is_installed('torch'):
 39 |         raise RuntimeError('Please install PyTorch to use this feature.')
 40 |     import torch
 41 |     # TODO how to handle the cpu/gpu
 42 |     if graph.training is True:
 43 |         warnings.warn(
 44 |             'Graph is in training mode. Converting to evaluation mode')
 45 |         graph.eval()
 46 |     torch.jit.save(graph, path)
 47 | 
 48 | 
 49 | def save_onnx(graph, path: str):
 50 |     """
 51 |     Serialize an ONNX object to disk.
 52 |     :param graph: ONNX graph object
 53 |     :param path: Path to which the object will be serialized
 54 |     """
 55 |     with open(path, 'wb') as f:
 56 |         f.write(graph.SerializeToString())
 57 | 
 58 | 
 59 | def save_sklearn(model, path: str, initial_types=None, prototype=None, shape=None, dtype=None):
 60 |     """
 61 |     Convert a scikit-learn model to onnx first and then save it to disk using `save_onnx`.
 62 |     We use onnxmltools to do the conversion from scikit-learn to ONNX and currently not all
 63 |     scikit-learn models are supported by onnxmltools. A list of supported models can be found
 64 |     in the documentation.
 65 |     :param model: Scikit-learn model
 66 |     :param path: Path to which the object will be serialized
 67 |     :param initial_types: a python list. Each element is a tuple of a variable name and a type
 68 |         defined in onnxconverter_common.data_types. If initial_types is empty, we'll guess the
 69 |         required information from prototype or infer it by using shape and dtype.
 70 |     :param prototype: A numpy array that gives shape and type information. This is ignored if
 71 |         initial_types is not None
 72 |     :param shape: Shape of the input to the model. Ignored if initial_types or prototype is not None
 73 |     :param dtype: numpy dtype (or str) which represents the type of the input to the model.
 74 |         Ignored if initial_types or prototype is not None
 75 |     """
 76 |     if not utils.is_installed(['onnxmltools', 'skl2onnx', 'pandas']):
 77 |         raise RuntimeError('Please install onnxmltools, skl2onnx & pandas to use this feature.')
 78 |     from onnxmltools import convert_sklearn
 79 |     if initial_types is None:
 80 |         initial_types = [utils.guess_onnx_tensortype(prototype, shape, dtype)]
 81 |     if not isinstance(initial_types, list):
 82 |         raise TypeError((
 83 |             "`initial_types` has to be a list. "
 84 |             "If you have only one initial_type, put that into a list"))
 85 |     serialized = convert_sklearn(model, initial_types=initial_types)
 86 |     save_onnx(serialized, path)
 87 | 
 88 | 
 89 | def save_sparkml(
 90 |         model, path, initial_types=None, prototype=None,
 91 |         shape=None, dtype=None, spark_session=None):
 92 |     """
 93 |     Convert a spark model to onnx first and then save it to disk using `save_onnx`.
 94 |     We use onnxmltools to do the conversion from spark to ONNX and currently not all
 95 |     spark models are supported by onnxmltools. A list of supported models can be found
 96 |     in the documentation.
 97 |     :param model: PySpark model object
 98 |     :param path: Path to which the object will be serialized
 99 |     :param initial_types: a python list. Each element is a tuple of a variable name and a type
100 |         defined in onnxconverter_common.data_types. If initial_types is empty, we'll guess the
101 |         required information from prototype or infer it by using shape and dtype.
102 |     :param prototype: A numpy array that gives shape and type information. This is ignored if
103 |         initial_types is not None
104 |     :param shape: Shape of the input to the model. Ignored if initial_types or prototype is not None
105 |     :param dtype: numpy dtype (or str) which represents the type of the input to the model.
106 |         Ignored if initial_types or prototype is not None
107 |     :param spark_session: Optional SparkSession, passed through to the onnxmltools converter
108 |     """
109 |     if not utils.is_installed(['onnxmltools', 'pyspark']):
110 |         raise RuntimeError('Please install onnxmltools & pyspark to use this feature.')
111 |     from onnxmltools import convert_sparkml
112 |     if initial_types is None:
113 |         initial_types = [utils.guess_onnx_tensortype(prototype, shape, dtype)]
114 |     if not isinstance(initial_types, list):
115 |         raise TypeError((
116 |             "`initial_types` has to be a list. "
117 |             "If you have only one initial_type, put that into a list"))
118 |     # TODO: test issue with passing different datatype for numerical values
119 |     # known issue: https://github.com/onnx/onnxmltools/tree/master/onnxmltools/convert/sparkml
120 |     serialized = convert_sparkml(model, initial_types=initial_types, spark_session=spark_session)
121 |     save_onnx(serialized, path)
122 | 
--------------------------------------------------------------------------------
/ml2rt/importer.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from pathlib import Path
 3 | from collections import namedtuple
 4 | 
 5 | from .exporter import save_tensorflow
 6 | 
 7 | model_container = namedtuple('model_container', field_names=['data', 'inputs', 'outputs'])
 8 | 
 9 | 
10 | def load_model(path: str, tags=None, signature=None):
11 |     """
12 |     Return the binary model data. If the input path is a tensorflow SavedModel
13 |     directory, it converts the SavedModel to a frozen model protobuf and then
14 |     reads it as binary. In that case it also returns the input and output node
15 |     lists along with the binary model data.
16 | 
17 |     :param path: File path of the saved model (or directory, for a SavedModel)
18 |     :param tags: Tags for reading from SavedModel
19 |     :param signature: SignatureDef for reading from SavedModel
26 |     """
27 |     path = Path(path)
28 |     if path.is_dir():  # Expecting TF SavedModel
29 |         import tensorflow as tf
30 |         if int(tf.__version__.split('.')[0]) > 1:
31 |             raise RuntimeError("Current tensorflow version must be 1.x (preferably 1.15) "
32 |                                "even if the model is built with 2.x. If that doesn't "
33 |                                "work, follow the steps mentioned in this guide which uses tracing - "
34 |                                "https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/"
35 |                                "\nBe warned that creating the graph by tracing might not give you "
36 |                                "the expected result if your graph relies on dynamic ops internally")
37 |         if tags is None:
38 |             tags = ['serve']
39 |         if signature is None:
40 |             signature = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
41 |         graph = tf.Graph()
42 |         sess = tf.Session(graph=graph)
43 |         with graph.as_default():
44 |             try:
45 |                 model = tf.saved_model.loader.load(sess=sess, tags=tags, export_dir=str(path))
46 |             except Exception as e:
47 |                 raise RuntimeError("We could not load the provided SavedModel. It is probably "
48 |                                    "caused by a tensorflow version difference (you might have "
49 |                                    "used a different version of tensorflow for saving the model). "
50 |                                    "The stacktrace above should have more information") from e
51 |         # TODO: try with multiple input/output
52 |         inputs = []
53 |         for val in model.signature_def[signature].inputs.values():
54 |             inputs.append(val.name.split(':')[0])
55 |         outputs = []
56 |         for val in model.signature_def[signature].outputs.values():
57 |             outputs.append(val.name.split(':')[0])
58 |         tmp_path = Path('model.pb')
59 |         save_tensorflow(sess, str(tmp_path), outputs)
60 |         with open(tmp_path, 'rb') as f:
61 |             data = f.read()
62 |         tmp_path.unlink()
63 |         return model_container(data, inputs, outputs)
64 |     else:
65 |         with open(path, 'rb') as f:
66 |             return f.read()
67 | 
68 | 
69 | def load_script(path: str):
70 |     """
71 |     Load script is a convenient method that, as of now, just reads the content
72 |     of the file and returns it. Eventually it could do validations using
73 |     PyTorch's script compilation utility, clean up the input file for the user, etc.
74 |     """
75 |     with open(path, 'rb') as f:
76 |         return f.read()
77 | 
--------------------------------------------------------------------------------
/ml2rt/utils.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | 
 3 | 
 4 | def _get_tensortype(shape, dtype):
 5 |     # TODO: remove this once this is added to onnxconverter_common or move this there
 6 |     from onnxconverter_common import data_types as onnx_dtypes
 7 | 
 8 |     if dtype == 'float32':
 9 |         onnx_dtype = onnx_dtypes.FloatTensorType(shape)
10 |     elif dtype == 'float64':
11 |         onnx_dtype = onnx_dtypes.DoubleTensorType(shape)
12 |     elif dtype in ('int64', 'uint64'):
13 |         onnx_dtype = onnx_dtypes.Int64TensorType(shape)
14 |     elif dtype in ('int32', 'uint32'):  # noqa
15 |         onnx_dtype = onnx_dtypes.Int32TensorType(shape)
16 |     else:
17 |         raise NotImplementedError(f"'{dtype}' is not supported by ONNXRuntime.")
18 |     return onnx_dtype
19 | 
20 | 
21 | def guess_onnx_tensortype(prototype=None, shape=None, dtype=None, node_name='features'):
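22 |     """
23 |     Guess an ONNX tensortype from a prototype array or from an explicit
24 |     shape & dtype, returning a ``(node_name, tensortype)`` tuple that can be
25 |     used as one element of ``initial_types``. A sketch (the shape and dtype
26 |     here are just examples)::
27 | 
28 |         initial_type = guess_onnx_tensortype(shape=(1, 4), dtype='float32')
29 |         save_sklearn(model, path, initial_types=[initial_type])
30 |     """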
31 |     # TODO: perhaps move this to onnxconverter_common
32 |     import numpy as np
33 | 
34 |     if prototype is not None:
35 |         if hasattr(prototype, 'shape') and hasattr(prototype, 'dtype'):
36 |             shape = prototype.shape
37 |             dtype = prototype.dtype.name
38 |         else:
39 |             raise TypeError("`prototype` has to be a valid `numpy.ndarray` matching the shape of your input")
40 |     else:
41 |         if not all([shape, dtype]):
42 |             raise RuntimeError(
43 |                 "Did you forget to pass `prototype` or (`shape` & `dtype`)?")
44 |         try:
45 |             dtype = np.dtype(dtype).name
46 |         except TypeError:
47 |             raise TypeError(
48 |                 '`dtype` not understood. '
49 |                 'It has to be a valid `np.dtype` or `np.dtype.type` object '
50 |                 'or a `str` that represents a valid numpy data type')
51 |         if not isinstance(shape, (tuple, list)):
52 |             raise RuntimeError("Inferred `shape` attribute is not a tuple / list")
53 |     return node_name, _get_tensortype(shape, dtype)
54 | 
55 | 
56 | def is_installed(packages):
57 |     if not isinstance(packages, list):
58 |         packages = [packages]
59 |     for p in packages:
60 |         if importlib.util.find_spec(p) is None:
61 |             return False
62 |     return True
63 | 
--------------------------------------------------------------------------------
/ml2rt/version.py:
--------------------------------------------------------------------------------
1 | # Store the version here so:
2 | # 1) we don't load dependencies by storing it in __init__.py
3 | # 2) we can import it in setup.py for the same reason
4 | # 3) we can import it into your module
5 | __version__ = '0.2.0'
6 | 
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
 1 | pip==18.1
 2 | bumpversion==0.5.3
 3 | wheel==0.32.1
 4 | watchdog==0.9.0
 5 | flake8==3.5.0
 6 | tox==3.5.2
 7 | coverage==4.5.1
 8 | Sphinx==1.8.1
 9 | twine==1.12.1
10 | 
11 | pytest==3.8.2
12 | pytest-runner==4.2
13 | numpy==1.16.4
14 | onnx==1.5.0
15 | pyspark==2.4.3
16 | sklearn
17 | tensorflow==1.15.2
18 | torch==1.1.0
19 | 
20 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
 1 | [bumpversion]
 2 | current_version = 0.2.0
 3 | commit = True
 4 | tag = True
 5 | 
 6 | [bumpversion:file:ml2rt/version.py]
 7 | search = __version__ = '{current_version}'
 8 | replace = __version__ = '{new_version}'
 9 | 
10 | [bdist_wheel]
11 | universal = 1
12 | 
13 | [flake8]
14 | exclude = docs
15 | 
16 | [aliases]
17 | # Define setup.py command aliases here
18 | test = pytest
19 | 
20 | [tool:pytest]
21 | collect_ignore = ['setup.py']
22 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | from setuptools import setup, find_packages
 2 | 
 3 | with open('README.rst') as readme_file:
 4 |     readme = readme_file.read()
 5 | 
 6 | exec(open('ml2rt/version.py').read())
 7 | 
 8 | requirements = []
 9 | extras_require = {
10 |     'tensorflow': ['tensorflow'],
11 |     'pytorch': ['torch'],
12 |     'sklearn': ['sklearn', 'skl2onnx', 'pandas', 'onnxmltools', 'onnxconverter_common', 'numpy'],
13 |     'sparkml': ['pyspark', 'onnxmltools', 'onnxconverter_common', 'numpy'],
14 |     'onnx': ['onnx'],
15 |     'all': []
16 | }
17 | 
18 | for key, packages in extras_require.items():
19 |     if key != 'all':
20 |         extras_require['all'].extend(packages)
21 | 
22 | setup_requirements = ['pytest-runner', ]
23 | 
24 | test_requirements = ['pytest', ]
25 | 
26 | setup(
27 |     author="Sherin Thomas",
28 |     author_email='sherin@tensorwerk.com',
29 |     classifiers=[
30 |         'Development Status :: 2 - Pre-Alpha',
31 |         'Intended Audience :: Developers',
32 |         'License :: OSI Approved :: Apache Software License',
33 |         'Natural Language :: English',
34 |         'Programming Language :: Python :: 3',
35 |         'Programming Language :: Python :: 3 :: Only',
36 |         'Programming Language :: Python :: 3.6',
37 |         'Programming Language :: Python :: 3.7',
38 |     ],
39 |     description="Machine learning utilities for model conversion, serialization, loading etc",
40 |     install_requires=requirements,
41 |     extras_require=extras_require,
42 |     license="Apache Software License 2.0",
43 |     long_description=readme + '\n\n',
44 |     include_package_data=True,
45 |     keywords='ml2rt',
46 |     name='ml2rt',
47 |     packages=find_packages(include=['ml2rt']),
48 |     setup_requires=setup_requirements,
49 |     test_suite='tests',
50 |     tests_require=test_requirements,
51 |     url='https://github.com/hhsecond/ml2rt',
52 |     version=__version__,  # comes from ml2rt/version.py
53 |     zip_safe=False,
54 | )
55 | 
--------------------------------------------------------------------------------
/tests/test_mlut.py:
--------------------------------------------------------------------------------
  1 | import time
  2 | import os
  3 | import sys
  4 | 
  5 | from ml2rt import load_model, load_script
  6 | from ml2rt import (
  7 |     save_tensorflow, save_torch, save_onnx, save_sklearn, save_sparkml)
  8 | from ml2rt import utils
  9 | import tensorflow as tf
 10 | import torch
 11 | from sklearn import linear_model, datasets
 12 | import onnx
 13 | import numpy as np
 14 | from pyspark.sql import SparkSession
 15 | from pyspark.ml.linalg import Vectors
 16 | from pyspark.ml.regression import LinearRegression
 17 | import pyspark
 18 | 
 19 | 
 20 | def get_tf_graph():
 21 |     x = tf.placeholder(tf.float32, name='input')
 22 |     W = tf.Variable(5., name='W')
 23 |     b = tf.Variable(3., name='b')
 24 |     y = x * W + b
 25 |     y = tf.identity(y, name='output')
 26 | 
 27 | 
 28 | class MyModule(torch.jit.ScriptModule):
 29 |     def __init__(self):
 30 |         super(MyModule, self).__init__()
 31 | 
 32 |     @torch.jit.script_method
 33 |     def forward(self, a, b):
 34 |         return a + b
 35 | 
 36 | 
 37 | def get_sklearn_model_and_prototype():
 38 |     model = linear_model.LinearRegression()
 39 |     boston = datasets.load_boston()
 40 |     X, y = boston.data, boston.target
 41 |     model.fit(X, y)
 42 |     return model, X[0].reshape(1, -1).astype(np.float32)
 43 | 
 44 | 
 45 | def get_onnx_model():
 46 |     torch_model = torch.nn.ReLU()
 47 |     # there may be a way to pass the onnx model without writing
 48 |     # to disk, but I couldn't find one
 49 |     torch.onnx.export(torch_model, torch.rand(1, 1), 'model.onnx')
 50 |     onnx_model = onnx.load('model.onnx')
 51 |     os.remove('model.onnx')
 52 |     return onnx_model
 53 | 
 54 | 
 55 | def get_spark_model_and_prototype():
 56 |     executable = sys.executable
 57 |     os.environ["SPARK_HOME"] = pyspark.__path__[0]
 58 |     os.environ["PYSPARK_PYTHON"] = executable
 59 |     os.environ["PYSPARK_DRIVER_PYTHON"] = executable
 60 |     spark = SparkSession.builder.appName("redisai_test").getOrCreate()
 61 |     # label is input + 1
 62 |     data = spark.createDataFrame([
 63 |         (2.0, Vectors.dense(1.0)),
 64 |         (3.0, Vectors.dense(2.0)),
 65 |         (4.0, Vectors.dense(3.0)),
 66 |         (5.0, Vectors.dense(4.0)),
 67 |         (6.0, Vectors.dense(5.0)),
 68 |         (7.0, Vectors.dense(6.0))
 69 |     ], ["label", "features"])
 70 |     lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
 71 |     model = lr.fit(data)
 72 |     prototype = np.array([[1.0]], dtype=np.float32)
 73 |     return model, prototype
 74 | 
 75 | 
 76 | class TestModel:
 77 |     # TODO: Detailed tests
 78 | 
 79 |     def test_TFGraph(self):
 80 |         get_tf_graph()
 81 |         init = tf.global_variables_initializer()
 82 |         sess = tf.Session()
 83 |         sess.run(init)
 84 |         path = f'{time.time()}.pb'
 85 |         save_tensorflow(sess, path, output=['output'])
 86 |         assert os.path.exists(path)
 87 |         os.remove(path)
 88 | 
 89 |     def test_PyTorchGraph(self):
 90 |         torch_graph = MyModule()
 91 |         path = f'{time.time()}.pt'
 92 |         save_torch(torch_graph, path)
 93 |         assert os.path.exists(path)
 94 |         os.remove(path)
 95 | 
 96 |     def test_ScriptLoad(self):
 97 |         dirname = os.path.dirname(__file__)
 98 |         path = f'{dirname}/testdata/script.txt'
 99 |         load_script(path)
100 | 
101 |     def test_ONNXGraph(self):
102 |         onnx_model = get_onnx_model()
103 |         path = f'{time.time()}.onnx'
104 |         save_onnx(onnx_model, path)
105 |         assert os.path.exists(path)
106 |         load_model(path)
107 |         os.remove(path)
108 | 
109 |     def test_SKLearnGraph(self):
110 |         sklearn_model, prototype = get_sklearn_model_and_prototype()
111 | 
112 |         # saving with prototype
113 |         path = f'{time.time()}.onnx'
114 |         save_sklearn(sklearn_model, path, prototype=prototype)
115 |         assert os.path.exists(path)
116 |         load_model(path)
117 |         os.remove(path)
118 | 
119 |         # saving with shape and dtype
120 |         shape = prototype.shape
121 |         if prototype.dtype == np.float32:
122 |             dtype = prototype.dtype
123 |         else:
124 |             raise RuntimeError("Test is not configured to run with another type")
125 |         path = f'{time.time()}.onnx'
126 |         save_sklearn(sklearn_model, path, shape=shape, dtype=dtype)
127 |         assert os.path.exists(path)
128 |         load_model(path)
129 |         os.remove(path)
130 | 
131 |         # saving with initial_types
132 |         initial_type = utils.guess_onnx_tensortype(shape=shape, dtype=dtype)
133 |         path = f'{time.time()}.onnx'
134 |         save_sklearn(sklearn_model, path, initial_types=[initial_type])
135 |         assert os.path.exists(path)
136 |         load_model(path)
137 |         os.remove(path)
138 | 
139 |     def test_SparkMLGraph(self):
140 |         spark_model, prototype = get_spark_model_and_prototype()
141 | 
142 |         # saving with prototype
143 |         path = f'{time.time()}.onnx'
144 |         save_sparkml(spark_model, path, prototype=prototype)
145 |         load_model(path)
146 |         assert os.path.exists(path)
147 |         os.remove(path)
148 | 
149 |         # saving with shape and dtype
150 |         shape = prototype.shape
151 |         if prototype.dtype == np.float32:
152 |             dtype = prototype.dtype
153 |         else:
154 |             raise RuntimeError("Test is not configured to run with another type")
155 |         path = f'{time.time()}.onnx'
156 |         save_sparkml(spark_model, path, shape=shape, dtype=dtype)
157 |         assert os.path.exists(path)
158 |         load_model(path)
159 |         os.remove(path)
160 | 
161 |         # saving with initial_types
162 |         initial_type = utils.guess_onnx_tensortype(shape=shape, dtype=dtype)
163 |         path = f'{time.time()}.onnx'
164 |         save_sparkml(spark_model, path, initial_types=[initial_type])
165 |         assert os.path.exists(path)
166 |         load_model(path)
167 |         os.remove(path)
168 | 
--------------------------------------------------------------------------------
/tests/testdata/script.txt:
--------------------------------------------------------------------------------
1 | def bar(a, b):
2 |     return a + b
3 | 
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
 1 | [tox]
 2 | envlist = py36, py37, flake8
 3 | 
 4 | [travis]
 5 | python =
 6 |     3.6: py36
 7 |     3.7: py37
 8 | 
 9 | [testenv:flake8]
10 | basepython = python
11 | deps = flake8
12 | commands = flake8 ml2rt
13 | 
14 | [testenv]
15 | setenv =
16 |     PYTHONPATH = {toxinidir}
17 | deps =
18 |     -r{toxinidir}/requirements_dev.txt
19 | ; If you want to make tox run the tests with the same versions, create a
20 | ; requirements.txt with the pinned versions and uncomment the following line:
21 | ;     -r{toxinidir}/requirements.txt
22 | commands =
23 |     pip install -U pip
24 |     py.test --basetemp={envtmpdir}
25 | 
26 | 
--------------------------------------------------------------------------------