├── .gitattributes
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .travis.yml
├── .zenodo.json
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── ceteris_paribus
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── insurance.csv
│   │   └── titanic_train.csv
│   ├── explainer.py
│   ├── gower.py
│   ├── plots
│   │   ├── __init__.py
│   │   ├── ceterisParibusD3.js
│   │   ├── plot_template.html
│   │   └── plots.py
│   ├── profiles.py
│   ├── select_data.py
│   └── utils.py
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── ceteris_paribus.plots.rst
│   ├── ceteris_paribus.rst
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── misc
│   ├── modules.rst
│   └── readme_include.rst
├── examples
│   ├── __init__.py
│   ├── categorical_variables.py
│   ├── cheatsheet.py
│   ├── classification_example.py
│   ├── keras_example.py
│   ├── multiple_models.py
│   └── regression_example.py
├── jupyter-notebooks
│   ├── CeterisParibusCheatsheet.ipynb
│   ├── __init__.py
│   └── titanic.ipynb
├── misc
│   ├── multiclass_models.png
│   ├── titanic_interactions_average.png
│   ├── titanic_many_models.png
│   ├── titanic_many_variables.png
│   └── titanic_single_response.png
├── paper
│   ├── img
│   │   ├── figure1.png
│   │   ├── figure2.png
│   │   ├── figure3.png
│   │   └── figure4.png
│   ├── paper.bib
│   └── paper.md
├── publish.sh
├── requirements-dev.txt
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── test_api.py
    ├── test_explain.py
    ├── test_plots.py
    ├── test_profiles.py
    └── test_select.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
ceteris_paribus/plots/ceterisParibusD3.js linguist-vendored

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Attach minimal code reproducing the bug.

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- Version [e.g. 0.0.1]
- Browser [e.g. chrome, safari] (if applicable, e.g. plotting)
- Jupyter/Console (if relevant)

**Additional context**
Add any other context about the problem here.

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
__pycache__/
.coverage
*.pyc

ceteris_paribus/plots/obs*.js
ceteris_paribus/plots/params*.js
ceteris_paribus/plots/profile*.js
ceteris_paribus/plots/plots*.html

/docs/_build/

**/.ipynb_checkpoints
/build/
/pyCeterisParibus.egg-info/
/dist/

# data files produced for plots
**/_plot_files

--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python
matrix:
  include:
  - os: linux
    dist: trusty
    python: '3.6'
  - os: linux
    dist: trusty
    python: '3.5'
  - os: linux
    dist: xenial
    python: '3.7'
  - os: linux
    dist: xenial
    python: '3.6'
  - os: linux
    dist: xenial
    python: '3.5'

# command to install dependencies
install:
- pip install -r requirements.txt
- pip install -r requirements-dev.txt
script:
- python -m pytest --cov=./
after_success:
- codecov

--------------------------------------------------------------------------------
/.zenodo.json:
--------------------------------------------------------------------------------
{
  "description": "Python library for Ceteris Paribus Plots (What-if plots)",
  "license": "other-open",
  "title": "pyCeterisParibus: explaining Machine Learning models with Ceteris Paribus Profiles in Python",
  "upload_type": "software",
  "creators": [
    {
      "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology, Faculty of Mathematics, Informatics, and Mechanics, University of Warsaw",
      "name": "Michał Kuźba",
      "orcid": "0000-0002-9181-0126"
    },
    {
      "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology",
      "name": "Ewa Baranowska",
      "orcid": "0000-0002-2278-4219"
    },
    {
      "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology, Faculty of Mathematics, Informatics, and Mechanics, University of Warsaw",
      "name": "Przemysław Biecek",
      "orcid": "0000-0001-8423-1823"
    }
  ],
  "access_right": "open"
}

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
## Contribution
Great to see you here. Thank you for considering contributing to pyCeterisParibus. Let's make it better together!

### Reporting bugs
Have you found a bug?
* Make sure you have the **latest version** of the package.
* **Check whether it has already been reported in** [Issues](https://github.com/ModelOriented/pyCeterisParibus/issues).
* Add an issue following the **Bug report template**.

### Fixing bugs
Do you know how to fix the bug? You are more than welcome to do this!
* Clone this repo.
* Fix the bug.
* Make sure all the **tests** pass.
* Open a pull request and reference the issue number in the description if possible.

### Suggest a new feature
* Use the **Feature request template**.

### Launching tests
Tests are launched automatically on Travis CI.
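To reproduce the run locally, first install the package and development requirements — below is a minimal sketch mirroring the `install` step from `.travis.yml` above:

```bash
pip install -r requirements.txt
pip install -r requirements-dev.txt
```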
22 | 23 | You can also run them with ```python -m pytest --cov=./``` 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
recursive-include ceteris_paribus/datasets *.csv
include ceteris_paribus/plots/ceterisParibusD3.js
include ceteris_paribus/plots/plot_template.html

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![travis](https://travis-ci.org/ModelOriented/pyCeterisParibus.svg?branch=master)](https://travis-ci.org/ModelOriented/pyCeterisParibus)
[![codecov](https://codecov.io/gh/ModelOriented/pyCeterisParibus/branch/master/graph/badge.svg)](https://codecov.io/gh/ModelOriented/pyCeterisParibus)
[![Documentation Status](https://readthedocs.org/projects/pyceterisparibus/badge/?version=latest)](https://pyceterisparibus.readthedocs.io/en/latest/?badge=latest)
[![Downloads](https://pepy.tech/badge/pyceterisparibus)](https://pepy.tech/project/pyceterisparibus)
[![PyPI version](https://badge.fury.io/py/pyCeterisParibus.svg)](https://badge.fury.io/py/pyCeterisParibus)

[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.2667756.svg)](https://doi.org/10.5281/zenodo.2667756)
[![status](http://joss.theoj.org/papers/aad9a21c61c01adebe11bc5bc1ceca92/status.svg)](http://joss.theoj.org/papers/aad9a21c61c01adebe11bc5bc1ceca92)

# pyCeterisParibus

**Please note that the Ceteris Paribus method has been moved to the dalex Python package, which is actively maintained. If you experience any problems with pyCeterisParibus, please consider the dalex implementation at https://dalex.drwhy.ai/python/api/.**

pyCeterisParibus is a Python library based on the *R* package [CeterisParibus](https://github.com/pbiecek/ceterisParibus).
It implements Ceteris Paribus Plots.
They show how the model response would change if a selected variable were changed, which makes them a perfect tool for What-If scenarios.
Ceteris Paribus is a Latin phrase meaning "all else unchanged".
These plots present the change in model response as the values of one feature change, with all others held fixed.
The Ceteris Paribus method is model-agnostic - it works with any Machine Learning model.
The idea is an extension of PDP (Partial Dependence Plots) and ICE (Individual Conditional Expectations) plots.
It allows explaining single observations for multiple variables at the same time.
The plot engine is developed [here](https://github.com/ModelOriented/ceterisParibusD3).

## Why is it so useful?
There might be several motivations behind utilizing this idea.
Imagine a person gets a low credit score.
The client wants to understand how to increase the score, and the scoring institution (e.g. a bank) should be able to answer such questions.
Moreover, this method is useful for researchers and developers to analyze, debug, explain and improve Machine Learning models, assisting the entire process of model design.

## Setup
Tested on Python 3.5+

pyCeterisParibus is on [PyPI](https://pypi.org/project/pyCeterisParibus/). Simply run:

```bash
pip install pyCeterisParibus
```
or install the newest version from GitHub by executing:
```bash
pip install git+https://github.com/ModelOriented/pyCeterisParibus
```
or download the sources, enter the main directory and perform:
```bash
git clone https://github.com/ModelOriented/pyCeterisParibus.git
cd pyCeterisParibus
python setup.py install    # (alternatively use pip install .)
```

## Docs
A detailed description of all methods and their parameters can be found in the [documentation](https://pyceterisparibus.readthedocs.io/en/latest/ceteris_paribus.html).

To build the documentation locally:
```bash
pip install -r requirements-dev.txt
cd docs
make html
```
and open `_build/html/index.html`

## Examples
Below we present use cases on two well-known datasets - Titanic and Iris. More examples, e.g. for regression problems, can be found [here](examples) and in Jupyter notebooks [here](jupyter-notebooks).

Note that in order to run the examples you need to install the extra requirements from `requirements-dev.txt`.

## Use case - Titanic survival
We demonstrate Ceteris Paribus Plots using the well-known Titanic dataset. In this problem, we examine the chance of survival for Titanic passengers.
We start with preprocessing the data and creating an XGBoost model.
```python
import pandas as pd
df = pd.read_csv('titanic_train.csv')

y = df['Survived']
x = df.drop(['Survived', 'PassengerId', 'Name', 'Cabin', 'Ticket'],
            inplace=False, axis=1)

missing = x['Age'].isnull() | x['Embarked'].isnull()
x = x[~missing]
y = y[~missing]

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(x, y,
                                                    test_size=0.2, random_state=42)
```
```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer

# We create the preprocessing pipelines for both numeric and categorical data.
numeric_features = ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare']
numeric_transformer = Pipeline(steps=[
    ('scaler', StandardScaler())])

categorical_features = ['Embarked', 'Sex']
categorical_transformer = Pipeline(steps=[
    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

preprocessor = ColumnTransformer(
    transformers=[
        ('num', numeric_transformer, numeric_features),
        ('cat', categorical_transformer, categorical_features)])
```

```python
from xgboost import XGBClassifier
xgb_clf = Pipeline(steps=[('preprocessor', preprocessor),
                          ('classifier', XGBClassifier())])
xgb_clf.fit(X_train, y_train)
```

Here is where pyCeterisParibus starts. Since this library works in a model-agnostic fashion, we first need to create a wrapper around the model with a uniform predict interface.
```python
from ceteris_paribus.explainer import explain
explainer_xgb = explain(xgb_clf, data=x, y=y, label='XGBoost',
                        predict_function=lambda X: xgb_clf.predict_proba(X)[::, 1])
```

### Single variable profile
Let's look at Mr Ernest James Crease, a 19-year-old man travelling in the 3rd class from Southampton with an 8-pound ticket in his pocket. He died on the Titanic. Most likely, this would not have been the case had Ernest been a few years younger.
Figure 1 presents the chance of survival for a person like Ernest at different ages. We can see things were tough for people like him, unless they were children.

```python
ernest = X_test.iloc[10]
label_ernest = y_test.iloc[10]
from ceteris_paribus.profiles import individual_variable_profile
cp_xgb = individual_variable_profile(explainer_xgb, ernest, label_ernest)
```

Having calculated the profile, we can plot it. Note that `plot_notebook` might be used instead of `plot` when working in Jupyter notebooks.

```python
from ceteris_paribus.plots.plots import plot
plot(cp_xgb, selected_variables=["Age"])
```

![Chance of survival depending on age](misc/titanic_single_response.png)

### Many models
The above picture explains the prediction of the XGBoost model. What if we compare various models?

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
rf_clf = Pipeline(steps=[('preprocessor', preprocessor),
                         ('classifier', RandomForestClassifier())])
linear_clf = Pipeline(steps=[('preprocessor', preprocessor),
                             ('classifier', LogisticRegression())])

rf_clf.fit(X_train, y_train)
linear_clf.fit(X_train, y_train)

explainer_rf = explain(rf_clf, data=x, y=y, label='RandomForest',
                       predict_function=lambda X: rf_clf.predict_proba(X)[::, 1])
explainer_linear = explain(linear_clf, data=x, y=y, label='LogisticRegression',
                           predict_function=lambda X: linear_clf.predict_proba(X)[::, 1])

cp_rf = individual_variable_profile(explainer_rf, ernest, label_ernest)
cp_linear = individual_variable_profile(explainer_linear, ernest, label_ernest)

plot(cp_xgb, cp_rf, cp_linear, selected_variables=["Age"])
```

![The probability of survival estimated with various models.](misc/titanic_many_models.png)

Clearly, XGBoost offers a better fit than Logistic Regression.
Also, it predicts a higher chance of survival at a child's age than the Random Forest model does.
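When working inside a Jupyter notebook, the same comparison can be embedded inline with `plot_notebook`, a thin wrapper around `plot` that accepts the same arguments. A minimal sketch:

```python
from ceteris_paribus.plots.plots import plot_notebook

# renders the figure inside the notebook instead of opening a browser tab
plot_notebook(cp_xgb, cp_rf, cp_linear, selected_variables=["Age"])
```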
### Profiles for many variables
This time we have a look at Miss Elizabeth Mussey Eustis. She is 54 years old, travels in the 1st class with her sister Marta, as they return to the US from their tour of southern Europe. They both survived the disaster.

```python
elizabeth = X_test.iloc[1]
label_elizabeth = y_test.iloc[1]
cp_xgb_2 = individual_variable_profile(explainer_xgb, elizabeth, label_elizabeth)
```

```python
plot(cp_xgb_2, selected_variables=["Pclass", "Sex", "Age", "Embarked"])
```

![Profiles for many variables.](misc/titanic_many_variables.png)

Would she have returned home if she had travelled in the 3rd class, or if she had been a man? As we can observe, this is less likely. On the other hand, for a first-class female passenger, the chances of survival were high regardless of age. Note that this was different in the case of Ernest. The place of embarkation (Cherbourg) has no influence, which is the expected behaviour.

### Feature interactions and average response
Now, what if we look at passengers most similar to Miss Eustis (middle-aged, upper class)?

```python
from ceteris_paribus.select_data import select_neighbours
neighbours = select_neighbours(X_train, elizabeth,
                               selected_variables=['Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'],
                               n=15)
cp_xgb_ns = individual_variable_profile(explainer_xgb, neighbours)
```

```python
plot(cp_xgb_ns, color="Sex", selected_variables=["Pclass", "Age"],
     aggregate_profiles='mean', size_pdps=6, alpha_pdps=1, size=2)
```

![Interaction with gender. Apart from charts with Ceteris Paribus Profiles (top of the visualisation), we can plot a table with the observations used to calculate these profiles (bottom of the visualisation).](misc/titanic_interactions_average.png)

There are two distinct clusters of passengers determined by their gender; therefore, the *PDP* average plot (in grey) does not show the whole picture. Children of both genders were likely to survive, but then we see a large gap. Also, being female increased the chance of survival mostly for first- and second-class passengers.

The plot function comes with extensive customization options; a list of all parameters can be found in the documentation. Additionally, one can interact with the plot by hovering over a point of interest to see more details. Similarly, there is an interactive table with options for highlighting relevant elements, as well as filtering and sorting rows.
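For instance, a more tailored call for the neighbours' profiles above could look as follows (a sketch only - the parameter values are illustrative, while all keyword arguments come from the documented `plot` signature):

```python
plot(cp_xgb_ns, selected_variables=["Pclass", "Age"],
     color="Sex", size=3, alpha=0.7,
     show_observations=False, show_rugs=True,
     plot_title="Passengers similar to Miss Eustis",
     yaxis_title="probability",
     width=750, height=550)
```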
### Multiclass models - Iris dataset
Prepare the dataset and model
```python
from sklearn.datasets import load_iris

iris = load_iris()

def random_forest_classifier():
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_model.fit(iris['data'], iris['target'])
    return rf_model, iris['data'], iris['target'], iris['feature_names']
```

Wrap the model into explainers
```python
rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier()

explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y,
                        predict_function=lambda X: rf_model.predict_proba(X)[::, 0], label=iris.target_names[0])
explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y,
                        predict_function=lambda X: rf_model.predict_proba(X)[::, 1], label=iris.target_names[1])
explainer_rf3 = explain(rf_model, iris_var_names, iris_x, iris_y,
                        predict_function=lambda X: rf_model.predict_proba(X)[::, 2], label=iris.target_names[2])
```

Calculate profiles and plot
```python
cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0])
cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0])
cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0])

plot(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'])
```
![Multiclass models](misc/multiclass_models.png)

## Contributing
You're more than welcome to contribute to this package. See the [guidelines](CONTRIBUTING.md).

## Acknowledgments
Work on this package was financially supported by the ‘NCN Opus grant 2016/21/B/ST6/0217’.
--------------------------------------------------------------------------------
/ceteris_paribus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/ceteris_paribus/__init__.py
--------------------------------------------------------------------------------
/ceteris_paribus/datasets/__init__.py:
--------------------------------------------------------------------------------
import os

DATASETS_DIR = os.path.dirname(__file__)
--------------------------------------------------------------------------------
/ceteris_paribus/explainer.py:
--------------------------------------------------------------------------------
import logging
import re
from collections import namedtuple

import numpy as np
import pandas as pd

Explainer = namedtuple("Explainer", "model var_names data y predict_fun label")


def explain(model, variable_names=None, data=None, y=None, predict_function=None, label=None):
    """
    This function creates a unified representation of a model, which can be further processed by various explainers

    :param model: a model to be explained
    :param variable_names: names of variables, if not supplied then derived from data
    :param data: data that was used for fitting
    :param y: labels for the data
    :param predict_function: function that takes the data and returns predictions
    :param label: label of the model; if not supplied the function will try to infer it from the model object, falling back to 'unlabeled_model'
    :return: Explainer object
    """
    if not predict_function:
        if hasattr(model, 'predict'):
            if isinstance(data, pd.core.frame.DataFrame):
                predict_function = model.predict
            else:
                predict_function = lambda df: model.predict(df.values)
        else:
            raise ValueError('Unable to find predict function')
    if not label:
        logging.warning("Model is unlabeled... \n You can add label using method set_label")
        label_items = re.split(r'\(', str(model))
        if label_items and len(label_items) > 1:
            label = label_items[0]
        else:
            label = 'unlabeled_model'
    if variable_names is None:
        if isinstance(data, pd.core.frame.DataFrame):
            variable_names = list(data)
        else:
            raise ValueError("Unable to infer the variable names. Those must be supplied directly!")
    if data is not None:
        if not isinstance(data, pd.core.frame.DataFrame):
            data = np.array(data)
            if data.ndim == 1:
                # make 1D array 2D
                data = data.reshape((1, -1))
            if len(variable_names) != data.shape[1]:
                raise ValueError("Incorrect number of variables given.")

            data = pd.DataFrame(data, columns=variable_names)

    if y is not None:
        if isinstance(y, pd.core.frame.DataFrame):
            y = pd.Series(y[0])
        else:
            y = pd.Series(y)

    explainer = Explainer(model, list(variable_names), data, y, predict_function, label)
    return explainer
--------------------------------------------------------------------------------
/ceteris_paribus/gower.py:
--------------------------------------------------------------------------------
""" This is the module for calculating Gower's distance/dissimilarity """

import numpy as np
import pandas as pd


# Normalize the array
def _normalize_mixed_data_columns(arr):
    """
    Returns the numpy array representation of the data.
    Loses information about the types
    """
    return np.array(arr, dtype=object)


def _calc_range_mixed_data_columns(data, observation, dtypes):
    """ Return range for each numeric column, 0 for categorical variables """
    _, cols = data.shape

    result = np.zeros(cols)
    for col in range(cols):
        if np.issubdtype(dtypes[col], np.number):
            result[col] = max(max(data[:, col]), observation[col]) - min(min(data[:, col]), observation[col])
    return result


def _gower_dist(xi, xj, ranges, dtypes):
    """
    Return Gower's distance between xi and xj

    :param ranges: ranges of values for each column
    :param dtypes: types of each column
    """
    dtypes = np.array(dtypes)

    sum_sij = 0.0
    sum_wij = 0.0

    cols = len(ranges)
    for col in range(cols):
        if np.issubdtype(dtypes[col], np.number):
            if pd.isnull(xi[col]) or pd.isnull(xj[col]) or np.isclose(0, ranges[col]):
                wij = 0
                sij = 0
            else:
                wij = 1
                sij = abs(xi[col] - xj[col]) / ranges[col]
        else:
            sij = xi[col] != xj[col]
            wij = 0 if pd.isnull(xi[col]) and pd.isnull(xj[col]) else 1

        sum_sij += wij * sij
        sum_wij += wij

    return sum_sij / sum_wij


def gower_distances(data, observation):
    """
    Return an array of distances between all observations and a chosen one
    Based on:
    https://sourceforge.net/projects/gower-distance-4python
    https://beta.vu.nl/nl/Images/stageverslag-hoven_tcm235-777817.pdf

    :type data: DataFrame
    :type observation: pandas Series
    """
    dtypes = data.dtypes
    data = _normalize_mixed_data_columns(data)
    observation = _normalize_mixed_data_columns(observation)
    ranges = _calc_range_mixed_data_columns(data, observation, dtypes)
    return np.array([_gower_dist(row, observation, ranges, dtypes) for row in data])
--------------------------------------------------------------------------------
/ceteris_paribus/plots/__init__.py:
--------------------------------------------------------------------------------
import os

PLOTS_DIR = os.path.dirname(__file__)
--------------------------------------------------------------------------------
/ceteris_paribus/plots/plot_template.html:
--------------------------------------------------------------------------------
[The HTML markup of this 45-line template was stripped during text extraction; the only text content that survives is the page title, "ceterisParibus D3 template".]
--------------------------------------------------------------------------------
/ceteris_paribus/plots/plots.py:
--------------------------------------------------------------------------------
import json
import logging
import os
import sys
import webbrowser
from shutil import copyfile

from flask import Flask, render_template

from ceteris_paribus.plots import PLOTS_DIR
from ceteris_paribus.utils import save_observations, save_profiles

app = Flask(__name__, template_folder=PLOTS_DIR)

MAX_PLOTS_PER_SESSION = 10000

# generates ids for subsequent plots
_PLOT_NUMBER = iter(range(MAX_PLOTS_PER_SESSION))

# directory with all files produced in the plot generation process
_DATA_PATH = '_plot_files'
os.makedirs(_DATA_PATH, exist_ok=True)
_D3_engine_filename = 'ceterisParibusD3.js'
copyfile(os.path.join(PLOTS_DIR, _D3_engine_filename), os.path.join(_DATA_PATH, _D3_engine_filename))


def _calculate_plot_variables(cp_profile, selected_variables):
    """
    Helper function to calculate a valid subset of variables to be plotted
    """
    if not selected_variables:
        return cp_profile.selected_variables
    if not set(selected_variables).issubset(set(cp_profile.selected_variables)):
        logging.warning("Selected variables are not a subset of all variables. Parameter is ignored.")
        return cp_profile.selected_variables
    else:
        return list(selected_variables)


def _params_update(params, **kwargs):
    for key, val in kwargs.items():
        if val:
            params[key] = val
    return params


def _detect_plot_destination(destination):
    """
    Detect the plot destination (browser or embedded inside a notebook) based on the user choice
    """
    if destination == "notebook":
        try:
            from IPython.display import IFrame
            return "notebook"
        except ImportError:
            logging.warning("Notebook environment not detected. Plots will be placed in a new tab")
    # when browser is explicitly chosen or as a default
    return "browser"


def plot_notebook(cp_profile, *args, **kwargs):
    """
    Wrapper for the ``plot`` function with the option to embed in the notebook
    """
    plot(cp_profile, *args, destination="notebook", **kwargs)


def plot(cp_profile, *args, destination="browser",
         show_profiles=True, show_observations=True, show_residuals=False, show_rugs=False,
         aggregate_profiles=None, selected_variables=None,
         color=None, size=2, alpha=0.4,
         color_pdps=None, size_pdps=None, alpha_pdps=None,
         color_points=None, size_points=None, alpha_points=None,
         color_residuals=None, size_residuals=None, alpha_residuals=None,
         height=500, width=600,
         plot_title='', yaxis_title='y',
         print_observations=True,
         **kwargs):
    """
    Plot ceteris paribus profile

    :param cp_profile: ceteris paribus profile
    :param args: next (optional) ceteris paribus profiles to be plotted alongside
    :param destination: available *browser* - open plot in a new tab, *notebook* - embed a plot in a jupyter notebook if possible
    :param show_profiles: whether to show profiles
    :param show_observations: whether to show individual observations
    :param show_residuals: whether to plot residuals
    :param show_rugs: whether to plot rugs
    :param aggregate_profiles: if specified an additional aggregated profile will be plotted, available values: `mean`, `median`
    :param selected_variables: variables selected for the plots
    :param color: color for profiles - either a color or a variable that should be used for coloring
    :param size: size of lines to be plotted
    :param alpha: opacity of lines (between 0 and 1)
    :param color_pdps: color of pdps - aggregated profiles
    :param size_pdps: size of pdps - aggregated profiles
    :param alpha_pdps: opacity of pdps - aggregated profiles
    :param color_points: color of points to be plotted
    :param size_points: size of points to be plotted
    :param alpha_points: opacity of points
    :param color_residuals: color of plotted residuals
    :param size_residuals: size of plotted residuals
    :param alpha_residuals: opacity of plotted residuals
    :param height: height of the window containing plots
    :param width: width of the window containing plots
    :param plot_title: title displayed above the plot
    :param yaxis_title: label for the y axis
    :param print_observations: whether to print the table with observation values
    :param kwargs: other options passed to the plot
    """

    params = dict()
    params.update(kwargs)
    params["variables"] = _calculate_plot_variables(cp_profile, selected_variables)
    params['color'] = "_label_" if args else color
    params['show_profiles'] = show_profiles
    params['show_observations'] = show_observations
    params['show_rugs'] = show_rugs
    params['show_residuals'] = show_residuals and (cp_profile.new_observation_true is not None)
    params['add_table'] = print_observations
    params['height'] = height
    params['width'] = width
    params['plot_title'] = plot_title
    params['size_ices'] = size
    params['alpha_ices'] = alpha
    params = _params_update(params,
                            color_pdps=color_pdps, size_pdps=size_pdps, alpha_pdps=alpha_pdps,
                            size_points=size_points, alpha_points=alpha_points, color_points=color_points,
                            size_residuals=size_residuals, alpha_residuals=alpha_residuals,
                            color_residuals=color_residuals,
                            yaxis_title=yaxis_title)

    if aggregate_profiles in {'mean', 'median', None}:
        params['aggregate_profiles'] = aggregate_profiles
    else:
        logging.warning("Incorrect function for profile aggregation: {}. Parameter ignored. "
                        "Available values are: 'mean' and 'median'".format(aggregate_profiles))
        params['aggregate_profiles'] = None

    all_profiles = [cp_profile] + list(args)

    plot_id = str(next(_PLOT_NUMBER))
    plot_path, params_path, obs_path, profile_path = _get_data_paths(plot_id)

    with open(params_path, 'w') as f:
        f.write("params = " + json.dumps(params, indent=2) + ";")

    save_observations(all_profiles, obs_path)
    save_profiles(all_profiles, profile_path)

    with app.app_context():
        data = render_template("plot_template.html", i=plot_id, params=params)

    with open(plot_path, 'w') as f:
        f.write(data)

    destination = _detect_plot_destination(destination)
    if destination == "notebook":
        from IPython.display import IFrame, display
        display(IFrame(plot_path, width=int(width * 1.1), height=int(height * 1.1)))
    else:
        # open plot in a browser
        if sys.platform == "darwin":  # check if on OSX
            plot_path = "file://" + os.path.abspath(plot_path)
        webbrowser.open(plot_path)


def _get_data_paths(plot_id):
    plot_path = os.path.join(_DATA_PATH, "plots{}.html".format(plot_id))
    params_path = os.path.join(_DATA_PATH, "params{}.js".format(plot_id))
    obs_path = os.path.join(_DATA_PATH, 'obs{}.js'.format(plot_id))
    profile_path = os.path.join(_DATA_PATH, "profile{}.js".format(plot_id))
    return plot_path, params_path, obs_path, profile_path
--------------------------------------------------------------------------------
/ceteris_paribus/profiles.py:
--------------------------------------------------------------------------------
import logging
from collections import OrderedDict

import numpy as np
import pandas as pd

from ceteris_paribus.utils import transform_into_Series


def individual_variable_profile(explainer, new_observation, y=None, variables=None, grid_points=101,
                                variable_splits=None):
    """
    Calculate ceteris paribus profile

    :param explainer: a model to be explained
    :param new_observation: a new observation for which the profiles are calculated
    :param y: true labels for `new_observation`. If specified then they will be added to the ceteris paribus plots
    :param variables: collection of variables selected for calculating profiles
    :param grid_points: number of points for the profile
    :param variable_splits: dictionary of splits for variables, in most cases created with `_calculate_variable_splits()`. If None then it will be calculated based on the validation data available in the `explainer`.
    :return: instance of the CeterisParibus class
    """
    variables = _get_variables(variables, explainer)
    if not isinstance(new_observation, pd.core.frame.DataFrame):
        new_observation = np.array(new_observation)
        if new_observation.ndim == 1:
            # make 1D array 2D
            new_observation = new_observation.reshape((1, -1))
        new_observation = pd.DataFrame(new_observation, columns=explainer.var_names)
    else:
        try:
            new_observation.columns = explainer.var_names
        except ValueError:
            raise ValueError("Mismatched number of variables: {} instead of {}".format(
                len(new_observation.columns), len(explainer.var_names)))

    if y is not None:
        y = transform_into_Series(y)

    cp_profile = CeterisParibus(explainer, new_observation, y, variables, grid_points, variable_splits)
    return cp_profile


def _get_variables(variables, explainer):
    """
    Get valid variables for the profile

    :param variables: collection of variables
    :param explainer: Explainer object
    :return: collection of variables
    """
    if variables:
        if not set(variables).issubset(explainer.var_names):
            raise ValueError('Invalid variable names')
    else:
        variables = explainer.var_names
    return variables


def _valid_variable_splits(variable_splits, variables):
    """
    Validate variable splits
    """
    if set(variable_splits.keys()) == set(variables):
        return True
    else:
        logging.warning("Variable splits are incorrect - wrong set of variables supplied. Parameter is ignored")
        return False


class CeterisParibus:

    def __init__(self, explainer, new_observation, y, selected_variables, grid_points, variable_splits):
        """
        Creates a Ceteris Paribus object

        :param explainer: explainer wrapping the model
        :param new_observation: DataFrame with observations for which the profiles will be calculated
        :param y: pandas Series with labels for the observations
        :param selected_variables: variables for which the profiles are calculated
        :param grid_points: number of points in a single variable split if calculated automatically
        :param variable_splits: mapping of variables into the points at which the profile will be calculated; if None then calculated with the function `_calculate_variable_splits`
        """
        self._data = explainer.data
        self._predict_function = explainer.predict_fun
        self._grid_points = grid_points
        self._label = explainer.label
        self.all_variable_names = explainer.var_names
        self.new_observation = new_observation
        self.selected_variables = list(selected_variables)
        variable_splits = self._get_variable_splits(variable_splits)
        self.profile = self._calculate_profile(variable_splits)
        self.new_observation_values = self.new_observation[self.selected_variables]
        self.new_observation_predictions = self._predict_function(self.new_observation)
        self.new_observation_true = y

    def _get_variable_splits(self, variable_splits):
        """
        Helper function for calculating variable splits
        """
        if variable_splits is None or not _valid_variable_splits(variable_splits, self.selected_variables):
            variables_dict = self._data.to_dict(orient='series')
            chosen_variables_dict = dict((var, variables_dict[var]) for var in self.selected_variables)
            variable_splits = self._calculate_variable_splits(chosen_variables_dict)
        return variable_splits

    def _calculate_profile(self, variable_splits):
        """
        Calculate the profile DataFrame
        """
        profiles_list = [self._single_variable_df(var_name, var_split)
                         for var_name, var_split in variable_splits.items()]
        profile = pd.concat(profiles_list, ignore_index=True)
        return profile

    def _calculate_single_split(self, X_var):
        """
        Calculate the split for a single variable

        :param X_var: variable data - pandas Series
        :return: selected subset of values for the variable
        """
        if np.issubdtype(X_var.dtype, np.floating):
            # grid points might be larger than the number of unique values
            quantiles = np.linspace(0, 1, self._grid_points)
            return np.quantile(X_var, quantiles)
        else:
            return np.unique(X_var)

    def _calculate_variable_splits(self, chosen_variables_dict):
        """
        Calculate splits for the given variables

        :param chosen_variables_dict: mapping of variables into the values
        :return: mapping of variables into selected subsets of values
        """
        return dict(
            (var, self._calculate_single_split(X_var))
            for (var, X_var) in chosen_variables_dict.items()
        )

    def _single_variable_df(self, var_name, var_split):
        """
        Calculate profiles for a given variable

        :param var_name: variable name
        :param var_split: split values for the variable
        :return: DataFrame with profiles for a given variable
        """
        return pd.concat([self._single_observation_df(observation, var_name, var_split, profile_id)
                          for profile_id, observation in self.new_observation.iterrows()], ignore_index=True)

    def _single_observation_df(self, observation, var_name, var_split, profile_id):
        """
        Calculates the single profile

        :param observation: observation for which the profile is calculated
        :param var_name: variable name
        :param var_split: split values for the variable
        :param profile_id: profile id
        :return: DataFrame with the calculated profile values
        """
        # grid_points and self._grid_points might differ for categorical variables
        grid_points = len(var_split)
        X = np.tile(observation, (grid_points, 1))
        X_dict = OrderedDict(zip(self.all_variable_names, X.T))
        df = pd.DataFrame.from_dict(X_dict)
        df[var_name] = var_split
        df['_yhat_'] = self._predict_function(df)
        df['_vname_'] = np.repeat(var_name, grid_points)
        df['_label_'] = self._label
        df['_ids_'] = profile_id
        return df

    def split_by(self, column):
        """
        Split the cp profile data frame by values of a given column

        :return: sorted mapping of values to dataframes
        """
        return OrderedDict(sorted(list(self.profile.groupby(column, sort=False))))

    def set_label(self, label):
        self._label = label

    def print_profile(self):
        print('Selected variables: {}'.format(self.selected_variables))
        print('Training data size: {}'.format(self._data.shape[0]))
        print(self.profile)
--------------------------------------------------------------------------------
/ceteris_paribus/select_data.py:
--------------------------------------------------------------------------------
import logging

import numpy as np
import pandas as pd

from ceteris_paribus.gower import gower_distances
from ceteris_paribus.utils import transform_into_Series


def select_sample(data, y=None, n=15, seed=42):
    """
    Select a sample from the dataset.

    :param data: array or DataFrame with observations
    :param y: labels for the observations
    :param n: size of the sample
    :param seed: seed for the random number generator
    :return: selected observations, and the corresponding labels if provided
    """
    np.random.seed(seed)
    if n > data.shape[0]:
        logging.warning("Given n ({}) is larger than data size ({})".format(n, data.shape[0]))
        n = data.shape[0]
    indices = np.random.choice(data.shape[0], n, replace=False)

    if isinstance(data, pd.core.frame.DataFrame):
        sampled_x = data.iloc[indices]
        sampled_x.reset_index(drop=True, inplace=True)
    else:
        sampled_x = data[indices, :]

    if y is not None:
        y = transform_into_Series(y)
        return sampled_x, y[indices].reset_index(drop=True)
    else:
        return sampled_x


def _select_columns(data, observation, variable_names=None, selected_variables=None):
    """
    Select data with the specified columns

    :param data: DataFrame with observations
    :param observation: pandas Series with the reference observation for neighbours selection
    :param variable_names: names of all variables
    :param selected_variables: names of selected variables
    :return: DataFrame with observations and pandas Series with the reference observation, restricted to the selected columns
    """
    if selected_variables is None:
        return data, observation
    try:
        if variable_names is None:
            if isinstance(data, pd.core.frame.DataFrame):
                variable_names = data.columns
            else:
                raise ValueError("Impossible to detect variable names")
        indices = [list(variable_names).index(var) for var in selected_variables]
    except ValueError:
        logging.warning("Selected variables: {} are not a subset of variables: {}".format(
            selected_variables, variable_names))
        return data, observation

    subset_data = data.iloc[:, indices]
    return subset_data, observation[indices]


def select_neighbours(data, observation, y=None, variable_names=None, selected_variables=None, dist_fun='gower', n=20):
    """
    Select observations from the dataset that are similar to a given observation

    :param data: array or DataFrame with observations
    :param observation: reference observation for neighbours selection
    :param y: labels for the observations
    :param variable_names: names of variables
    :param selected_variables: selected variables - requires supplying variable names along with the data
    :param dist_fun: 'gower' or a distance function with the interface of sklearn pairwise distances; gower works with missing data
    :param n: size of the sample
    :return: DataFrame with selected observations, and a pandas Series with the corresponding labels if provided
    """
    if n > data.shape[0]:
        logging.warning("Given n ({}) is larger than data size ({})".format(n, data.shape[0]))
        n = data.shape[0]

    if not isinstance(data, pd.core.frame.DataFrame):
        data = pd.DataFrame(data)

    observation = transform_into_Series(observation)

    # columns are selected for the purpose of distance calculation
    selected_data, observation = _select_columns(data, observation, variable_names, selected_variables)

    if dist_fun == 'gower':
        distances = gower_distances(selected_data, observation)
    else:
        if not callable(dist_fun):
            raise ValueError('Distance has to be "gower" or a custom function')
        distances = dist_fun([observation], selected_data)[0]
= dist_fun([observation], selected_data)[0] 98 | 99 | indices = np.argpartition(distances, n - 1)[:n] 100 | 101 | # selected points have all variables 102 | selected_points = data.iloc[indices] 103 | selected_points.reset_index(drop=True, inplace=True) 104 | 105 | if y is not None: 106 | y = transform_into_Series(y) 107 | return selected_points, y.iloc[indices].reset_index(drop=True) 108 | else: 109 | return selected_points 110 | -------------------------------------------------------------------------------- /ceteris_paribus/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def save_profiles(profiles, filename): 8 | data = dump_profiles(profiles) 9 | with open(filename, 'w') as f: 10 | f.write("profile = {};".format(json.dumps(data, indent=2, default=default))) 11 | 12 | 13 | def dump_profiles(profiles): 14 | """ 15 | Dump profiles into json format accepted by the plotting library 16 | 17 | :return: list of dicts representing points in the profiles 18 | """ 19 | data = [] 20 | for cp_profile in profiles: 21 | for i, row in cp_profile.profile.iterrows(): 22 | data.append(dict(zip(cp_profile.profile.columns, row))) 23 | return data 24 | 25 | 26 | def default(o): 27 | """ 28 | Workaround for dumping arrays with np.int64 type into json 29 | From: https://stackoverflow.com/a/50577730/7828646 30 | 31 | """ 32 | if isinstance(o, np.int64): 33 | return int(o) 34 | return float(o) 35 | 36 | 37 | def save_observations(profiles, filename): 38 | data = dump_observations(profiles) 39 | with open(filename, 'w') as f: 40 | f.write("observation = {};".format(json.dumps(data, indent=2, default=default))) 41 | 42 | 43 | def dump_observations(profiles): 44 | """ 45 | Dump observations data into json format accepted by the plotting library 46 | 47 | :return: list of dicts representing observations in the profiles 48 | """ 49 | data = [] 50 | for profile in profiles: 51 | for i, yhat in enumerate(profile.new_observation_predictions): 52 | d = dict(zip(profile.all_variable_names, profile.new_observation.iloc[i])) 53 | d['_yhat_'] = yhat 54 | d['_label_'] = profile._label 55 | d['_ids_'] = i 56 | d['_y_'] = profile.new_observation_true[i] if profile.new_observation_true is not None else None 57 | data.append(d) 58 | return data 59 | 60 | 61 | def transform_into_Series(y): 62 | if isinstance(y, pd.core.frame.DataFrame): 63 | y = pd.Series(y.iloc[:, 0]) 64 | else: 65 | y = pd.Series(y) 66 | return y 67 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | ignore: 2 | - "examples/" 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
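# For example, "make html" builds the HTML documentation into "_build/html".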
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/ceteris_paribus.plots.rst: -------------------------------------------------------------------------------- 1 | ceteris\_paribus.plots package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | ceteris\_paribus.plots.plots module 8 | ----------------------------------- 9 | 10 | .. automodule:: ceteris_paribus.plots.plots 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: ceteris_paribus.plots 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/ceteris_paribus.rst: -------------------------------------------------------------------------------- 1 | ceteris\_paribus package 2 | ======================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | ceteris_paribus.plots 10 | 11 | Submodules 12 | ---------- 13 | 14 | ceteris\_paribus.explainer module 15 | --------------------------------- 16 | 17 | .. automodule:: ceteris_paribus.explainer 18 | :members: 19 | :undoc-members: 20 | 21 | ceteris\_paribus.gower module 22 | ----------------------------- 23 | Gower distance is a measure of similarity between two observations whose variables may be both categorical and numerical. It also permits missing values in categorical variables, so it can be applied to virtually any dataset. Here it serves as the default function for finding the observations closest to a given one. 24 | 25 | The original paper describing the idea can be found `here `_. 26 | 27 | .. automodule:: ceteris_paribus.gower 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | 32 | ceteris\_paribus.profiles module 33 | -------------------------------- 34 | 35 | .. automodule:: ceteris_paribus.profiles 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | 40 | ceteris\_paribus.select\_data module 41 | ------------------------------------ 42 | 43 | .. automodule:: ceteris_paribus.select_data 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | 48 | 49 | Module contents 50 | --------------- 51 | 52 | .. automodule:: ceteris_paribus 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
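# In this project the repository root (one level up from docs/) is added below, so that autodoc can import the ceteris_paribus package when building the documentation.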
14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath('../')) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = 'pyCeterisParibus' 23 | copyright = '2019, Michał Kuźba' 24 | author = 'Michał Kuźba' 25 | 26 | # The short X.Y version 27 | version = '' 28 | # The full version, including alpha/beta/rc tags 29 | release = '' 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # If your documentation needs a minimal Sphinx version, state it here. 34 | # 35 | # needs_sphinx = '1.0' 36 | 37 | # Add any Sphinx extension module names here, as strings. They can be 38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 39 | # ones. 40 | extensions = [ 41 | 'sphinx.ext.autodoc', 42 | 'm2r' 43 | ] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # The suffix(es) of source filenames. 49 | # You can specify multiple suffixes as a list of strings: 50 | # 51 | source_suffix = ['.rst', '.md'] 52 | # source_suffix = '.rst' 53 | 54 | # The master toctree document. 55 | master_doc = 'index' 56 | 57 | # The language for content autogenerated by Sphinx. Refer to the documentation 58 | # for a list of supported languages. 59 | # 60 | # This is also used if you do content translation via gettext catalogs. 61 | # Usually you set "language" from the command line for these cases. 62 | language = None 63 | 64 | # List of patterns, relative to source directory, that match files and 65 | # directories to ignore when looking for source files. 66 | # This pattern also affects html_static_path and html_extra_path. 67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 68 | 69 | # The name of the Pygments (syntax highlighting) style to use. 70 | pygments_style = None 71 | 72 | # -- Options for HTML output ------------------------------------------------- 73 | 74 | # The theme to use for HTML and HTML Help pages. See the documentation for 75 | # a list of builtin themes. 76 | # 77 | html_theme = 'default' 78 | # html_theme_path = sphinx_bootstrap_theme.get_html_theme_path() 79 | 80 | # Theme options are theme-specific and customize the look and feel of a theme 81 | # further. For a list of options available for each theme, see the 82 | # documentation. 83 | # 84 | # html_theme_options = {} 85 | 86 | # Add any paths that contain custom static files (such as style sheets) here, 87 | # relative to this directory. They are copied after the builtin static files, 88 | # so a file named "default.css" will overwrite the builtin "default.css". 89 | html_static_path = ['_static'] 90 | 91 | # Custom sidebar templates, must be a dictionary that maps document names 92 | # to template names. 93 | # 94 | # The default sidebars (for documents that don't match any pattern) are 95 | # defined by the theme itself. Builtin themes use these templates by 96 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 97 | # 'searchbox.html']``. 98 | # 99 | # html_sidebars = {} 100 | 101 | 102 | # -- Options for HTMLHelp output --------------------------------------------- 103 | 104 | # Output file base name for HTML help builder. 105 | htmlhelp_basename = 'pyCeterisParibusdoc' 106 | 107 | # -- Options for LaTeX output ------------------------------------------------ 108 | 109 | latex_elements = { 110 | # The paper size ('letterpaper' or 'a4paper'). 
111 | # 112 | # 'papersize': 'letterpaper', 113 | 114 | # The font size ('10pt', '11pt' or '12pt'). 115 | # 116 | # 'pointsize': '10pt', 117 | 118 | # Additional stuff for the LaTeX preamble. 119 | # 120 | # 'preamble': '', 121 | 122 | # Latex figure (float) alignment 123 | # 124 | # 'figure_align': 'htbp', 125 | } 126 | 127 | # Grouping the document tree into LaTeX files. List of tuples 128 | # (source start file, target name, title, 129 | # author, documentclass [howto, manual, or own class]). 130 | latex_documents = [ 131 | (master_doc, 'pyCeterisParibus.tex', 'pyCeterisParibus Documentation', 132 | 'Michał Kuźba', 'manual'), 133 | ] 134 | 135 | # -- Options for manual page output ------------------------------------------ 136 | 137 | # One entry per manual page. List of tuples 138 | # (source start file, name, description, authors, manual section). 139 | man_pages = [ 140 | (master_doc, 'pyceterisparibus', 'pyCeterisParibus Documentation', 141 | [author], 1) 142 | ] 143 | 144 | # -- Options for Texinfo output ---------------------------------------------- 145 | 146 | # Grouping the document tree into Texinfo files. List of tuples 147 | # (source start file, target name, title, author, 148 | # dir menu entry, description, category) 149 | texinfo_documents = [ 150 | (master_doc, 'pyCeterisParibus', 'pyCeterisParibus Documentation', 151 | author, 'pyCeterisParibus', 'One line description of project.', 152 | 'Miscellaneous'), 153 | ] 154 | 155 | # -- Options for Epub output ------------------------------------------------- 156 | 157 | # Bibliographic Dublin Core info. 158 | epub_title = project 159 | 160 | # The unique identifier of the text. This can be a ISBN number 161 | # or the project homepage. 162 | # 163 | # epub_identifier = '' 164 | 165 | # A unique identification for the text. 166 | # 167 | # epub_uid = '' 168 | 169 | # A list of files that should not be packed into the epub file. 170 | epub_exclude_files = ['search.html'] 171 | 172 | # -- Extension configuration ------------------------------------------------- 173 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. pyCeterisParibus documentation master file, created by 2 | sphinx-quickstart on Tue Jan 29 19:38:46 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | Welcome to pyCeterisParibus's documentation! 8 | ============================================ 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: Contents: 13 | 14 | readme_include 15 | modules 16 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/misc: -------------------------------------------------------------------------------- 1 | ../misc/ -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | ceteris_paribus 2 | =============== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | ceteris_paribus 8 | -------------------------------------------------------------------------------- /docs/readme_include.rst: -------------------------------------------------------------------------------- 1 | .. mdinclude:: ../README.md 2 | .. mdinclude:: ../misc 3 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/examples/__init__.py -------------------------------------------------------------------------------- /examples/categorical_variables.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | from ceteris_paribus.datasets import DATASETS_DIR 6 | from ceteris_paribus.plots.plots import plot 7 | 8 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv')) 9 | 10 | x = df.drop(['charges'], inplace=False, axis=1) 11 | 12 | y = df['charges'] 13 | 14 | var_names = list(x) 15 | 16 | from sklearn.pipeline import Pipeline 17 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 18 | from sklearn.compose import ColumnTransformer 19 | from sklearn.ensemble import RandomForestRegressor 20 | 21 | # We create the preprocessing pipelines for both numeric and categorical data. 22 | numeric_features = ['age', 'bmi', 'children'] 23 | numeric_transformer = Pipeline(steps=[ 24 | ('scaler', StandardScaler())]) 25 | 26 | categorical_features = ['sex', 'smoker', 'region'] 27 | categorical_transformer = Pipeline(steps=[ 28 | ('onehot', OneHotEncoder(handle_unknown='ignore'))]) 29 | 30 | preprocessor = ColumnTransformer( 31 | transformers=[ 32 | ('num', numeric_transformer, numeric_features), 33 | ('cat', categorical_transformer, categorical_features)]) 34 | 35 | # Append classifier to preprocessing pipeline. 36 | # Now we have a full prediction pipeline. 
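# Note: although the final step is named 'classifier', it holds a RandomForestRegressor, so this pipeline actually performs regression on the 'charges' target.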
37 | clf = Pipeline(steps=[('preprocessor', preprocessor), 38 | ('classifier', RandomForestRegressor())]) 39 | 40 | clf.fit(x, y) 41 | 42 | from ceteris_paribus.explainer import explain 43 | 44 | explainer_cat = explain(clf, var_names, x, y, label="categorical_model") 45 | 46 | from ceteris_paribus.profiles import individual_variable_profile 47 | 48 | cp_cat = individual_variable_profile(explainer_cat, x.iloc[:10], y.iloc[:10]) 49 | 50 | cp_cat.print_profile() 51 | plot(cp_cat) 52 | 53 | plot(cp_cat, color="smoker") 54 | -------------------------------------------------------------------------------- /examples/cheatsheet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from sklearn import ensemble, svm 4 | from sklearn.datasets import load_iris 5 | from sklearn.linear_model import LinearRegression 6 | 7 | from ceteris_paribus.explainer import explain 8 | from ceteris_paribus.plots.plots import plot 9 | from ceteris_paribus.profiles import individual_variable_profile 10 | from ceteris_paribus.select_data import select_neighbours 11 | from ceteris_paribus.datasets import DATASETS_DIR 12 | 13 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv')) 14 | 15 | df = df[['age', 'bmi', 'children', 'charges']] 16 | 17 | x = df.drop(['charges'], inplace=False, axis=1) 18 | y = df['charges'] 19 | var_names = list(x.columns) 20 | x = x.values 21 | y = y.values 22 | 23 | iris = load_iris() 24 | 25 | 26 | def random_forest_classifier(): 27 | rf_model = ensemble.RandomForestClassifier(n_estimators=100, random_state=42) 28 | 29 | rf_model.fit(iris['data'], iris['target']) 30 | 31 | return rf_model, iris['data'], iris['target'], iris['feature_names'] 32 | 33 | 34 | def linear_regression_model(): 35 | linear_model = LinearRegression() 36 | linear_model.fit(x, y) 37 | # model, data, labels, variable_names 38 | return linear_model, x, y, var_names 39 | 40 | 41 | def gradient_boosting_model(): 42 | gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42) 43 | gb_model.fit(x, y) 44 | return gb_model, x, y, var_names 45 | 46 | 47 | def supported_vector_machines_model(): 48 | svm_model = svm.SVR(C=0.01, gamma='scale', kernel='poly') 49 | svm_model.fit(x, y) 50 | return svm_model, x, y, var_names 51 | 52 | 53 | if __name__ == "__main__": 54 | 55 | (linear_model, data, labels, variable_names) = linear_regression_model() 56 | (gb_model, _, _, _) = gradient_boosting_model() 57 | (svm_model, _, _, _) = supported_vector_machines_model() 58 | 59 | explainer_linear = explain(linear_model, variable_names, data, y) 60 | explainer_gb = explain(gb_model, variable_names, data, y) 61 | explainer_svm = explain(svm_model, variable_names, data, y) 62 | 63 | # single profile 64 | cp_1 = individual_variable_profile(explainer_gb, x[0], y[0]) 65 | plot(cp_1, destination="notebook", selected_variables=["bmi"], print_observations=False) 66 | 67 | # local fit 68 | neighbours_x, neighbours_y = select_neighbours(x, x[10], y=y, n=10) 69 | cp_2 = individual_variable_profile(explainer_gb, 70 | neighbours_x, neighbours_y) 71 | plot(cp_2, show_residuals=True, selected_variables=["age"], print_observations=False, color_residuals='red', 72 | plot_title='') 73 | 74 | # aggregate profiles 75 | plot(cp_2, aggregate_profiles="mean", selected_variables=["age"], color_pdps='black', size_pdps=6, 76 | alpha_pdps=0.7, print_observations=False, 77 | plot_title='') 78 | 79 | # many variables 80 | plot(cp_1, selected_variables=["bmi", "age", 
"children"], print_observations=False, plot_title='', width=950) 81 | 82 | # many models 83 | cp_svm = individual_variable_profile(explainer_svm, x[0], y[0]) 84 | cp_linear = individual_variable_profile(explainer_linear, x[0], y[0]) 85 | plot(cp_1, cp_svm, cp_linear, print_observations=False, plot_title='', width=1050) 86 | 87 | # color by feature 88 | plot(cp_2, color="bmi", print_observations=False, plot_title='', width=1050, selected_variables=["age"], size=3) 89 | 90 | # classification multiplot 91 | rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier() 92 | 93 | explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y, 94 | predict_function= lambda X: rf_model.predict_proba(X)[::, 0], label=iris.target_names[0]) 95 | explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y, 96 | predict_function= lambda X: rf_model.predict_proba(X)[::, 1], label=iris.target_names[1]) 97 | explainer_rf3 = explain(rf_model, iris_var_names, iris_x, iris_y, 98 | predict_function= lambda X: rf_model.predict_proba(X)[::, 2], label=iris.target_names[2]) 99 | 100 | 101 | cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0]) 102 | cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0]) 103 | cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0]) 104 | 105 | plot(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'], 106 | plot_title='', print_observations=False, width=1050) -------------------------------------------------------------------------------- /examples/classification_example.py: -------------------------------------------------------------------------------- 1 | from sklearn.datasets import load_iris 2 | from sklearn.ensemble import RandomForestClassifier 3 | from sklearn.model_selection import train_test_split 4 | 5 | from ceteris_paribus.explainer import explain 6 | from ceteris_paribus.plots.plots import plot 7 | from ceteris_paribus.profiles import individual_variable_profile 8 | 9 | iris = load_iris() 10 | 11 | X = iris['data'] 12 | y = iris['target'] 13 | 14 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) 15 | 16 | print(iris['feature_names']) 17 | 18 | def random_forest_classifier(): 19 | rf_model = RandomForestClassifier(n_estimators=100, random_state=42) 20 | 21 | rf_model.fit(X_train, y_train) 22 | 23 | return rf_model, X_train, y_train, iris['feature_names'] 24 | 25 | 26 | if __name__ == "__main__": 27 | (model, data, labels, variable_names) = random_forest_classifier() 28 | predict_function = lambda X: model.predict_proba(X)[::, 0] 29 | explainer_rf = explain(model, variable_names, data, labels, predict_function=predict_function) 30 | cp_profile = individual_variable_profile(explainer_rf, X[1], y=y[1]) 31 | plot(cp_profile) 32 | -------------------------------------------------------------------------------- /examples/keras_example.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Dense, Activation 2 | from keras.models import Sequential 3 | from keras.wrappers.scikit_learn import KerasRegressor 4 | from sklearn.datasets import load_boston 5 | from sklearn.model_selection import train_test_split 6 | from sklearn.pipeline import Pipeline 7 | from sklearn.preprocessing import StandardScaler 8 | 9 | from ceteris_paribus.explainer import explain 10 | from ceteris_paribus.plots.plots import plot 11 | from ceteris_paribus.profiles import 
individual_variable_profile 12 | 13 | boston = load_boston() 14 | x = boston.data 15 | y = boston.target 16 | x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) 17 | 18 | 19 | def network_architecture(): 20 | model = Sequential() 21 | model.add(Dense(640, input_dim=x.shape[1])) 22 | model.add(Activation('tanh')) 23 | model.add(Dense(320)) 24 | model.add(Activation('tanh')) 25 | model.add(Dense(1)) 26 | model.compile(loss='mean_squared_error', optimizer='adam') 27 | return model 28 | 29 | 30 | def keras_model(): 31 | estimators = [ 32 | ('scaler', StandardScaler()), 33 | ('mlp', KerasRegressor(build_fn=network_architecture, epochs=200)) 34 | ] 35 | model = Pipeline(estimators) 36 | model.fit(x_train, y_train) 37 | return model, x_train, y_train, boston.feature_names 38 | 39 | 40 | if __name__ == "__main__": 41 | model, x_train, y_train, var_names = keras_model() 42 | explainer_keras = explain(model, var_names, x_train, y_train, label='KerasMLP') 43 | cp = individual_variable_profile(explainer_keras, x_train[:10], y=y_train[:10], 44 | variables=["CRIM", "ZN", "AGE", "INDUS", "B"]) 45 | plot(cp, show_residuals=True, selected_variables=["CRIM", "ZN", "AGE", "B"], show_observations=True, 46 | show_rugs=True) 47 | -------------------------------------------------------------------------------- /examples/multiple_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from sklearn import ensemble, svm 5 | from sklearn.linear_model import LinearRegression 6 | 7 | from ceteris_paribus.datasets import DATASETS_DIR 8 | from ceteris_paribus.explainer import explain 9 | from ceteris_paribus.plots.plots import plot 10 | from ceteris_paribus.profiles import individual_variable_profile 11 | from ceteris_paribus.select_data import select_sample 12 | 13 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv')) 14 | 15 | df = df[['age', 'bmi', 'children', 'charges']] 16 | 17 | x = df.drop(['charges'], inplace=False, axis=1) 18 | y = df['charges'] 19 | var_names = list(x.columns) 20 | x = x.values 21 | y = y.values 22 | 23 | 24 | def linear_regression_model(): 25 | # Create linear regression object 26 | linear_model = LinearRegression() 27 | 28 | # Train the model using the training set 29 | linear_model.fit(x, y) 30 | 31 | # model, data, labels, variable_names 32 | return linear_model, x, y, var_names 33 | 34 | 35 | def gradient_boosting_model(): 36 | gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42) 37 | gb_model.fit(x, y) 38 | return gb_model, x, y, var_names 39 | 40 | 41 | def supported_vector_machines_model(): 42 | svm_model = svm.SVR(C=0.01, gamma='scale') 43 | svm_model.fit(x, y) 44 | return svm_model, x, y, var_names 45 | 46 | 47 | if __name__ == "__main__": 48 | (linear_model, data, labels, variable_names) = linear_regression_model() 49 | (gb_model, _, _, _) = gradient_boosting_model() 50 | (svm_model, _, _, _) = supported_vector_machines_model() 51 | 52 | explainer_linear = explain(linear_model, variable_names, data, y) 53 | explainer_gb = explain(gb_model, variable_names, data, y) 54 | explainer_svm = explain(svm_model, variable_names, data, y) 55 | 56 | cp_profile = individual_variable_profile(explainer_linear, x[0], y[0]) 57 | plot(cp_profile, show_residuals=True) 58 | 59 | sample_x, sample_y = select_sample(x, y, n=10) 60 | cp2 = individual_variable_profile(explainer_gb, sample_x, y=sample_y) 61 | 62 | cp3 = 
individual_variable_profile(explainer_gb, x[0], y[0]) 63 | plot(cp3, show_residuals=True) 64 | 65 | plot(cp_profile, cp3, show_residuals=True) 66 | -------------------------------------------------------------------------------- /examples/regression_example.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets, ensemble 2 | from sklearn.model_selection import train_test_split 3 | 4 | from ceteris_paribus.explainer import explain 5 | from ceteris_paribus.plots.plots import plot 6 | from ceteris_paribus.profiles import individual_variable_profile 7 | from ceteris_paribus.select_data import select_sample, select_neighbours 8 | 9 | boston = datasets.load_boston() 10 | 11 | X = boston['data'] 12 | y = boston['target'] 13 | 14 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42) 15 | 16 | 17 | def random_forest_regression(): 18 | # Create random forest regression object 19 | rf_model = ensemble.RandomForestRegressor(n_estimators=100, random_state=42) 20 | 21 | # Train the model using the training set 22 | rf_model.fit(X_train, y_train) 23 | 24 | # model, data, labels, variable_names 25 | return rf_model, X_train, y_train, list(boston['feature_names']) 26 | 27 | 28 | if __name__ == "__main__": 29 | (model, data, labels, variable_names) = random_forest_regression() 30 | explainer_rf = explain(model, variable_names, data, labels) 31 | 32 | cp_profile = individual_variable_profile(explainer_rf, X_train[0], y=y_train[0], variables=['TAX', 'CRIM']) 33 | plot(cp_profile) 34 | 35 | sample = select_sample(X_train, n=3) 36 | cp2 = individual_variable_profile(explainer_rf, sample, variables=['TAX', 'CRIM']) 37 | plot(cp2) 38 | 39 | neighbours = select_neighbours(X_train, X_train[0], variable_names=variable_names, 40 | selected_variables=variable_names, n=15) 41 | cp3 = individual_variable_profile(explainer_rf, neighbours, variables=['LSTAT', 'RM'], 42 | variable_splits={'LSTAT': [10, 20, 30], 'RM': [4, 5, 6, 7]}) 43 | plot(cp3) 44 | -------------------------------------------------------------------------------- /jupyter-notebooks/CeterisParibusCheatsheet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Ceteris Paribus cheatsheet" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Prepare dataset" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "from ceteris_paribus.datasets import DATASETS_DIR\n", 25 | "import os\n", 26 | "\n", 27 | "df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv'))\n", 28 | "\n", 29 | "df = df[['age', 'bmi', 'children', 'charges']]\n", 30 | "\n", 31 | "x = df.drop(['charges'], inplace=False, axis=1)\n", 32 | "y = df['charges']\n", 33 | "var_names = list(x.columns)\n", 34 | "x = x.values\n", 35 | "y = y.values" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "### Prepare regression models" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from sklearn.linear_model import LinearRegression\n", 52 | "from sklearn import ensemble, svm\n", 53 | "\n", 54 | "def linear_regression_model():\n", 55 | " linear_model = LinearRegression()\n", 56 | " 
linear_model.fit(x, y)\n", 57 | " # model, data, labels, variable_names\n", 58 | " return linear_model, x, y, var_names\n", 59 | "\n", 60 | "\n", 61 | "def gradient_boosting_model():\n", 62 | " gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42)\n", 63 | " gb_model.fit(x, y)\n", 64 | " return gb_model, x, y, var_names\n", 65 | "\n", 66 | "\n", 67 | "def supported_vector_machines_model():\n", 68 | " svm_model = svm.SVR(C=0.01, gamma='scale', kernel='poly')\n", 69 | " svm_model.fit(x, y)\n", 70 | " return svm_model, x, y, var_names" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### Calculate single profile variables" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "name": "stderr", 87 | "output_type": "stream", 88 | "text": [ 89 | "WARNING:root:Model is unlabeled... \n", 90 | " You can add label using method set_label\n" 91 | ] 92 | }, 93 | { 94 | "data": { 95 | "text/html": [ 96 | "\n", 97 | " \n", 104 | " " 105 | ], 106 | "text/plain": [ 107 | "" 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | } 113 | ], 114 | "source": [ 115 | "from ceteris_paribus.explainer import explain\n", 116 | "from ceteris_paribus.plots.plots import plot_notebook\n", 117 | "from ceteris_paribus.profiles import individual_variable_profile\n", 118 | "\n", 119 | "(gb_model, data, labels, variable_names) = gradient_boosting_model()\n", 120 | "\n", 121 | "explainer_gb = explain(gb_model, variable_names, data, y)\n", 122 | "\n", 123 | "cp_1 = individual_variable_profile(explainer_gb, x[0], y[0])\n", 124 | "plot_notebook(cp_1, selected_variables=[\"bmi\"], print_observations=False)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "### Local fit plots" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 4, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "(svm_model, _, _, _) = supported_vector_machines_model()\n", 141 | "(linear_model, data, labels, variable_names) = linear_regression_model()\n", 142 | "\n", 143 | "explainer_linear = explain(linear_model, variable_names, data, y, label='linear_model')\n", 144 | "explainer_svm = explain(svm_model, variable_names, data, y, label='svm_model')" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 5, 150 | "metadata": {}, 151 | "outputs": [ 152 | { 153 | "data": { 154 | "text/html": [ 155 | "\n", 156 | " \n", 163 | " " 164 | ], 165 | "text/plain": [ 166 | "" 167 | ] 168 | }, 169 | "metadata": {}, 170 | "output_type": "display_data" 171 | } 172 | ], 173 | "source": [ 174 | "from ceteris_paribus.select_data import select_neighbours\n", 175 | "\n", 176 | "neighbours_x, neighbours_y = select_neighbours(x, x[10], y=y, n=10)\n", 177 | "cp_2 = individual_variable_profile(explainer_gb,\n", 178 | " neighbours_x, neighbours_y)\n", 179 | "plot_notebook(cp_2, show_residuals=True, selected_variables=[\"age\"], print_observations=False, color_residuals='red', \n", 180 | " plot_title='')" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "### Aggregate profiles" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 6, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "data": { 197 | "text/html": [ 198 | "\n", 199 | " \n", 206 | " " 207 | ], 208 | "text/plain": [ 209 | "" 210 | ] 211 | }, 212 | "metadata": 
{}, 213 | "output_type": "display_data" 214 | } 215 | ], 216 | "source": [ 217 | "plot_notebook(cp_2, aggregate_profiles=\"mean\", selected_variables=[\"age\"], color_pdps='black', size_pdps=6,\n", 218 | " alpha_pdps=0.7, print_observations=False,\n", 219 | " plot_title='')" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "### Many variables" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 7, 232 | "metadata": {}, 233 | "outputs": [ 234 | { 235 | "data": { 236 | "text/html": [ 237 | "\n", 238 | " \n", 245 | " " 246 | ], 247 | "text/plain": [ 248 | "" 249 | ] 250 | }, 251 | "metadata": {}, 252 | "output_type": "display_data" 253 | } 254 | ], 255 | "source": [ 256 | "plot_notebook(cp_1, selected_variables=[\"bmi\", \"age\", \"children\"], print_observations=False, plot_title='', width=900)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "### Many models" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 8, 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "data": { 273 | "text/html": [ 274 | "\n", 275 | " \n", 282 | " " 283 | ], 284 | "text/plain": [ 285 | "" 286 | ] 287 | }, 288 | "metadata": {}, 289 | "output_type": "display_data" 290 | } 291 | ], 292 | "source": [ 293 | "cp_svm = individual_variable_profile(explainer_svm, x[0], y[0])\n", 294 | "cp_linear = individual_variable_profile(explainer_linear, x[0], y[0])\n", 295 | "plot_notebook(cp_1, cp_svm, cp_linear, print_observations=False, plot_title='', width=850, size=3, alpha=0.7)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "### Color by feature" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 9, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "data": { 312 | "text/html": [ 313 | "\n", 314 | " \n", 321 | " " 322 | ], 323 | "text/plain": [ 324 | "" 325 | ] 326 | }, 327 | "metadata": {}, 328 | "output_type": "display_data" 329 | } 330 | ], 331 | "source": [ 332 | "plot_notebook(cp_2, color=\"age\", plot_title='', width=900, size=3)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "markdown", 337 | "metadata": {}, 338 | "source": [ 339 | "### Prepare classification example" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 10, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "from sklearn.datasets import load_iris\n", 349 | "from sklearn.ensemble import RandomForestClassifier\n", 350 | "iris = load_iris()\n", 351 | "\n", 352 | "def random_forest_classifier():\n", 353 | " rf_model = ensemble.RandomForestClassifier(n_estimators=100, random_state=42)\n", 354 | "\n", 355 | " rf_model.fit(iris['data'], iris['target'])\n", 356 | "\n", 357 | " return rf_model, iris['data'], iris['target'], iris['feature_names']" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 11, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier()\n", 367 | "\n", 368 | "explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y,\n", 369 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 0], label=iris.target_names[0])\n", 370 | "explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y,\n", 371 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 1], label=iris.target_names[1])\n", 372 | "explainer_rf3 = explain(rf_model, 
iris_var_names, iris_x, iris_y,\n", 373 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 2], label=iris.target_names[2])\n", 374 | "\n", 375 | "\n", 376 | "cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0])\n", 377 | "cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0])\n", 378 | "cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0])" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "### Multiclass profiles" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 14, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "data": { 395 | "text/html": [ 396 | "\n", 397 | " \n", 404 | " " 405 | ], 406 | "text/plain": [ 407 | "" 408 | ] 409 | }, 410 | "metadata": {}, 411 | "output_type": "display_data" 412 | } 413 | ], 414 | "source": [ 415 | "plot_notebook(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'],\n", 416 | " plot_title='', print_observations=False, width=800, height=300, size=4, alpha=0.9)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [] 425 | } 426 | ], 427 | "metadata": { 428 | "kernelspec": { 429 | "display_name": "ceteris", 430 | "language": "python", 431 | "name": "ceteris" 432 | }, 433 | "language_info": { 434 | "codemirror_mode": { 435 | "name": "ipython", 436 | "version": 3 437 | }, 438 | "file_extension": ".py", 439 | "mimetype": "text/x-python", 440 | "name": "python", 441 | "nbconvert_exporter": "python", 442 | "pygments_lexer": "ipython3", 443 | "version": "3.5.2" 444 | } 445 | }, 446 | "nbformat": 4, 447 | "nbformat_minor": 2 448 | } 449 | -------------------------------------------------------------------------------- /jupyter-notebooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/jupyter-notebooks/__init__.py -------------------------------------------------------------------------------- /jupyter-notebooks/titanic.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Titanic survival" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "### Read data" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pandas as pd\n", 24 | "from ceteris_paribus.datasets import DATASETS_DIR\n", 25 | "import os\n", 26 | "df = pd.read_csv(os.path.join(DATASETS_DIR, 'titanic_train.csv'))" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "data": { 36 | "text/html": [ 37 | "
\n", 38 | "\n", 51 | "\n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | "
PassengerIdSurvivedPclassNameSexAgeSibSpParchTicketFareCabinEmbarked
0103Braund, Mr. Owen Harrismale22.010A/5 211717.2500NaNS
1211Cumings, Mrs. John Bradley (Florence Briggs Th...female38.010PC 1759971.2833C85C
2313Heikkinen, Miss. Lainafemale26.000STON/O2. 31012827.9250NaNS
3411Futrelle, Mrs. Jacques Heath (Lily May Peel)female35.01011380353.1000C123S
4503Allen, Mr. William Henrymale35.0003734508.0500NaNS
\n", 147 | "
" 148 | ], 149 | "text/plain": [ 150 | " PassengerId Survived Pclass \\\n", 151 | "0 1 0 3 \n", 152 | "1 2 1 1 \n", 153 | "2 3 1 3 \n", 154 | "3 4 1 1 \n", 155 | "4 5 0 3 \n", 156 | "\n", 157 | " Name Sex Age SibSp \\\n", 158 | "0 Braund, Mr. Owen Harris male 22.0 1 \n", 159 | "1 Cumings, Mrs. John Bradley (Florence Briggs Th... female 38.0 1 \n", 160 | "2 Heikkinen, Miss. Laina female 26.0 0 \n", 161 | "3 Futrelle, Mrs. Jacques Heath (Lily May Peel) female 35.0 1 \n", 162 | "4 Allen, Mr. William Henry male 35.0 0 \n", 163 | "\n", 164 | " Parch Ticket Fare Cabin Embarked \n", 165 | "0 0 A/5 21171 7.2500 NaN S \n", 166 | "1 0 PC 17599 71.2833 C85 C \n", 167 | "2 0 STON/O2. 3101282 7.9250 NaN S \n", 168 | "3 0 113803 53.1000 C123 S \n", 169 | "4 0 373450 8.0500 NaN S " 170 | ] 171 | }, 172 | "execution_count": 2, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "df.head()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 3, 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "y = df['Survived']\n", 188 | "x = df.drop(['Survived', 'PassengerId', 'Name', 'Cabin', 'Ticket'], inplace=False, axis=1)" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 4, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "data": { 198 | "text/html": [ 199 | "
\n", 200 | "\n", 213 | "\n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | "
PclassSexAgeSibSpParchFareEmbarked
03male22.0107.2500S
11female38.01071.2833C
23female26.0007.9250S
31female35.01053.1000S
43male35.0008.0500S
\n", 279 | "
" 280 | ], 281 | "text/plain": [ 282 | " Pclass Sex Age SibSp Parch Fare Embarked\n", 283 | "0 3 male 22.0 1 0 7.2500 S\n", 284 | "1 1 female 38.0 1 0 71.2833 C\n", 285 | "2 3 female 26.0 0 0 7.9250 S\n", 286 | "3 1 female 35.0 1 0 53.1000 S\n", 287 | "4 3 male 35.0 0 0 8.0500 S" 288 | ] 289 | }, 290 | "execution_count": 4, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "x.head()" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 5, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "valid = x['Age'].isnull() | x['Embarked'].isnull()\n", 306 | "x = x[-valid]\n", 307 | "y = y[-valid]" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 6, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "x['Pclass'] = x['Pclass'].astype('float64')\n", 317 | "x['SibSp'] = x['SibSp'].astype('float64')\n", 318 | "x['Parch'] = x['Parch'].astype('float64')" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 7, 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "data": { 328 | "text/html": [ 329 | "
\n", 330 | "\n", 343 | "\n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | "
PclassSexAgeSibSpParchFareEmbarked
03.0male22.01.00.07.2500S
11.0female38.01.00.071.2833C
23.0female26.00.00.07.9250S
31.0female35.01.00.053.1000S
43.0male35.00.00.08.0500S
\n", 409 | "
" 410 | ], 411 | "text/plain": [ 412 | " Pclass Sex Age SibSp Parch Fare Embarked\n", 413 | "0 3.0 male 22.0 1.0 0.0 7.2500 S\n", 414 | "1 1.0 female 38.0 1.0 0.0 71.2833 C\n", 415 | "2 3.0 female 26.0 0.0 0.0 7.9250 S\n", 416 | "3 1.0 female 35.0 1.0 0.0 53.1000 S\n", 417 | "4 3.0 male 35.0 0.0 0.0 8.0500 S" 418 | ] 419 | }, 420 | "execution_count": 7, 421 | "metadata": {}, 422 | "output_type": "execute_result" 423 | } 424 | ], 425 | "source": [ 426 | "x.head()" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": 8, 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "from sklearn.model_selection import train_test_split\n", 436 | "X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "metadata": {}, 442 | "source": [ 443 | "### Building the models" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 53, 449 | "metadata": {}, 450 | "outputs": [], 451 | "source": [ 452 | "from sklearn.pipeline import Pipeline\n", 453 | "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", 454 | "from sklearn.compose import ColumnTransformer\n", 455 | "\n", 456 | "# We create the preprocessing pipelines for both numeric and categorical data.\n", 457 | "numeric_features = ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare']\n", 458 | "numeric_transformer = Pipeline(steps=[\n", 459 | " ('scaler', StandardScaler())])\n", 460 | "\n", 461 | "categorical_features = ['Embarked', 'Sex']\n", 462 | "categorical_transformer = Pipeline(steps=[\n", 463 | " ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n", 464 | "\n", 465 | "preprocessor = ColumnTransformer(\n", 466 | " transformers=[\n", 467 | " ('num', numeric_transformer, numeric_features),\n", 468 | " ('cat', categorical_transformer, categorical_features)])" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": 54, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "from xgboost import XGBClassifier\n", 478 | "from sklearn.ensemble import RandomForestClassifier\n", 479 | "from sklearn.linear_model import LogisticRegression" 480 | ] 481 | }, 482 | { 483 | "cell_type": "code", 484 | "execution_count": 55, 485 | "metadata": {}, 486 | "outputs": [], 487 | "source": [ 488 | "# Append classifier to preprocessing pipeline.\n", 489 | "# Now we have a full prediction pipeline.\n", 490 | "xgb_clf = Pipeline(steps=[('preprocessor', preprocessor),\n", 491 | "('classifier', XGBClassifier())])" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 107, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "rf_clf = Pipeline(steps=[('preprocessor', preprocessor),\n", 501 | "('classifier', RandomForestClassifier(n_estimators=100, min_samples_leaf=2))])" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": 108, 507 | "metadata": {}, 508 | "outputs": [], 509 | "source": [ 510 | "linear_clf = Pipeline(steps=[('preprocessor', preprocessor),\n", 511 | "('classifier', LogisticRegression())])" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "### Train the models" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 109, 524 | "metadata": {}, 525 | "outputs": [ 526 | { 527 | "data": { 528 | "text/plain": [ 529 | "Pipeline(memory=None,\n", 530 | " steps=[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,\n", 
531 | " transformer_weights=None,\n", 532 | " transformers=[('num', Pipeline(memory=None,\n", 533 | " steps=[('scaler', StandardScaler(copy=True, with_mean=True, with_std=True))]), ['Pclass', 'Age', 'SibSp', 'Parch...\n", 534 | " reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,\n", 535 | " silent=True, subsample=1))])" 536 | ] 537 | }, 538 | "execution_count": 109, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "xgb_clf.fit(X_train, y_train)" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": 110, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/plain": [ 555 | "Pipeline(memory=None,\n", 556 | " steps=[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,\n", 557 | " transformer_weights=None,\n", 558 | " transformers=[('num', Pipeline(memory=None,\n", 559 | " steps=[('scaler', StandardScaler(copy=True, with_mean=True, with_std=True))]), ['Pclass', 'Age', 'SibSp', 'Parch...obs=None,\n", 560 | " oob_score=False, random_state=None, verbose=0,\n", 561 | " warm_start=False))])" 562 | ] 563 | }, 564 | "execution_count": 110, 565 | "metadata": {}, 566 | "output_type": "execute_result" 567 | } 568 | ], 569 | "source": [ 570 | "rf_clf.fit(X_train, y_train)" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": 111, 576 | "metadata": {}, 577 | "outputs": [ 578 | { 579 | "name": "stderr", 580 | "output_type": "stream", 581 | "text": [ 582 | "/home/michal/Desktop/dalex/ceteris_env/lib/python3.5/site-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n", 583 | " FutureWarning)\n" 584 | ] 585 | }, 586 | { 587 | "data": { 588 | "text/plain": [ 589 | "Pipeline(memory=None,\n", 590 | " steps=[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,\n", 591 | " transformer_weights=None,\n", 592 | " transformers=[('num', Pipeline(memory=None,\n", 593 | " steps=[('scaler', StandardScaler(copy=True, with_mean=True, with_std=True))]), ['Pclass', 'Age', 'SibSp', 'Parch...penalty='l2', random_state=None, solver='warn',\n", 594 | " tol=0.0001, verbose=0, warm_start=False))])" 595 | ] 596 | }, 597 | "execution_count": 111, 598 | "metadata": {}, 599 | "output_type": "execute_result" 600 | } 601 | ], 602 | "source": [ 603 | "linear_clf.fit(X_train, y_train)" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": {}, 609 | "source": [ 610 | "### Evaluate the models" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 112, 616 | "metadata": {}, 617 | "outputs": [ 618 | { 619 | "data": { 620 | "text/plain": [ 621 | "array([0, 1])" 622 | ] 623 | }, 624 | "execution_count": 112, 625 | "metadata": {}, 626 | "output_type": "execute_result" 627 | } 628 | ], 629 | "source": [ 630 | "xgb_clf.classes_" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 113, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "name": "stdout", 640 | "output_type": "stream", 641 | "text": [ 642 | "XGB 0.8041958041958042\n", 643 | "Random Forest 0.7832167832167832\n", 644 | "Linear 0.7972027972027972\n" 645 | ] 646 | } 647 | ], 648 | "source": [ 649 | "from sklearn.metrics import accuracy_score\n", 650 | "print(\"XGB {}\".format(accuracy_score(y_test, xgb_clf.predict(X_test))))\n", 651 | "print(\"Random Forest {}\".format(accuracy_score(y_test, 
rf_clf.predict(X_test))))\n", 652 | "print(\"Linear {}\".format(accuracy_score(y_test, linear_clf.predict(X_test))))" 653 | ] 654 | }, 655 | { 656 | "cell_type": "markdown", 657 | "metadata": {}, 658 | "source": [ 659 | "### Explain the models" 660 | ] 661 | }, 662 | { 663 | "cell_type": "code", 664 | "execution_count": 114, 665 | "metadata": {}, 666 | "outputs": [], 667 | "source": [ 668 | "from ceteris_paribus.explainer import explain\n", 669 | "\n", 670 | "explainer_xgb = explain(xgb_clf, data=x, y=y, label='XGBoost', predict_function=lambda X: xgb_clf.predict_proba(X)[::, 1])\n", 671 | "explainer_rf = explain(rf_clf, data=x, y=y, label='RandomForest', predict_function=lambda X: rf_clf.predict_proba(X)[::, 1])\n", 672 | "explainer_linear = explain(linear_clf, data=x, y=y, label='LogisticRegression', predict_function=lambda X: linear_clf.predict_proba(X)[::, 1])" 673 | ] 674 | }, 675 | { 676 | "cell_type": "markdown", 677 | "metadata": {}, 678 | "source": [ 679 | "##### Ernest James Crease" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 115, 685 | "metadata": { 686 | "scrolled": true 687 | }, 688 | "outputs": [ 689 | { 690 | "name": "stdout", 691 | "output_type": "stream", 692 | "text": [ 693 | "Referenced observation \n", 694 | "Pclass 3\n", 695 | "Sex male\n", 696 | "Age 19\n", 697 | "SibSp 0\n", 698 | "Parch 0\n", 699 | "Fare 8.1583\n", 700 | "Embarked S\n", 701 | "Name: 67, dtype: object\n" 702 | ] 703 | } 704 | ], 705 | "source": [ 706 | "import warnings\n", 707 | "import sklearn\n", 708 | "warnings.filterwarnings(\"ignore\", category=sklearn.exceptions.DataConversionWarning)\n", 709 | "ernest = X_test.iloc[10]\n", 710 | "label_ernest = y_test.iloc[10]\n", 711 | "print(\"Referenced observation \\n{}\".format(ernest))\n", 712 | "from ceteris_paribus.profiles import individual_variable_profile\n", 713 | "cp_xgb = individual_variable_profile(explainer_xgb, ernest, label_ernest)\n", 714 | "cp_rf = individual_variable_profile(explainer_rf, ernest, label_ernest)\n", 715 | "cp_linear = individual_variable_profile(explainer_linear, ernest, label_ernest)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": 116, 721 | "metadata": {}, 722 | "outputs": [], 723 | "source": [ 724 | "from ceteris_paribus.plots.plots import plot_notebook, plot" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "execution_count": 117, 730 | "metadata": {}, 731 | "outputs": [ 732 | { 733 | "data": { 734 | "text/html": [ 735 | "\n", 736 | " \n", 743 | " " 744 | ], 745 | "text/plain": [ 746 | "" 747 | ] 748 | }, 749 | "metadata": {}, 750 | "output_type": "display_data" 751 | } 752 | ], 753 | "source": [ 754 | "plot_notebook(cp_xgb, selected_variables=[\"Age\"], width=700, height=800, show_rugs=True, size=4)" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": 120, 760 | "metadata": {}, 761 | "outputs": [ 762 | { 763 | "data": { 764 | "text/html": [ 765 | "\n", 766 | " \n", 773 | " " 774 | ], 775 | "text/plain": [ 776 | "" 777 | ] 778 | }, 779 | "metadata": {}, 780 | "output_type": "display_data" 781 | } 782 | ], 783 | "source": [ 784 | "plot_notebook(cp_xgb, cp_rf, cp_linear, selected_variables=[\"Age\"], width=650, height=800, size=3)" 785 | ] 786 | }, 787 | { 788 | "cell_type": "markdown", 789 | "metadata": {}, 790 | "source": [ 791 | "##### Miss. 
Elizabeth Mussey Eustis" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 30, 797 | "metadata": { 798 | "scrolled": true 799 | }, 800 | "outputs": [ 801 | { 802 | "name": "stdout", 803 | "output_type": "stream", 804 | "text": [ 805 | "Pclass 1\n", 806 | "Sex female\n", 807 | "Age 54\n", 808 | "SibSp 1\n", 809 | "Parch 0\n", 810 | "Fare 78.2667\n", 811 | "Embarked C\n", 812 | "Name: 496, dtype: object\n" 813 | ] 814 | } 815 | ], 816 | "source": [ 817 | "elizabeth = X_test.iloc[1]\n", 818 | "print(elizabeth)\n", 819 | "label_elizabeth = y_test.iloc[1]\n", 820 | "cp_xgb_2 = individual_variable_profile(explainer_xgb, elizabeth, label_elizabeth)" 821 | ] 822 | }, 823 | { 824 | "cell_type": "code", 825 | "execution_count": 31, 826 | "metadata": {}, 827 | "outputs": [ 828 | { 829 | "data": { 830 | "text/html": [ 831 | "\n", 832 | " \n", 839 | " " 840 | ], 841 | "text/plain": [ 842 | "" 843 | ] 844 | }, 845 | "metadata": {}, 846 | "output_type": "display_data" 847 | } 848 | ], 849 | "source": [ 850 | "plot_notebook(cp_xgb_2, selected_variables=[\"Pclass\", \"Sex\", \"Age\", \"Embarked\"], width=900, height=1000, size=4)" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": 32, 856 | "metadata": {}, 857 | "outputs": [], 858 | "source": [ 859 | "from ceteris_paribus.select_data import select_neighbours" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "execution_count": 33, 865 | "metadata": {}, 866 | "outputs": [], 867 | "source": [ 868 | "neighbours = select_neighbours(X_train, elizabeth, selected_variables=['Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], n=15)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": 34, 874 | "metadata": { 875 | "scrolled": true 876 | }, 877 | "outputs": [], 878 | "source": [ 879 | "cp_xgb_ns = individual_variable_profile(explainer_xgb, neighbours)" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 35, 885 | "metadata": {}, 886 | "outputs": [ 887 | { 888 | "data": { 889 | "text/html": [ 890 | "\n", 891 | " \n", 898 | " " 899 | ], 900 | "text/plain": [ 901 | "" 902 | ] 903 | }, 904 | "metadata": {}, 905 | "output_type": "display_data" 906 | } 907 | ], 908 | "source": [ 909 | "plot_notebook(cp_xgb_ns, color=\"Sex\", selected_variables=[\"Pclass\", \"Age\"],\n", 910 | " height=600, width=1000, \n", 911 | " aggregate_profiles='mean', size_pdps=6, alpha_pdps=1, size=2)" 912 | ] 913 | } 914 | ], 915 | "metadata": { 916 | "kernelspec": { 917 | "display_name": "ceteris", 918 | "language": "python", 919 | "name": "ceteris" 920 | }, 921 | "language_info": { 922 | "codemirror_mode": { 923 | "name": "ipython", 924 | "version": 3 925 | }, 926 | "file_extension": ".py", 927 | "mimetype": "text/x-python", 928 | "name": "python", 929 | "nbconvert_exporter": "python", 930 | "pygments_lexer": "ipython3", 931 | "version": "3.5.2" 932 | } 933 | }, 934 | "nbformat": 4, 935 | "nbformat_minor": 2 936 | } 937 | -------------------------------------------------------------------------------- /misc/multiclass_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/misc/multiclass_models.png -------------------------------------------------------------------------------- /misc/titanic_interactions_average.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/misc/titanic_interactions_average.png -------------------------------------------------------------------------------- /misc/titanic_many_models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/misc/titanic_many_models.png -------------------------------------------------------------------------------- /misc/titanic_many_variables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/misc/titanic_many_variables.png -------------------------------------------------------------------------------- /misc/titanic_single_response.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/misc/titanic_single_response.png -------------------------------------------------------------------------------- /paper/img/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/paper/img/figure1.png -------------------------------------------------------------------------------- /paper/img/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/paper/img/figure2.png -------------------------------------------------------------------------------- /paper/img/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/paper/img/figure3.png -------------------------------------------------------------------------------- /paper/img/figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/paper/img/figure4.png -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | 2 | 3 | @book{ONeil, 4 | author = {O'Neil, Cathy}, 5 | title = {Weapons of Math Destruction: How Big Data Increases Inequality and Threatens Democracy}, 6 | year = {2016}, 7 | isbn = {0553418815, 9780553418811}, 8 | publisher = {Crown Publishing Group}, 9 | address = {New York, NY, USA}, 10 | } 11 | 12 | @misc{RightToExpl3, 13 | title={Machine learning and the right to explanation in GDPR}, 14 | author={Javier Ruiz}, 15 | year={2018}, 16 | howpublished = {\url{https://www.openrightsgroup.org/blog/2018/machine-learning-and-the-right-to-explanation-in-gdpr}} 17 | } 18 | 19 | @article{DALEX, 20 | title={DALEX: Explainers for Complex Predictive Models in R}, 21 | author={Biecek, Przemysław}, 22 | journal={Journal of Machine Learning Research}, 23 | volume={19}, 24 | pages={1--5}, 25 | year={2018} 26 | } 27 | 28 | @misc{CeterisParibus, 29 | author = {Biecek, Przemysław}, 30 | title = {Ceteris Paribus Plots (What-If plots) for explanations of a single 
observation}, 31 | year = {2019}, 32 | publisher = {GitHub}, 33 | journal = {GitHub repository}, 34 | howpublished = {\url{https://github.com/pbiecek/ceterisParibus}}, 35 | commit = {5eb44b93a52ee3f9c84dc455e8fcc6a81de4c6e7} 36 | } 37 | 38 | @incollection{SHAP, 39 | title = {A Unified Approach to Interpreting Model Predictions}, 40 | author = {Lundberg, Scott M and Lee, Su-In}, 41 | booktitle = {Advances in Neural Information Processing Systems 30}, 42 | editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett}, 43 | pages = {4765--4774}, 44 | year = {2017}, 45 | publisher = {Curran Associates, Inc.}, 46 | url = {http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf} 47 | } 48 | 49 | 50 | 51 | @inbook{LIME, 52 | title={Why Should I Trust You?: Explaining the Predictions of Any Classifier}, 53 | ISBN={978-1-4503-4232-2}, 54 | url={http://dl.acm.org/citation.cfm?doid=2939672.2939778}, 55 | DOI={10.1145/2939672.2939778}, 56 | publisher={ACM Press}, 57 | author={Ribeiro, Marco Tulio and Singh, Sameer and Guestrin, Carlos}, 58 | year={2016}, 59 | pages={1135–1144} 60 | } 61 | 62 | @Article{GoldsteinICE, 63 | title = {Peeking Inside the Black Box: Visualizing Statistical Learning With Plots of Individual Conditional Expectation}, 64 | author = {Alex Goldstein and Adam Kapelner and Justin Bleich and Emil Pitkin}, 65 | journal = {Journal of Computational and Graphical Statistics}, 66 | volume = {24}, 67 | number = {1}, 68 | pages = {44--65}, 69 | doi = {10.1080/10618600.2014.907095}, 70 | year = {2015}, 71 | } 72 | 73 | @article{friedman2001, 74 | author = "Friedman, Jerome H.", 75 | doi = "10.1214/aos/1013203451", 76 | fjournal = "The Annals of Statistics", 77 | journal = "Ann. 
Statist.", 78 | month = "10", 79 | number = "5", 80 | pages = "1189--1232", 81 | publisher = "The Institute of Mathematical Statistics", 82 | title = "Greedy function approximation: A gradient boosting machine.", 83 | url = "https://doi.org/10.1214/aos/1013203451", 84 | volume = "29", 85 | year = "2001" 86 | } 87 | 88 | @misc{pramit_choudhary_2018_1198885, 89 | author = {Pramit Choudhary and 90 | Aaron Kramer and 91 | datascience.com team, contributors}, 92 | title = {{Skater: Model Interpretation Library}}, 93 | month = mar, 94 | year = 2018, 95 | doi = {10.5281/zenodo.1198885}, 96 | url = {https://doi.org/10.5281/zenodo.1198885} 97 | } 98 | 99 | @article{gower, 100 | author = {Gower, John}, 101 | year = {1971}, 102 | month = {12}, 103 | pages = {857-871}, 104 | title = {A General Coefficient of Similarity and Some of Its Properties}, 105 | volume = {27}, 106 | journal = {Biometrics}, 107 | doi = {10.2307/2528823} 108 | } 109 | -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'pyCeterisParibus: explaining Machine Learning models with Ceteris Paribus Profiles in Python' 3 | tags: 4 | - Python 5 | - xAI 6 | - Machine Learning 7 | authors: 8 | - name: Michał Kuźba 9 | orcid: 0000-0002-9181-0126 10 | affiliation: "1, 2" 11 | - name: Ewa Baranowska 12 | orcid: 0000-0002-2278-4219 13 | affiliation: 1 14 | - name: Przemysław Biecek 15 | orcid: 0000-0001-8423-1823 16 | affiliation: "2, 1" 17 | affiliations: 18 | - name: Faculty of Mathematics and Information Science, Warsaw University of Technology 19 | index: 1 20 | - name: Faculty of Mathematics, Informatics, and Mechanics, University of Warsaw 21 | index: 2 22 | date: 18 March 2019 23 | bibliography: paper.bib 24 | --- 25 | 26 | # Introduction 27 | 28 | Machine learning is needed and used everywhere. It has fundamentally changed all data-driven disciplines, such as health care, biology, finance, law, the military, security, transportation, and many others. The increasing availability of large annotated data sources, combined with recent developments in Machine Learning, is revolutionizing many disciplines. However, predictive models become more and more complex. It is not uncommon to have ensembles of predictive models with thousands or millions of parameters. Such models act as black boxes. It is almost impossible for a human to understand the reasons for model decisions. 29 | 30 | This lack of interpretability often leads to harmful situations: models do not work properly or are hard to debug, results are biased in a systematic way, data drift leads to a deterioration in model performance, or a model is wrong but no one can explain what caused the wrong prediction. Many examples of these problems are listed in *Weapons of Math Destruction*, a bestseller with the expressive subtitle *How Big Data Increases Inequality and Threatens Democracy* [@ONeil]. Reactions to some of these problems include new legal regulations, like the *General Data Protection Regulation and the Right to Explanation* [@RightToExpl3]. 31 | 32 | New tools are being created to support model interpretability. The best-known methods are Local Interpretable Model-agnostic Explanations (LIME) [@LIME], SHapley Additive exPlanations (SHAP) [@SHAP] and Descriptive mAchine Learning EXplanations (DALEX) [@DALEX]. 
General purpose libraries for interpretable Machine Learning in Python are skater [@pramit_choudhary_2018_1198885] and [ELI5](https://eli5.readthedocs.io/en/latest/). 33 | 34 | An interesting alternative to these tools is the methodology of *Ceteris Paribus Profiles* and their averages called *Partial Dependency Plots*. They make it possible to understand how the model response would change if a selected variable were changed. This is a perfect tool for What-If scenarios. _Ceteris Paribus_ is a Latin phrase meaning _all else unchanged_. These plots present the change in model response as the values of one feature change, with all others being fixed. The Ceteris Paribus method is model-agnostic - it works for any Machine Learning model. 35 | The idea is an extension of *PDP* (Partial Dependency Plots) [@friedman2001] and *ICE* (Individual Conditional Expectations) plots [@GoldsteinICE]. It allows explaining single observations for multiple variables at the same time. 36 | 37 | In this paper, we introduce the `pyCeterisParibus` library for Python that supports a wide range of tools built on Ceteris Paribus Profiles. There might be several motivations behind utilizing this idea. Imagine a person gets a low credit score. The client wants to understand how to increase the score, and the scoring institution (e.g., a bank) should be able to answer such questions. Moreover, this method is useful for researchers and developers to analyze, debug, explain and improve Machine Learning models, assisting the entire process of model design. A more detailed demonstration is available in the *Examples* section. 38 | 39 | 40 | # pyCeterisParibus library 41 | `pyCeterisParibus` is a Python [library](https://github.com/ModelOriented/pyCeterisParibus) based on the *R* package *CeterisParibus* [@CeterisParibus]. It is open-source software released under the Apache license. 42 | 43 | The workflow consists of three steps: 44 | 45 | * wrapping models into a unified representation 46 | * calculating profiles for given observations, variables and models 47 | * plotting the profiles 48 | 49 | Plots are drawn using a separate *D3.js* [library](https://github.com/ModelOriented/ceterisParibusD3) which produces interactive Ceteris Paribus plots for both the *Python* (this package) and [*R*](https://github.com/flaminka/ceterisParibusD3) implementations. 50 | This package allows explaining multiple observations and observing local behaviours of the model. For this purpose, methods for sampling and selecting neighbouring observations are implemented, along with the *Gower's distance* [@gower] function. 51 | A more detailed description can be found in the package [documentation](https://pyceterisparibus.readthedocs.io/). 52 | 53 | # Examples 54 | We demonstrate Ceteris Paribus Plots using the well-known Titanic dataset. In this problem, we examine the chance of survival for Titanic passengers. To make the method concrete, a short sketch of what a single profile computes is shown below; after that, we start with preprocessing the data and creating an XGBoost model. 
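The following is a minimal, hypothetical sketch of the computation behind a single profile - it is not the package's actual implementation, and the helper name `cp_profile` is introduced here only for illustration. The observation is copied once per grid value, only the selected variable is varied, and the model is queried on each copy (the `_yhat_` column merely mirrors the naming used in the package's output):

```python
import numpy as np
import pandas as pd


def cp_profile(predict, observation, variable, grid):
    """Evaluate `predict` on copies of `observation` in which only
    `variable` is varied over `grid` - all else unchanged."""
    rows = pd.DataFrame([observation] * len(grid))
    rows[variable] = np.asarray(grid)  # vary one feature, fix the rest
    return pd.DataFrame({variable: np.asarray(grid),
                         "_yhat_": np.asarray(predict(rows))})
```

The library automates exactly this computation, chooses sensible grids for each variable and adds interactive visualisation on top.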
56 | ```python 57 | import pandas as pd 58 | df = pd.read_csv('titanic_train.csv') 59 | 60 | y = df['Survived'] 61 | x = df.drop(['Survived', 'PassengerId', 'Name', 'Cabin', 'Ticket'], 62 | inplace=False, axis=1) 63 | 64 | valid = x['Age'].isnull() | x['Embarked'].isnull() 65 | x = x[~valid] 66 | y = y[~valid] 67 | 68 | from sklearn.model_selection import train_test_split 69 | X_train, X_test, y_train, y_test = train_test_split(x, y, 70 | test_size=0.2, random_state=42) 71 | ``` 72 | ```python 73 | from sklearn.pipeline import Pipeline 74 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 75 | from sklearn.compose import ColumnTransformer 76 | 77 | # We create the preprocessing pipelines for both numeric and categorical data. 78 | numeric_features = ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare'] 79 | numeric_transformer = Pipeline(steps=[ 80 | ('scaler', StandardScaler())]) 81 | 82 | categorical_features = ['Embarked', 'Sex'] 83 | categorical_transformer = Pipeline(steps=[ 84 | ('onehot', OneHotEncoder(handle_unknown='ignore'))]) 85 | 86 | preprocessor = ColumnTransformer( 87 | transformers=[ 88 | ('num', numeric_transformer, numeric_features), 89 | ('cat', categorical_transformer, categorical_features)]) 90 | ``` 91 | 92 | ```python 93 | from xgboost import XGBClassifier 94 | xgb_clf = Pipeline(steps=[('preprocessor', preprocessor), 95 | ('classifier', XGBClassifier())]) 96 | xgb_clf.fit(X_train, y_train) 97 | ``` 98 | 99 | Here is where pyCeterisParibus comes in. Since this library works in a model-agnostic fashion, we first need to create a wrapper around the model with a uniform predict interface. 100 | ```python 101 | from ceteris_paribus.explainer import explain 102 | explainer_xgb = explain(xgb_clf, data=x, y=y, label='XGBoost', 103 | predict_function=lambda X: xgb_clf.predict_proba(X)[::, 1]) 104 | ``` 105 | 106 | 107 | ### Single variable profile 108 | Let's look at Mr Ernest James Crease, a 19-year-old man travelling in the 3rd class from Southampton with an 8-pound ticket in his pocket. He died on the Titanic. Most likely, this would not have been the case had Ernest been a few years younger. 109 | Figure 1 presents the chance of survival for a person like Ernest at different ages. We can see things were tough for people like him unless they were children. 110 | 111 | ```python 112 | ernest = X_test.iloc[10] 113 | label_ernest = y_test.iloc[10] 114 | from ceteris_paribus.profiles import individual_variable_profile 115 | cp_xgb = individual_variable_profile(explainer_xgb, ernest, label_ernest) 116 | ``` 117 | 118 | Having calculated the profile, we can plot it. Note that `plot_notebook` can be used instead of `plot` when working in Jupyter notebooks. 119 | 120 | ```python 121 | from ceteris_paribus.plots.plots import plot 122 | plot(cp_xgb, selected_variables=["Age"]) 123 | ``` 124 | 125 | ![Chance of survival depending on age](./img/figure1.png) 126 | 127 | ### Many models 128 | The above picture explains the prediction of the XGBoost model. What if we compare various models? 
129 | 130 | ```python 131 | from sklearn.ensemble import RandomForestClassifier 132 | from sklearn.linear_model import LogisticRegression 133 | rf_clf = Pipeline(steps=[('preprocessor', preprocessor), 134 | ('classifier', RandomForestClassifier())]) 135 | linear_clf = Pipeline(steps=[('preprocessor', preprocessor), 136 | ('classifier', LogisticRegression())]) 137 | 138 | rf_clf.fit(X_train, y_train) 139 | linear_clf.fit(X_train, y_train) 140 | 141 | explainer_rf = explain(rf_clf, data=x, y=y, label='RandomForest', 142 | predict_function=lambda X: rf_clf.predict_proba(X)[::, 1]) 143 | explainer_linear = explain(linear_clf, data=x, y=y, label='LogisticRegression', 144 | predict_function=lambda X: linear_clf.predict_proba(X)[::, 1]) 145 | cp_rf = individual_variable_profile(explainer_rf, ernest, label_ernest) cp_linear = individual_variable_profile(explainer_linear, ernest, label_ernest) 146 | plot(cp_xgb, cp_rf, cp_linear, selected_variables=["Age"]) 147 | ``` 148 | 149 | ![The probability of survival estimated with various models.](./img/figure2.png) 150 | 151 | Clearly, XGBoost offers a better fit than Logistic Regression (Figure 2). 152 | Also, it predicts a higher chance of survival at a child's age than the Random Forest model does. 153 | 154 | ### Profiles for many variables 155 | This time we have a look at Miss. Elizabeth Mussey Eustis. She is 54 years old and travels in the 1st class with her sister Marta, as they return to the US from their tour of southern Europe. They both survived the disaster. 156 | 157 | ```python 158 | elizabeth = X_test.iloc[1] 159 | label_elizabeth = y_test.iloc[1] 160 | cp_xgb_2 = individual_variable_profile(explainer_xgb, elizabeth, label_elizabeth) 161 | ``` 162 | 163 | ```python 164 | plot(cp_xgb_2, selected_variables=["Pclass", "Sex", "Age", "Embarked"]) 165 | ``` 166 | 167 | ![Profiles for many variables.](./img/figure3.png) 168 | 169 | Would she have returned home if she had travelled in the 3rd class or if she had been a man? As we can observe (Figure 3), this is less likely. On the other hand, for a first-class female passenger, the chances of survival were high regardless of age. Note that this was different in the case of Ernest. The place of embarkation (Cherbourg) has no influence, which is the expected behaviour. 170 | 171 | ### Feature interactions and average response 172 | Now, what if we look at passengers most similar to Miss. Eustis (middle-aged, upper class)? 173 | 174 | ```python 175 | from ceteris_paribus.select_data import select_neighbours 176 | neighbours = select_neighbours(X_train, elizabeth, 177 | selected_variables=['Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'], 178 | n=15) 179 | cp_xgb_ns = individual_variable_profile(explainer_xgb, neighbours) 180 | ``` 181 | 182 | ```python 183 | plot(cp_xgb_ns, color="Sex", selected_variables=["Pclass", "Age"], 184 | aggregate_profiles='mean', size_pdps=6, alpha_pdps=1, size=2) 185 | ``` 186 | 187 | ![Interaction with gender. Apart from charts with Ceteris Paribus Profiles (top of the visualisation), we can plot a table with observations used to calculate these profiles (bottom of the visualisation).](./img/figure4.png) 188 | 189 | There are two distinct clusters of passengers determined by their gender (Figure 4); therefore, a *PDP* average plot (in grey) does not show the whole picture. Children of both genders were likely to survive, but then we see a large gap. Also, being female increased the chance of survival mostly for first- and second-class passengers. 190 | 191 | The plot function comes with extensive customization options. A list of all parameters can be found in the documentation. 
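As one illustrative combination - a sketch only - the call below reuses keyword arguments that appear in the accompanying Titanic notebook (`width`, `height`, `show_rugs`, `size`), assuming `plot` accepts the same styling options as its notebook counterpart `plot_notebook`:

```python
# Sketch with assumed options: enlarge the Age profile and draw rugs for the
# observed data points. These keyword arguments are copied from the
# plot_notebook calls in the accompanying Titanic notebook.
plot(cp_xgb, selected_variables=["Age"], width=700, height=800,
     show_rugs=True, size=4)
```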
Additionally, one can interact with the plot by hovering over a point of interest to see more details. Similarly, there is an interactive table with options for highlighting relevant elements as well as filtering and sorting rows. 192 | 193 | 194 | # Acknowledgements 195 | 196 | Michał Kuźba was financially supported by NCN Opus grant 2016/21/B/ST6/0217, and Ewa Baranowska was financially supported by NCN Opus grant 2017/27/B/ST6/01307. 197 | 198 | # References 199 | -------------------------------------------------------------------------------- /publish.sh: -------------------------------------------------------------------------------- 1 | # build and publish to 2 | # https://pypi.org/project/pyCeterisParibus/ 3 | rm -r dist/ 4 | python setup.py sdist bdist_wheel 5 | python -m twine upload dist/* 6 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | codecov>=2.0.15 2 | coverage>=4.5.2 3 | Keras>=2.2.4 4 | m2r==0.2.1 5 | pytest>=4.0.1 6 | pytest-cov>=2.6.0 7 | scikit-learn~=0.20.0 8 | sklearn==0.0 9 | Sphinx>=1.8.3 10 | tensorflow>=1.12.0 11 | xgboost>=0.82 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask>=1.0.2 2 | numpy>=1.15.4 3 | pandas>=0.23.4 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | PACKAGE_NAME = 'pyCeterisParibus' 4 | REQUIREMENTS_PATH = 'requirements.txt' 5 | README_PATH = 'README.md' 6 | VERSION = '0.5.2' 7 | 8 | 9 | def get_requirements(): 10 | with open(REQUIREMENTS_PATH, 'r') as f: 11 | requirements = f.read().splitlines() 12 | return requirements 13 | 14 | 15 | def get_readme(): 16 | with open(README_PATH, 'r') as f: 17 | return f.read() 18 | 19 | 20 | setup(name=PACKAGE_NAME, 21 | version=VERSION, 22 | description='Ceteris Paribus python package', 23 | long_description=get_readme(), 24 | long_description_content_type='text/markdown', 25 | url='https://github.com/ModelOriented/pyCeterisParibus', 26 | author='Michał Kuźba', 27 | author_email='michal.kuzba@students.mimuw.edu.pl', 28 | packages=find_packages(exclude=['examples', 'tests']), 29 | package_data={'ceteris_paribus': ['plots/ceterisParibusD3.js', 'plots/plot_template.html', 'datasets/*.csv']}, 30 | install_requires=get_requirements(), 31 | classifiers=[ 32 | 'Intended Audience :: Science/Research', 33 | 'Intended Audience :: Developers', 34 | 'License :: OSI Approved :: Apache Software License', 35 | 'Operating System :: OS Independent', 36 | 'Programming Language :: Python :: 3', 37 | 'Programming Language :: Python :: 3.5', 38 | 'Programming Language :: Python :: 3.6', 39 | 'Programming Language :: Python :: 3.7', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'Topic :: Scientific/Engineering :: Visualization' 42 | ]) 43 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_api.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Tests for using the main function `individual_variable_profile` 3 | """ 4 | import os 5 | import unittest 6 | 7 | import pandas as pd 8 | from keras.layers import Dense, Activation 9 | from keras.models import Sequential 10 | from keras.wrappers.scikit_learn import KerasRegressor 11 | from sklearn import datasets, ensemble 12 | from sklearn.compose import ColumnTransformer 13 | from sklearn.datasets import load_iris 14 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 15 | from sklearn.model_selection import train_test_split 16 | from sklearn.pipeline import Pipeline 17 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 18 | 19 | from ceteris_paribus.datasets import DATASETS_DIR 20 | from ceteris_paribus.explainer import explain 21 | from ceteris_paribus.profiles import CeterisParibus 22 | from ceteris_paribus.profiles import individual_variable_profile 23 | from ceteris_paribus.select_data import select_sample, select_neighbours 24 | 25 | 26 | def random_forest_classifier(X_train, y_train, var_names): 27 | rf_model = RandomForestClassifier(n_estimators=100, random_state=42) 28 | rf_model.fit(X_train, y_train) 29 | return rf_model, X_train, y_train, var_names 30 | 31 | 32 | class TestClassification(unittest.TestCase): 33 | 34 | def setUp(self): 35 | self.iris = load_iris() 36 | 37 | self.X = self.iris['data'] 38 | self.y = self.iris['target'] 39 | X_train, X_test, y_train, y_test = train_test_split(self.X, self.y, test_size=0.33, random_state=42) 40 | (model, data, labels, variable_names) = random_forest_classifier(X_train, y_train, 41 | list(self.iris['feature_names'])) 42 | predict_function = lambda X: model.predict_proba(X)[::, 0] 43 | self.explainer_rf = explain(model, variable_names, data, labels, predict_function=predict_function, 44 | label="rf_model") 45 | 46 | def test_iris_classification_1(self): 47 | cp_profile = individual_variable_profile(self.explainer_rf, self.X[1], y=self.y[1]) 48 | self.assertTrue(isinstance(cp_profile, CeterisParibus)) 49 | self.assertIsNotNone(cp_profile.new_observation_true) 50 | self.assertEqual(len(cp_profile._predict_function(self.X[:3])), 3) 51 | 52 | def test_iris_classification_2(self): 53 | grid_points = 3 54 | feature = self.iris['feature_names'][0] 55 | cp_profile = individual_variable_profile(self.explainer_rf, self.X[:10], variables=[feature], 56 | grid_points=grid_points) 57 | self.assertIn(feature, cp_profile.profile.columns) 58 | self.assertEqual(len(cp_profile.profile), 10 * grid_points) 59 | self.assertIsNone(cp_profile.new_observation_true) 60 | self.assertEqual(cp_profile.selected_variables, [feature]) 61 | 62 | def test_iris_classification_3(self): 63 | feature = self.iris['feature_names'][0] 64 | grid_points = 5 # should be ignored 65 | num_points = 3 66 | cp_profile = individual_variable_profile(self.explainer_rf, self.X[:num_points], y=self.y[:num_points], 67 | variables=[feature], grid_points=grid_points, 68 | variable_splits={feature: [10, 31]}) 69 | self.assertEqual(len(cp_profile.profile), num_points * 2) 70 | 71 | def test_iris_classification_4(self): 72 | X_data = pd.DataFrame(self.X[:20]) 73 | cp_profile = individual_variable_profile(self.explainer_rf, X_data, y=self.y[:20]) 74 | self.assertEqual(len(cp_profile.profile), len(self.iris['feature_names']) * 20 * cp_profile._grid_points) 75 | 76 | def test_iris_classification_5(self): 77 | feature = self.iris['feature_names'][0] 78 | X_data = 
pd.DataFrame(self.X[:20]) 79 | y_data = pd.DataFrame(self.y) 80 | cp_profile = individual_variable_profile(self.explainer_rf, X_data, y_data, variables=[feature]) 81 | self.assertEqual(len(cp_profile.profile), 20 * cp_profile._grid_points) 82 | self.assertLessEqual(max(cp_profile.profile['_yhat_']), 1) 83 | self.assertGreaterEqual(min(cp_profile.profile['_yhat_']), 0) 84 | 85 | def test_iris_classification_6(self): 86 | X_data = pd.DataFrame(self.X[5]).T 87 | cp_profile = individual_variable_profile(self.explainer_rf, X_data, self.y[5]) 88 | self.assertEqual(len(cp_profile.profile), len(self.iris['feature_names']) * cp_profile._grid_points) 89 | 90 | def test_iris_classification_7(self): 91 | X_data = pd.DataFrame(self.X[5]) 92 | with self.assertRaises(ValueError): 93 | cp_profile = individual_variable_profile(self.explainer_rf, X_data, self.y[5]) 94 | 95 | 96 | def random_forest_regression(X_train, y_train, var_names): 97 | # Create a random forest regression model 98 | rf_model = ensemble.RandomForestRegressor(n_estimators=100, random_state=42) 99 | 100 | # Train the model using the training set 101 | rf_model.fit(X_train, y_train) 102 | 103 | # model, data, labels, variable_names 104 | return rf_model, X_train, y_train, var_names 105 | 106 | 107 | class TestRegression(unittest.TestCase): 108 | 109 | def setUp(self): 110 | boston = datasets.load_boston() 111 | 112 | X = boston['data'] 113 | y = boston['target'] 114 | 115 | self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.33, random_state=42) 116 | 117 | (model, data, labels, self.variable_names) = random_forest_regression(self.X_train, self.y_train, 118 | list(boston['feature_names'])) 119 | self.explainer_rf = explain(model, self.variable_names, data, labels, label="rf_model") 120 | 121 | def test_regression_1(self): 122 | cp_profile = individual_variable_profile(self.explainer_rf, self.X_train[0], y=self.y_train[0], 123 | variables=['TAX', 'CRIM']) 124 | self.assertIsNotNone(cp_profile.new_observation_true) 125 | self.assertEqual(len(cp_profile.selected_variables), 2) 126 | self.assertEqual(len(cp_profile.profile), cp_profile._grid_points * 2) 127 | self.assertIn("TAX", cp_profile.profile.columns) 128 | 129 | def test_regression_2(self): 130 | n = 3 131 | sample = select_sample(self.X_train, n=n) 132 | cp2 = individual_variable_profile(self.explainer_rf, sample, variables=['TAX', 'CRIM']) 133 | self.assertEqual(len(cp2.profile), cp2._grid_points * 2 * n) 134 | 135 | def test_regression_3(self): 136 | variable_names = self.variable_names 137 | neighbours = select_neighbours(self.X_train, self.X_train[0], variable_names=variable_names, 138 | selected_variables=variable_names, n=15) 139 | cp3 = individual_variable_profile(self.explainer_rf, neighbours, variables=['LSTAT', 'RM'], 140 | variable_splits={'LSTAT': [10, 20, 30], 'RM': [4, 5, 6, 7]}) 141 | self.assertEqual(cp3.selected_variables, ['LSTAT', 'RM']) 142 | # num of different values in splits 143 | self.assertEqual(len(cp3.profile), 15 * 7) 144 | 145 | 146 | class TestKeras(unittest.TestCase): 147 | 148 | def setUp(self): 149 | boston = datasets.load_boston() 150 | x = boston.data 151 | y = boston.target 152 | x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) 153 | 154 | def network_architecture(): 155 | model = Sequential() 156 | model.add(Dense(640, input_dim=x.shape[1])) 157 | model.add(Activation('tanh')) 158 | model.add(Dense(320)) 159 | model.add(Activation('tanh')) 160 | model.add(Dense(1)) 161 | 
model.compile(loss='mean_squared_error', optimizer='adam') 162 | return model 163 | 164 | def keras_model(): 165 | estimators = [ 166 | ('scaler', StandardScaler()), 167 | ('mlp', KerasRegressor(build_fn=network_architecture, epochs=20)) 168 | ] 169 | model = Pipeline(estimators) 170 | model.fit(x_train, y_train) 171 | return model, x_train, y_train, boston.feature_names 172 | 173 | model, self.x_train, self.y_train, self.var_names = keras_model() 174 | self.explainer_keras = explain(model, self.var_names, self.x_train, self.y_train, label='KerasMLP') 175 | 176 | def test_keras_1(self): 177 | cp = individual_variable_profile(self.explainer_keras, self.x_train[:10], y=self.y_train[:10], 178 | variables=["CRIM", "ZN", "AGE", "INDUS", "B"]) 179 | self.assertEqual(len(cp.new_observation_true), 10) 180 | self.assertEqual(len(cp.profile), cp._grid_points * 5 * 10) 181 | 182 | def test_keras_2(self): 183 | cp = individual_variable_profile(self.explainer_keras, pd.DataFrame(self.x_train[:10]), 184 | y=list(self.y_train[:10]), 185 | variables=["CRIM", "ZN", "AGE", "INDUS", "B"]) 186 | self.assertEqual(len(cp.new_observation_true), 10) 187 | self.assertEqual(len(cp.profile), cp._grid_points * 5 * 10) 188 | 189 | def test_keras_3(self): 190 | cp = individual_variable_profile(self.explainer_keras, self.x_train[5], y=self.y_train[5], 191 | variables=["CRIM", "ZN", "AGE", "INDUS", "B"]) 192 | self.assertEqual(len(cp.new_observation_true), 1) 193 | self.assertEqual(len(cp.profile), cp._grid_points * 5) 194 | 195 | 196 | class TestCategorical(unittest.TestCase): 197 | def setUp(self): 198 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv')) 199 | 200 | self.x = df.drop(['charges'], inplace=False, axis=1) 201 | 202 | self.y = df['charges'] 203 | 204 | var_names = list(self.x) 205 | 206 | # We create the preprocessing pipelines for both numeric and categorical data. 207 | numeric_features = ['age', 'bmi', 'children'] 208 | numeric_transformer = Pipeline(steps=[ 209 | ('scaler', StandardScaler())]) 210 | 211 | categorical_features = ['sex', 'smoker', 'region'] 212 | categorical_transformer = Pipeline(steps=[ 213 | ('onehot', OneHotEncoder(handle_unknown='ignore'))]) 214 | 215 | preprocessor = ColumnTransformer( 216 | transformers=[ 217 | ('num', numeric_transformer, numeric_features), 218 | ('cat', categorical_transformer, categorical_features)]) 219 | 220 | # Append classifier to preprocessing pipeline. 221 | # Now we have a full prediction pipeline. 
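# (The regressor below receives the scaled numeric features and one-hot encoded categorical features produced by the preprocessor.)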
222 | clf = Pipeline(steps=[('preprocessor', preprocessor), 223 | ('classifier', RandomForestRegressor())]) 224 | 225 | clf.fit(self.x, self.y) 226 | 227 | self.explainer_cat = explain(clf, var_names, self.x, self.y, label="categorical_model") 228 | 229 | def test_categorical_1(self): 230 | cp = individual_variable_profile(self.explainer_cat, self.x.iloc[:10], self.y.iloc[:10]) 231 | self.assertEqual(len(cp.new_observation_true), 10) 232 | self.assertIn('female', list(cp.profile['sex'])) 233 | self.assertIn('sex', list(cp.profile['_vname_'])) 234 | -------------------------------------------------------------------------------- /tests/test_explain.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn import datasets, ensemble 7 | 8 | from ceteris_paribus.explainer import explain 9 | 10 | 11 | class TestExplain(unittest.TestCase): 12 | 13 | @classmethod 14 | def setUpClass(cls): 15 | boston = datasets.load_boston() 16 | 17 | cls.X = boston['data'] 18 | cls.y = boston['target'] 19 | 20 | cls.rf_model = ensemble.RandomForestRegressor(n_estimators=100, random_state=42) 21 | 22 | # Train the model using the training set 23 | cls.rf_model.fit(cls.X, cls.y) 24 | 25 | cls.var_names = list(boston['feature_names']) 26 | 27 | cls.df = pd.DataFrame.from_dict({'a': [1, 2], 'b': [3, 6]}) 28 | 29 | def test_explainer_1(self): 30 | model = MagicMock() 31 | delattr(model, 'predict') 32 | with self.assertRaises(ValueError) as c: 33 | explain(model, self.var_names) 34 | 35 | def test_explainer_2(self): 36 | model = MagicMock(predict=id) 37 | explainer = explain(model, data=pd.DataFrame()) 38 | self.assertEqual(explainer.predict_fun, id) 39 | 40 | def test_explainer_3(self): 41 | explainer = explain(self.rf_model, [], predict_function=sum) 42 | self.assertEqual(explainer.predict_fun, sum) 43 | 44 | def test_explainer_4(self): 45 | label = "xyz" 46 | explainer = explain(self.rf_model, [], label=label) 47 | self.assertEqual(explainer.label, label) 48 | 49 | def test_explainer_5(self): 50 | # raises warning 51 | explainer = explain(self.rf_model, []) 52 | self.assertEqual(explainer.label, "RandomForestRegressor") 53 | 54 | def test_explainer_6(self): 55 | model = MagicMock() 56 | model.__str__.return_value = 'xyz' 57 | # raises warning 58 | explainer = explain(model, []) 59 | self.assertEqual(explainer.label, "unlabeled_model") 60 | 61 | def test_explainer_7(self): 62 | # no labels given 63 | with self.assertRaises(ValueError) as c: 64 | explainer = explain(self.rf_model) 65 | 66 | def test_explainer_8(self): 67 | # labels imputed from the dataframe 68 | explainer = explain(self.rf_model, data=self.df) 69 | self.assertEqual(explainer.var_names, ['a', 'b']) 70 | 71 | def test_explainer_9(self): 72 | explainer = explain(self.rf_model, variable_names=["a", "b", "c"], y=[1, 2, 3]) 73 | np.testing.assert_array_equal(explainer.y, pd.Series([1, 2, 3])) 74 | 75 | def test_explainer_10(self): 76 | explainer = explain(self.rf_model, variable_names=["a", "b"], y=np.array([1, 4])) 77 | np.testing.assert_array_equal(explainer.y, pd.Series([1, 4])) 78 | 79 | def test_explainer_11(self): 80 | explainer = explain(self.rf_model, variable_names=["a", "b"], y=pd.DataFrame(np.array([1, 4]))) 81 | np.testing.assert_array_equal(explainer.y, pd.Series([1, 4])) 82 | 83 | def test_explainer_12(self): 84 | # data from dataframe 85 | explainer = explain(self.rf_model, data=self.df) 
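# the explainer should store the provided dataframe unchanged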
86 | np.testing.assert_array_equal(explainer.data, self.df) 87 | 88 | def test_explainer_13(self): 89 | # data from numpy array 90 | explainer = explain(self.rf_model, variable_names=["a", "b"], data=self.df.values) 91 | np.testing.assert_array_equal(explainer.data, self.df) 92 | 93 | def test_explainer_14(self): 94 | # data for one observation - 1D array 95 | explainer = explain(self.rf_model, variable_names=["a", "b"], data=np.array(["cc", "dd"])) 96 | np.testing.assert_array_equal(explainer.data, pd.DataFrame.from_dict({"a": ["cc"], "b": ["dd"]})) 97 | 98 | def test_explainer_15(self): 99 | # wrong number of variables 100 | with self.assertRaises(ValueError): 101 | explainer = explain(self.rf_model, variable_names=["a", "b", "c"], data=self.df.values) 102 | 103 | def test_explainer_16(self): 104 | # predict function for array 105 | explainer = explain(self.rf_model, variable_names=self.var_names, data=self.X[:10], y=self.y[:10]) 106 | self.assertEqual(len(explainer.predict_fun(pd.DataFrame(self.X[:10]))), 10) 107 | 108 | def test_explainer_17(self): 109 | # predict function for dataframe 110 | boston_df = pd.DataFrame(self.X[:10]) 111 | explainer = explain(self.rf_model, data=boston_df) 112 | self.assertEqual(len(explainer.predict_fun(boston_df)), 10) 113 | -------------------------------------------------------------------------------- /tests/test_plots.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | 4 | from ceteris_paribus.plots.plots import _calculate_plot_variables, _params_update, _get_data_paths 5 | 6 | 7 | class TestPlots(unittest.TestCase): 8 | 9 | def test_calculate_plot_variables(self): 10 | profile = MagicMock(selected_variables=["a", "b", "c"]) 11 | self.assertEqual(_calculate_plot_variables(profile, ["b", "c"]), ["b", "c"]) 12 | 13 | def test_calculate_plot_variables_2(self): 14 | profile = MagicMock(selected_variables=["a", "b", "c"]) 15 | self.assertEqual(_calculate_plot_variables(profile, None), ["a", "b", "c"]) 16 | 17 | def test_calculate_plot_variables_3(self): 18 | profile = MagicMock(selected_variables=["a", "b", "c"]) 19 | # prints warning 20 | self.assertEqual(_calculate_plot_variables(profile, ["a", "d"]), ["a", "b", "c"]) 21 | 22 | def test_params_update(self): 23 | self.assertEqual(_params_update({}, a=1, b=None, c=13), {'a': 1, 'c': 13}) 24 | 25 | def test_data_paths(self): 26 | plot_path, params_path, obs_path, profile_path = _get_data_paths(5) 27 | self.assertTrue(plot_path.endswith('.html')) 28 | self.assertTrue(params_path.endswith('.js')) 29 | self.assertTrue(obs_path.endswith('.js')) 30 | self.assertTrue(profile_path.endswith('.js')) 31 | -------------------------------------------------------------------------------- /tests/test_profiles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | from collections import OrderedDict 4 | from unittest.mock import MagicMock 5 | 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from ceteris_paribus.profiles import _get_variables, CeterisParibus, _valid_variable_splits 10 | from ceteris_paribus.utils import dump_profiles, dump_observations, transform_into_Series, save_profiles, \ 11 | save_observations 12 | 13 | 14 | class TestProfiles(unittest.TestCase): 15 | 16 | def setUp(self): 17 | self.cp = object.__new__(CeterisParibus) 18 | 19 | def test_get_variables(self): 20 | explainer = MagicMock(var_names=["c", "a", "b"]) 21 | variables = ["b", 
"a"] 22 | self.assertEqual(_get_variables(variables, explainer), variables) 23 | 24 | def test_get_variables_2(self): 25 | explainer = MagicMock(var_names=["c", "a", "b"]) 26 | variables = ["a", "a"] 27 | self.assertEqual(_get_variables(variables, explainer), variables) 28 | 29 | def test_get_variables_3(self): 30 | explainer = MagicMock(var_names=["a", "b"]) 31 | variables = ["c", "a"] 32 | with self.assertRaises(ValueError): 33 | _get_variables(variables, explainer) 34 | 35 | def test_get_variables_4(self): 36 | variables = ["c", "a", "b"] 37 | explainer = MagicMock(var_names=variables) 38 | self.assertEqual(_get_variables(None, explainer), variables) 39 | 40 | def test_calculate_single_split(self): 41 | X_var = np.array([1, 2, 1, 2, 1]) 42 | np.testing.assert_array_equal(self.cp._calculate_single_split(X_var), np.array([1, 2])) 43 | 44 | def test_calculate_single_split_2(self): 45 | np.random.seed(42) 46 | X_var = np.random.random(100) 47 | self.cp._grid_points = 5 48 | splits = self.cp._calculate_single_split(X_var) 49 | self.assertEqual(splits.shape, (self.cp._grid_points,)) 50 | 51 | def test_calculate_single_split_3(self): 52 | X_var = np.array([1, 1.2, 3, 1.7]) 53 | self.cp._grid_points = 2 54 | splits = self.cp._calculate_single_split(X_var) 55 | np.testing.assert_array_equal(splits, np.array([1, 3])) 56 | 57 | def test_calculate_single_split_4(self): 58 | np.random.seed(42) 59 | X_var = np.random.random(1000) 60 | self.cp._grid_points = 200 61 | splits = self.cp._calculate_single_split(X_var) 62 | self.assertEqual(len(splits), len(set(splits))) 63 | np.testing.assert_array_equal(sorted(splits), splits) 64 | 65 | def test_calculate_variable_splits(self): 66 | self.cp._grid_points = 4 67 | chosen_variables_dict = { 68 | 'a': np.array([1, 2, 1, 2]), 69 | 'b': np.array([1., 2., 3., 4., 5.5, 7.2]) 70 | } 71 | splits_dict = self.cp._calculate_variable_splits(chosen_variables_dict) 72 | self.assertEqual(len(splits_dict), 2) 73 | self.assertEqual(len(splits_dict['b']), self.cp._grid_points) 74 | np.testing.assert_array_equal(splits_dict['a'], [1, 2]) 75 | 76 | def test_single_observation_df(self): 77 | self.cp._label = "xyz" 78 | self.cp.all_variable_names = ["a", "b", "c"] 79 | self.cp._predict_function = lambda df: df.sum(axis=1) 80 | splits = np.array([1, 3, 15]) 81 | observation_df = self.cp._single_observation_df(np.array([1, 2, 10]), "a", splits, profile_id=42) 82 | self.assertEqual(set(observation_df.columns), {"a", "b", "c", "_yhat_", "_vname_", "_label_", "_ids_"}) 83 | np.testing.assert_array_equal(observation_df["_yhat_"], [13, 15, 27]) 84 | np.testing.assert_array_equal(observation_df["a"], splits) 85 | np.testing.assert_array_equal(observation_df["b"], [2] * 3) 86 | 87 | def test_single_variable_df(self): 88 | self.cp._label = "xyz" 89 | self.cp.all_variable_names = ["a", "b", "c"] 90 | self.cp._predict_function = lambda df: df.sum(axis=1) 91 | splits = np.array([1, 3, 15]) 92 | self.cp.new_observation = pd.DataFrame(np.array([[1, 2, 10]])) 93 | variable_df = self.cp._single_variable_df("a", splits) 94 | self.assertEqual(set(variable_df.columns), {"a", "b", "c", "_yhat_", "_vname_", "_label_", "_ids_"}) 95 | np.testing.assert_array_equal(variable_df["_yhat_"], [13, 15, 27]) 96 | np.testing.assert_array_equal(variable_df["a"], splits) 97 | np.testing.assert_array_equal(variable_df["b"], [2] * 3) 98 | np.testing.assert_array_equal(variable_df["_ids_"], [0] * 3) 99 | 100 | def test_set_label(self): 101 | label = "xyzabc" 102 | self.cp.set_label(label) 103 | 
self.assertEqual(self.cp._label, label) 104 | 105 | def test_valid_variable_splits_1(self): 106 | var_splits = {"a": [1, 2], "b": [4]} 107 | self.assertTrue(_valid_variable_splits(var_splits, ["a", "b"])) 108 | # expect warning here 109 | self.assertFalse(_valid_variable_splits(var_splits, ["c", "a", "b"])) 110 | 111 | def test_get_variable_splits_1(self): 112 | self.cp.selected_variables = ["a", "c"] 113 | var_splits = {"a": [1, 2], "c": [3]} 114 | self.assertEqual(var_splits, self.cp._get_variable_splits(var_splits)) 115 | 116 | def test_get_variable_splits_2(self): 117 | self.cp.selected_variables = ["b", "c"] 118 | self.cp._data = pd.DataFrame.from_dict({ 119 | "a": [1, 2, 4, 2], 120 | "b": ["a", "x", "a", "c"], 121 | "c": [1.21, 1.45, 1.72, 1.9132] 122 | }) 123 | self.cp._grid_points = 4 124 | var_splits = self.cp._get_variable_splits(None) 125 | self.assertEqual(len(var_splits["b"]), 3) 126 | self.assertEqual(len(var_splits["c"]), 4) 127 | 128 | def test_calculate_profile(self): 129 | self.cp._single_variable_df = lambda var_name, var_split: pd.DataFrame({"c": [5], "d": [7]}) 130 | var_splits = {"a": [2, 5, 2], "b": [3, 6]} 131 | self.assertEqual(self.cp._calculate_profile(var_splits).shape, (2, 2)) 132 | 133 | 134 | class TestProfilesUtils(unittest.TestCase): 135 | 136 | def setUp(self): 137 | self.cp1 = object.__new__(CeterisParibus) 138 | 139 | def test_dump_profiles_1(self): 140 | self.cp1.profile = pd.DataFrame() 141 | # one empty profile 142 | observations = dump_profiles([self.cp1]) 143 | self.assertEqual(observations, []) 144 | 145 | def test_dump_profiles_2(self): 146 | records = [{ 147 | "age": 19.0, 148 | "_vname_": "children", 149 | "_yhat_": 19531.877043934743, 150 | "_label_": "GradientBoostingRegressor", 151 | "children": 4.0, 152 | "_ids_": 0, 153 | "bmi": 27.9 154 | }, 155 | { 156 | "age": 19.0, 157 | "_vname_": "children", 158 | "_yhat_": 19531.877043934743, 159 | "_label_": "GradientBoostingRegressor", 160 | "children": 4.0, 161 | "_ids_": 0, 162 | "bmi": 27.9 163 | }] 164 | 165 | self.cp1.profile = pd.DataFrame.from_records(records) 166 | self.assertEqual(records, dump_profiles([self.cp1])) 167 | 168 | def test_dump_profiles_3(self): 169 | records = [{ 170 | "age": 19.0, 171 | "_vname_": "children", 172 | "_yhat_": 19531.877043934743, 173 | "_label_": "GradientBoostingRegressor", 174 | "children": 4.0, 175 | "_ids_": 0, 176 | "bmi": 27.9 177 | }, 178 | { 179 | "age": 19.0, 180 | "_vname_": "children", 181 | "_yhat_": 19531.877043934743, 182 | "_label_": "GradientBoostingRegressor", 183 | "children": 4.0, 184 | "_ids_": 0, 185 | "bmi": 27.9 186 | }] 187 | self.cp1.profile = pd.DataFrame.from_records(records) 188 | self.assertEqual(records + records + records, dump_profiles([self.cp1, self.cp1, self.cp1])) 189 | 190 | def test_dump_observations(self): 191 | self.cp1.new_observation = pd.DataFrame.from_dict({ 192 | "a": [1.2, 3.4, 2.6], 193 | "c": [4, 5, 12] 194 | }) 195 | self.cp1.all_variable_names = list(self.cp1.new_observation.columns) 196 | self.cp1.selected_variables = ["a", "c"] 197 | self.cp1.new_observation_predictions = [12, 3, 6] 198 | self.cp1._label = "some_label" 199 | # true values not given 200 | self.cp1.new_observation_true = None 201 | observations = dump_observations([self.cp1]) 202 | self.assertEqual(len(observations), 3) 203 | self.assertEqual(observations[2]["_label_"], "some_label") 204 | self.assertEqual(observations[1]["_y_"], None) 205 | 206 | def test_dump_observations_2(self): 207 | self.cp1.new_observation = pd.DataFrame.from_dict({ 208 
| "a": [1.2, 3.4, 2.6], 209 | "c": [4, 5, 12] 210 | }) 211 | self.cp1.all_variable_names = list(self.cp1.new_observation.columns) 212 | self.cp1.selected_variables = ["a"] 213 | self.cp1.new_observation_predictions = [12, 3, 6] 214 | self.cp1._label = "some_label" 215 | self.cp1.new_observation_true = [13, 4, 5] 216 | observations = dump_observations([self.cp1, self.cp1]) 217 | self.assertEqual(len(observations), 6) 218 | self.assertEqual(observations[1]["_y_"], 4) 219 | 220 | def test_transform_into_Series_1(self): 221 | a = [1, 4, 2] 222 | b = transform_into_Series(a) 223 | np.testing.assert_array_equal(a, b) 224 | 225 | def test_transform_into_Series_2(self): 226 | a = np.array([4, 1, 6]) 227 | b = transform_into_Series(a) 228 | np.testing.assert_array_equal(a, b) 229 | 230 | def test_transform_into_Series_3(self): 231 | a = pd.DataFrame(OrderedDict(zip(['a', 'b'], [[1, 2, 3], [4, 2, 1]]))) 232 | b = transform_into_Series(a) 233 | np.testing.assert_array_equal(b, [1, 2, 3]) 234 | 235 | def test_save_observations(self): 236 | self.cp1.new_observation = pd.DataFrame.from_dict({ 237 | "a": [1.2, 3.4, 2.6], 238 | "c": [4, 5, 12] 239 | }) 240 | self.cp1.all_variable_names = list(self.cp1.new_observation.columns) 241 | self.cp1.selected_variables = ["a", "c"] 242 | self.cp1.new_observation_predictions = np.array([12, 3, 6], dtype=np.int64) 243 | self.cp1._label = "some_label" 244 | # true values not given 245 | self.cp1.new_observation_true = None 246 | filename = '_tmp_file_' 247 | save_observations([self.cp1], filename) 248 | with open(filename, 'r') as f: 249 | self.assertTrue(f.read().startswith('observation =')) 250 | os.remove(filename) 251 | 252 | def test_save_profiles(self): 253 | records = [{ 254 | "age": 19.0, 255 | "_vname_": "children", 256 | "_yhat_": 19531.877043934743, 257 | "_label_": "GradientBoostingRegressor", 258 | "children": 4.0, 259 | "_ids_": 0, 260 | "bmi": 27.9 261 | }, 262 | { 263 | "age": 19.0, 264 | "_vname_": "children", 265 | "_yhat_": 19531.877043934743, 266 | "_label_": "GradientBoostingRegressor", 267 | "children": 4.0, 268 | "_ids_": 0, 269 | "bmi": 27.9 270 | }] 271 | 272 | self.cp1.profile = pd.DataFrame.from_records(records) 273 | self.cp1.profile['_yhat_'] = self.cp1.profile['_yhat_'].astype(np.float128) 274 | filename = '_tmp_file2_' 275 | save_profiles([self.cp1], filename) 276 | with open(filename, 'r') as f: 277 | self.assertTrue(f.read().startswith('profile =')) 278 | os.remove(filename) 279 | -------------------------------------------------------------------------------- /tests/test_select.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.metrics.pairwise import euclidean_distances 7 | 8 | from ceteris_paribus.gower import _normalize_mixed_data_columns, gower_distances, _calc_range_mixed_data_columns, \ 9 | _gower_dist 10 | from ceteris_paribus.select_data import select_sample, select_neighbours, _select_columns 11 | 12 | 13 | class TestSelect(unittest.TestCase): 14 | 15 | def setUp(self): 16 | self.x = np.array([[1, 0, 1], [2, 0, 2], [10, 2, 10], [1, 0, 2]]) 17 | self.y = np.array([11, 12.1, 13.2, 10.5]) 18 | 19 | def test_select_sample(self): 20 | (_, m) = self.x.shape 21 | size = 2 22 | sample = select_sample(self.x, n=size) 23 | self.assertEqual(sample.shape, (size, m)) 24 | 25 | def test_select_sample_2(self): 26 | sample = select_sample(self.x, n=1) 27 | self.assertIn(sample[0], 
self.x) 28 | 29 | def test_select_sample_3(self): 30 | sample_x, sample_y = select_sample(self.x, self.y, n=1) 31 | pos = list(self.y).index(sample_y[0]) 32 | self.assertSequenceEqual(list(sample_x[0]), list(self.x[pos])) 33 | 34 | def test_select_sample_4(self): 35 | sample_x = select_sample(self.x, n=300) 36 | self.assertEqual(len(sample_x), len(self.x)) 37 | 38 | def test_select_sample_5(self): 39 | sample_x = select_sample(self.x, n=300) 40 | sample_x_2 = select_sample(pd.DataFrame(self.x), n=300) 41 | np.testing.assert_array_equal(sample_x, sample_x_2) 42 | 43 | def test_select_sample_6(self): 44 | sample_x, sample_y = select_sample(pd.DataFrame(self.x), pd.DataFrame(self.y), n=1) 45 | pos = list(self.y).index(sample_y[0]) 46 | self.assertSequenceEqual(list(sample_x.iloc[0]), list(self.x[pos])) 47 | 48 | def test_select_neighbours(self): 49 | neighbours = select_neighbours(self.x, self.x[0], dist_fun=euclidean_distances, n=1) 50 | neighbours2 = select_neighbours(self.x, self.x[0], dist_fun='gower', n=1) 51 | self.assertSequenceEqual(list(neighbours.iloc[0]), list(self.x[0])) 52 | self.assertSequenceEqual(list(neighbours2.iloc[0]), list(self.x[0])) 53 | 54 | def test_select_neighbours_2(self): 55 | (_, m) = self.x.shape 56 | size = 3 57 | neighbours = select_neighbours(self.x, np.array([4, 3, 2]), dist_fun=euclidean_distances, n=size) 58 | self.assertEqual(neighbours.shape, (size, m)) 59 | neighbours2 = select_neighbours(self.x, np.array([4, 3, 2]), dist_fun='gower', n=size) 60 | self.assertEqual(neighbours2.shape, (size, m)) 61 | 62 | def test_select_neighbours_3(self): 63 | sample_x, sample_y = select_neighbours(self.x, np.array([4, 3, 2]), y=self.y, n=3) 64 | pos = list(self.y).index(sample_y[1]) 65 | self.assertSequenceEqual(list(sample_x.iloc[1]), list(self.x[pos])) 66 | 67 | def test_select_neighbours_4(self): 68 | # it logs warning 69 | sample_x = select_neighbours(self.x, np.array([4, 3, 2]), n=300) 70 | self.assertEqual(len(sample_x), len(self.x)) 71 | 72 | def test_select_neighbours_5(self): 73 | # wrong distance function given 74 | with self.assertRaises(ValueError) as c: 75 | select_neighbours(self.x, np.array([4, 3, 2]), n=1, dist_fun='euclidean') 76 | 77 | def test_select_neighbours_6(self): 78 | sample_x = select_neighbours(pd.DataFrame(self.x), np.array([4, 3, 2]), n=300) 79 | self.assertEqual(len(sample_x), len(self.x)) 80 | 81 | def test_select_neighbours_7(self): 82 | sample_x = select_neighbours(pd.DataFrame(self.x, columns=['a', 'b', 'c']), [4, 1, 5], n=2, 83 | selected_variables=['a', 'b']) 84 | self.assertEqual(sample_x.shape, (2, 3)) 85 | 86 | def test_select_neighbours_8(self): 87 | sample_x = select_neighbours(pd.DataFrame(self.x, columns=['a', 'b', 'c']), [4, 1, 5], n=10, 88 | selected_variables=['a', 'd']) 89 | sample_x2 = select_neighbours(pd.DataFrame(self.x), [4, 1, 5], n=10) 90 | np.testing.assert_array_equal(sample_x, sample_x2) 91 | 92 | def test_select_neighbours_9(self): 93 | sample_x = select_neighbours(pd.DataFrame(self.x, columns=['a', 'b', 'c']), [4, 1, 5], n=10, 94 | variable_names=['a', 'b', 'c'], 95 | selected_variables=['a', 'd']) 96 | sample_x2 = select_neighbours(pd.DataFrame(self.x), [4, 1, 5], n=10) 97 | np.testing.assert_array_equal(sample_x, sample_x2) 98 | 99 | def test_select_neighbours_10(self): 100 | df = pd.DataFrame({'a': list(range(100)), 'b': 11, 'c': np.arange(0, 200, 2) / 7}) 101 | y = pd.Series(range(100)) 102 | sample_x, sample_y = select_neighbours(df, [3, 11, 7.4], y, n=5) 103 | self.assertEqual(sample_x.shape, (5, 3)) 104 
| self.assertEqual(len(sample_y), 5) 105 | np.testing.assert_array_equal(sample_x['a'], sample_y) 106 | 107 | 108 | @staticmethod 109 | def select_columns_helper(true, result): 110 | np.testing.assert_array_equal(true[0], result[0]) 111 | np.testing.assert_array_equal(true[1], result[1]) 112 | 113 | def test_select_columns_1(self): 114 | observation = self.x[0] 115 | self.select_columns_helper((pd.DataFrame(self.x), pd.Series(observation)), _select_columns(self.x, observation)) 116 | 117 | def test_select_columns_2(self): 118 | observation = self.x[0] 119 | variables = ['var1', 'var2', 'var3'] 120 | self.select_columns_helper((pd.DataFrame(self.x), pd.Series(observation)), 121 | _select_columns(pd.DataFrame(self.x), pd.Series(observation), variables, variables)) 122 | 123 | def test_select_columns_3(self): 124 | observation = self.x[0] 125 | variables = ['var1', 'var2', 'var3'] 126 | selected_variables = ['var3', 'var2'] 127 | subset = _select_columns(pd.DataFrame(self.x), pd.Series(observation), variable_names=variables, 128 | selected_variables=selected_variables) 129 | self.select_columns_helper(subset, (self.x[:, [2, 1]], observation[[2, 1]])) 130 | 131 | def test_select_columns_4(self): 132 | # warning expected 133 | observation = self.x[0] 134 | variables = ['var1', 'var2', 'var3'] 135 | # selection of invalid variable 136 | selected_variables = ['var4', 'var2'] 137 | subset = _select_columns(pd.DataFrame(self.x), pd.Series(observation), variable_names=variables, 138 | selected_variables=selected_variables) 139 | self.select_columns_helper(subset, (self.x, observation)) 140 | 141 | 142 | class TestGower(unittest.TestCase): 143 | 144 | def setUp(self): 145 | X_items = [ 146 | ('age', [21, 21, 19, 30, 21, 21, 19, 30]), 147 | ('gender', ['M', 'M', 'N', 'M', 'F', 'F', 'F', 'F']), 148 | ('civil_status', ['MARRIED', 'SINGLE', 'SINGLE', 'SINGLE', 'MARRIED', 'SINGLE', 'WIDOW', 'DIVORCED']), 149 | ('salary', [3000.0, 1200.0, 32000.0, 1800.0, 2900.0, 1100.0, 10000.0, 1500.0]), 150 | ('children', [True, False, True, True, True, False, False, True]), 151 | ('available_credit', [2200, 100, 22000, 1100, 2000, 100, 6000, 2200]) 152 | ] 153 | self.X = pd.DataFrame.from_dict(OrderedDict(X_items)) 154 | self.arr = _normalize_mixed_data_columns(self.X) 155 | self.observation = [22, 'F', 'DIVORCED', 2000, False, 1000] 156 | self.observation_missing = [22, np.nan, np.nan, 2000, False, 1000] 157 | self.observation = _normalize_mixed_data_columns(self.observation) 158 | self.first = _normalize_mixed_data_columns(self.X.iloc[0]) 159 | self.ranges = _calc_range_mixed_data_columns(self.arr, self.observation, self.X.dtypes) 160 | 161 | def test_normalize_1(self): 162 | self.assertEqual(self.X.shape, self.arr.shape) 163 | 164 | def test_ranges_1(self): 165 | np.testing.assert_array_almost_equal(self.ranges, np.array([11, 0, 0, 30900, 0, 21900]), decimal=2) 166 | 167 | def test_gower_distances_1(self): 168 | distances = gower_distances(self.X, self.X.iloc[0]) 169 | np.testing.assert_array_almost_equal(distances, 170 | np.array([0, 0.3590, 0.6707, 0.3178, 0.1687, 0.5262, 0.5969, 0.4777]), 171 | decimal=3) 172 | 173 | def test_gower_dist_1(self): 174 | # test dist(a, a) == 0 175 | distance = _gower_dist(self.observation, self.observation, self.ranges, self.X.dtypes) 176 | distance2 = _gower_dist(self.observation_missing, self.observation_missing, self.ranges, self.X.dtypes) 177 | self.assertEqual(distance, 0.0) 178 | self.assertEqual(distance2, 0.0) 179 | 180 | def test_gower_dist_2(self): 181 | # test symmetry 
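# of the metric: _gower_dist(a, b) == _gower_dist(b, a)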
182 | distance1 = _gower_dist(self.observation, self.first, self.ranges, self.X.dtypes) 183 | distance2 = _gower_dist(self.first, self.observation, self.ranges, self.X.dtypes) 184 | self.assertAlmostEqual(distance1, distance2, delta=0.0001) 185 | 186 | def test_gower_dist_3(self): 187 | distance = _gower_dist(self.observation, self.first, self.ranges, self.X.dtypes) 188 | self.assertAlmostEqual(distance, 0.52967, delta=0.0001) 189 | 190 | def test_gower_dist_4(self): 191 | # test with missing values 192 | X_with_nans = self.X.append({'age': 21, 'children': True}, ignore_index=True) 193 | dtypes = X_with_nans.dtypes 194 | X_with_nans = _normalize_mixed_data_columns(X_with_nans) 195 | ranges = _calc_range_mixed_data_columns(X_with_nans, self.observation, dtypes) 196 | distance = _gower_dist(X_with_nans[-1], self.observation, ranges, dtypes) 197 | self.assertAlmostEqual(distance, 0.7727, delta=0.0001) 198 | --------------------------------------------------------------------------------