├── .flake8 ├── .github └── workflows │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── clearbox_wrapper ├── __init__.py ├── data_preparation │ ├── __init__.py │ └── data_preparation.py ├── exceptions │ ├── __init__.py │ └── exceptions.py ├── keras │ ├── __init__.py │ └── keras.py ├── model │ ├── __init__.py │ └── model.py ├── preprocessing │ ├── __init__.py │ └── preprocessing.py ├── pyfunc │ ├── __init__.py │ ├── model.py │ └── pyfunc.py ├── pytorch │ ├── __init__.py │ ├── pickle_module.py │ └── pytorch.py ├── schema │ ├── __init__.py │ ├── schema.py │ └── utils.py ├── signature │ ├── __init__.py │ └── signature.py ├── sklearn │ ├── __init__.py │ └── sklearn.py ├── utils │ ├── __init__.py │ ├── environment.py │ ├── file_utils.py │ └── model_utils.py ├── wrapper │ ├── __init__.py │ ├── model.py │ ├── utils.py │ └── wrapper.py └── xgboost │ ├── __init__.py │ └── xgboost.py ├── conftest.py ├── docs └── images │ ├── clearbox_ai_wrapper_no_preprocessing.png │ ├── clearbox_ai_wrapper_preprocessing.png │ └── clearbox_ai_wrapper_preprocessing_data_preparation.png ├── examples ├── 1_iris_sklearn │ ├── 1_Clearbox_Wrapper_Iris_Scikit.ipynb │ ├── iris_test_set.csv │ ├── iris_training_set.csv │ └── iris_wrapped_model_v0.3.10.zip ├── 2_loans_preprocessing_xgboost │ ├── 2_Clearbox_Wrapper_Loans_Xgboost.ipynb │ ├── loans_test_set.csv │ ├── loans_training_set.csv │ └── loans_xgboost_wrapped_model_v0.3.10.zip ├── 3_boston_preprocessing_pytorch │ ├── 3_Clearbox_Wrapper_Boston_Pytorch.ipynb │ ├── boston_test_set.csv │ ├── boston_training_set.csv │ └── boston_wrapped_model_v0.3.10.zip ├── 4_adult_data_cleaning_preprocessing_keras │ ├── 4_Clearbox_Wrapper_Adult_Keras.ipynb │ ├── adult_test.csv │ ├── adult_training.csv │ └── adult_wrapped_model_preparation_preprocessing_v0.3.10.zip └── 5_hospital_preprocessing_pytorch │ ├── 5_Clearbox_Wrapper_Hospital_Pytorch.ipynb │ ├── hospital_readmissions_test.csv │ ├── hospital_readmissions_training.csv │ └── hospital_wrapped_model_v0.3.10.zip ├── mypy.ini ├── noxfile.py ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── datasets ├── adult_test_50_rows.csv ├── adult_training_500_rows.csv ├── boston_housing.csv ├── iris_test_set_one_hot_y.csv └── iris_training_set_one_hot_y.csv ├── keras ├── test_keras_adult.py ├── test_keras_boston.py └── test_keras_iris.py ├── pytorch ├── test_pytorch_adult.py ├── test_pytorch_boston.py └── test_pytorch_iris.py ├── sklearn ├── test_sklearn_adult.py ├── test_sklearn_boston.py └── test_sklearn_iris.py └── xgboost ├── test_xgboost_adult.py ├── test_xgboost_boston.py └── test_xgboost_iris.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = B, B9, BLK, C, E, F, I, W 3 | ignore = E203,E501,W503 4 | max-line-length = 88 5 | application-import-names = clearbox_wrapper, tests 6 | import-order-style = google 7 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | release: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: actions/setup-python@v1 11 | with: 12 | python-version: 3.8.8 13 | architecture: x64 14 | - run: pip install poetry==1.1.4 15 | - run: poetry build 16 | - run: poetry publish --username=__token__ --password=${{ secrets.PYPI_TOKEN }} 17 | 
-------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Lint, Static Typing and Tests 2 | on: push 3 | jobs: 4 | tests: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v2 8 | - uses: actions/setup-python@v1 9 | with: 10 | python-version: 3.8.8 11 | architecture: x64 12 | - run: pip install nox==2020.8.22 13 | - run: pip install poetry==1.1.4 14 | - run: nox 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | **/__pycache__ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # VS Code 133 | .vscode 134 | 135 | # pyenv 136 | .python-version 137 | 138 | # mlflow 139 | mlruns 140 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - repo: local 8 | hooks: 9 | - id: black 10 | name: black 11 | entry: poetry run black 12 | language: system 13 | types: [python] 14 | - id: flake8 15 | name: flake8 16 | entry: poetry run flake8 17 | language: system 18 | types: [python] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Tests](https://github.com/Clearbox-AI/clearbox-wrapper/workflows/Tests/badge.svg)](https://github.com/Clearbox-AI/clearbox-wrapper/actions?workflow=Tests) 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/clearbox-wrapper.svg)](https://pypi.org/project/clearbox-wrapper/) 4 | 5 | # Clearbox AI Wrapper 6 | 7 | Clearbox AI Wrapper is a Python library to package and save a Machine Learning model built with common ML/DL frameworks. It is designed to wrap models trained on structured (tabular) data. It includes optional **preprocessing** and **data preparation** arguments which can be used to build ready-to-production pipelines. 8 | 9 | By passing the original training set the model was trained on as the `input_data` parameter, a **model signature** will be generated. A signature is a description of the final pipeline's inputs and outputs. The signature is stored in JSON format in the MLmodel file, together with other model metadata. 10 | 11 | ## Main Features 12 | 13 | The wrapper began as a fork of [mlflow](https://github.com/mlflow/mlflow) and is based on its [standard format](https://mlflow.org/docs/latest/models.html). It adds the possibility to package, together with the fitted model, preprocessing and data preparation functions in order to create a production-ready pipeline able to receive new data, preprocess them and make predictions. The resulting wrapped model/pipeline is saved as a zipped folder. 14 | 15 | The library is designed to automatically detect the Python version and the model framework and its version, adding this information to the requirements saved in the final folder. Additional dependencies (e.g. libraries used in preprocessing or data preparation) can also be added as a list parameter if necessary. 16 | 17 | The resulting wrapped folder can be loaded via the Wrapper and the model will be ready to take input through the `predict` or `predict_proba` (if present) method. 18 | 19 | **IMPORTANT**: Currently, it is necessary to load the wrapped model with the same Python version with which the model was saved. 20 | 21 | ## No preprocessing 22 | 23 | In the simplest case, the original dataset has already been preprocessed or doesn't need any preprocessing. It contains only numerical values (ordinal features or one-hot encoded categorical features) and we can easily train a model on it. Then, we only need to save the model and it will be ready to receive new data and make predictions on them. 24 | 25 | ![](/docs/images/clearbox_ai_wrapper_no_preprocessing.png) 26 | 27 | The following lines show how to wrap and save a simple Scikit-Learn model without preprocessing or data preparation: 28 | 29 | ```python 30 | import clearbox_wrapper as cbw 31 | 32 | model = DecisionTreeClassifier(max_depth=4, random_state=42) 33 | model.fit(X_train, y_train) 34 | cbw.save_model('wrapped_model_path', model, input_data=X_train) 35 | ``` 36 | 37 | ## Preprocessing 38 | 39 | Typically, data are preprocessed before being fed into the model. It is almost always necessary to transform (e.g. scaling, binarizing, ...) raw data values into a representation that is more suitable for the downstream model. Most kinds of ML models take only numeric data as input, so we must at least encode the non-numeric data, if any.
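For instance, a minimal encoding step could look like the following sketch (the data and column names here are hypothetical, purely for illustration):

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Hypothetical raw training data with one categorical column
x_raw = pd.DataFrame({"age": [25, 32, 47], "city": ["Rome", "Turin", "Rome"]})

# Fit the encoder on the training data...
encoder = OneHotEncoder(sparse=False, handle_unknown="ignore")
encoder.fit(x_raw[["city"]])

# ...so that new raw data can later be encoded in exactly the same way
x_city_encoded = encoder.transform(x_raw[["city"]])
```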
40 | 41 | Preprocessing is usually written and performed separately, before building and training the model. We fit some transformers, transform the whole dataset(s) and train the model on the processed data. If the model goes into production, we need to ship the preprocessing as well. New raw data must be processed in the same way the training dataset was. 42 | 43 | With Clearbox AI Wrapper it's possible to wrap and save the preprocessing along with the model, so as to have a Preprocessing+Model pipeline ready to take raw data, preprocess them and make predictions. 44 | 45 | ![](/docs/images/clearbox_ai_wrapper_preprocessing.png) 46 | 47 | All the preprocessing code **must be wrapped in a single function** so it can be passed as the `preprocessing` parameter to the `save_model` method. You can use your own custom code for the preprocessing, just remember to wrap all of it in a single function, save it along with the model and add any extra dependencies. 48 | 49 | **IMPORTANT**: If the preprocessing includes any kind of fitting on the training dataset (e.g. Scikit Learn transformers), it must be performed outside the final preprocessing function to save. Fit the transformer(s) outside the function and put only the `transform` method inside it. Furthermore, if the entire preprocessing is performed with a single Scikit-Learn transformer, you can directly pass it (fitted) to the `save_model` method. 50 | 51 | ```python 52 | from sklearn.preprocessing import RobustScaler 53 | import xgboost as xgb 54 | 55 | import clearbox_wrapper as cbw 56 | 57 | 58 | x, y = dataset 59 | x_preprocessor = RobustScaler() 60 | x_preprocessed = x_preprocessor.fit_transform(x) 61 | 62 | model = xgb.XGBClassifier(use_label_encoder=False) 63 | fitted_model = model.fit(x_preprocessed, y) 64 | cbw.save_model('wrapped_model_path', 65 | fitted_model, 66 | preprocessing=x_preprocessor, 67 | input_data=x, 68 | additional_deps=["scikit-learn==0.23.2"]) 69 | ``` 70 | 71 | ## Data Preparation (advanced usage) 72 | 73 | For a complex task, a single preprocessing step might not be enough. The raw data initially collected could be very noisy, contain useless columns or be split across different dataframes/tables. A first round of data processing is usually performed even before any model is considered. The entire dataset is cleaned, and any further processing and the model itself are built on the cleaned data only. This is not always the end of the story, though: sometimes the same cleaning has to be applied to data fed in real time to a model in production. 74 | 75 | We believe that a two-step data processing is required to deal with this situation. We refer to the first additional step by the term **Data Preparation**. With Clearbox AI Wrapper it's possible to wrap a data preparation step as well, in order to save a final Data Preparation + Preprocessing + Model pipeline ready to take input. 76 | 77 | ![](/docs/images/clearbox_ai_wrapper_preprocessing_data_preparation.png) 78 | 79 | All the data preparation code **must be wrapped in a single function** so it can be passed as the `data_preparation` parameter to the `save_model` method. The same considerations written above for the preprocessing step also apply to data preparation.
80 | 81 | ```python 82 | import numpy as np 83 | from sklearn.preprocessing import MaxAbsScaler 84 | from tensorflow.keras.layers import Dense 85 | from tensorflow.keras.models import Sequential 86 | 87 | import clearbox_wrapper as cbw 88 | 89 | def preparation(x): 90 | data_prepared = np.delete(x, 0, axis=1) 91 | return data_prepared 92 | 93 | x_preprocessor = MaxAbsScaler() 94 | 95 | x, y = dataset 96 | x_prepared = preparation(x) 97 | x_preprocessed = x_preprocessor.fit_transform(x_prepared) 98 | 99 | model = Sequential() 100 | model.add(Dense(8, input_dim=x_preprocessed.shape[1], activation="relu")) 101 | model.add(Dense(3, activation="softmax")) 102 | 103 | model.compile( 104 | optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] 105 | ) 106 | model.fit(x_preprocessed, y) 107 | 108 | cbw.save_model( 109 | 'wrapped_model_path', 110 | model, 111 | preprocessing=x_preprocessor, 112 | data_preparation=preparation, 113 | input_data=x, 114 | additional_deps=["scikit-learn==0.23.2", "numpy==1.18.0"] 115 | ) 116 | ``` 117 | 118 | ### Data Preparation vs. Preprocessing 119 | 120 | The difference between preprocessing and data preparation is not always clear, and it's not easy to tell where data preparation ends and preprocessing begins. There are no rules that apply to every case, but in general you should build the data preparation step working only with the dataset, without considering the model your data will be fed into. Any kind of operation is allowed, but often preparing the raw data includes removing or normalizing some columns, replacing values, adding a column based on other columns' values, and so on. After this step, no matter what kind of transformation the data have been through, they should still be readable and understandable by a human user. 121 | 122 | The preprocessing step, on the contrary, should be considered closely tied to the downstream ML model and adapted to its particular "needs". The data produced by this second step are typically numeric only and not necessarily understandable by a human.
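As a rough illustration of the distinction, in the following sketch (with hypothetical data and column names) the prepared data is still a readable table, while the preprocessed data is a purely numeric array shaped around the model's needs:

```python
import pandas as pd
from sklearn.preprocessing import StandardScaler

x_raw = pd.DataFrame(
    {
        "customer_id": [1, 2, 3],
        "age": [25.0, None, 47.0],
        "income": [30000.0, 45000.0, 52000.0],
    }
)

def preparation(x):
    # Data preparation: drop an identifier column and fill missing ages;
    # the result is still a human-readable DataFrame
    x = x.drop(columns=["customer_id"])
    return x.fillna({"age": x["age"].median()})

x_prepared = preparation(x_raw)

# Preprocessing: a fitted transformer tied to the downstream model;
# the output is a plain numeric array
x_preprocessor = StandardScaler()
x_preprocessed = x_preprocessor.fit_transform(x_prepared)
```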
123 | 124 | ## Supported ML frameworks 125 | 126 | - Scikit-Learn 127 | - XGBoost 128 | - Keras 129 | - Pytorch 130 | 131 | ## Installation 132 | 133 | Install the latest released version from the [Python Package Index (PyPI)](https://pypi.org/project/clearbox-wrapper/) with 134 | 135 | ```shell 136 | pip install clearbox-wrapper 137 | ``` 138 | 139 | ## Examples 140 | 141 | The following Jupyter notebooks provide examples of simple and complex cases: 142 | 143 | - [Scikit Learn Decision Tree on Iris Dataset](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/examples/1_iris_sklearn/1_Clearbox_Wrapper_Iris_Scikit.ipynb) (No preprocessing, No data preparation) 144 | - [XGBoost Model on Lending Club Loans Dataset](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/examples/2_loans_preprocessing_xgboost/2_Clearbox_Wrapper_Loans_Xgboost.ipynb) (Preprocessing, No data preparation) 145 | - [Pytorch Network on Boston Housing Dataset](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/examples/3_boston_preprocessing_pytorch/3_Clearbox_Wrapper_Boston_Pytorch.ipynb) (Preprocessing, No data preparation) 146 | - [Keras Network on UCI Adult Dataset](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/examples/4_adult_data_cleaning_preprocessing_keras/4_Clearbox_Wrapper_Adult_Keras.ipynb) (Preprocessing and data preparation) 147 | - [Pytorch Network on Diabetes Hospital Readmissions](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/examples/5_hospital_preprocessing_pytorch/5_Clearbox_Wrapper_Hospital_Pytorch.ipynb) (Preprocessing and data preparation) 148 | 149 | ## License 150 | 151 | [Apache License 2.0](https://github.com/Clearbox-AI/clearbox-wrapper/blob/master/LICENSE) 152 | -------------------------------------------------------------------------------- /clearbox_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.11" 2 | 3 | from .exceptions import ClearboxWrapperException 4 | from .model import Model 5 | from .wrapper import load_model, save_model 6 | 7 | 8 | __all__ = ["ClearboxWrapperException", "load_model", "Model", "save_model"] 9 | -------------------------------------------------------------------------------- /clearbox_wrapper/data_preparation/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_preparation import ( 2 | create_and_save_data_preparation, 3 | DataPreparation, 4 | load_serialized_data_preparation, 5 | ) 6 | 7 | __all__ = [ 8 | "create_and_save_data_preparation", 9 | "load_serialized_data_preparation", 10 | "DataPreparation", 11 | ] 12 | -------------------------------------------------------------------------------- /clearbox_wrapper/data_preparation/data_preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Union 3 | 4 | import cloudpickle 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from clearbox_wrapper.exceptions import ClearboxWrapperException 9 | 10 | DataPreparationInput = Union[pd.DataFrame, pd.Series, np.ndarray] 11 | DataPreparationOutput = Union[pd.DataFrame, pd.Series, np.ndarray] 12 | 13 | 14 | class DataPreparation(object): 15 | def __init__(self, data_preparation_function: Callable): 16 | """Create a DataPreparation instance. 17 | 18 | Parameters 19 | ---------- 20 | data_preparation_function : Callable 21 | A function to use for data preparation. You can use your own custom code for
You can use your own custom code for 22 | preprocessing, but it must be wrapped in a single function. 23 | 24 | NOTE: If the preprocessing includes any kind of fitting on the training dataset 25 | (e.g. Scikit Learn transformers), it must be performed outside the final 26 | preprocessing function to save. Fit the transformer(s) outside the function and 27 | put only the transform method inside it. Furthermore, if the entire preprocessing 28 | is performed with a single Scikit-Learn transformer, you can directly pass it 29 | (fitted) to this method. 30 | 31 | Raises 32 | ------ 33 | TypeError 34 | If preprocessing_function is not a function (Callable type) 35 | """ 36 | self.data_preparation = data_preparation_function 37 | 38 | def __repr__(self) -> str: 39 | return "Data Preparation: \n" " {}\n".format(repr(self.data_preparation)) 40 | 41 | @property 42 | def data_preparation(self) -> Callable: 43 | """Get the data preparation function. 44 | 45 | Returns 46 | ------- 47 | Callable 48 | The data preparation function. 49 | """ 50 | return self._data_preparation 51 | 52 | @data_preparation.setter 53 | def data_preparation(self, data_preparation_function: Callable) -> None: 54 | """Set the data preparation function. 55 | 56 | Parameters 57 | ---------- 58 | value : Callable 59 | The data preparation function. 60 | """ 61 | self._data_preparation = data_preparation_function 62 | 63 | def prepare_data(self, data: DataPreparationInput) -> DataPreparationOutput: 64 | """Prepare input data using the data preparation function. 65 | 66 | Parameters 67 | ---------- 68 | data : DataPreparationInput 69 | Input data to prepare. 70 | 71 | Returns 72 | ------- 73 | DataPreparationOutput 74 | Prepared data. 75 | """ 76 | prepared_data = ( 77 | self.data_preparation.transform(data) 78 | if hasattr(self.data_preparation, "transform") 79 | else self.data_preparation(data) 80 | ) 81 | return prepared_data 82 | 83 | def save(self, path: str) -> None: 84 | if os.path.exists(path): 85 | raise ClearboxWrapperException( 86 | "Data preparation path '{}' already exists".format(path) 87 | ) 88 | with open(path, "wb") as data_preparation_serialized_file: 89 | cloudpickle.dump(self, data_preparation_serialized_file) 90 | 91 | 92 | def create_and_save_data_preparation( 93 | data_preparation_function: Callable, path: str 94 | ) -> None: 95 | """Create, serialize and save a DataPreparation instance. 96 | 97 | Parameters 98 | ---------- 99 | data_preparation_function : Callable 100 | A function to use as data preparation. You can use your own custom code for 101 | data preparation, but it must be wrapped in a single function. 102 | 103 | NOTE: If the data preparation includes any kind of fitting on the training dataset 104 | (e.g. Scikit Learn transformers), it must be performed outside the final data 105 | preparation function to save. Fit the transformer(s) outside the function and put 106 | only the transform method inside it. Furthermore, if the entire data preparation 107 | is performed with a single Scikit-Learn transformer, you can directly pass it 108 | (fitted) to this method. 109 | path : str 110 | Local path to save the data preparation to. 111 | 112 | Raises 113 | ------ 114 | TypeError 115 | If data_preparation_function is not a function (Callable type) 116 | ClearboxWrapperException 117 | If data preparation path already exists. 
118 | """ 119 | if not isinstance(data_preparation_function, Callable): 120 | raise TypeError( 121 | "data_preparation_function should be a Callable, got '{}'".format( 122 | type(data_preparation_function) 123 | ) 124 | ) 125 | if os.path.exists(path): 126 | raise ClearboxWrapperException( 127 | "Data preparation path '{}' already exists".format(path) 128 | ) 129 | 130 | data_preparation = DataPreparation(data_preparation_function) 131 | with open(path, "wb") as data_preparation_serialized_file: 132 | cloudpickle.dump(data_preparation, data_preparation_serialized_file) 133 | 134 | 135 | def load_serialized_data_preparation( 136 | serialized_data_preparation_path: str, 137 | ) -> DataPreparation: 138 | with open(serialized_data_preparation_path, "rb") as serialized_data_preparation: 139 | return cloudpickle.load(serialized_data_preparation) 140 | -------------------------------------------------------------------------------- /clearbox_wrapper/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | from .exceptions import ClearboxWrapperException 2 | 3 | __all__ = [ClearboxWrapperException] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/exceptions/exceptions.py: -------------------------------------------------------------------------------- 1 | class ClearboxWrapperException(Exception): 2 | def __init(self, message): 3 | super().__init__(message) 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/keras/__init__.py: -------------------------------------------------------------------------------- 1 | from .keras import _load_clearbox, _load_pyfunc, save_keras_model 2 | 3 | __all__ = [_load_clearbox, _load_pyfunc, save_keras_model] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MLMODEL_FILE_NAME, Model 2 | 3 | __all__ = [MLMODEL_FILE_NAME, Model] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/model/model.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import json 3 | import os 4 | from typing import Dict, Optional 5 | 6 | import yaml 7 | 8 | from clearbox_wrapper.schema import Schema 9 | from clearbox_wrapper.signature import Signature 10 | 11 | 12 | MLMODEL_FILE_NAME = "MLmodel" 13 | 14 | 15 | class Model(object): 16 | """A ML Model that can support multiple model flavors.""" 17 | 18 | def __init__( 19 | self, 20 | timestamp: Optional[datetime] = None, 21 | flavors: Optional[Dict] = None, 22 | model_signature: Optional[Signature] = None, 23 | preprocessing_signature: Optional[Signature] = None, 24 | preparation_signature: Optional[Signature] = None, 25 | ) -> None: 26 | """Create a new Model object. 27 | 28 | Parameters 29 | ---------- 30 | timestamp : Optional[datetime], optional 31 | A timestamp of the model creation, by default None. 32 | If None, it will be use datetime.utcnow() 33 | flavors : Optional[Dict], optional 34 | Dictionary of flavors: a "flavor" is a convention 35 | representing the framework the model was created with 36 | and/or downstream tools that can "understand" the model, 37 | by default None. 
38 | model_signature : Optional[Signature], optional 39 | Description of the model inputs and outputs as a Signature object, 40 | by default None 41 | """ 42 | self.timestamp = str(timestamp or datetime.utcnow()) 43 | self.flavors = flavors if flavors is not None else {} 44 | self.model_signature = model_signature 45 | self.preprocessing_signature = preprocessing_signature 46 | self.preparation_signature = preparation_signature 47 | 48 | def __eq__(self, other: "Model") -> bool: 49 | """Check if two models are equal. 50 | 51 | Parameters 52 | ---------- 53 | other : Model 54 | A Model object 55 | 56 | Returns 57 | ------- 58 | bool 59 | True, if the other Model is equal to self. 60 | """ 61 | if not isinstance(other, Model): 62 | return False 63 | return self.__dict__ == other.__dict__ 64 | 65 | def get_data_preparation_input_schema(self) -> Optional[Schema]: 66 | """Get the data preparation inputs schema. 67 | 68 | Returns 69 | ------- 70 | Schema 71 | Data preparation inputs schema: specification of types and column names. 72 | """ 73 | return ( 74 | self.preparation_signature.inputs 75 | if self.preparation_signature is not None 76 | else None 77 | ) 78 | 79 | def get_data_preparation_output_schema(self) -> Optional[Schema]: 80 | """Get the data preparation outputs schema. 81 | 82 | Returns 83 | ------- 84 | Schema 85 | Data preparation outputs schema: specification of types and column names. 86 | """ 87 | return ( 88 | self.preparation_signature.outputs 89 | if self.preparation_signature is not None 90 | else None 91 | ) 92 | 93 | def get_preprocessing_input_schema(self) -> Optional[Schema]: 94 | """Get the preprocessing inputs schema. 95 | 96 | Returns 97 | ------- 98 | Schema 99 | Preprocessing inputs schema: specification of types and column names. 100 | """ 101 | return ( 102 | self.preprocessing_signature.inputs 103 | if self.preprocessing_signature is not None 104 | else None 105 | ) 106 | 107 | def get_preprocessing_output_schema(self) -> Optional[Schema]: 108 | """Get the preprocessing outputs schema. 109 | 110 | Returns 111 | ------- 112 | Schema 113 | Preprocessing outputs schema: specification of types and column names. 114 | """ 115 | return ( 116 | self.preprocessing_signature.outputs 117 | if self.preprocessing_signature is not None 118 | else None 119 | ) 120 | 121 | def get_model_input_schema(self) -> Optional[Schema]: 122 | """Get the model inputs schema. 123 | 124 | Returns 125 | ------- 126 | Schema 127 | Model inputs schema: specification of types and column names. 128 | """ 129 | return self.model_signature.inputs if self.model_signature is not None else None 130 | 131 | def get_model_output_schema(self) -> Optional[Schema]: 132 | """Get the model outputs schema. 133 | 134 | Returns 135 | ------- 136 | Schema 137 | Model outputs schema: specification of types and column names. 138 | """ 139 | return ( 140 | self.model_signature.outputs if self.model_signature is not None else None 141 | ) 142 | 143 | def add_flavor(self, name: str, **params) -> "Model": 144 | """Add an entry for how to serve the model in a given format. 145 | 146 | Returns 147 | ------- 148 | Model 149 | self 150 | """ 151 | self.flavors[name] = params 152 | return self 153 | 154 | @property 155 | def preparation_signature(self) -> Optional[Signature]: 156 | """Get the data preparation signature associated with the model. 157 | 158 | Returns 159 | ------- 160 | Optional[Signature] 161 | Data preparation signature as a Signature instance.
162 | """ 163 | return self._preparation_signature 164 | 165 | @preparation_signature.setter 166 | def preparation_signature(self, value: Signature) -> None: 167 | """Set the data preparation signature of the model 168 | 169 | Parameters 170 | ---------- 171 | value : Signature 172 | Data preparation signature as a Signature instance. 173 | """ 174 | self._preparation_signature = value 175 | 176 | @property 177 | def preprocessing_signature(self) -> Optional[Signature]: 178 | """Get the preprocessing signature associated with the model. 179 | 180 | Returns 181 | ------- 182 | Optional[Signature] 183 | Preprocessing signature as a Signature instance. 184 | """ 185 | return self._preprocessing_signature 186 | 187 | @preprocessing_signature.setter 188 | def preprocessing_signature(self, value: Signature) -> None: 189 | """Set the preprocessing signature of the model 190 | 191 | Parameters 192 | ---------- 193 | value : Signature 194 | Preprocessing signature as a Signature instance. 195 | """ 196 | self._preprocessing_signature = value 197 | 198 | @property 199 | def model_signature(self) -> Optional[Signature]: 200 | """Get the model signature. 201 | 202 | Returns 203 | ------- 204 | Optional[Signature] 205 | The model signature: it defines the schema of a model inputs and outputs. 206 | Model inputs/outputs are described as a sequence of (optionally) named 207 | columns with type. 208 | """ 209 | return self._model_signature 210 | 211 | @model_signature.setter 212 | def model_signature(self, value: Signature) -> None: 213 | """Set the model signature 214 | 215 | Parameters 216 | ---------- 217 | value : Signature 218 | Model signature as a Signature object. 219 | """ 220 | self._model_signature = value 221 | 222 | def to_dict(self) -> Dict: 223 | """Get the model attributes as a dictionary. 224 | 225 | Returns 226 | ------- 227 | Dict 228 | Model attributes as dict 229 | """ 230 | model_dict = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} 231 | if self.model_signature is not None: 232 | model_dict["signature"] = self.model_signature.to_dict() 233 | if self.preprocessing_signature is not None: 234 | model_dict[ 235 | "preprocessing_signature" 236 | ] = self.preprocessing_signature.to_dict() 237 | if self.preparation_signature is not None: 238 | model_dict["preparation_signature"] = self.preparation_signature.to_dict() 239 | return model_dict 240 | 241 | def to_yaml(self, stream=None): 242 | """Serialize model object into a YAML stream. If stream is None, 243 | return the produced string instead. 244 | 245 | Parameters 246 | ---------- 247 | stream : f, optional 248 | YAML stream, by default None 249 | 250 | Returns 251 | ------- 252 | str 253 | YAML stream. If stream is None, return the produced string instead. 254 | """ 255 | return yaml.safe_dump(self.to_dict(), stream=stream, default_flow_style=False) 256 | 257 | def __str__(self) -> str: 258 | """Get model representation as string. 259 | 260 | Returns 261 | ------- 262 | str 263 | Model attributes as string. 264 | """ 265 | return self.to_yaml() 266 | 267 | def to_json(self) -> str: 268 | """Get model representation in JSON format. 269 | 270 | Returns 271 | ------- 272 | str 273 | Model attributes in JSON format. 274 | """ 275 | return json.dumps(self.to_dict()) 276 | 277 | def save(self, path: str) -> None: 278 | """Save model YAML representation to a file 279 | 280 | Parameters 281 | ---------- 282 | path : str 283 | Path of the file to save the YAML in. 
284 | """ 285 | with open(path, "w") as out: 286 | self.to_yaml(out) 287 | 288 | @classmethod 289 | def load(cls, path: str) -> "Model": 290 | """Load a model from file. 291 | 292 | Parameters 293 | ---------- 294 | path : str 295 | Path of the file to load the Model from. 296 | 297 | Returns 298 | ------- 299 | Model 300 | Loaded Model. 301 | """ 302 | if os.path.isdir(path): 303 | path = os.path.join(path, MLMODEL_FILE_NAME) 304 | with open(path) as f: 305 | return cls.from_dict(yaml.safe_load(f.read())) 306 | 307 | @classmethod 308 | def from_dict(cls, model_dict: Dict) -> "Model": 309 | """Load a model from its YAML representation. 310 | 311 | Parameters 312 | ---------- 313 | model_dict : Dict 314 | Model dictionary representation. 315 | 316 | Returns 317 | ------- 318 | Model 319 | Loaded Model. 320 | """ 321 | 322 | if "signature" in model_dict and isinstance(model_dict["signature"], dict): 323 | model_dict = model_dict.copy() 324 | model_dict["model_signature"] = Signature.from_dict(model_dict["signature"]) 325 | del model_dict["signature"] 326 | 327 | if "preprocessing_signature" in model_dict and isinstance( 328 | model_dict["preprocessing_signature"], dict 329 | ): 330 | model_dict = model_dict.copy() 331 | model_dict["preprocessing_signature"] = Signature.from_dict( 332 | model_dict["preprocessing_signature"] 333 | ) 334 | 335 | if "preparation_signature" in model_dict and isinstance( 336 | model_dict["preparation_signature"], dict 337 | ): 338 | model_dict = model_dict.copy() 339 | model_dict["preparation_signature"] = Signature.from_dict( 340 | model_dict["preparation_signature"] 341 | ) 342 | 343 | return cls(**model_dict) 344 | -------------------------------------------------------------------------------- /clearbox_wrapper/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import ( 2 | create_and_save_preprocessing, 3 | load_serialized_preprocessing, 4 | Preprocessing, 5 | ) 6 | 7 | __all__ = [ 8 | create_and_save_preprocessing, 9 | load_serialized_preprocessing, 10 | Preprocessing, 11 | ] 12 | -------------------------------------------------------------------------------- /clearbox_wrapper/preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Union 3 | 4 | import cloudpickle 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from clearbox_wrapper.exceptions import ClearboxWrapperException 9 | 10 | 11 | PreprocessingInput = Union[pd.DataFrame, pd.Series, np.ndarray] 12 | PreprocessingOutput = Union[pd.DataFrame, pd.Series, np.ndarray] 13 | 14 | 15 | class Preprocessing(object): 16 | def __init__(self, preprocessing_function: Callable): 17 | """Create a Preprocessing instance. 18 | 19 | Parameters 20 | ---------- 21 | preprocessing_function : Callable 22 | A function to use as a preprocessor. You can use your own custom code for 23 | preprocessing, but it must be wrapped in a single function. 24 | 25 | NOTE: If the preprocessing includes any kind of fitting on the training dataset 26 | (e.g. Scikit Learn transformers), it must be performed outside the final 27 | preprocessing function to save. Fit the transformer(s) outside the function and 28 | put only the transform method inside it. Furthermore, if the entire preprocessing 29 | is performed with a single Scikit-Learn transformer, you can directly pass it 30 | (fitted) to this method. 
31 | 32 | Raises 33 | ------ 34 | TypeError 35 | If preprocessing_function is not a function (Callable type) 36 | """ 37 | self.preprocessing = preprocessing_function 38 | 39 | def __repr__(self) -> str: 40 | return "Preprocessing: \n" " {}\n".format(repr(self.preprocessing)) 41 | 42 | @property 43 | def preprocessing(self) -> Callable: 44 | """Get the preprocessing function. 45 | 46 | Returns 47 | ------- 48 | Callable 49 | The preprocessing function. 50 | """ 51 | return self._preprocessing 52 | 53 | @preprocessing.setter 54 | def preprocessing(self, preprocessing_function: Callable) -> None: 55 | """Set the preprocessing function. 56 | 57 | Parameters 58 | ---------- 59 | preprocessing_function : Callable 60 | The preprocessing function. 61 | """ 62 | self._preprocessing = preprocessing_function 63 | 64 | def preprocess(self, data: PreprocessingInput) -> PreprocessingOutput: 65 | """Preprocess input data using the preprocessing function. 66 | 67 | Parameters 68 | ---------- 69 | data : PreprocessingInput 70 | Input data to preprocess. 71 | 72 | Returns 73 | ------- 74 | PreprocessingOutput 75 | Preprocessed data. 76 | """ 77 | preprocessed_data = ( 78 | self.preprocessing.transform(data) 79 | if hasattr(self.preprocessing, "transform") 80 | else self.preprocessing(data) 81 | ) 82 | return preprocessed_data 83 | 84 | def save(self, path: str) -> None: 85 | if os.path.exists(path): 86 | raise ClearboxWrapperException( 87 | "Preprocessing path '{}' already exists".format(path) 88 | ) 89 | with open(path, "wb") as preprocessing_serialized_file: 90 | cloudpickle.dump(self, preprocessing_serialized_file) 91 | 92 | 93 | def create_and_save_preprocessing(preprocessing_function: Callable, path: str) -> None: 94 | """Create, serialize and save a Preprocessing instance. 95 | 96 | Parameters 97 | ---------- 98 | preprocessing_function : Callable 99 | A function to use as a preprocessor. You can use your own custom code for 100 | preprocessing, but it must be wrapped in a single function. 101 | 102 | NOTE: If the preprocessing includes any kind of fitting on the training dataset 103 | (e.g. Scikit Learn transformers), it must be performed outside the final preprocessing 104 | function to save. Fit the transformer(s) outside the function and put only the transform 105 | method inside it. Furthermore, if the entire preprocessing is performed with a single 106 | Scikit-Learn transformer, you can directly pass it (fitted) to this method. 107 | path : str 108 | Local path to save the preprocessing to. 109 | 110 | Raises 111 | ------ 112 | TypeError 113 | If preprocessing_function is not a function (Callable type) 114 | ClearboxWrapperException 115 | If preprocessing path already exists.
116 | """ 117 | if not isinstance(preprocessing_function, Callable): 118 | raise TypeError( 119 | "preprocessing_function should be a Callable, got '{}'".format( 120 | type(preprocessing_function) 121 | ) 122 | ) 123 | if os.path.exists(path): 124 | raise ClearboxWrapperException( 125 | "Preprocessing path '{}' already exists".format(path) 126 | ) 127 | 128 | preprocessing = Preprocessing(preprocessing_function) 129 | with open(path, "wb") as preprocessing_serialized_file: 130 | cloudpickle.dump(preprocessing, preprocessing_serialized_file) 131 | 132 | 133 | def load_serialized_preprocessing(serialized_preprocessing_path: str) -> Preprocessing: 134 | with open(serialized_preprocessing_path, "rb") as serialized_preprocessing: 135 | return cloudpickle.load(serialized_preprocessing) 136 | -------------------------------------------------------------------------------- /clearbox_wrapper/pyfunc/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from clearbox_wrapper.model import Model 4 | from clearbox_wrapper.utils import PYTHON_VERSION 5 | 6 | FLAVOR_NAME = "python_function" 7 | MAIN = "loader_module" 8 | CODE = "code" 9 | DATA = "data" 10 | ENV = "env" 11 | PY_VERSION = "python_version" 12 | 13 | 14 | def add_pyfunc_flavor_to_model( 15 | model: Model, 16 | loader_module: str, 17 | data: str = None, 18 | code=None, 19 | env: str = None, 20 | **kwargs 21 | ) -> Model: 22 | """Add Pyfunc flavor to a model configuration. Caller can use this to create a valid 23 | Pyfunc model flavor out of an existing directory structure. A Pyfunc flavor will be 24 | added to the flavors list into the MLModel file: 25 | flavors: 26 | python_function: 27 | env: ... 28 | loader_module: ... 29 | model_path: ... 30 | python_version: ... 31 | 32 | Parameters 33 | ---------- 34 | model : Model 35 | Existing model. 36 | loader_module : str 37 | The module to be used to load the model (e.g. clearbox_wrapper.sklearn) 38 | data : str, optional 39 | Path to the model data, by default None. 40 | code : str, optional 41 | Path to the code dependencies, by default None. 42 | env : str, optional 43 | Path to the Conda environment, by default None. 44 | 45 | Returns 46 | ------- 47 | Model 48 | The Model with the new flavor added. 
49 | """ 50 | parms = deepcopy(kwargs) 51 | parms[MAIN] = loader_module 52 | parms[PY_VERSION] = PYTHON_VERSION 53 | if code: 54 | parms[CODE] = code 55 | if data: 56 | parms[DATA] = data 57 | if env: 58 | parms[ENV] = env 59 | return model.add_flavor(FLAVOR_NAME, **parms) 60 | 61 | 62 | __all__ = [ 63 | add_pyfunc_flavor_to_model, 64 | FLAVOR_NAME, 65 | MAIN, 66 | PY_VERSION, 67 | CODE, 68 | DATA, 69 | ENV, 70 | ] 71 | -------------------------------------------------------------------------------- /clearbox_wrapper/pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .pytorch import _load_clearbox, _load_pyfunc, save_pytorch_model 2 | 3 | __all__ = [_load_clearbox, _load_pyfunc, save_pytorch_model] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/pytorch/pickle_module.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module imports contents from CloudPickle in a way that is compatible with the 3 | ``pickle_module`` parameter of PyTorch's model persistence function: ``torch.save`` 4 | (see https://github.com/pytorch/pytorch/blob/692898fe379c9092f5e380797c32305145cd06e1/torch/ 5 | serialization.py#L192). It is included as a distinct module from :mod:`mlflow.pytorch` to avoid 6 | polluting the namespace with wildcard imports. 7 | 8 | Calling ``torch.save(..., pickle_module=mlflow.pytorch.pickle_module)`` will persist PyTorch 9 | definitions using CloudPickle, leveraging improved pickling functionality such as the ability 10 | to capture class definitions in the "__main__" scope. 11 | 12 | TODO: Remove this module or make it an alias of CloudPickle when CloudPickle and PyTorch have 13 | compatible pickling APIs. 
14 | """ 15 | 16 | from cloudpickle import * # noqa 17 | 18 | 19 | from cloudpickle import CloudPickler as Pickler # noqa 20 | 21 | 22 | from pickle import Unpickler # noqa 23 | -------------------------------------------------------------------------------- /clearbox_wrapper/pytorch/pytorch.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | import importlib 3 | import os 4 | import posixpath 5 | import shutil 6 | from typing import Any, Dict, Optional, Union 7 | 8 | import cloudpickle 9 | from loguru import logger 10 | import numpy as np 11 | import pandas as pd 12 | import yaml 13 | 14 | from clearbox_wrapper.exceptions import ClearboxWrapperException 15 | from clearbox_wrapper.model import MLMODEL_FILE_NAME, Model 16 | import clearbox_wrapper.pyfunc as pyfunc 17 | from clearbox_wrapper.pytorch import pickle_module as clearbox_pytorch_pickle_module 18 | from clearbox_wrapper.signature import Signature 19 | from clearbox_wrapper.utils import ( 20 | _copy_file_or_tree, 21 | _get_default_conda_env, 22 | _get_flavor_configuration, 23 | TempDir, 24 | ) 25 | from clearbox_wrapper.wrapper import add_clearbox_flavor_to_model 26 | 27 | 28 | FLAVOR_NAME = "pytorch" 29 | 30 | _SERIALIZED_TORCH_MODEL_FILE_NAME = "model.pth" 31 | _TORCH_STATE_DICT_FILE_NAME = "state_dict.pth" 32 | _PICKLE_MODULE_INFO_FILE_NAME = "pickle_module_info.txt" 33 | _EXTRA_FILES_KEY = "extra_files" 34 | _REQUIREMENTS_FILE_KEY = "requirements_file" 35 | 36 | 37 | def get_default_pytorch_conda_env() -> Dict: 38 | import torch 39 | import torchvision 40 | 41 | pip_deps = [ 42 | "cloudpickle=={}".format(cloudpickle.__version__), 43 | "pytorch=={}".format(torch.__version__), 44 | "torchvision=={}".format(torchvision.__version__), 45 | ] 46 | 47 | return _get_default_conda_env(additional_pip_deps=pip_deps) 48 | 49 | 50 | def save_pytorch_model( 51 | pytorch_model: Any, 52 | path: str, 53 | conda_env: Optional[Union[str, Dict]] = None, 54 | mlmodel: Optional[Model] = None, 55 | signature: Optional[Signature] = None, 56 | add_clearbox_flavor: bool = False, 57 | preprocessing_subpath: str = None, 58 | data_preparation_subpath: str = None, 59 | code_paths=None, 60 | pickle_module=None, 61 | requirements_file=None, 62 | extra_files=None, 63 | **kwargs 64 | ): 65 | import torch 66 | 67 | pickle_module = pickle_module or clearbox_pytorch_pickle_module 68 | if not isinstance(pytorch_model, torch.nn.Module): 69 | raise TypeError("Argument 'pytorch_model' should be a torch.nn.Module") 70 | if code_paths is not None: 71 | if not isinstance(code_paths, list): 72 | raise TypeError( 73 | "Argument code_paths should be a list, not {}".format(type(code_paths)) 74 | ) 75 | 76 | if os.path.exists(path): 77 | raise ClearboxWrapperException("Model path '{}' already exists".format(path)) 78 | os.makedirs(path) 79 | 80 | if mlmodel is None: 81 | mlmodel = Model() 82 | if signature is not None: 83 | mlmodel.signature = signature 84 | 85 | model_data_subpath = "data" 86 | model_data_path = os.path.join(path, model_data_subpath) 87 | os.makedirs(model_data_path) 88 | 89 | # Persist the pickle module name as a file in the model's `data` directory. 
This is 90 | # necessary because the `data` directory is the only available parameter to 91 | # `_load_pyfunc`, and it does not contain the MLmodel configuration; therefore, 92 | # it is not sufficient to place the module name in the MLmodel file. 93 | # 94 | # TODO: Stop persisting this information to the filesystem once we have a mechanism for 95 | # supplying the MLmodel configuration to `clearbox_wrapper.pytorch._load_pyfunc` 96 | pickle_module_path = os.path.join(model_data_path, _PICKLE_MODULE_INFO_FILE_NAME) 97 | with open(pickle_module_path, "w") as f: 98 | f.write(pickle_module.__name__) 99 | 100 | # Save the PyTorch model 101 | model_path = os.path.join(model_data_path, _SERIALIZED_TORCH_MODEL_FILE_NAME) 102 | if isinstance(pytorch_model, torch.jit.ScriptModule): 103 | torch.jit.ScriptModule.save(pytorch_model, model_path) 104 | else: 105 | torch.save(pytorch_model, model_path, pickle_module=pickle_module, **kwargs) 106 | 107 | torchserve_artifacts_config = {} 108 | 109 | if requirements_file: 110 | if not isinstance(requirements_file, str): 111 | raise TypeError("Path to requirements file should be a string") 112 | 113 | # Copy the requirements file into the model directory so that it is 114 | # packaged together with the saved model. 115 | rel_path = os.path.basename(requirements_file) 116 | torchserve_artifacts_config[_REQUIREMENTS_FILE_KEY] = {"path": rel_path} 117 | shutil.copy(requirements_file, os.path.join(path, rel_path)) 118 | 119 | if extra_files: 120 | torchserve_artifacts_config[_EXTRA_FILES_KEY] = [] 121 | if not isinstance(extra_files, list): 122 | raise TypeError("Extra files argument should be a list") 123 | 124 | # Copy each extra file into the model's "extra_files" directory. 125 | extra_files_dir = os.path.join(path, _EXTRA_FILES_KEY) 126 | os.makedirs(extra_files_dir) 127 | for extra_file in extra_files: 128 | rel_path = posixpath.join( 129 | _EXTRA_FILES_KEY, 130 | os.path.basename(extra_file), 131 | ) 132 | torchserve_artifacts_config[_EXTRA_FILES_KEY].append({"path": rel_path}) 133 | shutil.copy(extra_file, extra_files_dir) 134 | 135 | 136 | conda_env_subpath = "conda.yaml" 137 | if conda_env is None: 138 | conda_env = get_default_pytorch_conda_env() 139 | elif not isinstance(conda_env, dict): 140 | with open(conda_env, "r") as f: 141 | conda_env = yaml.safe_load(f) 142 | with open(os.path.join(path, conda_env_subpath), "w") as f: 143 | yaml.safe_dump(conda_env, stream=f, default_flow_style=False) 144 | 145 | if code_paths is not None: 146 | code_dir_subpath = "code" 147 | for code_path in code_paths: 148 | _copy_file_or_tree(src=code_path, dst=path, dst_dir=code_dir_subpath) 149 | else: 150 | code_dir_subpath = None 151 | 152 | mlmodel.add_flavor( 153 | FLAVOR_NAME, 154 | model_data=model_data_subpath, 155 | pytorch_version=torch.__version__, 156 | **torchserve_artifacts_config, 157 | ) 158 | 159 | pyfunc.add_pyfunc_flavor_to_model( 160 | mlmodel, 161 | loader_module="clearbox_wrapper.pytorch", 162 | data=model_data_subpath, 163 | pickle_module_name=pickle_module.__name__, 164 | code=code_dir_subpath, 165 | env=conda_env_subpath, 166 | ) 167 | 168 | if add_clearbox_flavor: 169 | add_clearbox_flavor_to_model( 170 | mlmodel, 171 | loader_module="clearbox_wrapper.pytorch", 172 | data=model_data_subpath, 173 | pickle_module_name=pickle_module.__name__, 174 | code=code_dir_subpath, 175 | env=conda_env_subpath, 176 | preprocessing=preprocessing_subpath, 177 | data_preparation=data_preparation_subpath, 178 | ) 179 | 180 | mlmodel.save(os.path.join(path, MLMODEL_FILE_NAME)) 181 | 182 | 183 | def _load_model(path, **kwargs): 184 | """ 185 | :param path: The path to a serialized PyTorch model.
186 | :param kwargs: Additional kwargs to pass to the PyTorch ``torch.load`` function. 187 | """ 188 | import torch 189 | 190 | if os.path.isdir(path): 191 | # `path` is a directory containing a serialized PyTorch model and a text file containing 192 | # information about the pickle module that should be used by PyTorch to load it 193 | model_path = os.path.join(path, "model.pth") 194 | pickle_module_path = os.path.join(path, _PICKLE_MODULE_INFO_FILE_NAME) 195 | with open(pickle_module_path, "r") as f: 196 | pickle_module_name = f.read() 197 | if ( 198 | "pickle_module" in kwargs 199 | and kwargs["pickle_module"].__name__ != pickle_module_name 200 | ): 201 | logger.warning( 202 | "Attempting to load the PyTorch model with a pickle module, '%s', that does not" 203 | " match the pickle module that was used to save the model: '%s'.", 204 | kwargs["pickle_module"].__name__, 205 | pickle_module_name, 206 | ) 207 | else: 208 | try: 209 | kwargs["pickle_module"] = importlib.import_module(pickle_module_name) 210 | except ImportError as exc: 211 | raise ClearboxWrapperException( 212 | message=( 213 | "Failed to import the pickle module that was used to save the PyTorch" 214 | " model. Pickle module name: `{pickle_module_name}`".format( 215 | pickle_module_name=pickle_module_name 216 | ) 217 | ) 218 | ) from exc 219 | 220 | else: 221 | model_path = path 222 | 223 | if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): 224 | return torch.load(model_path, **kwargs) 225 | else: 226 | try: 227 | # Load the model as an eager model. 228 | return torch.load(model_path, **kwargs) 229 | except Exception: 230 | # If that fails, assume the model is a scripted model. 231 | return torch.jit.load(model_path) 232 | 233 | 234 | def load_model(model_path, **kwargs): 235 | """ 236 | Load a PyTorch model from a local path. 237 | 238 | :param model_path: The local filesystem path to the directory containing the saved model, for example: 239 | 240 | - ``/Users/me/path/to/local/model`` 241 | - ``relative/path/to/local/model`` 242 | 243 | Remote URI schemes (such as ``s3://`` or ``runs:/``) are not supported by this wrapper: the model must be available on the local filesystem. 244 | 245 | :param kwargs: kwargs to pass to ``torch.load`` method. 246 | :return: A PyTorch model. 247 | 248 | .. code-block:: python 249 | :caption: Example 250 | 251 | import torch 252 | from clearbox_wrapper.pytorch.pytorch import load_model, save_pytorch_model 253 | 254 | # Class defined here 255 | class LinearNNModel(torch.nn.Module): 256 | ... 257 | 258 | # Initialize our model, criterion and optimizer 259 | ... 260 | 261 | # Training loop 262 | ... 263 | 264 | # Save the trained model, then load it back 265 | save_pytorch_model(model, "saved_model") 266 | loaded_model = load_model("saved_model") 267 | 268 | # Inference with the loaded model 269 | for x in [4.0, 6.0, 30.0]: 270 | X = torch.Tensor([[x]]) 271 | y_pred = loaded_model(X) 272 | print("predict X: {}, y_pred: {:.2f}".format(x, y_pred.data.item())) 273 | 274 | ..
code-block:: text 275 | :caption: Output 276 | 277 | predict X: 4.0, y_pred: 7.57 278 | predict X: 6.0, y_pred: 11.64 279 | predict X: 30.0, y_pred: 60.48 280 | """ 281 | import torch 282 | 283 | try: 284 | pyfunc_conf = _get_flavor_configuration( 285 | model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME 286 | ) 287 | except ClearboxWrapperException: 288 | pyfunc_conf = {} 289 | code_subpath = pyfunc_conf.get(pyfunc.CODE) 290 | if code_subpath is not None: 291 | pyfunc._add_code_to_system_path( 292 | code_path=os.path.join(model_path, code_subpath) 293 | ) 294 | 295 | pytorch_conf = _get_flavor_configuration( 296 | model_path=model_path, flavor_name=FLAVOR_NAME 297 | ) 298 | if torch.__version__ != pytorch_conf["pytorch_version"]: 299 | logger.warning( 300 | "Stored model version '%s' does not match installed PyTorch version '%s'", 301 | pytorch_conf["pytorch_version"], 302 | torch.__version__, 303 | ) 304 | torch_model_artifacts_path = os.path.join(model_path, pytorch_conf["model_data"]) 305 | return _load_model(path=torch_model_artifacts_path, **kwargs) 306 | 307 | 308 | def _load_pyfunc(path, **kwargs): 309 | """ 310 | Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``. 311 | 312 | :param path: Local filesystem path to the model directory with the ``pytorch`` flavor. 313 | """ 314 | return _PyTorchWrapper(_load_model(path, **kwargs)) 315 | 316 | 317 | def _load_clearbox(path, **kwargs): 318 | """ 319 | Load the Clearbox flavor implementation. Called by the Clearbox wrapper loader. 320 | 321 | :param path: Local filesystem path to the model directory with the ``pytorch`` flavor. 322 | """ 323 | return _PyTorchWrapper(_load_model(path, **kwargs)) 324 | 325 | 326 | class _PyTorchWrapper(object): 327 | """ 328 | Wrapper class that creates a predict function such that 329 | predict(data: pd.DataFrame) -> model's output as a numpy.ndarray 330 | """ 331 | 332 | def __init__(self, pytorch_model): 333 | self.pytorch_model = pytorch_model 334 | 335 | def predict(self, data, device="cpu"): 336 | import torch 337 | 338 | if isinstance(data, pd.DataFrame): 339 | inp_data = data.values.astype(np.float32) 340 | elif isinstance(data, np.ndarray): 341 | inp_data = data 342 | elif isinstance(data, torch.Tensor): 343 | inp_data = data.detach().numpy() 344 | elif isinstance(data, (list, dict)): 345 | raise TypeError( 346 | "The PyTorch flavor does not support List or Dict input types.
" 354 | "Please use a pandas.DataFrame or a numpy.ndarray" 355 | ) 356 | else: 357 | raise TypeError("Input data should be pandas.DataFrame or numpy.ndarray") 358 | 359 | self.pytorch_model.to(device) 360 | self.pytorch_model.eval() 361 | with torch.no_grad(): 362 | input_tensor = torch.from_numpy(inp_data).to(device) 363 | preds = self.pytorch_model(input_tensor.float()) 364 | if not isinstance(preds, torch.Tensor): 365 | raise TypeError( 366 | "Expected PyTorch model to output a single output tensor, " 367 | "but got output of type '{}'".format(type(preds)) 368 | ) 369 | """ if isinstance(data, pd.DataFrame): 370 | predicted = pd.DataFrame(preds.numpy()) 371 | predicted.index = data.index 372 | else: """ 373 | predicted = preds.numpy() 374 | return predicted 375 | -------------------------------------------------------------------------------- /clearbox_wrapper/schema/__init__.py: -------------------------------------------------------------------------------- 1 | from .schema import DataType, Schema 2 | from .utils import _infer_schema 3 | 4 | __all__ = [_infer_schema, DataType, Schema] 5 | -------------------------------------------------------------------------------- /clearbox_wrapper/schema/schema.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import json 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from clearbox_wrapper.exceptions import ClearboxWrapperException 9 | 10 | 11 | def _pandas_string_type(): 12 | try: 13 | return pd.StringDtype() 14 | except AttributeError: 15 | return np.object 16 | 17 | 18 | class DataType(Enum): 19 | 20 | boolean = (1, np.dtype("bool"), "BooleanType") # Logical data (True, False) 21 | integer = (2, np.dtype("int32"), "IntegerType") # 32b signed integer numbers 22 | long = (3, np.dtype("int64"), "LongType") # 64b signed integer numbers 23 | float = (4, np.dtype("float32"), "FloatType") # 32b floating point numbers 24 | double = (5, np.dtype("float64"), "DoubleType") # 64b floating point numbers 25 | string = (6, np.dtype("str"), "StringType", _pandas_string_type()) # Text data 26 | binary = (7, np.dtype("bytes"), "BinaryType", np.object) # Sequence of raw bytes 27 | 28 | 29 | class ColumnSpec(object): 30 | def __init__( 31 | self, type: DataType, name: Optional[str] = None, has_nans: bool = False 32 | ): 33 | self._name = name 34 | self._has_nans = has_nans 35 | try: 36 | self._type = DataType[type] if isinstance(type, str) else type 37 | except KeyError: 38 | raise ClearboxWrapperException( 39 | "Unsupported type '{0}', expected instance of DataType or " 40 | "one of {1}".format(type, [t.name for t in DataType]) 41 | ) 42 | if not isinstance(self.type, DataType): 43 | raise TypeError( 44 | "Expected Datatype or str for the 'type' " 45 | "argument, but got {}".format(self.type.__class__) 46 | ) 47 | 48 | @property 49 | def type(self) -> DataType: 50 | """The column data type.""" 51 | return self._type 52 | 53 | @property 54 | def has_nans(self) -> bool: 55 | """Wheter the columns contains Na values.""" 56 | return self._has_nans 57 | 58 | @property 59 | def name(self) -> Optional[str]: 60 | """The column name or None if the columns is unnamed.""" 61 | return self._name 62 | 63 | def to_dict(self) -> Dict[str, Any]: 64 | if self.name is None: 65 | return {"type": self.type.name, "has_nans": self.has_nans} 66 | else: 67 | return { 68 | "name": self.name, 69 | "type": self.type.name, 70 | "has_nans": str(self.has_nans), 
71 | } 72 | 73 | def __eq__(self, other) -> bool: 74 | names_eq = (self.name is None and other.name is None) or self.name == other.name 75 | return names_eq and self.type == other.type and self.has_nans == other.has_nans 76 | 77 | def __repr__(self) -> str: 78 | if self.name is None: 79 | return repr(self.type) 80 | else: 81 | return "{name}: {type}, Na values: {has_nans}".format( 82 | name=repr(self.name), type=repr(self.type), has_nans=repr(self.has_nans) 83 | ) 84 | 85 | 86 | class Schema(object): 87 | """ 88 | Specification of types and column names in a dataset. 89 | Schema is represented as a list of :py:class:`ColumnSpec`. 90 | The columns in a schema can be named, with unique non empty name for every column, 91 | or unnamed with implicit integer index defined by their list indices. 92 | Combination of named and unnamed columns is not allowed. 93 | """ 94 | 95 | def __init__(self, cols: List[ColumnSpec]): 96 | if not ( 97 | all(map(lambda x: x.name is None, cols)) 98 | or all(map(lambda x: x.name is not None, cols)) 99 | ): 100 | raise ClearboxWrapperException( 101 | "Creating Schema with a combination of named and unnamed columns " 102 | "is not allowed. Got column names {}".format([x.name for x in cols]) 103 | ) 104 | self._cols = cols 105 | 106 | @property 107 | def columns(self) -> List[ColumnSpec]: 108 | """The list of columns that defines this schema.""" 109 | return self._cols 110 | 111 | def column_names(self) -> List[Union[str, int]]: 112 | """Get list of column names or range of indices if the schema has no column names.""" 113 | return [x.name or i for i, x in enumerate(self.columns)] 114 | 115 | def has_column_names(self) -> bool: 116 | """ Return true iff this schema declares column names, false otherwise. """ 117 | return self.columns and self.columns[0].name is not None 118 | 119 | def column_types(self) -> List[DataType]: 120 | """ Get column types of the columns in the dataset.""" 121 | return [x.type for x in self._cols] 122 | 123 | def numpy_types(self) -> List[np.dtype]: 124 | """ Convenience shortcut to get the datatypes as numpy types.""" 125 | return [x.type.to_numpy() for x in self.columns] 126 | 127 | def pandas_types(self) -> List[np.dtype]: 128 | """ Convenience shortcut to get the datatypes as pandas types.""" 129 | return [x.type.to_pandas() for x in self.columns] 130 | 131 | def to_json(self) -> str: 132 | """Serialize into json string.""" 133 | return json.dumps([x.to_dict() for x in self.columns]) 134 | 135 | def to_dict(self) -> List[Dict[str, Any]]: 136 | """Serialize into a jsonable dictionary.""" 137 | return [x.to_dict() for x in self.columns] 138 | 139 | @classmethod 140 | def from_json(cls, json_str: str): 141 | """ Deserialize from a json string.""" 142 | return cls([ColumnSpec(**x) for x in json.loads(json_str)]) 143 | 144 | def __eq__(self, other) -> bool: 145 | if isinstance(other, Schema): 146 | return self.columns == other.columns 147 | else: 148 | return False 149 | 150 | def __repr__(self) -> str: 151 | return repr(self.columns) 152 | -------------------------------------------------------------------------------- /clearbox_wrapper/schema/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from clearbox_wrapper.exceptions import ClearboxWrapperException 7 | from .schema import ColumnSpec, DataType, Schema 8 | 9 | 10 | class TensorsNotSupportedException(ClearboxWrapperException): 11 | def __init__(self, msg): 12 | 
super().__init__( 13 | "Multidimensional arrays (aka tensors) are not supported. " "{}".format(msg) 14 | ) 15 | 16 | 17 | def _infer_schema(data: Any) -> Schema: 18 | """Infer a schema from data. 19 | 20 | The schema represents data as a sequence of (optionally) named columns with types. 21 | 22 | Parameters 23 | ---------- 24 | data : Any 25 | Valid data. It should be one of the following types: 26 | - pandas.DataFrame or pandas.Series 27 | - dictionary of { name -> numpy.ndarray} 28 | - numpy.ndarray 29 | The data types should be mappable to one of `clearbox.schema.DataType`. 30 | 31 | Returns 32 | ------- 33 | Schema 34 | Schema instance inferred from data. 35 | 36 | Raises 37 | ------ 38 | TypeError 39 | If type of data is not valid (pandas.DataFrame or pandas.Series, dictionary of 40 | { name -> numpy.ndarray}, numpy.ndarray) 41 | TensorsNotSupportedException 42 | If data are multidimensional (>2d) arrays (tensors). 43 | """ 44 | if hasattr(data, "toarray"): 45 | data = data.toarray() 46 | elif hasattr(data, "detach"): 47 | data = data.detach().numpy() 48 | if isinstance(data, dict): 49 | res = [] 50 | for col in data.keys(): 51 | ary = data[col] 52 | if not isinstance(ary, np.ndarray): 53 | raise TypeError("Data in the dictionary must be of type numpy.ndarray") 54 | dims = len(ary.shape) 55 | if dims == 1: 56 | res.append(ColumnSpec(type=_infer_numpy_array(ary), name=col)) 57 | else: 58 | raise TensorsNotSupportedException( 59 | "Data in the dictionary must be 1-dimensional, " 60 | "got shape {}".format(ary.shape) 61 | ) 62 | schema = Schema(res) 63 | elif isinstance(data, pd.Series): 64 | has_nans = data.isna().any() 65 | series_converted_to_numpy = data.dropna().values if has_nans else data.values 66 | schema = Schema( 67 | [ 68 | ColumnSpec( 69 | type=_infer_numpy_array(series_converted_to_numpy), 70 | has_nans=has_nans, 71 | ) 72 | ] 73 | ) 74 | elif isinstance(data, pd.DataFrame): 75 | columns_spec_list = [] 76 | for col in data.columns: 77 | has_nans = data[col].isna().any() 78 | col_converted_to_numpy = ( 79 | data[col].dropna().values if has_nans else data[col].values 80 | ) 81 | columns_spec_list.append( 82 | ColumnSpec( 83 | type=_infer_numpy_array(col_converted_to_numpy), 84 | name=col, 85 | has_nans=has_nans, 86 | ) 87 | ) 88 | schema = Schema(columns_spec_list) 89 | elif isinstance(data, np.ndarray): 90 | if len(data.shape) > 2: 91 | raise TensorsNotSupportedException( 92 | "Attempting to infer schema from numpy array with " 93 | "shape {}".format(data.shape) 94 | ) 95 | if data.dtype == np.object: 96 | data = pd.DataFrame(data).infer_objects() 97 | schema = Schema( 98 | [ 99 | ColumnSpec(type=_infer_numpy_array(data[col].values)) 100 | for col in data.columns 101 | ] 102 | ) 103 | elif len(data.shape) == 1: 104 | schema = Schema([ColumnSpec(type=_infer_numpy_dtype(data.dtype))]) 105 | elif len(data.shape) == 2: 106 | schema = Schema( 107 | [ColumnSpec(type=_infer_numpy_dtype(data.dtype))] * data.shape[1] 108 | ) 109 | else: 110 | raise TypeError( 111 | "Expected one of (pandas.DataFrame, numpy array, " 112 | "dictionary of (name -> numpy.ndarray)) " 113 | "but got '{}'".format(type(data)) 114 | ) 115 | return schema 116 | 117 | 118 | def _infer_numpy_dtype(dtype: np.dtype) -> DataType: 119 | """Infer DataType from numpy dtype. 120 | 121 | Parameters 122 | ---------- 123 | dtype : np.dtype 124 | Numpy dtype 125 | 126 | Returns 127 | ------- 128 | DataType 129 | Inferred DataType. 130 | 131 | Raises 132 | ------ 133 | TypeError 134 | If type of `dtype` is not numpy.dtype. 
135 | Exception 136 | If `dtype.kind`=='O' 137 | ClearboxWrapperException 138 | If `dtype` is unsupported. 139 | """ 140 | if not isinstance(dtype, np.dtype): 141 | raise TypeError("Expected numpy.dtype, got '{}'.".format(type(dtype))) 142 | if dtype.kind == "b": 143 | return DataType.boolean 144 | elif dtype.kind == "i" or dtype.kind == "u": 145 | if dtype.itemsize < 4 or (dtype.kind == "i" and dtype.itemsize == 4): 146 | return DataType.integer 147 | elif dtype.itemsize < 8 or (dtype.kind == "i" and dtype.itemsize == 8): 148 | return DataType.long 149 | elif dtype.kind == "f": 150 | if dtype.itemsize <= 4: 151 | return DataType.float 152 | elif dtype.itemsize <= 8: 153 | return DataType.double 154 | 155 | elif dtype.kind == "U": 156 | return DataType.string 157 | elif dtype.kind == "S": 158 | return DataType.binary 159 | elif dtype.kind == "O": 160 | raise Exception( 161 | "Cannot infer np.object without looking at the values, call " 162 | "_infer_numpy_array instead." 163 | ) 164 | raise ClearboxWrapperException( 165 | "Unsupported numpy data type '{0}', kind '{1}'".format(dtype, dtype.kind) 166 | ) 167 | 168 | 169 | def _infer_numpy_array(col: np.ndarray) -> DataType: 170 | """Infer DataType of a numpy array. 171 | 172 | Parameters 173 | ---------- 174 | col : np.ndarray 175 | Column representation as a numpy array. 176 | 177 | Returns 178 | ------- 179 | DataType 180 | Inferred datatype. 181 | 182 | Raises 183 | ------ 184 | TypeError 185 | If `col` is not a numpy array. 186 | ClearboxWrapperException 187 | If `col` is not a 1D array. 188 | """ 189 | if not isinstance(col, np.ndarray): 190 | raise TypeError("Expected numpy.ndarray, got '{}'.".format(type(col))) 191 | if len(col.shape) > 1: 192 | raise ClearboxWrapperException( 193 | "Expected 1d array, got array with shape {}".format(col.shape) 194 | ) 195 | 196 | class IsInstanceOrNone(object): 197 | def __init__(self, *args): 198 | self.classes = args 199 | self.seen_instances = 0 200 | 201 | def __call__(self, x): 202 | if x is None: 203 | return True 204 | elif any(map(lambda c: isinstance(x, c), self.classes)): 205 | self.seen_instances += 1 206 | return True 207 | else: 208 | return False 209 | 210 | if col.dtype.kind == "O": 211 | is_binary_test = IsInstanceOrNone(bytes, bytearray) 212 | if all(map(is_binary_test, col)) and is_binary_test.seen_instances > 0: 213 | return DataType.binary 214 | is_string_test = IsInstanceOrNone(str) 215 | if all(map(is_string_test, col)) and is_string_test.seen_instances > 0: 216 | return DataType.string 217 | # NB: bool is also instance of int => boolean test must precede integer test. 218 | is_boolean_test = IsInstanceOrNone(bool) 219 | if all(map(is_boolean_test, col)) and is_boolean_test.seen_instances > 0: 220 | return DataType.boolean 221 | is_long_test = IsInstanceOrNone(int) 222 | if all(map(is_long_test, col)) and is_long_test.seen_instances > 0: 223 | return DataType.long 224 | is_double_test = IsInstanceOrNone(float) 225 | if all(map(is_double_test, col)) and is_double_test.seen_instances > 0: 226 | return DataType.double 227 | else: 228 | raise ClearboxWrapperException( 229 | "Unable to map 'np.object' type to DataType. np.object can " 230 | "be mapped iff all values have identical data type which is one " 231 | "of (string, (bytes or bytearray), int, float)."
232 | ) 233 | else: 234 | return _infer_numpy_dtype(col.dtype) 235 | -------------------------------------------------------------------------------- /clearbox_wrapper/signature/__init__.py: -------------------------------------------------------------------------------- 1 | from .signature import infer_signature, Signature 2 | 3 | __all__ = ["infer_signature", "Signature"] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/signature/signature.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from clearbox_wrapper.schema import _infer_schema, Schema 7 | 8 | InferableDataset = Union[pd.DataFrame, np.ndarray, Dict[str, np.ndarray]] 9 | 10 | 11 | class Signature(object): 12 | """Description of a Model, Preprocessing or Data Preparation inputs and outputs. 13 | 14 | Attributes 15 | ---------- 16 | inputs : clearbox_wrapper.schema.Schema 17 | Inputs schema as a sequence of (optionally) named columns with types. 18 | outputs : Optional[clearbox_wrapper.schema.Schema] 19 | Outputs schema as a sequence of (optionally) named columns with types. 20 | 21 | """ 22 | 23 | def __init__(self, inputs: Schema, outputs: Optional[Schema] = None): 24 | """Create a new Signature instance given inputs and (optionally) outputs schemas. 25 | 26 | Parameters 27 | ---------- 28 | inputs : clearbox_wrapper.schema.Schema 29 | Inputs schema as a sequence of (optionally) named columns with types. 30 | outputs : Optional[clearbox_wrapper.schema.Schema] 31 | Outputs schema as a sequence of (optionally) named columns with types, 32 | by default None. 33 | 34 | 35 | Raises 36 | ------ 37 | TypeError 38 | If `inputs` is not type Schema or if `outputs` is not type Schema or None. 39 | """ 40 | if not isinstance(inputs, Schema): 41 | raise TypeError("Inputs must be type Schema, got '{}'".format(type(inputs))) 42 | if outputs is not None and not isinstance(outputs, Schema): 43 | raise TypeError( 44 | "Outputs must be either None or clearbox_wrapper.schema.Schema, " 45 | "got '{}'".format(type(outputs)) 46 | ) 47 | self.inputs = inputs 48 | self.outputs = outputs 49 | 50 | def to_dict(self) -> Dict[str, Any]: 51 | """Generate dictionary representation of the signature. 52 | 53 | Returns 54 | ------- 55 | Dict[str, Any] 56 | Signature dictionary {"inputs": inputs schema as JSON, 57 | "outputs": outputs schema as JSON} 58 | """ 59 | return { 60 | "inputs": self.inputs.to_json(), 61 | "outputs": self.outputs.to_json() if self.outputs is not None else None, 62 | } 63 | 64 | @classmethod 65 | def from_dict(cls, signature_dict: Dict[str, Any]) -> "Signature": 66 | """Create a Signature instance from a dictionary representation. 67 | 68 | Parameters 69 | ---------- 70 | signature_dict: Dict[str, Any] 71 | Signature dictionary {"inputs": inputs schema as JSON, 72 | "outputs": outputs schema as JSON} 73 | 74 | Returns 75 | ------- 76 | Signature 77 | A Signature instance. 78 | """ 79 | inputs = Schema.from_json(signature_dict["inputs"]) 80 | if "outputs" in signature_dict and signature_dict["outputs"] is not None: 81 | outputs = Schema.from_json(signature_dict["outputs"]) 82 | return cls(inputs, outputs) 83 | else: 84 | return cls(inputs) 85 | 86 | def __eq__(self, other: "Signature") -> bool: 87 | """Check if two Signature instances (self and other) are equal.
88 | 89 | Parameters 90 | ---------- 91 | other : Signature 92 | A Signature instance 93 | 94 | Returns 95 | ------- 96 | bool 97 | True if the two signatures are equal, False otherwise. 98 | """ 99 | return ( 100 | isinstance(other, Signature) 101 | and self.inputs == other.inputs 102 | and self.outputs == other.outputs 103 | ) 104 | 105 | def __repr__(self) -> str: 106 | """Generate string representation. 107 | 108 | Returns 109 | ------- 110 | str 111 | Signature string representation. 112 | """ 113 | return ( 114 | "inputs: \n" 115 | " {}\n" 116 | "outputs: \n" 117 | " {}\n".format(repr(self.inputs), repr(self.outputs)) 118 | ) 119 | 120 | 121 | def infer_signature(input_data: Any, output_data: InferableDataset = None) -> Signature: 122 | """Infer a Signature from input data and (optionally) output data. 123 | 124 | The signature represents inputs and outputs schemas as a sequence 125 | of (optionally) named columns with types. 126 | 127 | Parameters 128 | ---------- 129 | input_data : Any 130 | Valid input data. E.g. (a subset of) the training dataset. It should be 131 | one of the following types: 132 | - pandas.DataFrame 133 | - dictionary of { name -> numpy.ndarray} 134 | - numpy.ndarray 135 | The element types should be mappable to one of `clearbox.schema.DataType`. 136 | output_data : InferableDataset, optional 137 | Valid output data. E.g. Preprocessed data or model predictions for 138 | (a subset of) the training dataset. It should be one of the following types: 139 | - pandas.DataFrame 140 | - numpy.ndarray 141 | The element types should be mappable to one of `clearbox.schema.DataType`. 142 | By default None. 143 | 144 | Returns 145 | ------- 146 | Signature 147 | Inferred Signature 148 | """ 149 | inputs = _infer_schema(input_data) 150 | outputs = _infer_schema(output_data) if output_data is not None else None 151 | return Signature(inputs, outputs) 152 | -------------------------------------------------------------------------------- /clearbox_wrapper/sklearn/__init__.py: -------------------------------------------------------------------------------- 1 | from .sklearn import _load_clearbox, _load_pyfunc, save_sklearn_model 2 | 3 | __all__ = ["_load_clearbox", "_load_pyfunc", "save_sklearn_model"] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/sklearn/sklearn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from typing import Any, Dict, Optional, Union 4 | 5 | from loguru import logger 6 | import yaml 7 | 8 | from clearbox_wrapper.exceptions import ClearboxWrapperException 9 | from clearbox_wrapper.model import MLMODEL_FILE_NAME, Model 10 | import clearbox_wrapper.pyfunc as pyfunc 11 | from clearbox_wrapper.signature import Signature 12 | from clearbox_wrapper.utils import _get_default_conda_env, _get_flavor_configuration 13 | from clearbox_wrapper.wrapper import add_clearbox_flavor_to_model 14 | from clearbox_wrapper.wrapper import FLAVOR_NAME as cb_flavor_name 15 | 16 | 17 | FLAVOR_NAME = "sklearn" 18 | 19 | SERIALIZATION_FORMAT_PICKLE = "pickle" 20 | SERIALIZATION_FORMAT_CLOUDPICKLE = "cloudpickle" 21 | 22 | SUPPORTED_SERIALIZATION_FORMATS = [ 23 | SERIALIZATION_FORMAT_PICKLE, 24 | SERIALIZATION_FORMAT_CLOUDPICKLE, 25 | ] 26 | 27 | 28 | def get_default_sklearn_conda_env(include_cloudpickle: bool = False) -> Dict: 29 | """Generate the default Conda environment for Scikit-Learn models.
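For instance, ``get_default_sklearn_conda_env(include_cloudpickle=True)`` would pin
both ``scikit-learn`` and ``cloudpickle`` to the locally installed versions in the
``pip`` section of the generated environment (a hedged sketch; the exact versions
depend on the running environment)::

    dependencies:
      - python=3.8.8
      - pip
      - pip:
        - scikit-learn==0.23.2
        - cloudpickle==1.6.0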
30 | 31 | Parameters 32 | ---------- 33 | include_cloudpickle : bool, optional 34 | Whether to include cloudpickle as an environment dependency, by default False. 35 | 36 | Returns 37 | ------- 38 | Dict 39 | The default Conda environment for Scikit-Learn models as a dictionary. 40 | """ 41 | import sklearn 42 | 43 | pip_deps = ["scikit-learn=={}".format(sklearn.__version__)] 44 | 45 | if include_cloudpickle: 46 | import cloudpickle 47 | 48 | pip_deps += ["cloudpickle=={}".format(cloudpickle.__version__)] 49 | return _get_default_conda_env( 50 | additional_pip_deps=pip_deps, additional_conda_channels=None 51 | ) 52 | 53 | 54 | def save_sklearn_model( 55 | sk_model: Any, 56 | path: str, 57 | conda_env: Optional[Union[str, Dict]] = None, 58 | mlmodel: Optional[Model] = None, 59 | serialization_format: str = SERIALIZATION_FORMAT_CLOUDPICKLE, 60 | signature: Optional[Signature] = None, 61 | add_clearbox_flavor: bool = False, 62 | preprocessing_subpath: str = None, 63 | data_preparation_subpath: str = None, 64 | ): 65 | """Save a Scikit-Learn model. Produces a model directory containing the following flavors: 66 | * wrapper.sklearn 67 | * wrapper.pyfunc. NOTE: This flavor is only included for scikit-learn models 68 | that define at least `predict()`, since `predict()` is required for pyfunc model 69 | inference. 70 | 71 | Parameters 72 | ---------- 73 | sk_model : Any 74 | A Scikit-Learn model to be saved. 75 | path : str 76 | Local path to save the model to. 77 | conda_env : Optional[Union[str, Dict]], optional 78 | A dictionary representation of a Conda environment or the path to a Conda environment 79 | YAML file, by default None. This describes the environment this model should be run in. 80 | If None, the default Conda environment will be added to the model. Example of a 81 | dictionary representation of a Conda environment: 82 | { 83 | 'name': 'conda-env', 84 | 'channels': ['defaults'], 85 | 'dependencies': [ 86 | 'python=3.7.0', 87 | 'scikit-learn=0.19.2' 88 | ] 89 | } 90 | serialization_format : str, optional 91 | The format in which to serialize the model. This should be one of the formats listed in 92 | SUPPORTED_SERIALIZATION_FORMATS. Cloudpickle format, SERIALIZATION_FORMAT_CLOUDPICKLE, 93 | provides better cross-system compatibility by identifying and packaging code 94 | dependencies with the serialized model, by default SERIALIZATION_FORMAT_CLOUDPICKLE 95 | signature : Optional[Signature], optional 96 | A model signature describes the model input schema. It can be inferred from valid 97 | input data (e.g. the training dataset with the target column omitted), by default None 98 | 99 | Raises 100 | ------ 101 | ClearboxWrapperException 102 | If the serialization format is unrecognized or the model path already exists. 103 | """ 104 | import sklearn 105 | 106 | if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS: 107 | raise ClearboxWrapperException( 108 | "Unrecognized serialization format: {serialization_format}.
Please specify one" 109 | " of the following supported formats: {supported_formats}.".format( 110 | serialization_format=serialization_format, 111 | supported_formats=SUPPORTED_SERIALIZATION_FORMATS, 112 | ) 113 | ) 114 | 115 | if os.path.exists(path): 116 | raise ClearboxWrapperException("Model path '{}' already exists".format(path)) 117 | 118 | os.makedirs(path) 119 | if mlmodel is None: 120 | mlmodel = Model() 121 | 122 | if signature is not None: 123 | mlmodel.signature = signature 124 | 125 | model_data_subpath = "model.pkl" 126 | 127 | _serialize_and_save_model( 128 | sk_model=sk_model, 129 | output_path=os.path.join(path, model_data_subpath), 130 | serialization_format=serialization_format, 131 | ) 132 | 133 | conda_env_subpath = "conda.yaml" 134 | if conda_env is None: 135 | conda_env = get_default_sklearn_conda_env( 136 | include_cloudpickle=serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE 137 | ) 138 | elif not isinstance(conda_env, dict): 139 | with open(conda_env, "r") as f: 140 | conda_env = yaml.safe_load(f) 141 | 142 | with open(os.path.join(path, conda_env_subpath), "w") as f: 143 | yaml.safe_dump(conda_env, stream=f, default_flow_style=False) 144 | 145 | # `PyFuncModel` only works for sklearn models that define `predict()`. 146 | if hasattr(sk_model, "predict"): 147 | pyfunc.add_pyfunc_flavor_to_model( 148 | mlmodel, 149 | loader_module="clearbox_wrapper.sklearn", 150 | model_path=model_data_subpath, 151 | env=conda_env_subpath, 152 | ) 153 | 154 | if add_clearbox_flavor: 155 | add_clearbox_flavor_to_model( 156 | mlmodel, 157 | loader_module="clearbox_wrapper.sklearn", 158 | model_path=model_data_subpath, 159 | env=conda_env_subpath, 160 | preprocessing=preprocessing_subpath, 161 | data_preparation=data_preparation_subpath, 162 | ) 163 | 164 | mlmodel.add_flavor( 165 | FLAVOR_NAME, 166 | model_path=model_data_subpath, 167 | sklearn_version=sklearn.__version__, 168 | serialization_format=serialization_format, 169 | ) 170 | 171 | mlmodel.save(os.path.join(path, MLMODEL_FILE_NAME)) 172 | 173 | 174 | def _serialize_and_save_model( 175 | sk_model: Any, output_path: str, serialization_format: str 176 | ) -> None: 177 | """Serialize and save a Scikit-Learn model to a local file. 178 | 179 | Parameters 180 | ---------- 181 | sk_model : Any 182 | The Scikit-Learn model to serialize. 183 | output_path : str 184 | The file path to which to write the serialized model (.pkl). 185 | serialization_format : str 186 | The format in which to serialize the model. This should be one of the following: 187 | SERIALIZATION_FORMAT_PICKLE or SERIALIZATION_FORMAT_CLOUDPICKLE. 188 | 189 | Raises 190 | ------ 191 | ClearboxWrapperException 192 | Unrecognized serialization format. 193 | """ 194 | 195 | with open(output_path, "wb") as out: 196 | if serialization_format == SERIALIZATION_FORMAT_PICKLE: 197 | pickle.dump(sk_model, out) 198 | elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE: 199 | import cloudpickle 200 | 201 | cloudpickle.dump(sk_model, out) 202 | else: 203 | raise ClearboxWrapperException( 204 | "Unrecognized serialization format: {serialization_format}".format( 205 | serialization_format=serialization_format 206 | ) 207 | ) 208 | 209 | 210 | def _load_serialized_model( 211 | serialized_model_path: str, serialization_format: str 212 | ) -> Any: 213 | """Load a serialized (through pickle or cloudpickle) Scikit-Learn model. 214 | 215 | Parameters 216 | ---------- 217 | serialized_model_path : str 218 | File path to the Scikit-Learn serialized model. 
219 | serialization_format : str 220 | Format in which the model was serialized: SERIALIZATION_FORMAT_PICKLE or 221 | SERIALIZATION_FORMAT_CLOUDPICKLE 222 | 223 | Returns 224 | ------- 225 | Any 226 | A Scikit-Learn model. 227 | 228 | Raises 229 | ------ 230 | ClearboxWrapperException 231 | If Unrecognized serialization format. 232 | """ 233 | # TODO: we could validate the scikit-learn version here 234 | if serialization_format not in SUPPORTED_SERIALIZATION_FORMATS: 235 | raise ClearboxWrapperException( 236 | "Unrecognized serialization format: {serialization_format}. Please specify one" 237 | " of the following supported formats: {supported_formats}.".format( 238 | serialization_format=serialization_format, 239 | supported_formats=SUPPORTED_SERIALIZATION_FORMATS, 240 | ) 241 | ) 242 | with open(serialized_model_path, "rb") as f: 243 | # Models serialized with Cloudpickle cannot necessarily be deserialized using Pickle; 244 | if serialization_format == SERIALIZATION_FORMAT_PICKLE: 245 | return pickle.load(f) 246 | elif serialization_format == SERIALIZATION_FORMAT_CLOUDPICKLE: 247 | import cloudpickle 248 | 249 | return cloudpickle.load(f) 250 | 251 | 252 | def _load_pyfunc(model_path: str) -> Any: 253 | """Load Scikit-Learn model as a PyFunc model. This function is called by pyfunc.load_pyfunc. 254 | 255 | Parameters 256 | ---------- 257 | model_path : str 258 | File path to the model with sklearn flavor. 259 | 260 | Returns 261 | ------- 262 | Any 263 | A Scikit-Learn model. 264 | """ 265 | try: 266 | sklearn_flavor_conf = _get_flavor_configuration( 267 | model_path=model_path, flavor_name=FLAVOR_NAME 268 | ) 269 | serialization_format = sklearn_flavor_conf.get( 270 | "serialization_format", SERIALIZATION_FORMAT_PICKLE 271 | ) 272 | except ClearboxWrapperException: 273 | logger.warning( 274 | "Could not find scikit-learn flavor configuration during model loading process." 275 | " Assuming 'pickle' serialization format." 276 | ) 277 | serialization_format = SERIALIZATION_FORMAT_PICKLE 278 | 279 | pyfunc_flavor_conf = _get_flavor_configuration( 280 | model_path=model_path, flavor_name=pyfunc.FLAVOR_NAME 281 | ) 282 | serialized_model_path = os.path.join(model_path, pyfunc_flavor_conf["model_path"]) 283 | 284 | return _load_serialized_model( 285 | serialized_model_path=serialized_model_path, 286 | serialization_format=serialization_format, 287 | ) 288 | 289 | 290 | def _load_clearbox(model_path: str) -> Any: 291 | """Load Scikit-Learn model as a ClearboxWrapper model. 292 | 293 | Parameters 294 | ---------- 295 | model_path : str 296 | File path to the model with sklearn flavor. 297 | 298 | Returns 299 | ------- 300 | Any 301 | A Scikit-Learn model. 302 | """ 303 | try: 304 | sklearn_flavor_conf = _get_flavor_configuration( 305 | model_path=model_path, flavor_name=FLAVOR_NAME 306 | ) 307 | serialization_format = sklearn_flavor_conf.get( 308 | "serialization_format", SERIALIZATION_FORMAT_CLOUDPICKLE 309 | ) 310 | except ClearboxWrapperException: 311 | logger.warning( 312 | "Could not find scikit-learn flavor configuration during model loading process." 313 | " Assuming 'pickle' serialization format." 
314 | ) 315 | serialization_format = SERIALIZATION_FORMAT_PICKLE 316 | 317 | clearbox_flavor_conf = _get_flavor_configuration( 318 | model_path=model_path, flavor_name=cb_flavor_name 319 | ) 320 | serialized_model_path = os.path.join(model_path, clearbox_flavor_conf["model_path"]) 321 | 322 | return _load_serialized_model( 323 | serialized_model_path=serialized_model_path, 324 | serialization_format=serialization_format, 325 | ) 326 | -------------------------------------------------------------------------------- /clearbox_wrapper/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import ( 2 | _get_default_conda_env, 3 | get_major_minor_py_version, 4 | PYTHON_VERSION, 5 | ) 6 | from .file_utils import _copy_file_or_tree, TempDir 7 | from .model_utils import ( 8 | _get_flavor_configuration, 9 | get_super_classes_names, 10 | ) 11 | 12 | __all__ = [ 13 | "_copy_file_or_tree", 14 | "_get_default_conda_env", 15 | "_get_flavor_configuration", 16 | "get_major_minor_py_version", 17 | "get_super_classes_names", 18 | "PYTHON_VERSION", 19 | "TempDir", 20 | ] 21 | -------------------------------------------------------------------------------- /clearbox_wrapper/utils/environment.py: -------------------------------------------------------------------------------- 1 | from sys import version_info 2 | from typing import Dict, List, Optional, Union 3 | 4 | import yaml 5 | 6 | CONDA_HEADER = """\ 7 | name: conda-env 8 | channels: 9 | - defaults 10 | - conda-forge 11 | """ 12 | 13 | PYTHON_VERSION = "{major}.{minor}.{micro}".format( 14 | major=version_info.major, minor=version_info.minor, micro=version_info.micro 15 | ) 16 | 17 | 18 | def get_major_minor_py_version(py_version): 19 | return ".".join(py_version.split(".")[:2]) 20 | 21 | 22 | def _get_default_conda_env( 23 | path: str = None, 24 | additional_conda_deps: Optional[List] = None, 25 | additional_pip_deps: Optional[List] = None, 26 | additional_conda_channels: Optional[List] = None, 27 | ) -> Union[Dict, None]: 28 | """Generate and, optionally, save to file the default Conda environment for models. 29 | 30 | Parameters 31 | ---------- 32 | path : str, optional 33 | File path. If not None, the Conda env will be saved to file, by default None 34 | additional_conda_deps : Optional[List], optional 35 | List of additional conda dependencies, by default None 36 | additional_pip_deps : Optional[List], optional 37 | List of additional PyPI dependencies, by default None 38 | additional_conda_channels : Optional[List], optional 39 | List of additional conda channels, by default None 40 | 41 | Returns 42 | ------- 43 | Union[Dict, None] 44 | None if path is not None, else the Conda environment generated as a dictionary.
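Examples
--------
A hedged sketch of the generated environment (the Python version tracks the
running interpreter, here assumed to be 3.8.8):

>>> _get_default_conda_env(additional_pip_deps=["scikit-learn==0.23.2"])
{'name': 'conda-env',
 'channels': ['defaults', 'conda-forge'],
 'dependencies': ['python=3.8.8', 'pip', {'pip': ['scikit-learn==0.23.2']}]}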
45 | """ 46 | pip_deps = additional_pip_deps if additional_pip_deps else [] 47 | conda_deps = (additional_conda_deps if additional_conda_deps else []) + ( 48 | ["pip"] if pip_deps else [] 49 | ) 50 | 51 | env = yaml.safe_load(CONDA_HEADER) 52 | env["dependencies"] = ["python={}".format(PYTHON_VERSION)] 53 | if conda_deps is not None: 54 | env["dependencies"] += conda_deps 55 | env["dependencies"].append({"pip": pip_deps}) 56 | if additional_conda_channels is not None: 57 | env["channels"] += additional_conda_channels 58 | 59 | if path is not None: 60 | with open(path, "w") as out: 61 | yaml.safe_dump(env, stream=out, default_flow_style=False) 62 | return None 63 | else: 64 | return env 65 | -------------------------------------------------------------------------------- /clearbox_wrapper/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tempfile 4 | 5 | 6 | def _copy_file_or_tree(src, dst, dst_dir=None): 7 | """ 8 | :return: The path to the copied artifacts, relative to `dst` 9 | """ 10 | dst_subpath = os.path.basename(os.path.abspath(src)) 11 | if dst_dir is not None: 12 | dst_subpath = os.path.join(dst_dir, dst_subpath) 13 | dst_path = os.path.join(dst, dst_subpath) 14 | if os.path.isfile(src): 15 | dst_dirpath = os.path.dirname(dst_path) 16 | if not os.path.exists(dst_dirpath): 17 | os.makedirs(dst_dirpath) 18 | shutil.copy(src=src, dst=dst_path) 19 | else: 20 | shutil.copytree(src=src, dst=dst_path) 21 | return dst_subpath 22 | 23 | 24 | class TempDir(object): 25 | def __init__(self, chdr=False, remove_on_exit=True): 26 | self._dir = None 27 | self._path = None 28 | self._chdr = chdr 29 | self._remove = remove_on_exit 30 | 31 | def __enter__(self): 32 | self._path = os.path.abspath(tempfile.mkdtemp()) 33 | assert os.path.exists(self._path) 34 | if self._chdr: 35 | self._dir = os.path.abspath(os.getcwd()) 36 | os.chdir(self._path) 37 | return self 38 | 39 | def __exit__(self, tp, val, traceback): 40 | if self._chdr and self._dir: 41 | os.chdir(self._dir) 42 | self._dir = None 43 | if self._remove and os.path.exists(self._path): 44 | shutil.rmtree(self._path) 45 | 46 | assert not self._remove or not os.path.exists(self._path) 47 | assert os.path.exists(os.getcwd()) 48 | 49 | def path(self, *path): 50 | return ( 51 | os.path.join("./", *path) if self._chdr else os.path.join(self._path, *path) 52 | ) 53 | -------------------------------------------------------------------------------- /clearbox_wrapper/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | from clearbox_wrapper.exceptions import ClearboxWrapperException 6 | from clearbox_wrapper.model import MLMODEL_FILE_NAME, Model 7 | 8 | 9 | def _get_flavor_configuration(model_path: str, flavor_name: str) -> Dict: 10 | """Get the configuration for a specified flavor of a model. 11 | 12 | Parameters 13 | ---------- 14 | model_path : str 15 | Path to the model directory. 16 | flavor_name : str 17 | Name of the flavor configuration to load. 18 | 19 | Returns 20 | ------- 21 | Dict 22 | Flavor configuration as a dictionary. 23 | 24 | Raises 25 | ------ 26 | ClearboxWrapperException 27 | If it couldn't find a MLmodel file or if the model doesn't contain 28 | the specified flavor. 
29 | """ 30 | mlmodel_path = os.path.join(model_path, MLMODEL_FILE_NAME) 31 | if not os.path.exists(mlmodel_path): 32 | raise ClearboxWrapperException( 33 | 'Could not find an "{}" configuration file at "{}"'.format( 34 | MLMODEL_FILE_NAME, model_path 35 | ) 36 | ) 37 | 38 | mlmodel = Model.load(mlmodel_path) 39 | if flavor_name not in mlmodel.flavors: 40 | raise ClearboxWrapperException( 41 | 'Model does not have the "{}" flavor'.format(flavor_name) 42 | ) 43 | flavor_configuration_dict = mlmodel.flavors[flavor_name] 44 | return flavor_configuration_dict 45 | 46 | 47 | def get_super_classes_names(instance_or_class: Any) -> List[str]: 48 | """Given an instance or a class, computes and returns a list of its superclasses. 49 | 50 | Parameters 51 | ---------- 52 | instance_or_class : Any 53 | An instance of an object or a class. 54 | 55 | Returns 56 | ------- 57 | List[str] 58 | List of superclasses names strings. 59 | """ 60 | super_class_names_list = [] 61 | if not inspect.isclass(instance_or_class): 62 | instance_or_class = instance_or_class.__class__ 63 | super_classes_tuple = inspect.getmro(instance_or_class) 64 | for super_class in super_classes_tuple: 65 | super_class_name = ( 66 | str(super_class).replace("'", "").replace("", "") 67 | ) 68 | super_class_names_list.append(super_class_name) 69 | return super_class_names_list 70 | -------------------------------------------------------------------------------- /clearbox_wrapper/wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import add_clearbox_flavor_to_model, FLAVOR_NAME 2 | from .wrapper import load_model, save_model 3 | 4 | 5 | __all__ = [ 6 | add_clearbox_flavor_to_model, 7 | FLAVOR_NAME, 8 | load_model, 9 | save_model, 10 | ] 11 | -------------------------------------------------------------------------------- /clearbox_wrapper/wrapper/model.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import cloudpickle 4 | 5 | from clearbox_wrapper.utils import _get_default_conda_env 6 | 7 | 8 | def get_default_conda_env(): 9 | """ 10 | :return: The default Conda environment for MLflow Models produced by calls to 11 | :func:`save_model() ` 12 | and :func:`log_model() ` when a user-defined subclass of 13 | :class:`PythonModel` is provided. 14 | """ 15 | return _get_default_conda_env( 16 | additional_conda_deps=None, 17 | additional_pip_deps=["cloudpickle=={}".format(cloudpickle.__version__)], 18 | additional_conda_channels=None, 19 | ) 20 | 21 | 22 | class ClearboxModel(object, metaclass=ABCMeta): 23 | @abstractmethod 24 | def prepare_data(self, data): 25 | pass 26 | 27 | @abstractmethod 28 | def preprocess_data(self, data): 29 | pass 30 | 31 | @abstractmethod 32 | def predict(self, model_input): 33 | pass 34 | 35 | @abstractmethod 36 | def predict_proba(self, model_input): 37 | pass 38 | 39 | 40 | class _ModelWrapper(object): 41 | """ 42 | Wrapper class that creates a predict function such that 43 | predict(model_input: pd.DataFrame) -> model's output as pd.DataFrame (pandas DataFrame) 44 | """ 45 | 46 | def __init__(self, wrapper_model): 47 | """ 48 | :param python_model: An instance of a subclass of :class:`~PythonModel`. 49 | :param context: A :class:`~PythonModelContext` instance containing artifacts that 50 | ``python_model`` may use when performing inference. 
51 | """ 52 | self.wrapper_model = wrapper_model 53 | 54 | def prepare_data(self, data): 55 | return self.wrapper_model.prepare_data(data) 56 | 57 | def preprocess_data(self, data): 58 | return self.wrapper_model.preprocess_data(data) 59 | 60 | def predict(self, model_input): 61 | return self.wrapper_model.predict(model_input) 62 | -------------------------------------------------------------------------------- /clearbox_wrapper/wrapper/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import shutil 4 | import zipfile 5 | 6 | from clearbox_wrapper.model import Model 7 | from clearbox_wrapper.utils import PYTHON_VERSION 8 | 9 | FLAVOR_NAME = "clearbox" 10 | MAIN = "loader_module" 11 | CODE = "code" 12 | DATA = "data" 13 | ENV = "env" 14 | PREPROCESSING = "preprocessing_path" 15 | DATA_PREPARATION = "data_preparation_path" 16 | PY_VERSION = "python_version" 17 | 18 | 19 | def add_clearbox_flavor_to_model( 20 | model: Model, 21 | loader_module: str, 22 | data: str = None, 23 | code=None, 24 | env: str = None, 25 | preprocessing: str = None, 26 | data_preparation: str = None, 27 | **kwargs, 28 | ) -> Model: 29 | """Add Clearbox flavor to a model configuration. Caller can use this to create a valid 30 | Clearbox model flavor out of an existing directory structure. A Clearbox flavor will be 31 | added to the flavors list into the MLModel file: 32 | flavors: 33 | clearbox: 34 | env: ... 35 | loader_module: ... 36 | model_path: ... 37 | python_version: ... 38 | 39 | Parameters 40 | ---------- 41 | model : Model 42 | Existing model. 43 | loader_module : str 44 | The module to be used to load the model (e.g. clearbox_wrapper.sklearn) 45 | data : str, optional 46 | Path to the model data, by default None. 47 | code : str, optional 48 | Path to the code dependencies, by default None. 49 | env : str, optional 50 | Path to the Conda environment, by default None. 51 | 52 | Returns 53 | ------- 54 | Model 55 | The Model with the new flavor added. 56 | """ 57 | parms = deepcopy(kwargs) 58 | 59 | parms[MAIN] = loader_module 60 | parms[PY_VERSION] = PYTHON_VERSION 61 | if code: 62 | parms[CODE] = code 63 | if data: 64 | parms[DATA] = data 65 | if env: 66 | parms[ENV] = env 67 | if preprocessing: 68 | parms[PREPROCESSING] = preprocessing 69 | if data_preparation: 70 | parms[DATA_PREPARATION] = data_preparation 71 | return model.add_flavor(FLAVOR_NAME, **parms) 72 | 73 | 74 | def zip_directory(directory_path: str) -> None: 75 | """Given a directory path, zip the directory. 
80 | 81 | Parameters 82 | ---------- 83 | directory_path : str 84 | Path of the directory to archive. 85 | """ 86 | with zipfile.ZipFile(directory_path + ".zip", "w", zipfile.ZIP_DEFLATED) as zip_object: 87 | root_len = len(directory_path) + 1 88 | for base, _dirs, files in os.walk(directory_path): 89 | for file in files: 90 | fn = os.path.join(base, file) 91 | zip_object.write(fn, fn[root_len:]) 92 | shutil.rmtree(directory_path) 93 | -------------------------------------------------------------------------------- /clearbox_wrapper/wrapper/wrapper.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import sys 4 | from typing import Any, Callable, Dict, List, Optional, Union 5 | 6 | from loguru import logger 7 | import numpy as np 8 | import pandas as pd 9 | import yaml 10 | 11 | from clearbox_wrapper.data_preparation import ( 12 | DataPreparation, 13 | load_serialized_data_preparation, 14 | ) 15 | from clearbox_wrapper.exceptions import ClearboxWrapperException 16 | from clearbox_wrapper.keras import save_keras_model 17 | from clearbox_wrapper.model import MLMODEL_FILE_NAME, Model 18 | from clearbox_wrapper.preprocessing import ( 19 | load_serialized_preprocessing, 20 | Preprocessing, 21 | ) 22 | from clearbox_wrapper.pytorch import save_pytorch_model 23 | from clearbox_wrapper.signature import infer_signature 24 | from clearbox_wrapper.sklearn import save_sklearn_model 25 | from clearbox_wrapper.utils import ( 26 | _get_default_conda_env, 27 | get_major_minor_py_version, 28 | get_super_classes_names, 29 | PYTHON_VERSION, 30 | ) 31 | from clearbox_wrapper.xgboost import save_xgboost_model 32 | from .model import ClearboxModel 33 | from .utils import ( 34 | DATA, 35 | DATA_PREPARATION, 36 | FLAVOR_NAME, 37 | MAIN, 38 | PREPROCESSING, 39 | PY_VERSION, 40 | zip_directory, 41 | ) 42 | 43 | WrapperInput = Union[pd.DataFrame, pd.Series, np.ndarray, List[Any], Dict[str, Any]] 44 | WrapperOutput = Union[pd.DataFrame, pd.Series, np.ndarray, list] 45 | logger.add(sys.stdout, backtrace=False, diagnose=False) 46 | 47 | 48 | class WrapperModel(ClearboxModel): 49 | def __init__( 50 | self, 51 | model_meta: Model, 52 | model_impl: Any, 53 | preprocessing: Any = None, 54 | data_preparation: Any = None, 55 | ): 56 | if not hasattr(model_impl, "predict"): 57 | raise ClearboxWrapperException( 58 | "Model implementation is missing required predict method." 59 | ) 60 | if not model_meta: 61 | raise ClearboxWrapperException("Model is missing metadata.") 62 | if data_preparation is not None and preprocessing is None: 63 | raise ValueError( 64 | "Attribute 'preprocessing' is None but attribute " 65 | "'data_preparation' is not None.
If you have a single step " 66 | "preprocessing, pass it as attribute 'preprocessing'" 67 | ) 68 | 69 | self._model_meta = model_meta 70 | self._model_impl = model_impl 71 | self._preprocessing = preprocessing 72 | self._data_preparation = data_preparation 73 | 74 | def prepare_data(self, data: WrapperInput) -> WrapperOutput: 75 | if self._data_preparation is None: 76 | raise ClearboxWrapperException("This model has no data preparation.") 77 | return self._data_preparation.prepare_data(data) 78 | 79 | def preprocess_data(self, data: WrapperInput) -> WrapperOutput: 80 | if self._preprocessing is None: 81 | raise ClearboxWrapperException("This model has no preprocessing.") 82 | return self._preprocessing.preprocess(data) 83 | 84 | def predict( 85 | self, data: WrapperInput, preprocess: bool = True, prepare_data: bool = True 86 | ) -> WrapperOutput: 87 | if prepare_data and self._data_preparation is not None: 88 | data = self._data_preparation.prepare_data(data) 89 | elif not prepare_data: 90 | logger.warning( 91 | "This model has data preparation and you're bypassing it," 92 | " this can lead to unexpected results." 93 | ) 94 | 95 | if preprocess and self._preprocessing is not None: 96 | data = self._preprocessing.preprocess(data) 97 | elif not preprocess: 98 | logger.warning( 99 | "This model has preprocessing and you're bypassing it," 100 | " this can lead to unexpected results." 101 | ) 102 | 103 | return self._model_impl.predict(data) 104 | 105 | def predict_proba( 106 | self, data: WrapperInput, preprocess: bool = True, prepare_data: bool = True 107 | ) -> WrapperOutput: 108 | if not hasattr(self._model_impl, "predict_proba"): 109 | raise ClearboxWrapperException("This model has no predict_proba method.") 110 | 111 | if prepare_data and self._data_preparation is not None: 112 | data = self._data_preparation.prepare_data(data) 113 | elif not prepare_data: 114 | logger.warning( 115 | "This model has data preparation and you're bypassing it," 116 | " this can lead to unexpected results." 117 | ) 118 | 119 | if preprocess and self._preprocessing is not None: 120 | data = self._preprocessing.preprocess(data) 121 | elif not preprocess: 122 | logger.warning( 123 | "This model has preprocessing and you're bypassing it," 124 | " this can lead to unexpected results." 
125 | ) 126 | 127 | return self._model_impl.predict_proba(data) 128 | 129 | @property 130 | def metadata(self): 131 | """Model metadata.""" 132 | if self._model_meta is None: 133 | raise ClearboxWrapperException("Model is missing metadata.") 134 | return self._model_meta 135 | 136 | def __repr__(self): 137 | info = {} 138 | if self._model_meta is not None: 139 | info["flavor"] = self._model_meta.flavors[FLAVOR_NAME]["loader_module"] 140 | return yaml.safe_dump({"wrapper.loaded_model": info}, default_flow_style=False) 141 | 142 | 143 | def _check_and_get_conda_env(model: Any, additional_deps: List = None) -> Dict: 144 | import cloudpickle 145 | 146 | pip_deps = ["cloudpickle=={}".format(cloudpickle.__version__)] 147 | 148 | if additional_deps is not None: 149 | pip_deps += additional_deps 150 | 151 | model_super_classes = get_super_classes_names(model) 152 | 153 | # order of ifs matters 154 | if any("xgboost" in super_class for super_class in model_super_classes): 155 | import xgboost 156 | 157 | pip_deps.append("xgboost=={}".format(xgboost.__version__)) 158 | elif any("sklearn" in super_class for super_class in model_super_classes): 159 | pip_deps.append( 160 | "scikit-learn=={}".format(model.__getstate__()["_sklearn_version"]) 161 | ) 162 | elif any("keras" in super_class for super_class in model_super_classes): 163 | import tensorflow 164 | 165 | pip_deps.append("tensorflow=={}".format(tensorflow.__version__)) 166 | elif any("torch" in super_class for super_class in model_super_classes): 167 | import torch 168 | 169 | pip_deps.append("torch=={}".format(torch.__version__)) 170 | 171 | unique_pip_deps = [dep.split("==")[0] for dep in pip_deps] 172 | if len(unique_pip_deps) > len(set(unique_pip_deps)): 173 | raise ValueError( 174 | "Multiple occurrences of the same dependency are not allowed: {}".format( 175 | pip_deps 176 | ) 177 | ) 178 | 179 | return _get_default_conda_env(additional_pip_deps=pip_deps) 180 | 181 | 182 | def _warn_potentially_incompatible_py_version_if_necessary(model_py_version=None): 183 | """ 184 | Compares the version of Python that was used to save a given model with the version 185 | of Python that is currently running. If a major or minor version difference is detected, 186 | logs an appropriate warning. 187 | """ 188 | if model_py_version is None: 189 | logger.warning( 190 | "The model was saved without a specified Python version. 
It may be" 191 | " incompatible with the version of Python that is currently running: Python %s", 192 | PYTHON_VERSION, 193 | ) 194 | elif get_major_minor_py_version(model_py_version) != get_major_minor_py_version( 195 | PYTHON_VERSION 196 | ): 197 | logger.warning( 198 | "The version of Python that the model was saved in, `Python %s`, differs" 199 | " from the version of Python that is currently running, `Python %s`," 200 | " and may be incompatible", 201 | model_py_version, 202 | PYTHON_VERSION, 203 | ) 204 | 205 | 206 | def save_model( 207 | path: str, 208 | model: Any, 209 | input_data: Optional[WrapperInput] = None, 210 | preprocessing: Optional[Callable] = None, 211 | data_preparation: Optional[Callable] = None, 212 | additional_deps: Optional[List] = None, 213 | zip: bool = True, 214 | ) -> None: 215 | 216 | path_check = path + ".zip" if zip else path 217 | if os.path.exists(path_check): 218 | raise ClearboxWrapperException("Model path '{}' already exists".format(path)) 219 | 220 | mlmodel = Model() 221 | saved_preprocessing_subpath = None 222 | saved_data_preparation_subpath = None 223 | 224 | if data_preparation is not None and preprocessing is None: 225 | raise ValueError( 226 | "Attribute 'preprocessing' is None but attribute " 227 | "'data_preparation' is not None. If you have a single step " 228 | "preprocessing, pass it as attribute 'preprocessing'" 229 | ) 230 | 231 | if data_preparation and preprocessing: 232 | preparation = DataPreparation(data_preparation) 233 | data_preprocessing = Preprocessing(preprocessing) 234 | saved_data_preparation_subpath = "data_preparation.pkl" 235 | saved_preprocessing_subpath = "preprocessing.pkl" 236 | if input_data is not None: 237 | if isinstance(input_data, pd.DataFrame) and input_data.shape[0] > 50: 238 | input_data = input_data.head(50) 239 | elif isinstance(input_data, np.ndarray) and input_data.shape[0] > 50: 240 | input_data = input_data[:50, :] 241 | 242 | data_preparation_output = preparation.prepare_data(input_data) 243 | preprocessing_output = data_preprocessing.preprocess( 244 | data_preparation_output 245 | ) 246 | data_preparation_signature = infer_signature( 247 | input_data, data_preparation_output 248 | ) 249 | preprocessing_signature = infer_signature( 250 | data_preparation_output, preprocessing_output 251 | ) 252 | model_signature = infer_signature(preprocessing_output) 253 | mlmodel.preparation_signature = data_preparation_signature 254 | mlmodel.preprocessing_signature = preprocessing_signature 255 | mlmodel.model_signature = model_signature 256 | elif preprocessing: 257 | data_preprocessing = Preprocessing(preprocessing) 258 | saved_preprocessing_subpath = "preprocessing.pkl" 259 | if input_data is not None: 260 | preprocessing_output = data_preprocessing.preprocess(input_data) 261 | preprocessing_signature = infer_signature(input_data, preprocessing_output) 262 | model_signature = infer_signature(preprocessing_output) 263 | mlmodel.preprocessing_signature = preprocessing_signature 264 | mlmodel.model_signature = model_signature 265 | elif input_data is not None: 266 | model_signature = infer_signature(input_data) 267 | mlmodel.model_signature = model_signature 268 | 269 | conda_env = _check_and_get_conda_env(model, additional_deps) 270 | model_super_classes = get_super_classes_names(model) 271 | 272 | if any("sklearn" in super_class for super_class in model_super_classes): 273 | save_sklearn_model( 274 | model, 275 | path, 276 | conda_env=conda_env, 277 | mlmodel=mlmodel, 278 | add_clearbox_flavor=True, 279 | 
preprocessing_subpath=saved_preprocessing_subpath, 280 | data_preparation_subpath=saved_data_preparation_subpath, 281 | ) 282 | elif any("xgboost" in super_class for super_class in model_super_classes): 283 | save_xgboost_model( 284 | model, 285 | path, 286 | conda_env=conda_env, 287 | mlmodel=mlmodel, 288 | add_clearbox_flavor=True, 289 | preprocessing_subpath=saved_preprocessing_subpath, 290 | data_preparation_subpath=saved_data_preparation_subpath, 291 | ) 292 | elif any("keras" in super_class for super_class in model_super_classes): 293 | save_keras_model( 294 | model, 295 | path, 296 | conda_env=conda_env, 297 | mlmodel=mlmodel, 298 | add_clearbox_flavor=True, 299 | preprocessing_subpath=saved_preprocessing_subpath, 300 | data_preparation_subpath=saved_data_preparation_subpath, 301 | ) 302 | elif any("torch" in super_class for super_class in model_super_classes): 303 | save_pytorch_model( 304 | model, 305 | path, 306 | conda_env=conda_env, 307 | mlmodel=mlmodel, 308 | add_clearbox_flavor=True, 309 | preprocessing_subpath=saved_preprocessing_subpath, 310 | data_preparation_subpath=saved_data_preparation_subpath, 311 | ) 312 | 313 | if preprocessing: 314 | data_preprocessing.save(os.path.join(path, saved_preprocessing_subpath)) 315 | if data_preparation: 316 | preparation.save(os.path.join(path, saved_data_preparation_subpath)) 317 | if zip: 318 | zip_directory(path) 319 | 320 | 321 | def load_model(model_path: str, suppress_warnings: bool = False) -> WrapperModel: 322 | """Load a model that has the Clearbox flavor. 323 | 324 | Parameters 325 | ---------- 326 | model_path : str 327 | Filepath of the model directory. 328 | suppress_warnings : bool, optional 329 | If True, non-fatal warning messages associated with the model loading process 330 | are suppressed, by default False 331 | 332 | Returns 333 | ------- 334 | WrapperModel 335 | The loaded model, together with its optional preprocessing and data preparation steps. 336 | 337 | Raises 338 | ------ 339 | ClearboxWrapperException 340 | If the model does not have the Clearbox flavor.
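Example
-------
A minimal usage sketch, assuming a model previously saved with ``save_model`` and already unzipped to ``./my_model`` (the path and the ``X_test`` frame are illustrative only):

    import clearbox_wrapper as cbw

    loaded_model = cbw.load_model("./my_model")
    predictions = loaded_model.predict(X_test)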
341 | """ 342 | preprocessing = None 343 | data_preparation = None 344 | 345 | mlmodel = Model.load(os.path.join(model_path, MLMODEL_FILE_NAME)) 346 | clearbox_flavor_configuration = mlmodel.flavors.get(FLAVOR_NAME) 347 | 348 | if clearbox_flavor_configuration is None: 349 | raise ClearboxWrapperException( 350 | 'Model does not have the "{flavor_name}" flavor'.format( 351 | flavor_name=FLAVOR_NAME 352 | ) 353 | ) 354 | 355 | model_python_version = clearbox_flavor_configuration.get(PY_VERSION) 356 | 357 | if not suppress_warnings: 358 | _warn_potentially_incompatible_py_version_if_necessary( 359 | model_py_version=model_python_version 360 | ) 361 | 362 | data_path = ( 363 | os.path.join(model_path, clearbox_flavor_configuration[DATA]) 364 | if (DATA in clearbox_flavor_configuration) 365 | else model_path 366 | ) 367 | 368 | model_implementation = importlib.import_module( 369 | clearbox_flavor_configuration[MAIN] 370 | )._load_clearbox(data_path) 371 | 372 | if PREPROCESSING in clearbox_flavor_configuration: 373 | preprocessing_path = os.path.join( 374 | model_path, clearbox_flavor_configuration[PREPROCESSING] 375 | ) 376 | preprocessing = load_serialized_preprocessing(preprocessing_path) 377 | 378 | if DATA_PREPARATION in clearbox_flavor_configuration: 379 | data_preparation_path = os.path.join( 380 | model_path, clearbox_flavor_configuration[DATA_PREPARATION] 381 | ) 382 | data_preparation = load_serialized_data_preparation(data_preparation_path) 383 | 384 | loaded_model = WrapperModel( 385 | model_meta=mlmodel, 386 | model_impl=model_implementation, 387 | preprocessing=preprocessing, 388 | data_preparation=data_preparation, 389 | ) 390 | 391 | return loaded_model 392 | -------------------------------------------------------------------------------- /clearbox_wrapper/xgboost/__init__.py: -------------------------------------------------------------------------------- 1 | from .xgboost import _load_clearbox, _load_pyfunc, save_xgboost_model 2 | 3 | __all__ = [_load_clearbox, _load_pyfunc, save_xgboost_model] 4 | -------------------------------------------------------------------------------- /clearbox_wrapper/xgboost/xgboost.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, Optional, Union 3 | 4 | import yaml 5 | 6 | from clearbox_wrapper.exceptions import ClearboxWrapperException 7 | from clearbox_wrapper.model import MLMODEL_FILE_NAME, Model 8 | import clearbox_wrapper.pyfunc as pyfunc 9 | from clearbox_wrapper.signature import Signature 10 | from clearbox_wrapper.utils import _get_default_conda_env, _get_flavor_configuration 11 | from clearbox_wrapper.wrapper import add_clearbox_flavor_to_model 12 | from clearbox_wrapper.wrapper import FLAVOR_NAME as cb_flavor_name 13 | 14 | 15 | FLAVOR_NAME = "xgboost" 16 | 17 | 18 | def get_default_xgboost_conda_env() -> Dict: 19 | """Generate the default Conda environment for Scikit-Learn models. 20 | 21 | Parameters 22 | ---------- 23 | include_cloudpickle : bool, optional 24 | Whether to include cloudpickle as a environment dependency, by default False. 25 | 26 | Returns 27 | ------- 28 | Dict 29 | The default Conda environment for Scikit-Learn models as a dictionary. 
30 | """ 31 | import xgboost as xgb 32 | 33 | pip_deps = ["xgboost=={}".format(xgb.__version__)] 34 | return _get_default_conda_env(additional_pip_deps=pip_deps) 35 | 36 | 37 | def save_xgboost_model( 38 | xgb_model: Any, 39 | path: str, 40 | conda_env: Optional[Union[str, Dict]] = None, 41 | mlmodel: Optional[Model] = None, 42 | signature: Optional[Signature] = None, 43 | add_clearbox_flavor: bool = False, 44 | preprocessing_subpath: str = None, 45 | data_preparation_subpath: str = None, 46 | ): 47 | """Save a Scikit-Learn model. Produces an MLflow Model containing the following flavors: 48 | * wrapper.sklearn 49 | * wrapper.pyfunc. NOTE: This flavor is only included for scikit-learn models 50 | that define at least `predict()`, since `predict()` is required for pyfunc model 51 | inference. 52 | 53 | Parameters 54 | ---------- 55 | sk_model : Any 56 | A Scikit-Learn model to be saved. 57 | path : str 58 | Local path to save the model to. 59 | conda_env : Optional[Union[str, Dict]], optional 60 | A dictionary representation of a Conda environment or the path to a Conda environment 61 | YAML file, by default None. This decsribes the environment this model should be run in. 62 | If None, the default Conda environment will be added to the model. Example of a 63 | dictionary representation of a Conda environment: 64 | { 65 | 'name': 'conda-env', 66 | 'channels': ['defaults'], 67 | 'dependencies': [ 68 | 'python=3.7.0', 69 | 'scikit-learn=0.19.2' 70 | ] 71 | } 72 | serialization_format : str, optional 73 | The format in which to serialize the model. This should be one of the formats listed in 74 | SUPPORTED_SERIALIZATION_FORMATS. Cloudpickle format, SERIALIZATION_FORMAT_CLOUDPICKLE, 75 | provides better cross-system compatibility by identifying and packaging code 76 | dependencies with the serialized model, by default SERIALIZATION_FORMAT_CLOUDPICKLE 77 | signature : Optional[Signature], optional 78 | A model signature describes model input schema. It can be inferred from datasets with 79 | valid model type (e.g. the training dataset with target column omitted), by default None 80 | 81 | Raises 82 | ------ 83 | ClearboxWrapperException 84 | If unrecognized serialization format or model path already exists. 
85 | """ 86 | import xgboost as xgb 87 | 88 | if os.path.exists(path): 89 | raise ClearboxWrapperException("Model path '{}' already exists".format(path)) 90 | os.makedirs(path) 91 | 92 | if mlmodel is None: 93 | mlmodel = Model() 94 | if signature is not None: 95 | mlmodel.signature = signature 96 | 97 | model_data_subpath = "model.xgb" 98 | xgb_model.save_model(os.path.join(path, model_data_subpath)) 99 | 100 | conda_env_subpath = "conda.yaml" 101 | if conda_env is None: 102 | conda_env = get_default_xgboost_conda_env() 103 | elif not isinstance(conda_env, dict): 104 | with open(conda_env, "r") as f: 105 | conda_env = yaml.safe_load(f) 106 | 107 | with open(os.path.join(path, conda_env_subpath), "w") as f: 108 | yaml.safe_dump(conda_env, stream=f, default_flow_style=False) 109 | 110 | pyfunc.add_pyfunc_flavor_to_model( 111 | mlmodel, 112 | loader_module="clearbox_wrapper.xgboost", 113 | model_path=model_data_subpath, 114 | env=conda_env_subpath, 115 | ) 116 | 117 | if add_clearbox_flavor: 118 | add_clearbox_flavor_to_model( 119 | mlmodel, 120 | loader_module="clearbox_wrapper.xgboost", 121 | model_path=model_data_subpath, 122 | env=conda_env_subpath, 123 | preprocessing=preprocessing_subpath, 124 | data_preparation=data_preparation_subpath, 125 | ) 126 | 127 | mlmodel.add_flavor( 128 | FLAVOR_NAME, 129 | model_path=model_data_subpath, 130 | sklearn_version=xgb.__version__, 131 | data=model_data_subpath, 132 | ) 133 | 134 | mlmodel.save(os.path.join(path, MLMODEL_FILE_NAME)) 135 | 136 | 137 | def _load_model(model_path): 138 | import xgboost as xgb 139 | 140 | model = xgb.Booster() 141 | model.load_model(model_path) 142 | return model 143 | 144 | 145 | def _load_pyfunc(path): 146 | """ 147 | Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``. 148 | 149 | :param path: Local filesystem path to the MLflow Model with the ``xgboost`` flavor. 150 | """ 151 | return _XGBModelWrapper(_load_model(path)) 152 | 153 | 154 | def _load_clearbox(model_path: str): 155 | """ 156 | Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``. 157 | 158 | :param path: Local filesystem path to the MLflow Model with the ``xgboost`` flavor. 
159 | """ 160 | clearbox_flavor_conf = _get_flavor_configuration( 161 | model_path=model_path, flavor_name=cb_flavor_name 162 | ) 163 | serialized_model_path = os.path.join(model_path, clearbox_flavor_conf["model_path"]) 164 | return _XGBModelWrapper(_load_model(serialized_model_path)) 165 | 166 | 167 | class _XGBModelWrapper: 168 | def __init__(self, xgb_model): 169 | self.xgb_model = xgb_model 170 | 171 | def predict(self, dataframe): 172 | import xgboost as xgb 173 | 174 | return self.xgb_model.predict(xgb.DMatrix(dataframe)) 175 | 176 | def predict_proba(self, dataframe): 177 | if not hasattr(self.xgb_model, "predict_proba"): 178 | raise ClearboxWrapperException("This model has no predict_proba method.") 179 | import xgboost as xgb 180 | 181 | return self.xgb_model.predict_proba(xgb.DMatrix(dataframe)) 182 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/conftest.py -------------------------------------------------------------------------------- /docs/images/clearbox_ai_wrapper_no_preprocessing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/docs/images/clearbox_ai_wrapper_no_preprocessing.png -------------------------------------------------------------------------------- /docs/images/clearbox_ai_wrapper_preprocessing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/docs/images/clearbox_ai_wrapper_preprocessing.png -------------------------------------------------------------------------------- /docs/images/clearbox_ai_wrapper_preprocessing_data_preparation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/docs/images/clearbox_ai_wrapper_preprocessing_data_preparation.png -------------------------------------------------------------------------------- /examples/1_iris_sklearn/1_Clearbox_Wrapper_Iris_Scikit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Clearbox Wrapper Tutorial\n", 8 | "\n", 9 | "Clearbox Wrapper is a Python library to package and save a ML model.\n", 10 | "\n", 11 | "This is the simplest case possible: we'll wrap a Scikit-Learn model trained on the popular Iris Dataset. The dataset contains only ordinal values and we'll use all columns, so we do not need either **preprocessing** or **data preparation** for the X. 
We'll just use a simple LabelEncoder to encode the y strings to numerical values (0, 1, 2), but the LabelEncoder doesn't need to be saved together with the model.\n", 12 | "\n", 13 | "## Install and import required libraries" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": { 20 | "scrolled": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "%%capture\n", 25 | "!pip install pandas\n", 26 | "!pip install numpy\n", 27 | "!pip install scikit-learn\n", 28 | "\n", 29 | "!pip install clearbox-wrapper==0.3.10" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "import pandas as pd\n", 39 | "import numpy as np\n", 40 | "\n", 41 | "from sklearn.preprocessing import LabelEncoder\n", 42 | "from sklearn.tree import DecisionTreeClassifier\n", 43 | "import sklearn.metrics as metrics\n", 44 | "\n", 45 | "import clearbox_wrapper as cbw" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "## Datasets\n", 53 | "\n", 54 | "We have two different csv files for the training and test set." 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 3, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "iris_training_csv_path = 'iris_training_set.csv'\n", 64 | "iris_test_csv_path = 'iris_test_set.csv'" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "iris_training = pd.read_csv(iris_training_csv_path)\n", 74 | "iris_test = pd.read_csv(iris_test_csv_path)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 5, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "target_column = 'species'" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 6, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "y_train = iris_training[target_column]\n", 93 | "X_train = iris_training.drop(target_column, axis=1)" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "y_test = iris_test[target_column]\n", 103 | "X_test = iris_test.drop(target_column, axis=1)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 8, 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "output_type": "stream", 113 | "name": "stdout", 114 | "text": [ 115 | "\nRangeIndex: 120 entries, 0 to 119\nData columns (total 4 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 sepal_length 120 non-null float64\n 1 sepal_width 120 non-null float64\n 2 petal_length 120 non-null float64\n 3 petal_width 120 non-null float64\ndtypes: float64(4)\nmemory usage: 3.9 KB\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "X_train.info()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 9, 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "output_type": "stream", 130 | "name": "stdout", 131 | "text": [ 132 | "\nRangeIndex: 30 entries, 0 to 29\nData columns (total 4 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 sepal_length 30 non-null float64\n 1 sepal_width 30 non-null float64\n 2 petal_length 30 non-null float64\n 3 petal_width 30 non-null float64\ndtypes: float64(4)\nmemory usage: 1.1 KB\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "X_test.info()" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | 
"metadata": {}, 143 | "source": [ 144 | "We create a simple LabelEncoder for the y series:" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 10, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "y_encoder = LabelEncoder()" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "We fit the LabelEncoder on the y of the training set and we get the encoded y for both the datasets:" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 11, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "y_train = y_encoder.fit_transform(y_train)\n", 170 | "y_test = y_encoder.transform(y_test)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "## Create and train the model\n", 178 | "\n", 179 | "We build a simple Sklearn Decision Tree classifier setting some basic parameters..." 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 12, 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "tree_clf = DecisionTreeClassifier(max_depth=4, random_state=42)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "...and fit on the training dataset:" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 13, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "output_type": "execute_result", 205 | "data": { 206 | "text/plain": [ 207 | "DecisionTreeClassifier(max_depth=4, random_state=42)" 208 | ] 209 | }, 210 | "metadata": {}, 211 | "execution_count": 13 212 | } 213 | ], 214 | "source": [ 215 | "tree_clf.fit(X_train, y_train)" 216 | ] 217 | }, 218 | { 219 | "source": [ 220 | "We show some metrics on the training set..." 
221 | ], 222 | "cell_type": "markdown", 223 | "metadata": {} 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 14, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "training_predictions = tree_clf.predict(X_train)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 15, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "output_type": "stream", 241 | "name": "stdout", 242 | "text": [ 243 | "Training Metrics Report:\n precision recall f1-score support\n\n 0 1.00 1.00 1.00 43\n 1 1.00 1.00 1.00 39\n 2 1.00 1.00 1.00 38\n\n accuracy 1.00 120\n macro avg 1.00 1.00 1.00 120\nweighted avg 1.00 1.00 1.00 120\n\n" 244 | ] 245 | } 246 | ], 247 | "source": [ 248 | "print('Training Metrics Report:\\n', metrics.classification_report(y_train, training_predictions))" 249 | ] 250 | }, 251 | { 252 | "source": [ 253 | "...and on the test set:" 254 | ], 255 | "cell_type": "markdown", 256 | "metadata": {} 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": 16, 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "test_predictions = tree_clf.predict(X_test)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 17, 270 | "metadata": {}, 271 | "outputs": [ 272 | { 273 | "output_type": "stream", 274 | "name": "stdout", 275 | "text": [ 276 | "Test Metrics Report:\n precision recall f1-score support\n\n 0 1.00 1.00 1.00 7\n 1 0.90 0.82 0.86 11\n 2 0.85 0.92 0.88 12\n\n accuracy 0.90 30\n macro avg 0.92 0.91 0.91 30\nweighted avg 0.90 0.90 0.90 30\n\n" 277 | ] 278 | } 279 | ], 280 | "source": [ 281 | "print('Test Metrics Report:\\n', metrics.classification_report(y_test, test_predictions))" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "## Wrap and Save the Model\n", 289 | "\n", 290 | "Finally, we use Clearbox Wrapper to wrap and save the model as a zipped folder in a specified path. The only dependency required for this model is Scikit-Learn, but it is detected automatically by CBW and added to the requirements saved into the resulting folder. We pass the training dataset to `save_model` in order to generate a Model Signature (the signature represents model input as data frames with (optionally) named columns and data types)." 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 18, 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "wrapped_model_path = 'iris_wrapped_model_v0.3.10'" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": 19, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "cbw.save_model(wrapped_model_path, tree_clf, input_data=X_train)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "## Unzip and load the model\n", 316 | "\n", 317 | "The following cells are not necessary for final users: the zip created should be uploaded to our SaaS as it is. But here we want to show how to load a saved model and compare it to the original one.\n", 318 | "\n", 319 | "**IMPORTANT**: To ensure reproducibility and avoid loading errors, it is necessary to load the wrapped model with the same Python version with which the model was saved.
320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 20, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "import zipfile" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 21, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "zipped_model_path = 'iris_wrapped_model_v0.3.10.zip'\n", 338 | "unzipped_model_path = 'iris_wrapped_model_v0.3.10_unzipped'" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 22, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "with zipfile.ZipFile(zipped_model_path, 'r') as zip_ref:\n", 348 | " zip_ref.extractall(unzipped_model_path)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": 23, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "loaded_model = cbw.load_model(unzipped_model_path)" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 24, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "original_model_predictions = tree_clf.predict_proba(X_test)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 25, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "loaded_model_predictions = loaded_model.predict_proba(X_test)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": 26, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions)" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "## Remove all generated files and directory" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "import os\n", 401 | "import shutil" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "if os.path.exists(zipped_model_path):\n", 411 | " os.remove(zipped_model_path)" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "if os.path.exists(unzipped_model_path):\n", 421 | " shutil.rmtree(unzipped_model_path)" 422 | ] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Python 3.8.5 64-bit ('.venv': venv)", 428 | "language": "python", 429 | "name": "python38564bitvenvvenv859ca4a018fa42de98d84719ac267976" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.8.5-final" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 4 446 | } 447 | -------------------------------------------------------------------------------- /examples/1_iris_sklearn/iris_test_set.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,species 2 | 6,2.2,5,1.5,virginica 3 | 6.4,3.1,5.5,1.8,virginica 4 | 6.9,3.1,4.9,1.5,versicolor 5 | 6.3,2.7,4.9,1.8,virginica 6 | 5.5,2.4,3.8,1.1,versicolor 7 | 6.1,2.8,4,1.3,versicolor 8 | 6,2.7,5.1,1.6,versicolor 9 | 6.7,3.1,5.6,2.4,virginica 10 | 6.5,2.8,4.6,1.5,versicolor 11 | 6.2,2.2,4.5,1.5,versicolor 12 | 
5.2,3.5,1.5,0.2,setosa 13 | 5,2,3.5,1,versicolor 14 | 5.7,2.6,3.5,1,versicolor 15 | 7.7,3.8,6.7,2.2,virginica 16 | 5.4,3.4,1.5,0.4,setosa 17 | 5.7,2.8,4.1,1.3,versicolor 18 | 5.3,3.7,1.5,0.2,setosa 19 | 6.4,3.2,5.3,2.3,virginica 20 | 5.9,3,5.1,1.8,virginica 21 | 6.5,3,5.8,2.2,virginica 22 | 5,3.5,1.3,0.3,setosa 23 | 5.4,3.4,1.7,0.2,setosa 24 | 6.4,2.7,5.3,1.9,virginica 25 | 6.5,3,5.2,2,virginica 26 | 6.9,3.1,5.1,2.3,virginica 27 | 6.7,3,5,1.7,versicolor 28 | 5.6,3,4.1,1.3,versicolor 29 | 6.1,2.6,5.6,1.4,virginica 30 | 5.1,3.5,1.4,0.3,setosa 31 | 5.1,3.7,1.5,0.4,setosa 32 | -------------------------------------------------------------------------------- /examples/1_iris_sklearn/iris_training_set.csv: -------------------------------------------------------------------------------- 1 | sepal_length,sepal_width,petal_length,petal_width,species 2 | 7.9,3.8,6.4,2,virginica 3 | 4.7,3.2,1.6,0.2,setosa 4 | 5.5,2.6,4.4,1.2,versicolor 5 | 6,2.9,4.5,1.5,versicolor 6 | 5.4,3.9,1.3,0.4,setosa 7 | 6.6,3,4.4,1.4,versicolor 8 | 5.6,2.9,3.6,1.3,versicolor 9 | 5,3.4,1.6,0.4,setosa 10 | 5.5,2.5,4,1.3,versicolor 11 | 6.1,3,4.6,1.4,versicolor 12 | 6.6,2.9,4.6,1.3,versicolor 13 | 5.2,4.1,1.5,0.1,setosa 14 | 7.2,3.6,6.1,2.5,virginica 15 | 5.1,3.8,1.5,0.3,setosa 16 | 7.6,3,6.6,2.1,virginica 17 | 5.8,2.8,5.1,2.4,virginica 18 | 6.4,3.2,4.5,1.5,versicolor 19 | 5.6,2.5,3.9,1.1,versicolor 20 | 7.2,3,5.8,1.6,virginica 21 | 6.7,3.1,4.4,1.4,versicolor 22 | 6.8,3,5.5,2.1,virginica 23 | 6.4,2.9,4.3,1.3,versicolor 24 | 7.7,2.8,6.7,2,virginica 25 | 4.4,2.9,1.4,0.2,setosa 26 | 4.9,2.4,3.3,1,versicolor 27 | 5.6,3,4.5,1.5,versicolor 28 | 5.6,2.8,4.9,2,virginica 29 | 4.6,3.4,1.4,0.3,setosa 30 | 4.9,3,1.4,0.2,setosa 31 | 5.5,3.5,1.3,0.2,setosa 32 | 5,2.3,3.3,1,versicolor 33 | 5.5,2.4,3.7,1,versicolor 34 | 6.3,2.9,5.6,1.8,virginica 35 | 5,3,1.6,0.2,setosa 36 | 6.7,3.3,5.7,2.1,virginica 37 | 5.9,3.2,4.8,1.8,versicolor 38 | 7.7,2.6,6.9,2.3,virginica 39 | 5.5,4.2,1.4,0.2,setosa 40 | 5,3.2,1.2,0.2,setosa 41 | 6.3,2.3,4.4,1.3,versicolor 42 | 5.9,3,4.2,1.5,versicolor 43 | 5.7,4.4,1.5,0.4,setosa 44 | 6.1,2.9,4.7,1.4,versicolor 45 | 5.7,2.9,4.2,1.3,versicolor 46 | 5.7,2.5,5,2,virginica 47 | 6.9,3.2,5.7,2.3,virginica 48 | 6.5,3,5.5,1.8,virginica 49 | 5.7,2.8,4.5,1.3,versicolor 50 | 6.3,2.5,5,1.9,virginica 51 | 4.3,3,1.1,0.1,setosa 52 | 6.4,2.8,5.6,2.1,virginica 53 | 4.8,3,1.4,0.1,setosa 54 | 6.2,3.4,5.4,2.3,virginica 55 | 4.9,3.1,1.5,0.1,setosa 56 | 6.7,3,5.2,2.3,virginica 57 | 6.9,3.1,5.4,2.1,virginica 58 | 4.8,3.4,1.6,0.2,setosa 59 | 6.3,2.8,5.1,1.5,virginica 60 | 7.4,2.8,6.1,1.9,virginica 61 | 5.1,3.4,1.5,0.2,setosa 62 | 6.2,2.8,4.8,1.8,virginica 63 | 4.4,3,1.3,0.2,setosa 64 | 5,3.5,1.6,0.6,setosa 65 | 5,3.4,1.5,0.2,setosa 66 | 5,3.3,1.4,0.2,setosa 67 | 4.8,3.1,1.6,0.2,setosa 68 | 7.7,3,6.1,2.3,virginica 69 | 5.8,2.7,4.1,1,versicolor 70 | 4.6,3.1,1.5,0.2,setosa 71 | 6.3,3.3,4.7,1.6,versicolor 72 | 4.9,3.1,1.5,0.1,setosa 73 | 6.7,2.5,5.8,1.8,virginica 74 | 6.7,3.3,5.7,2.5,virginica 75 | 5,3.6,1.4,0.2,setosa 76 | 5.8,4,1.2,0.2,setosa 77 | 5.6,2.7,4.2,1.3,versicolor 78 | 6,2.2,4,1,versicolor 79 | 6.8,3.2,5.9,2.3,virginica 80 | 6.8,2.8,4.8,1.4,versicolor 81 | 4.8,3.4,1.9,0.2,setosa 82 | 5.1,3.5,1.4,0.2,setosa 83 | 4.6,3.6,1,0.2,setosa 84 | 4.8,3,1.4,0.3,setosa 85 | 4.9,3.1,1.5,0.1,setosa 86 | 5.8,2.7,5.1,1.9,virginica 87 | 6,3,4.8,1.8,virginica 88 | 5.5,2.3,4,1.3,versicolor 89 | 5.7,3.8,1.7,0.3,setosa 90 | 6,3.4,4.5,1.6,versicolor 91 | 5.8,2.7,5.1,1.9,virginica 92 | 6.5,3.2,5.1,2,virginica 93 | 7,3.2,4.7,1.4,versicolor 94 | 
4.6,3.2,1.4,0.2,setosa 95 | 6.2,2.9,4.3,1.3,versicolor 96 | 4.5,2.3,1.3,0.3,setosa 97 | 6.7,3.1,4.7,1.5,versicolor 98 | 5.7,3,4.2,1.2,versicolor 99 | 4.4,3.2,1.3,0.2,setosa 100 | 5.1,3.8,1.6,0.2,setosa 101 | 6.3,3.4,5.6,2.4,virginica 102 | 6.4,2.8,5.6,2.2,virginica 103 | 5.4,3,4.5,1.5,versicolor 104 | 5.2,2.7,3.9,1.4,versicolor 105 | 5.4,3.7,1.5,0.2,setosa 106 | 4.9,2.5,4.5,1.7,virginica 107 | 6.3,3.3,6,2.5,virginica 108 | 7.3,2.9,6.3,1.8,virginica 109 | 7.1,3,5.9,2.1,virginica 110 | 5.4,3.9,1.7,0.4,setosa 111 | 4.7,3.2,1.3,0.2,setosa 112 | 5.1,2.5,3,1.1,versicolor 113 | 5.8,2.7,3.9,1.2,versicolor 114 | 6.1,3,4.9,1.8,virginica 115 | 5.2,3.4,1.4,0.2,setosa 116 | 6.3,2.5,4.9,1.5,versicolor 117 | 5.1,3.3,1.7,0.5,setosa 118 | 6.1,2.8,4.7,1.2,versicolor 119 | 5.8,2.6,4,1.2,versicolor 120 | 5.1,3.8,1.9,0.4,setosa 121 | 7.2,3.2,6,1.8,virginica 122 | -------------------------------------------------------------------------------- /examples/1_iris_sklearn/iris_wrapped_model_v0.3.10.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/examples/1_iris_sklearn/iris_wrapped_model_v0.3.10.zip -------------------------------------------------------------------------------- /examples/2_loans_preprocessing_xgboost/loans_xgboost_wrapped_model_v0.3.10.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/examples/2_loans_preprocessing_xgboost/loans_xgboost_wrapped_model_v0.3.10.zip -------------------------------------------------------------------------------- /examples/3_boston_preprocessing_pytorch/boston_test_set.csv: -------------------------------------------------------------------------------- 1 | CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,MEDV 2 | 0.09178,0,4.05,0,0.51,6.416,84.1,2.6463,5,296,16.6,395.5,9.04,23.6 3 | 0.05644,40,6.41,1,0.447,6.758,32.9,4.0776,4,254,17.6,396.9,3.53,32.4 4 | 0.10574,0,27.74,0,0.609,5.983,98.8,1.8681,4,711,20.1,390.11,18.07,13.6 5 | 0.09164,0,10.81,0,0.413,6.065,7.8,5.2873,4,305,19.2,390.91,5.52,22.8 6 | 5.09017,0,18.1,0,0.713,6.297,91.8,2.3682,24,666,20.2,385.09,17.27,16.1 7 | 0.10153,0,12.83,0,0.437,6.279,74.5,4.0522,5,398,18.7,373.66,11.97,20 8 | 0.31827,0,9.9,0,0.544,5.914,83.2,3.9986,4,304,18.4,390.7,18.33,17.8 9 | 0.2909,0,21.89,0,0.624,6.174,93.6,1.6119,4,437,21.2,388.08,24.16,14 10 | 4.03841,0,18.1,0,0.532,6.229,90.7,3.0993,24,666,20.2,395.33,12.87,19.6 11 | 0.22438,0,9.69,0,0.585,6.027,79.7,2.4982,6,391,19.2,396.9,14.33,16.8 12 | 0.11069,0,13.89,1,0.55,5.951,93.8,2.8893,5,276,16.4,396.9,17.92,21.5 13 | 0.17004,12.5,7.87,0,0.524,6.004,85.9,6.5921,5,311,15.2,386.71,17.1,18.9 14 | 45.7461,0,18.1,0,0.693,4.519,100,1.6582,24,666,20.2,88.27,36.98,7 15 | 0.05646,0,12.83,0,0.437,6.232,53.7,5.0141,5,398,18.7,386.4,12.34,21.2 16 | 0.28392,0,7.38,0,0.493,5.708,74.3,4.7211,5,287,19.6,391.13,11.74,18.5 17 | 4.64689,0,18.1,0,0.614,6.98,67.6,2.5329,24,666,20.2,374.68,11.66,29.8 18 | 0.09849,0,25.65,0,0.581,5.879,95.8,2.0063,2,188,19.1,379.38,17.58,18.8 19 | 14.3337,0,18.1,0,0.7,4.88,100,1.5895,24,666,20.2,372.92,30.62,10.2 20 | 0.01381,80,0.46,0,0.422,7.875,32,5.6484,4,255,14.4,394.23,2.97,50 21 | 9.32909,0,18.1,0,0.713,6.185,98.7,2.2616,24,666,20.2,396.9,18.13,14.1 22 | 0.16211,20,6.96,0,0.464,6.24,16.3,4.429,3,223,18.6,396.9,6.59,25.2 23 | 
0.07978,40,6.41,0,0.447,6.482,32.1,4.1403,4,254,17.6,396.9,7.19,29.1 24 | 1.13081,0,8.14,0,0.538,5.713,94.1,4.233,4,307,21,360.17,22.6,12.7 25 | 0.06263,0,11.93,0,0.573,6.593,69.1,2.4786,1,273,21,391.99,9.67,22.4 26 | 7.02259,0,18.1,0,0.718,6.006,95.3,1.8746,24,666,20.2,319.98,15.7,14.2 27 | 8.05579,0,18.1,0,0.584,5.427,95.4,2.4298,24,666,20.2,352.58,18.14,13.8 28 | 0.08387,0,12.83,0,0.437,5.874,36.6,4.5026,5,398,18.7,396.06,9.1,20.3 29 | 9.51363,0,18.1,0,0.713,6.728,94.1,2.4961,24,666,20.2,6.68,18.71,14.9 30 | 0.17446,0,10.59,1,0.489,5.96,92.1,3.8771,4,277,18.6,393.25,17.27,21.7 31 | 0.26838,0,9.69,0,0.585,5.794,70.6,2.8927,6,391,19.2,396.9,14.1,18.3 32 | 0.13914,0,4.05,0,0.51,5.572,88.5,2.5961,5,296,16.6,396.9,14.69,23.1 33 | 0.1676,0,7.38,0,0.493,6.426,52.3,4.5404,5,287,19.6,396.9,7.2,23.8 34 | 19.6091,0,18.1,0,0.671,7.313,97.9,1.3163,24,666,20.2,396.9,13.44,15 35 | 3.67822,0,18.1,0,0.77,5.362,96.2,2.1036,24,666,20.2,380.79,10.19,20.8 36 | 4.42228,0,18.1,0,0.584,6.003,94.5,2.5403,24,666,20.2,331.29,21.32,19.1 37 | 2.14918,0,19.58,0,0.871,5.709,98.5,1.6232,5,403,14.7,261.95,15.79,19.4 38 | 0.02729,0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7 39 | 0.03427,0,5.19,0,0.515,5.869,46.3,5.2311,5,224,20.2,396.9,9.8,19.5 40 | 0.13587,0,10.59,1,0.489,6.064,59.1,4.2392,4,277,18.6,381.32,14.66,24.4 41 | 0.19539,0,10.81,0,0.413,6.245,6.2,5.2873,4,305,19.2,377.17,7.54,23.4 42 | 0.2896,0,9.69,0,0.585,5.39,72.9,2.7986,6,391,19.2,396.9,21.14,19.7 43 | 0.04932,33,2.18,0,0.472,6.849,70.3,3.1827,7,222,18.4,396.9,7.53,28.2 44 | 0.02009,95,2.68,0,0.4161,8.034,31.9,5.118,4,224,14.7,390.55,2.88,50 45 | 0.13554,12.5,6.07,0,0.409,5.594,36.8,6.498,4,345,18.9,396.9,13.09,17.4 46 | 0.04684,0,3.41,0,0.489,6.417,66.1,3.0923,2,270,17.8,392.18,8.81,22.6 47 | 6.96215,0,18.1,0,0.7,5.713,97,1.9265,24,666,20.2,394.43,17.11,15.1 48 | 1.15172,0,8.14,0,0.538,5.701,95,3.7872,4,307,21,358.77,18.35,13.1 49 | 0.08826,0,10.81,0,0.413,6.417,6.6,5.2873,4,305,19.2,383.73,6.72,24.2 50 | 4.34879,0,18.1,0,0.58,6.167,84,3.0334,24,666,20.2,396.9,16.29,19.9 51 | 0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24 52 | 0.11747,12.5,7.87,0,0.524,6.009,82.9,6.2267,5,311,15.2,396.9,13.27,18.9 53 | 0.03705,20,3.33,0,0.4429,6.968,37.2,5.2447,5,216,14.9,392.23,4.59,35.4 54 | 1.23247,0,8.14,0,0.538,6.142,91.7,3.9769,4,307,21,396.9,18.72,15.2 55 | 0.11432,0,8.56,0,0.52,6.781,71.3,2.8561,5,384,20.9,395.58,7.67,26.5 56 | 0.5405,20,3.97,0,0.575,7.47,52.6,2.872,5,264,13,390.3,3.16,43.5 57 | 3.67367,0,18.1,0,0.583,6.312,51.9,3.9917,24,666,20.2,388.62,10.58,21.2 58 | 5.66637,0,18.1,0,0.74,6.219,100,2.0048,24,666,20.2,395.69,16.59,18.4 59 | 0.03502,80,4.95,0,0.411,6.861,27.9,5.1167,4,245,19.2,396.9,3.33,28.5 60 | 0.05059,0,4.49,0,0.449,6.389,48,4.7794,3,247,18.5,396.9,9.62,23.9 61 | 0.19133,22,5.86,0,0.431,5.605,70.2,7.9549,7,330,19.1,389.13,18.46,18.5 62 | 0.1265,25,5.13,0,0.453,6.762,43.4,7.9809,8,284,19.7,395.58,9.5,25 63 | 0.01311,90,1.22,0,0.403,7.249,21.9,8.6966,5,226,17.9,395.93,4.81,35.4 64 | 0.44178,0,6.2,0,0.504,6.552,21.4,3.3751,8,307,17.4,380.34,3.76,31.5 65 | 0.80271,0,8.14,0,0.538,5.456,36.6,3.7965,4,307,21,288.99,11.69,20.2 66 | 0.0795,60,1.69,0,0.411,6.579,35.9,10.7103,4,411,18.3,370.78,5.49,24.1 67 | 0.43571,0,10.59,1,0.489,5.344,100,3.875,4,277,18.6,396.9,23.09,20 68 | 8.71675,0,18.1,0,0.693,6.471,98.8,1.7257,24,666,20.2,391.98,17.12,13.1 69 | 0.03659,25,4.86,0,0.426,6.302,32.2,5.4007,4,281,19,396.9,6.72,24.8 70 | 0.02763,75,2.95,0,0.428,6.595,21.8,5.4011,3,252,18.3,395.63,4.32,30.8 71 | 
4.66883,0,18.1,0,0.713,5.976,87.9,2.5806,24,666,20.2,10.48,19.01,12.7 72 | 0.18836,0,6.91,0,0.448,5.786,33.3,5.1004,3,233,17.9,396.9,14.15,20 73 | 5.70818,0,18.1,0,0.532,6.75,74.9,3.3317,24,666,20.2,393.07,7.74,23.7 74 | 12.8023,0,18.1,0,0.74,5.854,96.6,1.8956,24,666,20.2,240.52,23.79,10.8 75 | 0.10659,80,1.91,0,0.413,5.936,19.5,10.5857,4,334,22,376.04,5.57,20.6 76 | 0.08707,0,12.83,0,0.437,6.14,45.8,4.0905,5,398,18.7,386.96,10.27,20.8 77 | 38.3518,0,18.1,0,0.693,5.453,100,1.4896,24,666,20.2,396.9,30.59,5 78 | 0.1396,0,8.56,0,0.52,6.167,90,2.421,5,384,20.9,392.69,12.33,20.1 79 | 0.0351,95,2.68,0,0.4161,7.853,33.2,5.118,4,224,14.7,392.78,3.81,48.5 80 | 15.8744,0,18.1,0,0.671,6.545,99.1,1.5192,24,666,20.2,396.9,21.08,10.9 81 | 0.18337,0,27.74,0,0.609,5.414,98.3,1.7554,4,711,20.1,344.05,23.97,7 82 | 0.12816,12.5,6.07,0,0.409,5.885,33,6.498,4,345,18.9,396.9,8.79,20.9 83 | 7.40389,0,18.1,0,0.597,5.617,97.9,1.4547,24,666,20.2,314.64,26.4,17.2 84 | 0.03548,80,3.64,0,0.392,5.876,19.1,9.2203,1,315,16.4,395.18,9.25,20.9 85 | 11.5779,0,18.1,0,0.7,5.036,97,1.77,24,666,20.2,396.9,25.68,9.7 86 | 0.26169,0,9.9,0,0.544,6.023,90.4,2.834,4,304,18.4,396.3,11.72,19.4 87 | 0.44791,0,6.2,1,0.507,6.726,66.5,3.6519,8,307,17.4,360.2,8.05,29 88 | 4.81213,0,18.1,0,0.713,6.701,90,2.5975,24,666,20.2,255.23,16.42,16.4 89 | 0.34109,0,7.38,0,0.493,6.415,40.1,4.7211,5,287,19.6,396.9,6.12,25 90 | 0.02875,28,15.04,0,0.464,6.211,28.9,3.6659,4,270,18.2,396.33,6.21,25 91 | 0.35233,0,21.89,0,0.624,6.454,98.4,1.8498,4,437,21.2,394.08,14.59,17.1 92 | 0.07022,0,4.05,0,0.51,6.02,47.2,3.5549,5,296,16.6,393.23,10.11,23.2 93 | 25.9406,0,18.1,0,0.679,5.304,89.1,1.6475,24,666,20.2,127.36,26.64,10.4 94 | 1.19294,0,21.89,0,0.624,6.326,97.7,2.271,4,437,21.2,396.9,12.26,19.6 95 | 0.06162,0,4.39,0,0.442,5.898,52.3,8.0136,3,352,18.8,364.61,12.67,17.2 96 | 4.55587,0,18.1,0,0.718,3.561,87.9,1.6132,24,666,20.2,354.7,7.12,27.5 97 | 0.59005,0,21.89,0,0.624,6.372,97.9,2.3274,4,437,21.2,385.76,11.12,23 98 | 9.2323,0,18.1,0,0.631,6.216,100,1.1691,24,666,20.2,366.15,9.53,50 99 | 18.811,0,18.1,0,0.597,4.628,100,1.5539,24,666,20.2,28.79,34.37,17.9 100 | 14.4208,0,18.1,0,0.74,6.461,93.3,2.0026,24,666,20.2,27.49,18.05,9.6 101 | 14.0507,0,18.1,0,0.597,6.657,100,1.5275,24,666,20.2,35.05,21.22,17.2 102 | 0.05188,0,4.49,0,0.449,6.015,45.1,4.4272,3,247,18.5,395.99,12.86,22.5 103 | 0.09512,0,12.83,0,0.437,6.286,45,4.5026,5,398,18.7,383.23,8.94,21.4 104 | -------------------------------------------------------------------------------- /examples/3_boston_preprocessing_pytorch/boston_wrapped_model_v0.3.10.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/examples/3_boston_preprocessing_pytorch/boston_wrapped_model_v0.3.10.zip -------------------------------------------------------------------------------- /examples/4_adult_data_cleaning_preprocessing_keras/adult_wrapped_model_preparation_preprocessing_v0.3.10.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/examples/4_adult_data_cleaning_preprocessing_keras/adult_wrapped_model_preparation_preprocessing_v0.3.10.zip -------------------------------------------------------------------------------- /examples/5_hospital_preprocessing_pytorch/hospital_wrapped_model_v0.3.10.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/examples/5_hospital_preprocessing_pytorch/hospital_wrapped_model_v0.3.10.zip -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | 4 | [mypy-nox.*,pytest] 5 | ignore_missing_imports = True 6 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import nox 4 | 5 | # nox.options.sessions = "lint", "mypy", "tests" 6 | nox.options.sessions = "lint", "tests" 7 | locations = "clearbox_wrapper", "tests", "noxfile.py" 8 | 9 | 10 | def install_with_constraints(session, *args, **kwargs): 11 | with tempfile.NamedTemporaryFile() as requirements: 12 | session.run( 13 | "poetry", 14 | "export", 15 | "--dev", 16 | "--format=requirements.txt", 17 | "--without-hashes", 18 | f"--output={requirements.name}", 19 | external=True, 20 | ) 21 | session.install(f"--constraint={requirements.name}", *args, **kwargs) 22 | 23 | 24 | @nox.session(python=["3.8.5"]) 25 | def black(session): 26 | args = session.posargs or locations 27 | install_with_constraints(session, "black") 28 | session.run("black", *args) 29 | 30 | 31 | @nox.session(python=["3.8.5"]) 32 | def lint(session): 33 | args = session.posargs or locations 34 | install_with_constraints( 35 | session, 36 | "flake8", 37 | "flake8-black", 38 | "flake8-bugbear", 39 | "flake8-import-order", 40 | ) 41 | session.run("flake8", *args) 42 | 43 | 44 | @nox.session(python=["3.8.5"]) 45 | def mypy(session): 46 | args = session.posargs or locations 47 | install_with_constraints(session, "mypy") 48 | session.run("mypy", *args) 49 | 50 | 51 | @nox.session(python=["3.6.2", "3.7.0", "3.8.0", "3.8.5"]) 52 | def tests(session): 53 | args = session.posargs or ["--cov"] 54 | session.run("poetry", "install", external=True) 55 | session.run("pytest", *args) 56 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "clearbox-wrapper" 3 | version = "0.3.11" 4 | description = "An agnostic wrapper for the most common frameworks of ML models." 
5 | authors = ["Clearbox AI "] 6 | license = "Apache-2.0" 7 | readme = "README.md" 8 | homepage = "https://clearbox.ai" 9 | repository = "https://github.com/ClearBox-AI/clearbox-wrapper" 10 | keywords = ["ML wrapper", "machine learning"] 11 | 12 | [tool.poetry.dependencies] 13 | python = "^3.6.2" 14 | loguru = "^0.5.3" 15 | pandas = "^1" 16 | numpy = "^1.16.0" 17 | cloudpickle = "^1.6.0" 18 | PyYAML = "^5.4.1" 19 | 20 | [tool.poetry.dev-dependencies] 21 | flake8 = "^3.8.4" 22 | pytest = "^6.1.1" 23 | coverage = {extras = ["toml"], version = "^5.3"} 24 | pytest-cov = "^2.10.1" 25 | scikit-learn = "^0.23.2" 26 | black = "^20.8b1" 27 | pytest-lazy-fixture = "^0.6.3" 28 | PyYAML = "^5.3.1" 29 | xgboost = "^1.2.1" 30 | tensorflow = "^2.3.1" 31 | torch = "^1.7.0" 32 | torchvision = "^0.8.1" 33 | flake8-black = "^0.2.1" 34 | flake8-bugbear = "^20.11.1" 35 | flake8-import-order = "^0.18.1" 36 | mypy = "^0.790" 37 | 38 | [tool.coverage.paths] 39 | source = ["clearbox_wrapper", "*/site-packages"] 40 | 41 | [tool.coverage.run] 42 | branch = true 43 | source = ["clearbox_wrapper"] 44 | 45 | [tool.coverage.report] 46 | show_missing = true 47 | 48 | [build-system] 49 | requires = ["poetry-core>=1.0.0"] 50 | build-backend = "poetry.core.masonry.api" 51 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Clearbox-AI/clearbox-wrapper/3629923763629d2a12cb96723695b5ab497ad756/tests/__init__.py -------------------------------------------------------------------------------- /tests/datasets/adult_test_50_rows.csv: -------------------------------------------------------------------------------- 1 | age,work_class,education,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income 2 | 25,Private,11th,Never-married,Machine-op-inspct,Own-child,Black,Male,0,0,40,United-States,<=50K. 3 | 38,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,50,United-States,<=50K. 4 | 28,Local-gov,Assoc-acdm,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,>50K. 5 | 44,Private,Some-college,Married-civ-spouse,Machine-op-inspct,Husband,Black,Male,7688,0,40,United-States,>50K. 6 | 18,,Some-college,Never-married,,Own-child,White,Female,0,0,30,United-States,<=50K. 7 | 34,Private,10th,Never-married,Other-service,Not-in-family,White,Male,0,0,30,United-States,<=50K. 8 | 29,,HS-grad,Never-married,,Unmarried,Black,Male,0,0,40,United-States,<=50K. 9 | 63,Self-emp-not-inc,Prof-school,Married-civ-spouse,Prof-specialty,Husband,White,Male,3103,0,32,United-States,>50K. 10 | 24,Private,Some-college,Never-married,Other-service,Unmarried,White,Female,0,0,40,United-States,<=50K. 11 | 55,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,10,United-States,<=50K. 12 | 65,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,6418,0,40,United-States,>50K. 13 | 36,Federal-gov,Bachelors,Married-civ-spouse,Adm-clerical,Husband,White,Male,0,0,40,United-States,<=50K. 14 | 26,Private,HS-grad,Never-married,Adm-clerical,Not-in-family,White,Female,0,0,39,United-States,<=50K. 15 | 58,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,35,United-States,<=50K. 16 | 48,Private,HS-grad,Married-civ-spouse,Machine-op-inspct,Husband,White,Male,3103,0,48,United-States,>50K. 
17 | 43,Private,Masters,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,50,United-States,>50K. 18 | 20,State-gov,Some-college,Never-married,Other-service,Own-child,White,Male,0,0,25,United-States,<=50K. 19 | 43,Private,HS-grad,Married-civ-spouse,Adm-clerical,Wife,White,Female,0,0,30,United-States,<=50K. 20 | 37,Private,HS-grad,Widowed,Machine-op-inspct,Unmarried,White,Female,0,0,20,United-States,<=50K. 21 | 40,Private,Doctorate,Married-civ-spouse,Prof-specialty,Husband,Asian-Pac-Islander,Male,0,0,45,,>50K. 22 | 34,Private,Bachelors,Married-civ-spouse,Tech-support,Husband,White,Male,0,0,47,United-States,>50K. 23 | 34,Private,Some-college,Never-married,Other-service,Own-child,Black,Female,0,0,35,United-States,<=50K. 24 | 72,,7th-8th,Divorced,,Not-in-family,White,Female,0,0,6,United-States,<=50K. 25 | 25,Private,Bachelors,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,43,Peru,<=50K. 26 | 25,Private,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,<=50K. 27 | 45,Self-emp-not-inc,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,7298,0,90,United-States,>50K. 28 | 22,Private,HS-grad,Never-married,Adm-clerical,Own-child,White,Male,0,0,20,United-States,<=50K. 29 | 23,Private,HS-grad,Separated,Machine-op-inspct,Unmarried,Black,Male,0,0,54,United-States,<=50K. 30 | 54,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,35,United-States,<=50K. 31 | 32,Self-emp-not-inc,Some-college,Never-married,Prof-specialty,Not-in-family,White,Male,0,0,60,United-States,<=50K. 32 | 46,State-gov,Some-college,Married-civ-spouse,Exec-managerial,Husband,Black,Male,7688,0,38,United-States,>50K. 33 | 56,Self-emp-not-inc,11th,Widowed,Other-service,Unmarried,White,Female,0,0,50,United-States,<=50K. 34 | 24,Self-emp-not-inc,Bachelors,Never-married,Sales,Not-in-family,White,Male,0,0,50,United-States,<=50K. 35 | 23,Local-gov,Some-college,Married-civ-spouse,Protective-serv,Husband,White,Male,0,0,40,United-States,<=50K. 36 | 26,Private,HS-grad,Divorced,Exec-managerial,Unmarried,White,Female,0,0,40,United-States,<=50K. 37 | 65,,HS-grad,Married-civ-spouse,,Husband,White,Male,0,0,40,United-States,<=50K. 38 | 36,Local-gov,Bachelors,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40,United-States,>50K. 39 | 22,Private,5th-6th,Never-married,Priv-house-serv,Not-in-family,White,Male,0,0,50,Guatemala,<=50K. 40 | 17,Private,10th,Never-married,Machine-op-inspct,Not-in-family,White,Male,0,0,40,United-States,<=50K. 41 | 20,Private,HS-grad,Never-married,Craft-repair,Own-child,White,Male,0,0,40,United-States,<=50K. 42 | 65,Private,Masters,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,50,United-States,>50K. 43 | 44,Self-emp-inc,Assoc-voc,Married-civ-spouse,Sales,Husband,White,Male,0,0,45,United-States,>50K. 44 | 36,Private,HS-grad,Married-civ-spouse,Farming-fishing,Husband,White,Male,0,0,40,United-States,<=50K. 45 | 29,Private,11th,Married-civ-spouse,Other-service,Husband,White,Male,0,0,40,United-States,<=50K. 46 | 20,State-gov,Some-college,Never-married,Farming-fishing,Own-child,White,Male,0,0,32,United-States,<=50K. 47 | 28,Private,Assoc-voc,Married-civ-spouse,Prof-specialty,Wife,White,Female,0,0,36,United-States,>50K. 48 | 39,Private,7th-8th,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,Mexico,<=50K. 49 | 54,Private,Some-college,Married-civ-spouse,Transport-moving,Husband,White,Male,3908,0,50,United-States,<=50K. 50 | 52,Private,11th,Separated,Priv-house-serv,Not-in-family,Black,Female,0,0,18,United-States,<=50K. 
51 | 56,Self-emp-inc,HS-grad,Widowed,Exec-managerial,Not-in-family,White,Female,0,0,50,United-States,<=50K. 52 | 18,Private,Some-college,Never-married,Other-service,Own-child,White,Male,0,0,20,United-States,<=50K. 53 | 39,Private,HS-grad,Divorced,Handlers-cleaners,Own-child,Black,Male,0,0,40,United-States,<=50K. 54 | 21,Private,Some-college,Never-married,Other-service,Own-child,White,Female,0,1721,24,United-States,<=50K. 55 | 22,Private,HS-grad,Never-married,Other-service,Not-in-family,White,Male,14084,0,60,United-States,>50K. 56 | 38,Private,9th,Married-spouse-absent,Exec-managerial,Not-in-family,White,Male,0,0,54,Mexico,<=50K. 57 | 21,Private,Some-college,Never-married,Adm-clerical,Own-child,White,Female,0,0,40,United-States,<=50K. 58 | 63,Private,HS-grad,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40,United-States,<=50K. 59 | 34,Local-gov,Bachelors,Married-civ-spouse,Exec-managerial,Husband,White,Male,3103,0,50,United-States,>50K. 60 | 42,Self-emp-inc,HS-grad,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178,0,50,United-States,>50K. 61 | 33,Private,HS-grad,Married-civ-spouse,Handlers-cleaners,Husband,White,Male,0,0,40,United-States,<=50K. 62 | -------------------------------------------------------------------------------- /tests/datasets/iris_test_set_one_hot_y.csv: -------------------------------------------------------------------------------- 1 | 5.0,3.5,1.3,0.3,1,0,0 2 | 4.5,2.3,1.3,0.3,1,0,0 3 | 4.4,3.2,1.3,0.2,1,0,0 4 | 5.0,3.5,1.6,0.6,1,0,0 5 | 5.1,3.8,1.9,0.4,1,0,0 6 | 4.8,3.0,1.4,0.3,1,0,0 7 | 5.1,3.8,1.6,0.2,1,0,0 8 | 4.6,3.2,1.4,0.2,1,0,0 9 | 5.3,3.7,1.5,0.2,1,0,0 10 | 5.0,3.3,1.4,0.2,1,0,0 11 | 5.5,2.6,4.4,1.2,0,1,0 12 | 6.1,3.0,4.6,1.4,0,1,0 13 | 5.8,2.6,4.0,1.2,0,1,0 14 | 5.0,2.3,3.3,1.0,0,1,0 15 | 5.6,2.7,4.2,1.3,0,1,0 16 | 5.7,3.0,4.2,1.2,0,1,0 17 | 5.7,2.9,4.2,1.3,0,1,0 18 | 6.2,2.9,4.3,1.3,0,1,0 19 | 5.1,2.5,3.0,1.1,0,1,0 20 | 5.7,2.8,4.1,1.3,0,1,0 21 | 6.7,3.1,5.6,2.4,0,0,1 22 | 6.9,3.1,5.1,2.3,0,0,1 23 | 5.8,2.7,5.1,1.9,0,0,1 24 | 6.8,3.2,5.9,2.3,0,0,1 25 | 6.7,3.3,5.7,2.5,0,0,1 26 | 6.7,3.0,5.2,2.3,0,0,1 27 | 6.3,2.5,5.0,1.9,0,0,1 28 | 6.5,3.0,5.2,2.0,0,0,1 29 | 6.2,3.4,5.4,2.3,0,0,1 30 | 5.9,3.0,5.1,1.8,0,0,1 31 | -------------------------------------------------------------------------------- /tests/datasets/iris_training_set_one_hot_y.csv: -------------------------------------------------------------------------------- 1 | 5.1,3.5,1.4,0.2,1,0,0 2 | 4.9,3.0,1.4,0.2,1,0,0 3 | 4.7,3.2,1.3,0.2,1,0,0 4 | 4.6,3.1,1.5,0.2,1,0,0 5 | 5.0,3.6,1.4,0.2,1,0,0 6 | 5.4,3.9,1.7,0.4,1,0,0 7 | 4.6,3.4,1.4,0.3,1,0,0 8 | 5.0,3.4,1.5,0.2,1,0,0 9 | 4.4,2.9,1.4,0.2,1,0,0 10 | 4.9,3.1,1.5,0.1,1,0,0 11 | 5.4,3.7,1.5,0.2,1,0,0 12 | 4.8,3.4,1.6,0.2,1,0,0 13 | 4.8,3.0,1.4,0.1,1,0,0 14 | 4.3,3.0,1.1,0.1,1,0,0 15 | 5.8,4.0,1.2,0.2,1,0,0 16 | 5.7,4.4,1.5,0.4,1,0,0 17 | 5.4,3.9,1.3,0.4,1,0,0 18 | 5.1,3.5,1.4,0.3,1,0,0 19 | 5.7,3.8,1.7,0.3,1,0,0 20 | 5.1,3.8,1.5,0.3,1,0,0 21 | 5.4,3.4,1.7,0.2,1,0,0 22 | 5.1,3.7,1.5,0.4,1,0,0 23 | 4.6,3.6,1.0,0.2,1,0,0 24 | 5.1,3.3,1.7,0.5,1,0,0 25 | 4.8,3.4,1.9,0.2,1,0,0 26 | 5.0,3.0,1.6,0.2,1,0,0 27 | 5.0,3.4,1.6,0.4,1,0,0 28 | 5.2,3.5,1.5,0.2,1,0,0 29 | 5.2,3.4,1.4,0.2,1,0,0 30 | 4.7,3.2,1.6,0.2,1,0,0 31 | 4.8,3.1,1.6,0.2,1,0,0 32 | 5.4,3.4,1.5,0.4,1,0,0 33 | 5.2,4.1,1.5,0.1,1,0,0 34 | 5.5,4.2,1.4,0.2,1,0,0 35 | 4.9,3.1,1.5,0.1,1,0,0 36 | 5.0,3.2,1.2,0.2,1,0,0 37 | 5.5,3.5,1.3,0.2,1,0,0 38 | 4.9,3.1,1.5,0.1,1,0,0 39 | 4.4,3.0,1.3,0.2,1,0,0 40 | 5.1,3.4,1.5,0.2,1,0,0 41 | 7.0,3.2,4.7,1.4,0,1,0 42 | 6.4,3.2,4.5,1.5,0,1,0 43 | 
6.9,3.1,4.9,1.5,0,1,0 44 | 5.5,2.3,4.0,1.3,0,1,0 45 | 6.5,2.8,4.6,1.5,0,1,0 46 | 5.7,2.8,4.5,1.3,0,1,0 47 | 6.3,3.3,4.7,1.6,0,1,0 48 | 4.9,2.4,3.3,1.0,0,1,0 49 | 6.6,2.9,4.6,1.3,0,1,0 50 | 5.2,2.7,3.9,1.4,0,1,0 51 | 5.0,2.0,3.5,1.0,0,1,0 52 | 5.9,3.0,4.2,1.5,0,1,0 53 | 6.0,2.2,4.0,1.0,0,1,0 54 | 6.1,2.9,4.7,1.4,0,1,0 55 | 5.6,2.9,3.6,1.3,0,1,0 56 | 6.7,3.1,4.4,1.4,0,1,0 57 | 5.6,3.0,4.5,1.5,0,1,0 58 | 5.8,2.7,4.1,1.0,0,1,0 59 | 6.2,2.2,4.5,1.5,0,1,0 60 | 5.6,2.5,3.9,1.1,0,1,0 61 | 5.9,3.2,4.8,1.8,0,1,0 62 | 6.1,2.8,4.0,1.3,0,1,0 63 | 6.3,2.5,4.9,1.5,0,1,0 64 | 6.1,2.8,4.7,1.2,0,1,0 65 | 6.4,2.9,4.3,1.3,0,1,0 66 | 6.6,3.0,4.4,1.4,0,1,0 67 | 6.8,2.8,4.8,1.4,0,1,0 68 | 6.7,3.0,5.0,1.7,0,1,0 69 | 6.0,2.9,4.5,1.5,0,1,0 70 | 5.7,2.6,3.5,1.0,0,1,0 71 | 5.5,2.4,3.8,1.1,0,1,0 72 | 5.5,2.4,3.7,1.0,0,1,0 73 | 5.8,2.7,3.9,1.2,0,1,0 74 | 6.0,2.7,5.1,1.6,0,1,0 75 | 5.4,3.0,4.5,1.5,0,1,0 76 | 6.0,3.4,4.5,1.6,0,1,0 77 | 6.7,3.1,4.7,1.5,0,1,0 78 | 6.3,2.3,4.4,1.3,0,1,0 79 | 5.6,3.0,4.1,1.3,0,1,0 80 | 5.5,2.5,4.0,1.3,0,1,0 81 | 6.3,3.3,6.0,2.5,0,0,1 82 | 5.8,2.7,5.1,1.9,0,0,1 83 | 7.1,3.0,5.9,2.1,0,0,1 84 | 6.3,2.9,5.6,1.8,0,0,1 85 | 6.5,3.0,5.8,2.2,0,0,1 86 | 7.6,3.0,6.6,2.1,0,0,1 87 | 4.9,2.5,4.5,1.7,0,0,1 88 | 7.3,2.9,6.3,1.8,0,0,1 89 | 6.7,2.5,5.8,1.8,0,0,1 90 | 7.2,3.6,6.1,2.5,0,0,1 91 | 6.5,3.2,5.1,2.0,0,0,1 92 | 6.4,2.7,5.3,1.9,0,0,1 93 | 6.8,3.0,5.5,2.1,0,0,1 94 | 5.7,2.5,5.0,2.0,0,0,1 95 | 5.8,2.8,5.1,2.4,0,0,1 96 | 6.4,3.2,5.3,2.3,0,0,1 97 | 6.5,3.0,5.5,1.8,0,0,1 98 | 7.7,3.8,6.7,2.2,0,0,1 99 | 7.7,2.6,6.9,2.3,0,0,1 100 | 6.0,2.2,5.0,1.5,0,0,1 101 | 6.9,3.2,5.7,2.3,0,0,1 102 | 5.6,2.8,4.9,2.0,0,0,1 103 | 7.7,2.8,6.7,2.0,0,0,1 104 | 6.3,2.7,4.9,1.8,0,0,1 105 | 6.7,3.3,5.7,2.1,0,0,1 106 | 7.2,3.2,6.0,1.8,0,0,1 107 | 6.2,2.8,4.8,1.8,0,0,1 108 | 6.1,3.0,4.9,1.8,0,0,1 109 | 6.4,2.8,5.6,2.1,0,0,1 110 | 7.2,3.0,5.8,1.6,0,0,1 111 | 7.4,2.8,6.1,1.9,0,0,1 112 | 7.9,3.8,6.4,2.0,0,0,1 113 | 6.4,2.8,5.6,2.2,0,0,1 114 | 6.3,2.8,5.1,1.5,0,0,1 115 | 6.1,2.6,5.6,1.4,0,0,1 116 | 7.7,3.0,6.1,2.3,0,0,1 117 | 6.3,3.4,5.6,2.4,0,0,1 118 | 6.4,3.1,5.5,1.8,0,0,1 119 | 6.0,3.0,4.8,1.8,0,0,1 120 | 6.9,3.1,5.4,2.1,0,0,1 121 | -------------------------------------------------------------------------------- /tests/keras/test_keras_boston.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from sklearn.model_selection import train_test_split 7 | import sklearn.preprocessing as sk_preprocessing 8 | from tensorflow.keras.layers import Dense 9 | from tensorflow.keras.models import Sequential 10 | 11 | import clearbox_wrapper as cbw 12 | 13 | 14 | @pytest.fixture 15 | def model_path(tmpdir): 16 | return os.path.join(str(tmpdir), "model") 17 | 18 | 19 | @pytest.fixture(scope="module") 20 | def boston_training_test(): 21 | csv_path = "tests/datasets/boston_housing.csv" 22 | target_column = "MEDV" 23 | boston_dataset = pd.read_csv(csv_path) 24 | y = boston_dataset[target_column] 25 | x = boston_dataset.drop(target_column, axis=1) 26 | x_train, x_test, y_train, y_test = train_test_split( 27 | x, y, test_size=0.2, random_state=42 28 | ) 29 | return x_train, x_test, y_train, y_test 30 | 31 | 32 | @pytest.fixture() 33 | def keras_model(): 34 | keras_clf = Sequential() 35 | keras_clf.add(Dense(8, input_dim=13, activation="relu")) 36 | keras_clf.add(Dense(4, activation="relu")) 37 | keras_clf.add(Dense(1)) 38 | 39 | keras_clf.compile(optimizer="rmsprop", loss="mse", metrics=["mae"]) 40 | return keras_clf 41 | 42 | 43 | 
@pytest.fixture() 44 | def sk_function_transformer(): 45 | def simple_preprocessor(data_x): 46 | return data_x ** 2 47 | 48 | transformer = sk_preprocessing.FunctionTransformer( 49 | simple_preprocessor, validate=True 50 | ) 51 | return transformer 52 | 53 | 54 | @pytest.fixture() 55 | def custom_transformer(): 56 | def simple_preprocessor(data_x): 57 | transformed_x = data_x + 1.0 58 | return transformed_x 59 | 60 | return simple_preprocessor 61 | 62 | 63 | @pytest.fixture() 64 | def add_value_to_column_transformer(): 65 | def double_dataframe(dataframe_x): 66 | x_transformed = dataframe_x + dataframe_x 67 | return x_transformed 68 | 69 | return double_dataframe 70 | 71 | 72 | def test_boston_keras_no_preprocessing(boston_training_test, keras_model, model_path): 73 | 74 | x_train, x_test, y_train, _ = boston_training_test 75 | 76 | model = keras_model 77 | model.fit(x_train, y_train, epochs=10, batch_size=32) 78 | cbw.save_model(model_path, model, zip=False) 79 | 80 | loaded_model = cbw.load_model(model_path) 81 | 82 | original_model_predictions = model.predict(x_test) 83 | loaded_model_predictions = loaded_model.predict(x_test) 84 | 85 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 86 | 87 | 88 | @pytest.mark.parametrize( 89 | "sk_transformer", 90 | [ 91 | (sk_preprocessing.StandardScaler()), 92 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 93 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 94 | (sk_preprocessing.RobustScaler()), 95 | (sk_preprocessing.MaxAbsScaler()), 96 | ], 97 | ) 98 | def test_boston_keras_preprocessing( 99 | sk_transformer, boston_training_test, keras_model, model_path 100 | ): 101 | x_train, x_test, y_train, _ = boston_training_test 102 | x_train_transformed = sk_transformer.fit_transform(x_train) 103 | 104 | model = keras_model 105 | model.fit(x_train_transformed, y_train, epochs=10, batch_size=32) 106 | cbw.save_model(model_path, model, preprocessing=sk_transformer, zip=False) 107 | 108 | loaded_model = cbw.load_model(model_path) 109 | x_test_transformed = sk_transformer.transform(x_test) 110 | original_model_predictions = model.predict(x_test_transformed) 111 | loaded_model_predictions = loaded_model.predict(x_test) 112 | 113 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 114 | 115 | 116 | def test_boston_keras_preprocessing_with_function_transformer( 117 | sk_function_transformer, boston_training_test, keras_model, model_path 118 | ): 119 | x_train, x_test, y_train, _ = boston_training_test 120 | x_train_transformed = sk_function_transformer.fit_transform(x_train) 121 | 122 | model = keras_model 123 | model.fit(x_train_transformed, y_train, epochs=10, batch_size=32) 124 | cbw.save_model(model_path, model, preprocessing=sk_function_transformer, zip=False) 125 | 126 | loaded_model = cbw.load_model(model_path) 127 | x_test_transformed = sk_function_transformer.transform(x_test) 128 | original_model_predictions = model.predict(x_test_transformed) 129 | loaded_model_predictions = loaded_model.predict(x_test) 130 | 131 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 132 | 133 | 134 | def test_boston_keras_preprocessing_with_custom_transformer( 135 | custom_transformer, boston_training_test, keras_model, model_path 136 | ): 137 | x_train, x_test, y_train, _ = boston_training_test 138 | x_train_transformed = custom_transformer(x_train) 139 | 140 | model = keras_model 141 | model.fit(x_train_transformed, 
y_train, epochs=10, batch_size=32) 142 | cbw.save_model(model_path, model, preprocessing=custom_transformer, zip=False) 143 | 144 | loaded_model = cbw.load_model(model_path) 145 | x_test_transformed = custom_transformer(x_test) 146 | original_model_predictions = model.predict(x_test_transformed) 147 | loaded_model_predictions = loaded_model.predict(x_test) 148 | 149 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 150 | 151 | 152 | @pytest.mark.parametrize( 153 | "preprocessor", 154 | [ 155 | (sk_preprocessing.StandardScaler()), 156 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 157 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 158 | (sk_preprocessing.RobustScaler()), 159 | (sk_preprocessing.MaxAbsScaler()), 160 | ], 161 | ) 162 | def test_boston_keras_data_preparation_and_preprocessing( 163 | preprocessor, 164 | add_value_to_column_transformer, 165 | boston_training_test, 166 | keras_model, 167 | model_path, 168 | ): 169 | x_train, x_test, y_train, _ = boston_training_test 170 | 171 | x_train_prepared = add_value_to_column_transformer(x_train) 172 | x_train_transformed = preprocessor.fit_transform(x_train_prepared) 173 | 174 | model = keras_model 175 | model.fit(x_train_transformed, y_train, epochs=10, batch_size=32) 176 | cbw.save_model( 177 | model_path, 178 | model, 179 | preprocessing=preprocessor, 180 | data_preparation=add_value_to_column_transformer, 181 | zip=False, 182 | ) 183 | 184 | loaded_model = cbw.load_model(model_path) 185 | x_test_prepared = add_value_to_column_transformer(x_test) 186 | x_test_transformed = preprocessor.transform(x_test_prepared) 187 | original_model_predictions = model.predict(x_test_transformed) 188 | loaded_model_predictions = loaded_model.predict(x_test) 189 | 190 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 191 | 192 | 193 | def test_boston_keras_zipped_path_already_exists( 194 | sk_function_transformer, 195 | add_value_to_column_transformer, 196 | boston_training_test, 197 | keras_model, 198 | model_path, 199 | ): 200 | x_train, x_test, y_train, _ = boston_training_test 201 | 202 | x_train_prepared = add_value_to_column_transformer(x_train) 203 | x_train_transformed = sk_function_transformer.fit_transform(x_train_prepared) 204 | 205 | model = keras_model 206 | model.fit(x_train_transformed, y_train, epochs=10, batch_size=32) 207 | cbw.save_model( 208 | model_path, 209 | model, 210 | preprocessing=sk_function_transformer, 211 | data_preparation=add_value_to_column_transformer, 212 | ) 213 | 214 | with pytest.raises(cbw.ClearboxWrapperException): 215 | cbw.save_model( 216 | model_path, 217 | model, 218 | preprocessing=sk_function_transformer, 219 | data_preparation=add_value_to_column_transformer, 220 | ) 221 | 222 | 223 | def test_boston_keras_path_already_exists( 224 | sk_function_transformer, 225 | add_value_to_column_transformer, 226 | boston_training_test, 227 | keras_model, 228 | model_path, 229 | ): 230 | x_train, x_test, y_train, _ = boston_training_test 231 | 232 | x_train_prepared = add_value_to_column_transformer(x_train) 233 | x_train_transformed = sk_function_transformer.fit_transform(x_train_prepared) 234 | 235 | model = keras_model 236 | model.fit(x_train_transformed, y_train, epochs=10, batch_size=32) 237 | cbw.save_model( 238 | model_path, 239 | model, 240 | preprocessing=sk_function_transformer, 241 | data_preparation=add_value_to_column_transformer, 242 | zip=False, 243 | ) 244 | 245 | with 
pytest.raises(cbw.ClearboxWrapperException): 246 | cbw.save_model( 247 | model_path, 248 | model, 249 | preprocessing=sk_function_transformer, 250 | data_preparation=add_value_to_column_transformer, 251 | zip=False, 252 | ) 253 | -------------------------------------------------------------------------------- /tests/pytorch/test_pytorch_boston.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from sklearn.model_selection import train_test_split 7 | import sklearn.preprocessing as sk_preprocessing 8 | import torch 9 | import torch.nn as nn 10 | 11 | import clearbox_wrapper as cbw 12 | 13 | num_epochs = 20 14 | learning_rate = 0.0001 15 | size_hidden1 = 25 16 | size_hidden2 = 12 17 | size_hidden3 = 6 18 | size_hidden4 = 1 19 | 20 | 21 | class BostonModel(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.lin1 = nn.Linear(13, size_hidden1) 25 | self.relu1 = nn.ReLU() 26 | self.lin2 = nn.Linear(size_hidden1, size_hidden2) 27 | self.relu2 = nn.ReLU() 28 | self.lin3 = nn.Linear(size_hidden2, size_hidden3) 29 | self.relu3 = nn.ReLU() 30 | self.lin4 = nn.Linear(size_hidden3, size_hidden4) 31 | 32 | def forward(self, input): 33 | return self.lin4( 34 | self.relu3(self.lin3(self.relu2(self.lin2(self.relu1(self.lin1(input)))))) 35 | ) 36 | 37 | 38 | def train(model_inp, x_train, y_train, num_epochs=num_epochs): 39 | datasets = torch.utils.data.TensorDataset(x_train, y_train) 40 | train_iter = torch.utils.data.DataLoader(datasets, batch_size=10, shuffle=True) 41 | criterion = nn.MSELoss(reduction="sum") 42 | optimizer = torch.optim.RMSprop(model_inp.parameters(), lr=learning_rate) 43 | for epoch in range(num_epochs): # loop over the dataset multiple times 44 | running_loss = 0.0 45 | for inputs, labels in train_iter: 46 | # forward pass 47 | outputs = model_inp(inputs) 48 | # defining loss 49 | loss = criterion(outputs, labels) 50 | # zero the parameter gradients 51 | optimizer.zero_grad() 52 | # computing gradients 53 | loss.backward() 54 | # accumulating running loss 55 | running_loss += loss.item() 56 | # updated weights based on computed gradients 57 | optimizer.step() 58 | if epoch % 20 == 0: 59 | print( 60 | "Epoch [%d]/[%d] running accumulative loss across all batches: %.3f" 61 | % (epoch + 1, num_epochs, running_loss) 62 | ) 63 | 64 | 65 | @pytest.fixture 66 | def model_path(tmpdir): 67 | return os.path.join(str(tmpdir), "model") 68 | 69 | 70 | @pytest.fixture(scope="module") 71 | def boston_training_test(): 72 | csv_path = "tests/datasets/boston_housing.csv" 73 | target_column = "MEDV" 74 | boston_dataset = pd.read_csv(csv_path) 75 | y = boston_dataset[target_column] 76 | x = boston_dataset.drop(target_column, axis=1) 77 | x_train, x_test, y_train, y_test = train_test_split( 78 | x, y, test_size=0.2, random_state=42 79 | ) 80 | return x_train, x_test, y_train, y_test 81 | 82 | 83 | @pytest.fixture() 84 | def sk_function_transformer(): 85 | def simple_preprocessor(data_x): 86 | return data_x ** 2 87 | 88 | transformer = sk_preprocessing.FunctionTransformer( 89 | simple_preprocessor, validate=True 90 | ) 91 | return transformer 92 | 93 | 94 | @pytest.fixture() 95 | def custom_transformer(): 96 | def simple_preprocessor(data_x): 97 | transformed_x = data_x + 1.0 98 | return transformed_x 99 | 100 | return simple_preprocessor 101 | 102 | 103 | @pytest.fixture() 104 | def add_value_to_column_transformer(): 105 | def double_dataframe(dataframe_x): 106 | 
x_transformed = dataframe_x + dataframe_x 107 | return x_transformed 108 | 109 | return double_dataframe 110 | 111 | 112 | def test_boston_pytorch_no_preprocessing(boston_training_test, model_path): 113 | 114 | x_train, x_test, y_train, _ = boston_training_test 115 | 116 | x_train = torch.Tensor(x_train.values) 117 | y_train = torch.Tensor(y_train.values) 118 | 119 | x_test_tensor = torch.Tensor(x_test.values) 120 | 121 | model = BostonModel() 122 | model.train() 123 | train(model, x_train, y_train) 124 | 125 | cbw.save_model(model_path, model, zip=False) 126 | loaded_model = cbw.load_model(model_path) 127 | 128 | original_model_predictions = model(x_test_tensor).detach().numpy() 129 | loaded_model_predictions = loaded_model.predict(x_test) 130 | 131 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 132 | 133 | 134 | @pytest.mark.parametrize( 135 | "sk_transformer", 136 | [ 137 | (sk_preprocessing.StandardScaler()), 138 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 139 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 140 | (sk_preprocessing.RobustScaler()), 141 | (sk_preprocessing.MaxAbsScaler()), 142 | ], 143 | ) 144 | def test_boston_pytorch_preprocessing(sk_transformer, boston_training_test, model_path): 145 | x_train, x_test, y_train, _ = boston_training_test 146 | 147 | x_transformed = sk_transformer.fit_transform(x_train) 148 | x_transformed = torch.Tensor(x_transformed) 149 | y_train = torch.Tensor(y_train.values) 150 | 151 | def preprocessing_function(x_data): 152 | x_transformed = sk_transformer.transform(x_data) 153 | return x_transformed 154 | 155 | model = BostonModel() 156 | model.train() 157 | train(model, x_transformed, y_train) 158 | 159 | cbw.save_model(model_path, model, preprocessing=preprocessing_function, zip=False) 160 | loaded_model = cbw.load_model(model_path) 161 | 162 | x_test_transformed = preprocessing_function(x_test) 163 | x_test_transformed = torch.Tensor(x_test_transformed) 164 | original_model_predictions = model(x_test_transformed).detach().numpy() 165 | loaded_model_predictions = loaded_model.predict(x_test) 166 | 167 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 168 | 169 | 170 | def test_boston_pytorch_preprocessing_with_function_transformer( 171 | sk_function_transformer, boston_training_test, model_path 172 | ): 173 | x_train, x_test, y_train, _ = boston_training_test 174 | 175 | x_transformed = sk_function_transformer.fit_transform(x_train) 176 | x_transformed = torch.Tensor(x_transformed) 177 | y_train = torch.Tensor(y_train.values) 178 | 179 | def preprocessing_function(x_data): 180 | x_transformed = sk_function_transformer.transform(x_data) 181 | return x_transformed 182 | 183 | model = BostonModel() 184 | model.train() 185 | train(model, x_transformed, y_train) 186 | 187 | cbw.save_model(model_path, model, preprocessing=preprocessing_function, zip=False) 188 | loaded_model = cbw.load_model(model_path) 189 | 190 | x_test_transformed = preprocessing_function(x_test) 191 | x_test_transformed = torch.Tensor(x_test_transformed) 192 | original_model_predictions = model(x_test_transformed).detach().numpy() 193 | loaded_model_predictions = loaded_model.predict(x_test) 194 | 195 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 196 | 197 | 198 | def test_boston_pytorch_preprocessing_with_custom_transformer( 199 | custom_transformer, boston_training_test, model_path 200 | ): 201 | x_train, x_test, 
y_train, _ = boston_training_test 202 | 203 | x_transformed = custom_transformer(x_train) 204 | x_transformed = torch.Tensor(x_transformed.values) 205 | y_train = torch.Tensor(y_train.values) 206 | 207 | def preprocessing_function(x_data): 208 | x_transformed = custom_transformer(x_data) 209 | return x_transformed 210 | 211 | model = BostonModel() 212 | model.train() 213 | train(model, x_transformed, y_train) 214 | 215 | cbw.save_model(model_path, model, preprocessing=preprocessing_function, zip=False) 216 | loaded_model = cbw.load_model(model_path) 217 | 218 | x_test_transformed = preprocessing_function(x_test) 219 | x_test_transformed = torch.Tensor(x_test_transformed.values) 220 | original_model_predictions = model(x_test_transformed).detach().numpy() 221 | loaded_model_predictions = loaded_model.predict(x_test) 222 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 223 | 224 | 225 | @pytest.mark.parametrize( 226 | "preprocessor", 227 | [ 228 | (sk_preprocessing.StandardScaler()), 229 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 230 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 231 | (sk_preprocessing.RobustScaler()), 232 | (sk_preprocessing.MaxAbsScaler()), 233 | ], 234 | ) 235 | def test_boston_pytorch_data_preparation_and_preprocessing( 236 | preprocessor, add_value_to_column_transformer, boston_training_test, model_path 237 | ): 238 | x_train, x_test, y_train, _ = boston_training_test 239 | 240 | x_prepared = add_value_to_column_transformer(x_train) 241 | x_transformed = preprocessor.fit_transform(x_prepared) 242 | 243 | x_transformed = torch.Tensor(x_transformed) 244 | y_train = torch.Tensor(y_train.values) 245 | 246 | def preprocessing_function(x_data): 247 | x_transformed = preprocessor.transform(x_data) 248 | return x_transformed 249 | 250 | model = BostonModel() 251 | model.train() 252 | train(model, x_transformed, y_train) 253 | 254 | cbw.save_model( 255 | model_path, 256 | model, 257 | preprocessing=preprocessing_function, 258 | data_preparation=add_value_to_column_transformer, 259 | zip=False, 260 | ) 261 | loaded_model = cbw.load_model(model_path) 262 | 263 | x_test_prepared = add_value_to_column_transformer(x_test) 264 | x_test_transformed = preprocessing_function(x_test_prepared) 265 | x_test_transformed = torch.Tensor(x_test_transformed) 266 | original_model_predictions = model(x_test_transformed).detach().numpy() 267 | loaded_model_predictions = loaded_model.predict(x_test) 268 | 269 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 270 | -------------------------------------------------------------------------------- /tests/sklearn/test_sklearn_boston.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | import sklearn.datasets as datasets 5 | import sklearn.ensemble as ensemble 6 | import sklearn.linear_model as linear_model 7 | import sklearn.neighbors as neighbors 8 | import sklearn.preprocessing as sk_preprocessing 9 | import sklearn.svm as svm 10 | import sklearn.tree as tree 11 | 12 | import clearbox_wrapper as cbw 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def boston_data(): 17 | boston = datasets.load_boston() 18 | x = boston.data 19 | y = boston.target 20 | return x, y 21 | 22 | 23 | @pytest.fixture() 24 | def sk_function_transformer(): 25 | def simple_preprocessor(numpy_x): 26 | return numpy_x ** 2 27 | 28 | transformer = 
sk_preprocessing.FunctionTransformer( 29 | simple_preprocessor, validate=True 30 | ) 31 | return transformer 32 | 33 | 34 | @pytest.fixture() 35 | def custom_transformer(): 36 | def simple_preprocessor(numpy_x): 37 | transformed_x = numpy_x + 1.0 38 | return transformed_x 39 | 40 | return simple_preprocessor 41 | 42 | 43 | @pytest.fixture() 44 | def drop_column_transformer(): 45 | def drop_column(numpy_x): 46 | transformed_x = np.delete(numpy_x, 0, axis=1) 47 | return transformed_x 48 | 49 | return drop_column 50 | 51 | 52 | def _check_schema(pdf, input_schema): 53 | if isinstance(pdf, (list, np.ndarray, dict)): 54 | try: 55 | pdf = pd.DataFrame(pdf) 56 | except Exception as e: 57 | message = ( 58 | "This model contains a model signature, which suggests a DataFrame input. " 59 | "There was an error casting the input data to a DataFrame: {0}".format( 60 | str(e) 61 | ) 62 | ) 63 | raise cbw.ClearboxWrapperException(message) 64 | if not isinstance(pdf, pd.DataFrame): 65 | message = ( 66 | "Expected input to be DataFrame or list. Found: %s" % type(pdf).__name__ 67 | ) 68 | raise cbw.ClearboxWrapperException(message) 69 | 70 | if input_schema.has_column_names(): 71 | # make sure there are no missing columns 72 | col_names = input_schema.column_names() 73 | expected_names = set(col_names) 74 | actual_names = set(pdf.columns) 75 | missing_cols = expected_names - actual_names 76 | extra_cols = actual_names - expected_names 77 | # Preserve order from the original columns, since missing/extra columns are likely to 78 | # be in the same order. 79 | missing_cols = [c for c in col_names if c in missing_cols] 80 | extra_cols = [c for c in pdf.columns if c in extra_cols] 81 | if missing_cols: 82 | print( 83 | "Model input is missing columns {0}." 84 | " Note that there were extra columns: {1}".format( 85 | missing_cols, extra_cols 86 | ) 87 | ) 88 | return False 89 | else: 90 | if len(pdf.columns) != len(input_schema.columns): 91 | print( 92 | "The model signature declares " 93 | "{0} input columns but the provided input has " 94 | "{1} columns. 
Note: the columns were not named in the signature so we can " 95 | "only verify their count.".format( 96 | len(input_schema.columns), len(pdf.columns) 97 | ) 98 | ) 99 | return False 100 | 101 | return True 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "sklearn_model", 106 | [ 107 | (linear_model.LinearRegression()), 108 | (svm.SVR()), 109 | (neighbors.KNeighborsRegressor()), 110 | (tree.DecisionTreeRegressor()), 111 | (ensemble.RandomForestRegressor()), 112 | ], 113 | ) 114 | def test_boston_sklearn_no_preprocessing(sklearn_model, boston_data, tmpdir): 115 | x, y = boston_data 116 | fitted_model = sklearn_model.fit(x, y) 117 | tmp_model_path = str(tmpdir + "/saved_model") 118 | cbw.save_model(tmp_model_path, fitted_model, zip=False) 119 | loaded_model = cbw.load_model(tmp_model_path) 120 | original_model_predictions = fitted_model.predict(x[:5]) 121 | wrapped_model_predictions = loaded_model.predict(x[:5]) 122 | np.testing.assert_array_equal(original_model_predictions, wrapped_model_predictions) 123 | 124 | 125 | @pytest.mark.parametrize( 126 | "sklearn_model, preprocessor", 127 | [ 128 | ( 129 | linear_model.LinearRegression(), 130 | sk_preprocessing.StandardScaler(), 131 | ), 132 | ( 133 | svm.SVR(), 134 | sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50), 135 | ), 136 | ( 137 | neighbors.KNeighborsRegressor(), 138 | sk_preprocessing.MaxAbsScaler(), 139 | ), 140 | (tree.DecisionTreeRegressor(), sk_preprocessing.RobustScaler()), 141 | (ensemble.RandomForestRegressor(), sk_preprocessing.MaxAbsScaler()), 142 | ], 143 | ) 144 | def test_boston_sklearn_preprocessing(sklearn_model, preprocessor, boston_data, tmpdir): 145 | x, y = boston_data 146 | x_transformed = preprocessor.fit_transform(x) 147 | fitted_model = sklearn_model.fit(x_transformed, y) 148 | tmp_model_path = str(tmpdir + "/saved_model") 149 | cbw.save_model(tmp_model_path, fitted_model, preprocessing=preprocessor, zip=False) 150 | loaded_model = cbw.load_model(tmp_model_path) 151 | original_model_predictions = fitted_model.predict(x_transformed[:5]) 152 | loaded_model_predictions = loaded_model.predict(x[:5]) 153 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 154 | 155 | 156 | @pytest.mark.parametrize( 157 | "sklearn_model, preprocessor", 158 | [ 159 | ( 160 | linear_model.LinearRegression(), 161 | sk_preprocessing.StandardScaler(), 162 | ), 163 | ( 164 | svm.SVR(), 165 | sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=20), 166 | ), 167 | ( 168 | neighbors.KNeighborsRegressor(), 169 | sk_preprocessing.RobustScaler(), 170 | ), 171 | (tree.DecisionTreeRegressor(), sk_preprocessing.RobustScaler()), 172 | (ensemble.RandomForestRegressor(), sk_preprocessing.MaxAbsScaler()), 173 | ], 174 | ) 175 | def test_boston_sklearn_data_preparation_and_preprocessing( 176 | sklearn_model, preprocessor, boston_data, drop_column_transformer, tmpdir 177 | ): 178 | x, y = boston_data 179 | data_preparation = drop_column_transformer 180 | x_transformed = data_preparation(x) 181 | x_transformed = preprocessor.fit_transform(x_transformed) 182 | fitted_model = sklearn_model.fit(x_transformed, y) 183 | tmp_model_path = str(tmpdir + "/saved_model") 184 | cbw.save_model( 185 | tmp_model_path, 186 | fitted_model, 187 | preprocessing=preprocessor, 188 | data_preparation=data_preparation, 189 | zip=False, 190 | ) 191 | loaded_model = cbw.load_model(tmp_model_path) 192 | original_model_predictions = 
fitted_model.predict(x_transformed[:5]) 193 | loaded_model_predictions = loaded_model.predict(x[:5]) 194 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 195 | 196 | 197 | def test_boston_sklearn_zipped_path_already_exists(boston_data, tmpdir): 198 | x, y = boston_data 199 | sklearn_model = tree.DecisionTreeRegressor() 200 | fitted_model = sklearn_model.fit(x, y) 201 | tmp_model_path = str(tmpdir + "/saved_model") 202 | cbw.save_model(tmp_model_path, fitted_model) 203 | with pytest.raises(cbw.ClearboxWrapperException): 204 | cbw.save_model(tmp_model_path, fitted_model) 205 | 206 | 207 | def test_boston_sklearn_path_already_exists(boston_data, tmpdir): 208 | x, y = boston_data 209 | sklearn_model = tree.DecisionTreeRegressor() 210 | fitted_model = sklearn_model.fit(x, y) 211 | tmp_model_path = str(tmpdir + "/saved_model") 212 | cbw.save_model(tmp_model_path, fitted_model, zip=False) 213 | with pytest.raises(cbw.ClearboxWrapperException): 214 | cbw.save_model(tmp_model_path, fitted_model, zip=False) 215 | 216 | 217 | @pytest.mark.parametrize( 218 | "sklearn_model", 219 | [ 220 | (neighbors.KNeighborsRegressor()), 221 | (tree.DecisionTreeRegressor()), 222 | ], 223 | ) 224 | def test_boston_sklearn_no_preprocessing_check_model_signature( 225 | sklearn_model, boston_data, tmpdir 226 | ): 227 | x, y = boston_data 228 | fitted_model = sklearn_model.fit(x, y) 229 | tmp_model_path = str(tmpdir + "/saved_model") 230 | cbw.save_model(tmp_model_path, fitted_model, input_data=x, zip=False) 231 | loaded_model = cbw.load_model(tmp_model_path) 232 | original_model_predictions = fitted_model.predict(x[:5]) 233 | loaded_model_predictions = loaded_model.predict(x[:5]) 234 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 235 | 236 | mlmodel = cbw.Model.load(tmp_model_path) 237 | assert _check_schema(x, mlmodel.get_model_input_schema()) 238 | 239 | 240 | @pytest.mark.parametrize( 241 | "sklearn_model, preprocessor", 242 | [ 243 | (tree.DecisionTreeRegressor(), sk_preprocessing.RobustScaler()), 244 | (ensemble.RandomForestRegressor(), sk_preprocessing.MaxAbsScaler()), 245 | ], 246 | ) 247 | def test_boston_sklearn_preprocessing_check_model_and_preprocessing_signature( 248 | sklearn_model, preprocessor, boston_data, tmpdir 249 | ): 250 | x, y = boston_data 251 | x_transformed = preprocessor.fit_transform(x) 252 | fitted_model = sklearn_model.fit(x_transformed, y) 253 | tmp_model_path = str(tmpdir + "/saved_model") 254 | cbw.save_model( 255 | tmp_model_path, 256 | fitted_model, 257 | preprocessing=preprocessor, 258 | input_data=x, 259 | zip=False, 260 | ) 261 | loaded_model = cbw.load_model(tmp_model_path) 262 | original_model_predictions = fitted_model.predict(x_transformed[:5]) 263 | loaded_model_predictions = loaded_model.predict(x[:5]) 264 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 265 | 266 | mlmodel = cbw.Model.load(tmp_model_path) 267 | preprocessing_input_schema = mlmodel.get_preprocessing_input_schema() 268 | preprocessing_output_schema = mlmodel.get_preprocessing_output_schema() 269 | model_input_schema = mlmodel.get_model_input_schema() 270 | 271 | assert _check_schema(x, preprocessing_input_schema) 272 | assert _check_schema(x_transformed, preprocessing_output_schema) 273 | assert _check_schema(x_transformed, model_input_schema) 274 | assert preprocessing_output_schema == model_input_schema 275 | 276 | 277 | @pytest.mark.parametrize( 278 | "sklearn_model, preprocessor", 279 | [ 280 | 
(tree.DecisionTreeRegressor(), sk_preprocessing.RobustScaler()), 281 | (ensemble.RandomForestRegressor(), sk_preprocessing.MaxAbsScaler()), 282 | ], 283 | ) 284 | def test_boston_sklearn_check_model_preprocessing_and_data_preparation_signature( 285 | sklearn_model, preprocessor, boston_data, drop_column_transformer, tmpdir 286 | ): 287 | x, y = boston_data 288 | data_preparation = drop_column_transformer 289 | x_prepared = data_preparation(x) 290 | x_transformed = preprocessor.fit_transform(x_prepared) 291 | fitted_model = sklearn_model.fit(x_transformed, y) 292 | tmp_model_path = str(tmpdir + "/saved_model") 293 | cbw.save_model( 294 | tmp_model_path, 295 | fitted_model, 296 | preprocessing=preprocessor, 297 | data_preparation=data_preparation, 298 | input_data=x, 299 | zip=False, 300 | ) 301 | loaded_model = cbw.load_model(tmp_model_path) 302 | original_model_predictions = fitted_model.predict(x_transformed[:5]) 303 | loaded_model_predictions = loaded_model.predict(x[:5]) 304 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 305 | 306 | mlmodel = cbw.Model.load(tmp_model_path) 307 | data_preparation_input_schema = mlmodel.get_data_preparation_input_schema() 308 | data_preparation_output_schema = mlmodel.get_data_preparation_output_schema() 309 | preprocessing_input_schema = mlmodel.get_preprocessing_input_schema() 310 | preprocessing_output_schema = mlmodel.get_preprocessing_output_schema() 311 | model_input_schema = mlmodel.get_model_input_schema() 312 | 313 | assert _check_schema(x, data_preparation_input_schema) 314 | assert _check_schema(x_prepared, data_preparation_output_schema) 315 | assert _check_schema(x_prepared, preprocessing_input_schema) 316 | assert _check_schema(x_transformed, preprocessing_output_schema) 317 | assert _check_schema(x_transformed, model_input_schema) 318 | assert not _check_schema(x, model_input_schema) 319 | assert data_preparation_output_schema == preprocessing_input_schema 320 | assert preprocessing_output_schema == model_input_schema 321 | -------------------------------------------------------------------------------- /tests/xgboost/test_xgboost_boston.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pytest 6 | from sklearn.model_selection import train_test_split 7 | import sklearn.preprocessing as sk_preprocessing 8 | import xgboost as xgb 9 | 10 | import clearbox_wrapper as cbw 11 | 12 | 13 | @pytest.fixture 14 | def model_path(tmpdir): 15 | return os.path.join(str(tmpdir), "model") 16 | 17 | 18 | @pytest.fixture(scope="module") 19 | def boston_training_test(): 20 | csv_path = "tests/datasets/boston_housing.csv" 21 | target_column = "MEDV" 22 | boston_dataset = pd.read_csv(csv_path) 23 | y = boston_dataset[target_column] 24 | x = boston_dataset.drop(target_column, axis=1) 25 | x_train, x_test, y_train, y_test = train_test_split( 26 | x, y, test_size=0.2, random_state=42 27 | ) 28 | return x_train, x_test, y_train, y_test 29 | 30 | 31 | @pytest.fixture() 32 | def sk_function_transformer(): 33 | def simple_preprocessor(data_x): 34 | return data_x ** 2 35 | 36 | transformer = sk_preprocessing.FunctionTransformer( 37 | simple_preprocessor, validate=True 38 | ) 39 | return transformer 40 | 41 | 42 | @pytest.fixture() 43 | def custom_transformer(): 44 | def simple_preprocessor(data_x): 45 | transformed_x = data_x + 1.0 46 | return transformed_x 47 | 48 | return simple_preprocessor 49 | 50 | 51 | 
@pytest.fixture() 52 | def add_value_to_column_transformer(): 53 | def double_dataframe(dataframe_x): 54 | x_transformed = dataframe_x + dataframe_x 55 | return x_transformed 56 | 57 | return double_dataframe 58 | 59 | 60 | @pytest.fixture() 61 | def drop_column_transformer(): 62 | def drop_column(dataframe_x): 63 | x_transformed = dataframe_x.drop("INDUS", axis=1) 64 | return x_transformed 65 | 66 | return drop_column 67 | 68 | 69 | def _check_schema(pdf, input_schema): 70 | if isinstance(pdf, (list, np.ndarray, dict)): 71 | try: 72 | pdf = pd.DataFrame(pdf) 73 | except Exception as e: 74 | message = ( 75 | "This model contains a model signature, which suggests a DataFrame input. " 76 | "There was an error casting the input data to a DataFrame: {0}".format( 77 | str(e) 78 | ) 79 | ) 80 | raise cbw.ClearboxWrapperException(message) 81 | if not isinstance(pdf, pd.DataFrame): 82 | message = ( 83 | "Expected input to be DataFrame or list. Found: %s" % type(pdf).__name__ 84 | ) 85 | raise cbw.ClearboxWrapperException(message) 86 | 87 | if input_schema.has_column_names(): 88 | # make sure there are no missing columns 89 | col_names = input_schema.column_names() 90 | expected_names = set(col_names) 91 | actual_names = set(pdf.columns) 92 | missing_cols = expected_names - actual_names 93 | extra_cols = actual_names - expected_names 94 | # Preserve order from the original columns, since missing/extra columns are likely to 95 | # be in the same order. 96 | missing_cols = [c for c in col_names if c in missing_cols] 97 | extra_cols = [c for c in pdf.columns if c in extra_cols] 98 | if missing_cols: 99 | print( 100 | "Model input is missing columns {0}." 101 | " Note that there were extra columns: {1}".format( 102 | missing_cols, extra_cols 103 | ) 104 | ) 105 | return False 106 | else: 107 | if len(pdf.columns) != len(input_schema.columns): 108 | print( 109 | "The model signature declares " 110 | "{0} input columns but the provided input has " 111 | "{1} columns. 
Note: the columns were not named in the signature so we can " 112 | "only verify their count.".format( 113 | len(input_schema.columns), len(pdf.columns) 114 | ) 115 | ) 116 | return False 117 | 118 | return True 119 | 120 | 121 | def test_boston_xgboost_no_preprocessing(boston_training_test, model_path): 122 | x_train, x_test, y_train, _ = boston_training_test 123 | model = xgb.XGBRegressor() 124 | fitted_model = model.fit(x_train, y_train) 125 | cbw.save_model(model_path, fitted_model, zip=False) 126 | 127 | loaded_model = cbw.load_model(model_path) 128 | original_model_predictions = fitted_model.predict(x_test) 129 | loaded_model_predictions = loaded_model.predict(x_test) 130 | 131 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 132 | 133 | 134 | @pytest.mark.parametrize( 135 | "sk_transformer", 136 | [ 137 | (sk_preprocessing.StandardScaler()), 138 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 139 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 140 | (sk_preprocessing.RobustScaler()), 141 | (sk_preprocessing.MaxAbsScaler()), 142 | ], 143 | ) 144 | def test_boston_xgboost_preprocessing(sk_transformer, boston_training_test, model_path): 145 | x_train, x_test, y_train, _ = boston_training_test 146 | x_train_transformed = sk_transformer.fit_transform(x_train) 147 | 148 | model = xgb.XGBRegressor() 149 | fitted_model = model.fit(x_train_transformed, y_train) 150 | cbw.save_model(model_path, fitted_model, preprocessing=sk_transformer, zip=False) 151 | 152 | loaded_model = cbw.load_model(model_path) 153 | x_test_transformed = sk_transformer.transform(x_test) 154 | original_model_predictions = fitted_model.predict(x_test_transformed) 155 | loaded_model_predictions = loaded_model.predict(x_test) 156 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 157 | 158 | 159 | def test_boston_xgboost_preprocessing_with_function_transformer( 160 | sk_function_transformer, boston_training_test, model_path 161 | ): 162 | x_train, x_test, y_train, _ = boston_training_test 163 | x_train_transformed = sk_function_transformer.fit_transform(x_train) 164 | 165 | model = xgb.XGBRegressor() 166 | fitted_model = model.fit(x_train_transformed, y_train) 167 | cbw.save_model( 168 | model_path, fitted_model, preprocessing=sk_function_transformer, zip=False 169 | ) 170 | 171 | loaded_model = cbw.load_model(model_path) 172 | x_test_transformed = sk_function_transformer.transform(x_test) 173 | original_model_predictions = fitted_model.predict(x_test_transformed) 174 | loaded_model_predictions = loaded_model.predict(x_test) 175 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 176 | 177 | 178 | def test_boston_xgboost_preprocessing_with_custom_transformer( 179 | custom_transformer, boston_training_test, model_path 180 | ): 181 | x_train, x_test, y_train, _ = boston_training_test 182 | x_train_transformed = custom_transformer(x_train) 183 | 184 | model = xgb.XGBRegressor() 185 | fitted_model = model.fit(x_train_transformed, y_train) 186 | cbw.save_model( 187 | model_path, fitted_model, preprocessing=custom_transformer, zip=False 188 | ) 189 | 190 | loaded_model = cbw.load_model(model_path) 191 | x_test_transformed = custom_transformer(x_test) 192 | original_model_predictions = fitted_model.predict(x_test_transformed) 193 | loaded_model_predictions = loaded_model.predict(x_test) 194 | 
np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 195 | 196 | 197 | @pytest.mark.parametrize( 198 | "preprocessor", 199 | [ 200 | (sk_preprocessing.StandardScaler()), 201 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 202 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 203 | (sk_preprocessing.RobustScaler()), 204 | (sk_preprocessing.MaxAbsScaler()), 205 | ], 206 | ) 207 | def test_boston_xgboost_data_preparation_and_preprocessing( 208 | preprocessor, add_value_to_column_transformer, boston_training_test, model_path 209 | ): 210 | x_train, x_test, y_train, _ = boston_training_test 211 | x_train_prepared = add_value_to_column_transformer(x_train) 212 | x_train_transformed = preprocessor.fit_transform(x_train_prepared) 213 | 214 | model = xgb.XGBRegressor() 215 | fitted_model = model.fit(x_train_transformed, y_train) 216 | cbw.save_model( 217 | model_path, 218 | fitted_model, 219 | preprocessing=preprocessor, 220 | data_preparation=add_value_to_column_transformer, 221 | zip=False, 222 | ) 223 | 224 | loaded_model = cbw.load_model(model_path) 225 | x_test_prepared = add_value_to_column_transformer(x_test) 226 | x_test_transformed = preprocessor.transform(x_test_prepared) 227 | 228 | original_model_predictions = fitted_model.predict(x_test_transformed) 229 | loaded_model_predictions = loaded_model.predict(x_test) 230 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 231 | 232 | 233 | def test_boston_xgb_zipped_path_already_exists(boston_training_test, model_path): 234 | x_train, x_test, y_train, _ = boston_training_test 235 | model = xgb.XGBRegressor() 236 | fitted_model = model.fit(x_train, y_train) 237 | cbw.save_model(model_path, fitted_model) 238 | with pytest.raises(cbw.ClearboxWrapperException): 239 | cbw.save_model(model_path, fitted_model) 240 | 241 | 242 | def test_boston_xgb_path_already_exists(boston_training_test, model_path): 243 | x_train, x_test, y_train, _ = boston_training_test 244 | model = xgb.XGBRegressor() 245 | fitted_model = model.fit(x_train, y_train) 246 | cbw.save_model(model_path, fitted_model, zip=False) 247 | with pytest.raises(cbw.ClearboxWrapperException): 248 | cbw.save_model(model_path, fitted_model, zip=False) 249 | 250 | 251 | def test_boston_xgb_no_preprocessing_check_model_signature( 252 | boston_training_test, model_path 253 | ): 254 | x_train, x_test, y_train, _ = boston_training_test 255 | model = xgb.XGBRegressor() 256 | fitted_model = model.fit(x_train, y_train) 257 | cbw.save_model(model_path, fitted_model, input_data=x_train, zip=False) 258 | loaded_model = cbw.load_model(model_path) 259 | original_model_predictions = fitted_model.predict(x_train[:5]) 260 | loaded_model_predictions = loaded_model.predict(x_train[:5]) 261 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 262 | 263 | mlmodel = cbw.Model.load(model_path) 264 | assert _check_schema(x_train, mlmodel.get_model_input_schema()) 265 | 266 | 267 | @pytest.mark.parametrize( 268 | "preprocessor", 269 | [ 270 | (sk_preprocessing.StandardScaler()), 271 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 272 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 273 | ], 274 | ) 275 | def test_boston_xgb_preprocessing_check_model_and_preprocessing_signature( 276 | preprocessor, boston_training_test, model_path 277 | ): 278 | x_train, x_test, y_train, _ = boston_training_test 279 | x_train_transformed = 
preprocessor.fit_transform(x_train) 280 | 281 | model = xgb.XGBRegressor() 282 | fitted_model = model.fit(x_train_transformed, y_train) 283 | cbw.save_model( 284 | model_path, 285 | fitted_model, 286 | preprocessing=preprocessor, 287 | input_data=x_train, 288 | zip=False, 289 | ) 290 | 291 | loaded_model = cbw.load_model(model_path) 292 | original_model_predictions = fitted_model.predict(x_train_transformed[:5]) 293 | loaded_model_predictions = loaded_model.predict(x_train[:5]) 294 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 295 | 296 | mlmodel = cbw.Model.load(model_path) 297 | preprocessing_input_schema = mlmodel.get_preprocessing_input_schema() 298 | preprocessing_output_schema = mlmodel.get_preprocessing_output_schema() 299 | model_input_schema = mlmodel.get_model_input_schema() 300 | 301 | assert _check_schema(x_train, preprocessing_input_schema) 302 | assert _check_schema(x_train_transformed, preprocessing_output_schema) 303 | assert _check_schema(x_train_transformed, model_input_schema) 304 | assert preprocessing_output_schema == model_input_schema 305 | 306 | 307 | @pytest.mark.parametrize( 308 | "preprocessor", 309 | [ 310 | (sk_preprocessing.StandardScaler()), 311 | (sk_preprocessing.QuantileTransformer(random_state=0, n_quantiles=50)), 312 | (sk_preprocessing.KBinsDiscretizer(n_bins=2, encode="ordinal")), 313 | ], 314 | ) 315 | def test_boston_xgb_check_model_preprocessing_and_data_preparation_signature( 316 | preprocessor, boston_training_test, drop_column_transformer, model_path 317 | ): 318 | x_train, x_test, y_train, _ = boston_training_test 319 | 320 | x_train_prepared = drop_column_transformer(x_train) 321 | x_train_transformed = preprocessor.fit_transform(x_train_prepared) 322 | 323 | model = xgb.XGBRegressor() 324 | fitted_model = model.fit(x_train_transformed, y_train) 325 | cbw.save_model( 326 | model_path, 327 | fitted_model, 328 | preprocessing=preprocessor, 329 | data_preparation=drop_column_transformer, 330 | input_data=x_train, 331 | zip=False, 332 | ) 333 | 334 | loaded_model = cbw.load_model(model_path) 335 | x_test_prepared = drop_column_transformer(x_test) 336 | x_test_transformed = preprocessor.transform(x_test_prepared) 337 | 338 | original_model_predictions = fitted_model.predict(x_test_transformed) 339 | loaded_model_predictions = loaded_model.predict(x_test) 340 | np.testing.assert_array_equal(original_model_predictions, loaded_model_predictions) 341 | 342 | mlmodel = cbw.Model.load(model_path) 343 | data_preparation_input_schema = mlmodel.get_data_preparation_input_schema() 344 | data_preparation_output_schema = mlmodel.get_data_preparation_output_schema() 345 | preprocessing_input_schema = mlmodel.get_preprocessing_input_schema() 346 | preprocessing_output_schema = mlmodel.get_preprocessing_output_schema() 347 | model_input_schema = mlmodel.get_model_input_schema() 348 | 349 | assert _check_schema(x_train, data_preparation_input_schema) 350 | assert _check_schema(x_train_prepared, data_preparation_output_schema) 351 | assert _check_schema(x_train_prepared, preprocessing_input_schema) 352 | assert _check_schema(x_train_transformed, preprocessing_output_schema) 353 | assert _check_schema(x_train_transformed, model_input_schema) 354 | assert not _check_schema(x_train, model_input_schema) 355 | assert data_preparation_output_schema == preprocessing_input_schema 356 | assert preprocessing_output_schema == model_input_schema 357 | 
--------------------------------------------------------------------------------