├── .gitattributes
├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.md
│       └── feature_request.md
├── .gitignore
├── .travis.yml
├── .zenodo.json
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── ceteris_paribus
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── insurance.csv
│   │   └── titanic_train.csv
│   ├── explainer.py
│   ├── gower.py
│   ├── plots
│   │   ├── __init__.py
│   │   ├── ceterisParibusD3.js
│   │   ├── plot_template.html
│   │   └── plots.py
│   ├── profiles.py
│   ├── select_data.py
│   └── utils.py
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── ceteris_paribus.plots.rst
│   ├── ceteris_paribus.rst
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── misc
│   ├── modules.rst
│   └── readme_include.rst
├── examples
│   ├── __init__.py
│   ├── categorical_variables.py
│   ├── cheatsheet.py
│   ├── classification_example.py
│   ├── keras_example.py
│   ├── multiple_models.py
│   └── regression_example.py
├── jupyter-notebooks
│   ├── CeterisParibusCheatsheet.ipynb
│   ├── __init__.py
│   └── titanic.ipynb
├── misc
│   ├── multiclass_models.png
│   ├── titanic_interactions_average.png
│   ├── titanic_many_models.png
│   ├── titanic_many_variables.png
│   └── titanic_single_response.png
├── paper
│   ├── img
│   │   ├── figure1.png
│   │   ├── figure2.png
│   │   ├── figure3.png
│   │   └── figure4.png
│   ├── paper.bib
│   └── paper.md
├── publish.sh
├── requirements-dev.txt
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    ├── test_api.py
    ├── test_explain.py
    ├── test_plots.py
    ├── test_profiles.py
    └── test_select.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | ceteris_paribus/plots/ceterisParibusD3.js linguist-vendored
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Attach minimal code reproducing the bug
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **Screenshots**
20 | If applicable, add screenshots to help explain your problem.
21 |
22 | **Desktop (please complete the following information):**
23 | - Version [e.g. 0.0.1]
24 | - Browser [e.g. chrome, safari] (if applicable, e.g. plotting)
25 | - Jupyter/Console (if relevant)
26 |
27 | **Additional context**
28 | Add any other context about the problem here.
29 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | .coverage
4 | *.pyc
5 |
6 | ceteris_paribus/plots/obs*.js
7 | ceteris_paribus/plots/params*.js
8 | ceteris_paribus/plots/profile*.js
9 | ceteris_paribus/plots/plots*.html
10 |
11 | /docs/_build/
12 |
13 | **/.ipynb_checkpoints
14 | /build/
15 | /pyCeterisParibus.egg-info/
16 | /dist/
17 |
18 | # data files produced for plots
19 | **/_plot_files
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | matrix:
3 | include:
4 | - os: linux
5 | dist: trusty
6 | python: '3.6'
7 | - os: linux
8 | dist: trusty
9 | python: '3.5'
10 | - os: linux
11 | dist: xenial
12 | python: '3.7'
13 | - os: linux
14 | dist: xenial
15 | python: '3.6'
16 | - os: linux
17 | dist: xenial
18 | python: '3.5'
19 |
20 | # command to install dependencies
21 | install:
22 | - pip install -r requirements.txt
23 | - pip install -r requirements-dev.txt
24 | script:
25 | - python -m pytest --cov=./
26 | after_success:
27 | - codecov
28 |
--------------------------------------------------------------------------------
/.zenodo.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": "Python library for Ceteris Paribus Plots (What-if plots)",
3 | "license": "other-open",
4 | "title": "pyCeterisParibus: explaining Machine Learning models with Ceteris Paribus Profiles in Python",
5 | "upload_type": "software",
6 | "creators": [
7 | {
8 | "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology, Faculty of Mathematics, Informatics, and Mechanics, University of Warsaw",
9 | "name": "Michał Kuźba",
10 | "orcid": "0000-0002-9181-0126"
11 | },
12 | {
13 | "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology",
14 | "name": "Ewa Baranowska",
15 | "orcid": "0000-0002-2278-4219"
16 | },
17 | {
18 | "affiliation": "Faculty of Mathematics and Information Science, Warsaw University of Technology, Faculty of Mathematics, Informatics, and Mechanics, University of Warsaw",
19 | "name": "Przemysław Biecek",
20 | "orcid": "0000-0001-8423-1823"
21 | }
22 | ],
23 | "access_right": "open"
24 | }
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## Contribution
2 | Great to see you here. Thank you for considering contributing to pyCeterisParibus. Let's make it better together!
3 |
4 | ### Reporting bugs
5 | Have you found a bug?
6 | * Make sure you have the **latest version** of the package.
7 | * **Check whether it has already been reported in** [Issues](https://github.com/ModelOriented/pyCeterisParibus/issues).
8 | * Add an issue following the **Bug report template**.
9 |
10 | ### Fixing bugs
11 | Do you know how to fix the bug? You are more than welcome to do this!
12 | * Clone this repo.
13 | * Fix the bug.
14 | * Make sure all the **tests** pass.
15 | * Open a pull request and reference the issue number in the description if possible.
16 |
17 | ### Suggest a new feature
18 | * Use the **Feature request template**.
19 |
20 | ### Launching tests
21 | Tests are launched automatically using Travis.
22 |
23 | You can also run them with ```python -m pytest --cov=./```
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | recursive-include ceteris_paribus/datasets *.csv
2 | include ceteris_paribus/plots/ceterisParibusD3.js
3 | include ceteris_paribus/plots/plot_template.html
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [Build Status](https://travis-ci.org/ModelOriented/pyCeterisParibus)
3 | [Code coverage](https://codecov.io/gh/ModelOriented/pyCeterisParibus)
4 | [Documentation Status](https://pyceterisparibus.readthedocs.io/en/latest/?badge=latest)
5 | [Downloads](https://pepy.tech/project/pyceterisparibus)
6 | [PyPI version](https://badge.fury.io/py/pyCeterisParibus)
7 | 
8 | [DOI](https://doi.org/10.5281/zenodo.2667756)
9 | [JOSS status](http://joss.theoj.org/papers/aad9a21c61c01adebe11bc5bc1ceca92)
10 |
11 | # pyCeterisParibus
12 |
13 | **Please note that the Ceteris Paribus method has been moved to the dalex Python package, which is actively maintained. If you experience any problems with pyCeterisParibus, please consider the dalex implementation at https://dalex.drwhy.ai/python/api/.**
14 |
15 | pyCeterisParibus is a Python library based on the *R* package [CeterisParibus](https://github.com/pbiecek/ceterisParibus).
16 | It implements Ceteris Paribus Plots.
17 | They show how the model response would change if a selected variable were changed,
18 | which makes them a perfect tool for What-If scenarios. *Ceteris paribus* is a Latin phrase meaning "other things held constant".
19 | These plots present the change in model response as the values of one feature change, with all others held fixed.
20 | The Ceteris Paribus method is model-agnostic - it works for any Machine Learning model.
21 | The idea is an extension of PDP (Partial Dependence Plots) and ICE (Individual Conditional Expectation) plots.
22 | It allows explaining single observations for multiple variables at the same time.
23 | The plot engine is developed [here](https://github.com/ModelOriented/ceterisParibusD3).
24 |
25 | ## Why is it so useful?
26 | There are several motivations behind this idea.
27 | Imagine a person who gets a low credit score.
28 | The client wants to understand how to increase the score, and the scoring institution (e.g. a bank) should be able to answer such questions.
29 | Moreover, this method is useful for researchers and developers to analyse, debug, explain and improve Machine Learning models, assisting the entire model design process.
30 |
31 | ## Setup
32 | Tested on Python 3.5+
33 |
34 | PyCeterisParibus is on [PyPI](https://pypi.org/project/pyCeterisParibus/). Simply run:
35 |
36 | ```bash
37 | pip install pyCeterisParibus
38 | ```
39 | or install the newest version from GitHub by executing:
40 | ```bash
41 | pip install git+https://github.com/ModelOriented/pyCeterisParibus
42 | ```
43 | or download the sources, enter the main directory and perform:
44 | ```bash
45 | git clone https://github.com/ModelOriented/pyCeterisParibus.git
46 | cd pyCeterisParibus
47 | python setup.py install # (alternatively use pip install .)
48 | ```
49 |
50 | ## Docs
51 | A detailed description of all methods and their parameters can be found in the [documentation](https://pyceterisparibus.readthedocs.io/en/latest/ceteris_paribus.html).
52 |
53 | To build the documentation locally:
54 | ```bash
55 | pip install -r requirements-dev.txt
56 | cd docs
57 | make html
58 | ```
59 | and open `_build/html/index.html`
60 |
61 | ## Examples
62 | Below we present use cases on two well-known datasets - Titanic and Iris. More examples, e.g. for regression problems, can be found [here](examples) and in jupyter notebooks [here](jupyter-notebooks).
63 | 
64 | Note that in order to run the examples you need to install the extra requirements from `requirements-dev.txt`.
65 |
66 | ## Use case - Titanic survival
67 | We demonstrate Ceteris Paribus Plots using the well-known Titanic dataset. In this problem, we examine the chance of survival for Titanic passengers.
68 | We start with preprocessing the data and creating an XGBoost model.
69 | ```python
70 | import pandas as pd
71 | df = pd.read_csv('titanic_train.csv')
72 |
73 | y = df['Survived']
74 | x = df.drop(['Survived', 'PassengerId', 'Name', 'Cabin', 'Ticket'],
75 | inplace=False, axis=1)
76 |
77 | missing = x['Age'].isnull() | x['Embarked'].isnull()
78 | x = x[~missing]
79 | y = y[~missing]
80 |
81 | from sklearn.model_selection import train_test_split
82 | X_train, X_test, y_train, y_test = train_test_split(x, y,
83 | test_size=0.2, random_state=42)
84 | ```
85 | ```python
86 | from sklearn.pipeline import Pipeline
87 | from sklearn.preprocessing import StandardScaler, OneHotEncoder
88 | from sklearn.compose import ColumnTransformer
89 |
90 | # We create the preprocessing pipelines for both numeric and categorical data.
91 | numeric_features = ['Pclass', 'Age', 'SibSp', 'Parch', 'Fare']
92 | numeric_transformer = Pipeline(steps=[
93 | ('scaler', StandardScaler())])
94 |
95 | categorical_features = ['Embarked', 'Sex']
96 | categorical_transformer = Pipeline(steps=[
97 | ('onehot', OneHotEncoder(handle_unknown='ignore'))])
98 |
99 | preprocessor = ColumnTransformer(
100 | transformers=[
101 | ('num', numeric_transformer, numeric_features),
102 | ('cat', categorical_transformer, categorical_features)])
103 | ```
104 |
105 | ```python
106 | from xgboost import XGBClassifier
107 | xgb_clf = Pipeline(steps=[('preprocessor', preprocessor),
108 | ('classifier', XGBClassifier())])
109 | xgb_clf.fit(X_train, y_train)
110 | ```
111 |
112 | Here pyCeterisParibus comes in. Since this library works in a model-agnostic fashion, we first need to create a wrapper around the model with a uniform predict interface.
113 | ```python
114 | from ceteris_paribus.explainer import explain
115 | explainer_xgb = explain(xgb_clf, data=x, y=y, label='XGBoost',
116 | predict_function=lambda X: xgb_clf.predict_proba(X)[::, 1])
117 | ```
118 |
119 |
120 | ### Single variable profile
121 | Let's look at Mr Ernest James Crease, a 19-year-old man travelling in the 3rd class from Southampton with an 8-pound ticket in his pocket. He died on the Titanic. Most likely, this would not have been the case had Ernest been a few years younger.
122 | Figure 1 presents the chance of survival for a person like Ernest at different ages. We can see things were tough for people like him unless they were children.
123 |
124 | ```python
125 | ernest = X_test.iloc[10]
126 | label_ernest = y_test.iloc[10]
127 | from ceteris_paribus.profiles import individual_variable_profile
128 | cp_xgb = individual_variable_profile(explainer_xgb, ernest, label_ernest)
129 | ```
130 |
131 | Having calculated the profile we can plot it. Note that `plot_notebook` can be used instead of `plot` in Jupyter notebooks.
132 |
133 | ```python
134 | from ceteris_paribus.plots.plots import plot
135 | plot(cp_xgb, selected_variables=["Age"])
136 | ```
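137 | 
138 | For example, to embed the same plot in a notebook output cell (a minimal sketch, assuming the `cp_xgb` profile computed above):
139 | 
140 | ```python
141 | from ceteris_paribus.plots.plots import plot_notebook
142 | plot_notebook(cp_xgb, selected_variables=["Age"])
143 | ```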
137 |
138 | ![Single variable profile](misc/titanic_single_response.png)
139 |
140 | ### Many models
141 | The above picture explains the prediction of the XGBoost model. What if we compare various models?
142 |
143 | ```python
144 | from sklearn.ensemble import RandomForestClassifier
145 | from sklearn.linear_model import LogisticRegression
146 | rf_clf = Pipeline(steps=[('preprocessor', preprocessor),
147 | ('classifier', RandomForestClassifier())])
148 | linear_clf = Pipeline(steps=[('preprocessor', preprocessor),
149 | ('classifier', LogisticRegression())])
150 |
151 | rf_clf.fit(X_train, y_train)
152 | linear_clf.fit(X_train, y_train)
153 |
154 | explainer_rf = explain(rf_clf, data=x, y=y, label='RandomForest',
155 | predict_function=lambda X: rf_clf.predict_proba(X)[::, 1])
156 | explainer_linear = explain(linear_clf, data=x, y=y, label='LogisticRegression',
157 | predict_function=lambda X: linear_clf.predict_proba(X)[::, 1])
158 |
159 | cp_rf = individual_variable_profile(explainer_rf, ernest, label_ernest)
160 | cp_linear = individual_variable_profile(explainer_linear, ernest, label_ernest)
161 | 
162 | plot(cp_xgb, cp_rf, cp_linear, selected_variables=["Age"])
160 | ```
161 |
162 | ![Many models](misc/titanic_many_models.png)
163 |
164 | Clearly, XGBoost offers a better fit than Logistic Regression.
165 | Also, it predicts a higher chance of survival at a child's age than the Random Forest model does.
166 |
167 | ### Profiles for many variables
168 | This time we have a look at Miss Elizabeth Mussey Eustis. She is 54 years old and travels in the 1st class with her sister Marta, as they return to the US from their tour of southern Europe. They both survived the disaster.
169 |
170 | ```python
171 | elizabeth = X_test.iloc[1]
172 | label_elizabeth = y_test.iloc[1]
173 | cp_xgb_2 = individual_variable_profile(explainer_xgb, elizabeth, label_elizabeth)
174 | ```
175 |
176 | ```python
177 | plot(cp_xgb_2, selected_variables=["Pclass", "Sex", "Age", "Embarked"])
178 | ```
179 |
180 | ![Profiles for many variables](misc/titanic_many_variables.png)
181 |
182 | Would she have returned home if she had travelled in the 3rd class or if she had been a man? As we can observe, this is less likely. On the other hand, for a first-class female passenger the chances of survival were high regardless of age. Note that this was different in Ernest's case. The place of embarkation (Cherbourg) has no influence, which is the expected behaviour.
183 |
184 | ### Feature interactions and average response
185 | Now, what if we look at the passengers most similar to Miss Eustis (middle-aged, upper class)?
186 |
187 | ```python
188 | from ceteris_paribus.select_data import select_neighbours
189 | neighbours = select_neighbours(X_train, elizabeth,
190 | selected_variables=['Pclass', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked'],
191 | n=15)
192 | cp_xgb_ns = individual_variable_profile(explainer_xgb, neighbours)
193 | ```
194 |
195 | ```python
196 | plot(cp_xgb_ns, color="Sex", selected_variables=["Pclass", "Age"],
197 | aggregate_profiles='mean', size_pdps=6, alpha_pdps=1, size=2)
198 | ```
199 |
200 | ![Feature interactions and average response](misc/titanic_interactions_average.png)
201 |
202 | There are two distinct clusters of passengers determined by their gender, therefore a *PDP* average plot (in grey) does not show the whole picture. Children of both genders were likely to survive, but then we see a large gap. Also, being female increased the chance of survival mostly for second- and first-class passengers.
203 |
204 | The plot function comes with extensive customization options. A list of all parameters can be found in the documentation. Additionally, one can interact with the plot by hovering over a point of interest to see more details. Similarly, there is an interactive table with options for highlighting relevant elements as well as filtering and sorting rows.
205 |
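206 | For illustration, a minimal sketch combining a few of these options (all parameter names come from the `plot` signature; the values here are arbitrary):
207 | 
208 | ```python
209 | plot(cp_xgb, selected_variables=["Age"],
210 |     plot_title="Survival of passengers like Ernest",
211 |     yaxis_title="predicted survival probability",
212 |     color_points="black", size_points=4,
213 |     height=600, width=800)
214 | ```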
206 |
207 |
208 | ### Multiclass models - Iris dataset
209 | Prepare dataset and model
210 | ```python
211 | from sklearn.datasets import load_iris
212 | from sklearn.ensemble import RandomForestClassifier
213 | 
214 | iris = load_iris()
212 |
213 | def random_forest_classifier():
214 | rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
215 | rf_model.fit(iris['data'], iris['target'])
216 | return rf_model, iris['data'], iris['target'], iris['feature_names']
217 | ```
218 |
219 | Wrap model into explainers
220 | ```python
221 | rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier()
222 |
223 | explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y,
224 | predict_function= lambda X: rf_model.predict_proba(X)[::, 0], label=iris.target_names[0])
225 | explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y,
226 | predict_function= lambda X: rf_model.predict_proba(X)[::, 1], label=iris.target_names[1])
227 | explainer_rf3 = explain(rf_model, iris_var_names, iris_x, iris_y,
228 | predict_function= lambda X: rf_model.predict_proba(X)[::, 2], label=iris.target_names[2])
229 | ```
230 |
231 | Calculate profiles and plot
232 | ```python
233 | cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0])
234 | cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0])
235 | cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0])
236 |
237 | plot(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'])
238 | ```
239 | ![Multiclass models](misc/multiclass_models.png)
240 |
241 | ## Contributing
242 | You're more than welcome to contribute to this package. See the [guideline](CONTRIBUTING.md).
243 |
244 | ## Acknowledgments
245 | Work on this package was financially supported by the ‘NCN Opus grant 2016/21/B/ST6/0217’.
246 |
--------------------------------------------------------------------------------
/ceteris_paribus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/ceteris_paribus/__init__.py
--------------------------------------------------------------------------------
/ceteris_paribus/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | DATASETS_DIR = os.path.dirname(__file__)
4 |
--------------------------------------------------------------------------------
/ceteris_paribus/explainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from collections import namedtuple
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | Explainer = namedtuple("Explainer", "model var_names data y predict_fun label")
9 |
10 |
11 | def explain(model, variable_names=None, data=None, y=None, predict_function=None, label=None):
12 | """
13 | This function creates a unified representation of a model, which can be further processed by various explainers
14 |
15 | :param model: a model to be explained
16 | :param variable_names: names of variables, if not supplied then derived from data
17 | :param data: data that was used for fitting
18 | :param y: labels for the data
19 | :param predict_function: function that takes the data and returns predictions
20 | :param label: label of the model; if not supplied, the function will try to infer it from the model object, falling back to 'unlabeled_model'
21 | :return: Explainer object
22 | """
23 | if not predict_function:
24 | if hasattr(model, 'predict'):
25 | if isinstance(data, pd.core.frame.DataFrame):
26 | predict_function = model.predict
27 | else:
28 | predict_function = lambda df: model.predict(df.values)
29 | else:
30 | raise ValueError('Unable to find predict function')
31 | if not label:
32 | logging.warning("Model is unlabeled... \n You can add label using method set_label")
34 | label_items = re.split(r'\(', str(model))
34 | if label_items and len(label_items) > 1:
35 | label = label_items[0]
36 | else:
37 | label = 'unlabeled_model'
38 | if variable_names is None:
39 | if isinstance(data, pd.core.frame.DataFrame):
40 | variable_names = list(data)
41 | else:
42 | raise ValueError("Unable to impute the variable names. Those must be supplied directly!")
43 |
44 | if data is not None:
45 | if not isinstance(data, pd.core.frame.DataFrame):
46 | data = np.array(data)
47 | if data.ndim == 1:
48 | # make 1D array 2D
49 | data = data.reshape((1, -1))
50 | if len(variable_names) != data.shape[1]:
51 | raise ValueError("Incorrect number of variables given.")
52 |
53 | data = pd.DataFrame(data, columns=variable_names)
54 |
55 | if y is not None:
56 | if isinstance(y, pd.core.frame.DataFrame):
57 | y = pd.Series(y.iloc[:, 0])
58 | else:
59 | y = pd.Series(y)
60 |
61 | explainer = Explainer(model, list(variable_names), data, y, predict_function, label)
62 | return explainer
63 |
64 |
--------------------------------------------------------------------------------
/ceteris_paribus/gower.py:
--------------------------------------------------------------------------------
1 | """ This is the module for calculating gower's distance/dissimilarity """
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | # Normalize the array
8 | def _normalize_mixed_data_columns(arr):
9 | """
10 | Returns the numpy array representation of the data.
11 | Loses information about the types
12 | """
13 | return np.array(arr, dtype=object)
14 |
15 |
16 | def _calc_range_mixed_data_columns(data, observation, dtypes):
17 | """ Return range for each numeric column, 0 for categorical variables """
18 | _, cols = data.shape
19 |
20 | result = np.zeros(cols)
21 | for col in range(cols):
22 | if np.issubdtype(dtypes[col], np.number):
23 | result[col] = max(max(data[:, col]), observation[col]) - min(min(data[:, col]), observation[col])
24 | return result
25 |
26 |
27 | def _gower_dist(xi, xj, ranges, dtypes):
28 | """
29 | Return gower's distance between xi and xj
30 |
31 | :param ranges: ranges of values for each column
32 | :param dtypes: types of each column
33 | """
34 | dtypes = np.array(dtypes)
35 |
36 | sum_sij = 0.0
37 | sum_wij = 0.0
38 |
39 | cols = len(ranges)
40 | for col in range(cols):
41 | if np.issubdtype(dtypes[col], np.number):
42 | if pd.isnull(xi[col]) or pd.isnull(xj[col]) or np.isclose(0, ranges[col]):
43 | wij = 0
44 | sij = 0
45 | else:
46 | wij = 1
47 | sij = abs(xi[col] - xj[col]) / ranges[col]
48 | else:
49 | sij = xi[col] != xj[col]
50 | wij = 0 if pd.isnull(xi[col]) and pd.isnull(xj[col]) else 1
51 |
52 | sum_sij += wij * sij
53 | sum_wij += wij
54 |
55 | return sum_sij / sum_wij
56 |
57 |
58 | def gower_distances(data, observation):
59 | """
60 | Return an array of distances between all observations and a chosen one
61 | Based on:
62 | https://sourceforge.net/projects/gower-distance-4python
63 | https://beta.vu.nl/nl/Images/stageverslag-hoven_tcm235-777817.pdf
64 |
65 | :type data: DataFrame
66 | :type observation: pandas Series
67 | """
68 | dtypes = data.dtypes
69 | data = _normalize_mixed_data_columns(data)
70 | observation = _normalize_mixed_data_columns(observation)
71 | ranges = _calc_range_mixed_data_columns(data, observation, dtypes)
72 | return np.array([_gower_dist(row, observation, ranges, dtypes) for row in data])
73 |
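74 | 
75 | # A minimal usage sketch (hypothetical data): numeric columns contribute
76 | # range-normalized absolute differences, categorical columns 0/1 mismatches.
77 | # >>> df = pd.DataFrame({'age': [22, 38, 26], 'sex': ['m', 'f', 'f']})
78 | # >>> gower_distances(df, df.iloc[0])  # -> approximately [0.0, 1.0, 0.625]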
--------------------------------------------------------------------------------
/ceteris_paribus/plots/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | PLOTS_DIR = os.path.dirname(__file__)
4 |
--------------------------------------------------------------------------------
/ceteris_paribus/plots/plot_template.html:
--------------------------------------------------------------------------------
[Markup lost in extraction: the HTML tags of this template were stripped, leaving only its page title, "ceterisParibus D3 template". The template is rendered by ceteris_paribus/plots/plots.py with a plot id and params; it presumably wires together the generated params/obs/profile scripts and ceterisParibusD3.js.]
--------------------------------------------------------------------------------
/ceteris_paribus/plots/plots.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import sys
5 | import webbrowser
6 | from shutil import copyfile
7 |
8 | from flask import Flask, render_template
9 |
10 | from ceteris_paribus.plots import PLOTS_DIR
11 | from ceteris_paribus.utils import save_observations, save_profiles
12 |
13 | app = Flask(__name__, template_folder=PLOTS_DIR)
14 |
15 | MAX_PLOTS_PER_SESSION = 10000
16 |
17 | # generates ids for subsequent plots
18 | _PLOT_NUMBER = iter(range(MAX_PLOTS_PER_SESSION))
19 |
20 | # directory with all files produced in plot generation process
21 | _DATA_PATH = '_plot_files'
22 | os.makedirs(_DATA_PATH, exist_ok=True)
23 | _D3_engine_filename = 'ceterisParibusD3.js'
24 | copyfile(os.path.join(PLOTS_DIR, _D3_engine_filename), os.path.join(_DATA_PATH, _D3_engine_filename))
25 |
26 |
27 | def _calculate_plot_variables(cp_profile, selected_variables):
28 | """
29 | Helper function to calculate valid subset of variables to be plotted
30 | """
31 | if not selected_variables:
32 | return cp_profile.selected_variables
33 | if not set(selected_variables).issubset(set(cp_profile.selected_variables)):
34 | logging.warning("Selected variables are not subset of all variables. Parameter is ignored.")
35 | return cp_profile.selected_variables
36 | else:
37 | return list(selected_variables)
38 |
39 |
40 | def _params_update(params, **kwargs):
41 | for key, val in kwargs.items():
42 | if val:
43 | params[key] = val
44 | return params
45 |
46 |
47 | def _detect_plot_destination(destination):
48 | """
49 | Detect plot destination (browser or embedded inside a notebook) based on the user choice
50 | """
51 | if destination == "notebook":
52 | try:
53 | from IPython.display import IFrame
54 | return "notebook"
55 | except ImportError:
56 | logging.warning("Notebook environment not detected. Plots will be placed in a new tab")
57 | # when browser is explicitly chosen or as a default
58 | return "browser"
59 |
60 |
61 | def plot_notebook(cp_profile, *args, **kwargs):
62 | """
63 | Wrapper for the ``plot`` function with option to embed in the notebook
64 | """
65 | plot(cp_profile, *args, destination="notebook", **kwargs)
66 |
67 |
68 | def plot(cp_profile, *args, destination="browser",
69 | show_profiles=True, show_observations=True, show_residuals=False, show_rugs=False,
70 | aggregate_profiles=None, selected_variables=None,
71 | color=None, size=2, alpha=0.4,
72 | color_pdps=None, size_pdps=None, alpha_pdps=None,
73 | color_points=None, size_points=None, alpha_points=None,
74 | color_residuals=None, size_residuals=None, alpha_residuals=None,
75 | height=500, width=600,
76 | plot_title='', yaxis_title='y',
77 | print_observations=True,
78 | **kwargs):
79 | """
80 | Plot ceteris paribus profile
81 |
82 | :param cp_profile: ceteris paribus profile
83 | :param args: next (optional) ceteris paribus profiles to be plotted along
84 | :param destination: available *browser* - open plot in a new tab, *notebook* - embed a plot in jupyter notebook if possible
85 | :param show_profiles: whether to show profiles
86 | :param show_observations: whether to show individual observations
87 | :param show_residuals: whether to plot residuals
88 | :param show_rugs: whether to plot rugs
89 | :param aggregate_profiles: if specified additional aggregated profile will be plotted, available values: `mean`, `median`
90 | :param selected_variables: variables selected for the plots
91 | :param color: color for profiles - either a color or a variable that should be used for coloring
92 | :param size: size of lines to be plotted
93 | :param alpha: opacity of lines (between 0 and 1)
94 | :param color_pdps: color of pdps - aggregated profiles
95 | :param size_pdps: size of pdps - aggregated profiles
96 | :param alpha_pdps: opacity of pdps - aggregated profiles
97 | :param color_points: color points to be plotted
98 | :param size_points: size of points to be plotted
99 | :param alpha_points: opacity of points
100 | :param color_residuals: color of plotted residuals
101 | :param size_residuals: size of plotted residuals
102 | :param alpha_residuals: opacity of plotted residuals
103 | :param height: height of the window containing plots
104 | :param width: width of the window containing plots
105 | :param plot_title: Title of the plot displayed above
106 | :param yaxis_title: Label for the y axis
107 | :param print_observations: whether to print the table with observations values
108 | :param kwargs: other options passed to the plot
109 | """
110 |
111 | params = dict()
112 | params.update(kwargs)
113 | params["variables"] = _calculate_plot_variables(cp_profile, selected_variables)
114 | params['color'] = "_label_" if args else color
115 | params['show_profiles'] = show_profiles
116 | params['show_observations'] = show_observations
117 | params['show_rugs'] = show_rugs
118 | params['show_residuals'] = show_residuals and (cp_profile.new_observation_true is not None)
119 | params['add_table'] = print_observations
120 | params['height'] = height
121 | params['width'] = width
122 | params['plot_title'] = plot_title
123 | params['size_ices'] = size
124 | params['alpha_ices'] = alpha
125 | params = _params_update(params,
126 | color_pdps=color_pdps, size_pdps=size_pdps, alpha_pdps=alpha_pdps,
127 | size_points=size_points, alpha_points=alpha_points, color_points=color_points,
128 | size_residuals=size_residuals, alpha_residuals=alpha_residuals,
129 | color_residuals=color_residuals,
130 | yaxis_title=yaxis_title)
131 |
132 | if aggregate_profiles in {'mean', 'median', None}:
133 | params['aggregate_profiles'] = aggregate_profiles
134 | else:
135 | logging.warning("Incorrect function for profile aggregation: {}. Parameter ignored."
136 | "Available values are: 'mean' and 'median'".format(aggregate_profiles))
137 | params['aggregate_profiles'] = None
138 |
139 | all_profiles = [cp_profile] + list(args)
140 |
141 | plot_id = str(next(_PLOT_NUMBER))
142 | plot_path, params_path, obs_path, profile_path = _get_data_paths(plot_id)
143 |
144 | with open(params_path, 'w') as f:
145 | f.write("params = " + json.dumps(params, indent=2) + ";")
146 |
147 | save_observations(all_profiles, obs_path)
148 | save_profiles(all_profiles, profile_path)
149 |
150 | with app.app_context():
151 | data = render_template("plot_template.html", i=plot_id, params=params)
152 |
153 | with open(plot_path, 'w') as f:
154 | f.write(data)
155 |
156 | destination = _detect_plot_destination(destination)
157 | if destination == "notebook":
158 | from IPython.display import IFrame, display
159 | display(IFrame(plot_path, width=int(width * 1.1), height=int(height * 1.1)))
160 | else:
161 | # open plot in a browser
162 | if sys.platform == "darwin": # check if on OSX
163 | plot_path = "file://" + os.path.abspath(plot_path)
164 | webbrowser.open(plot_path)
165 |
166 |
167 | def _get_data_paths(plot_id):
168 | plot_path = os.path.join(_DATA_PATH, "plots{}.html".format(plot_id))
169 | params_path = os.path.join(_DATA_PATH, "params{}.js".format(plot_id))
170 | obs_path = os.path.join(_DATA_PATH, 'obs{}.js'.format(plot_id))
171 | profile_path = os.path.join(_DATA_PATH, "profile{}.js".format(plot_id))
172 | return plot_path, params_path, obs_path, profile_path
173 |
--------------------------------------------------------------------------------
/ceteris_paribus/profiles.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import pandas as pd
6 |
7 | from ceteris_paribus.utils import transform_into_Series
8 |
9 |
10 | def individual_variable_profile(explainer, new_observation, y=None, variables=None, grid_points=101,
11 | variable_splits=None):
12 | """
13 | Calculate ceteris paribus profile
14 |
15 | :param explainer: Explainer object wrapping the model to be explained
16 | :param new_observation: a new observation for which the profiles are calculated
17 | :param y: y true labels for `new_observation`. If specified then will be added to ceteris paribus plots
18 | :param variables: collection of variables selected for calculating profiles
19 | :param grid_points: number of points for profile
20 | :param variable_splits: dictionary of splits for variables, in most cases created with `_calculate_variable_splits()`. If None, then it will be calculated based on the validation data available in the `explainer`.
21 | :return: instance of CeterisParibus class
22 | """
23 | variables = _get_variables(variables, explainer)
24 | if not isinstance(new_observation, pd.core.frame.DataFrame):
25 | new_observation = np.array(new_observation)
26 | if new_observation.ndim == 1:
27 | # make 1D array 2D
28 | new_observation = new_observation.reshape((1, -1))
29 | new_observation = pd.DataFrame(new_observation, columns=explainer.var_names)
30 | else:
31 | try:
32 | new_observation.columns = explainer.var_names
33 | except ValueError:
34 | raise ValueError("Mismatched number of variables {} instead of {}".format(len(new_observation.columns),
35 | len(explainer.var_names)))
36 |
37 | if y is not None:
38 | y = transform_into_Series(y)
39 |
40 | cp_profile = CeterisParibus(explainer, new_observation, y, variables, grid_points, variable_splits)
41 | return cp_profile
42 |
43 |
44 | def _get_variables(variables, explainer):
45 | """
46 | Get valid variables for the profile
47 |
48 | :param variables: collection of variables
49 | :param explainer: Explainer object
50 | :return: collection of variables
51 | """
52 | if variables:
53 | if not set(variables).issubset(explainer.var_names):
54 | raise ValueError('Invalid variable names')
55 | else:
56 | variables = explainer.var_names
57 | return variables
58 |
59 |
60 | def _valid_variable_splits(variable_splits, variables):
61 | """
62 | Validate variable splits
63 | """
64 | if set(variable_splits.keys()) == set(variables):
65 | return True
66 | else:
67 | logging.warning("Variable splits are incorrect - wrong set of variables supplied. Parameter is ignored")
68 | return False
69 |
70 |
71 | class CeterisParibus:
72 |
73 | def __init__(self, explainer, new_observation, y, selected_variables, grid_points, variable_splits):
74 | """
75 | Creates Ceteris Paribus object
76 |
77 | :param explainer: explainer wrapping the model
78 | :param new_observation: DataFrame with observations for which the profiles will be calculated
79 | :param y: pandas Series with labels for the observations
80 | :param selected_variables: variables for which the profiles are calculated
81 | :param grid_points: number of points in a single variable split if calculated automatically
82 | :param variable_splits: mapping of variables to the points at which the profile will be calculated; if None, splits are calculated with the function `_calculate_variable_splits`
83 | """
84 | self._data = explainer.data
85 | self._predict_function = explainer.predict_fun
86 | self._grid_points = grid_points
87 | self._label = explainer.label
88 | self.all_variable_names = explainer.var_names
89 | self.new_observation = new_observation
90 | self.selected_variables = list(selected_variables)
91 | variable_splits = self._get_variable_splits(variable_splits)
92 | self.profile = self._calculate_profile(variable_splits)
93 | self.new_observation_values = self.new_observation[self.selected_variables]
94 | self.new_observation_predictions = self._predict_function(self.new_observation)
95 | self.new_observation_true = y
96 |
97 | def _get_variable_splits(self, variable_splits):
98 | """
99 | Helper function for calculating variable splits
100 | """
101 | if variable_splits is None or not _valid_variable_splits(variable_splits, self.selected_variables):
102 | variables_dict = self._data.to_dict(orient='series')
103 | chosen_variables_dict = dict((var, variables_dict[var]) for var in self.selected_variables)
104 | variable_splits = self._calculate_variable_splits(chosen_variables_dict)
105 | return variable_splits
106 |
107 | def _calculate_profile(self, variable_splits):
108 | """
109 | Calculate DataFrame profile
110 | """
111 | profiles_list = [self._single_variable_df(var_name, var_split)
112 | for var_name, var_split in variable_splits.items()]
113 | profile = pd.concat(profiles_list, ignore_index=True)
114 | return profile
115 |
116 | def _calculate_single_split(self, X_var):
117 | """
118 | Calculate the split for a single variable
119 |
120 | :param X_var: variable data - pandas Series
121 | :return: selected subset of values for the variable
122 | """
123 | if np.issubdtype(X_var.dtype, np.floating):
124 | # grid points might be larger than the number of unique values
125 | quantiles = np.linspace(0, 1, self._grid_points)
126 | return np.quantile(X_var, quantiles)
127 | else:
128 | return np.unique(X_var)
129 |
130 | def _calculate_variable_splits(self, chosen_variables_dict):
131 | """
132 | Calculate splits for the given variables
133 |
134 | :param chosen_variables_dict: mapping of variables into the values
135 | :return: mapping of variables into selected subsets of values
136 | """
137 | return dict(
138 | (var, self._calculate_single_split(X_var))
139 | for (var, X_var) in chosen_variables_dict.items()
140 | )
141 |
142 | def _single_variable_df(self, var_name, var_split):
143 | """
144 | Calculate profiles for a given variable
145 |
146 | :param var_name: variable name
147 | :param var_split: split values for the variable
148 | :return: DataFrame with profiles for a given variable
149 | """
150 | return pd.concat([self._single_observation_df(observation, var_name, var_split, profile_id)
151 | for profile_id, observation in self.new_observation.iterrows()], ignore_index=True)
152 |
153 | def _single_observation_df(self, observation, var_name, var_split, profile_id):
154 | """
155 | Calculates the single profile
156 |
157 | :param observation: observation for which the profile is calculated
158 | :param var_name: variable name
159 | :param var_split: split values for the variable
160 | :param profile_id: profile id
161 | :return: DataFrame with the calculated profile values
162 | """
163 | # grid_points and self._grid_points might differ for categorical variables
164 | grid_points = len(var_split)
165 | X = np.tile(observation, (grid_points, 1))
166 | X_dict = OrderedDict(zip(self.all_variable_names, X.T))
167 | df = pd.DataFrame.from_dict(X_dict)
168 | df[var_name] = var_split
169 | df['_yhat_'] = self._predict_function(df)
170 | df['_vname_'] = np.repeat(var_name, grid_points)
171 | df['_label_'] = self._label
172 | df['_ids_'] = profile_id
173 | return df
174 |
175 | def split_by(self, column):
176 | """
177 | Split cp profile data frame by values of a given column
178 |
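179 | For example, ``split_by('_vname_')`` returns one data frame per profiled variable, since each row of the profile stores its variable name in the `_vname_` column.
180 | 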
179 | :return: sorted mapping of values to dataframes
180 | """
181 | return OrderedDict(sorted(list(self.profile.groupby(column, sort=False))))
182 |
183 | def set_label(self, label):
184 | self._label = label
185 |
186 | def print_profile(self):
187 | print('Selected variables: {}'.format(self.selected_variables))
188 | print('Training data size: {}'.format(self._data.shape[0]))
189 | print(self.profile)
190 |
191 |
--------------------------------------------------------------------------------
/ceteris_paribus/select_data.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from ceteris_paribus.gower import gower_distances
7 | from ceteris_paribus.utils import transform_into_Series
8 |
9 |
10 | def select_sample(data, y=None, n=15, seed=42):
11 | """
12 | Select sample from dataset.
13 |
14 | :param data: array or dataframe with observations
15 | :param y: labels for observations
16 | :param n: size of the sample
17 | :param seed: seed for random number generator
18 | :return: selected observations and corresponding labels if provided
19 | """
20 | np.random.seed(seed)
21 | if n > data.shape[0]:
22 | logging.warning("Given n ({}) is larger than data size ({})".format(n, data.shape[0]))
23 | n = data.shape[0]
24 | indices = np.random.choice(data.shape[0], n, replace=False)
25 |
26 | if isinstance(data, pd.core.frame.DataFrame):
27 | sampled_x = data.iloc[indices]
28 | sampled_x.reset_index(drop=True, inplace=True)
29 | else:
30 | sampled_x = data[indices, :]
31 |
32 | if y is not None:
33 | y = transform_into_Series(y)
34 | return sampled_x, y[indices].reset_index(drop=True)
35 | else:
36 | return sampled_x
37 |
38 |
39 | def _select_columns(data, observation, variable_names=None, selected_variables=None):
40 | """
41 | Select data with specified columns
42 |
43 | :param data: DataFrame with observations
44 | :param observation: pandas Series with reference observation for neighbours selection
45 | :param variable_names: names of all variables
46 | :param selected_variables: names of selected variables
47 | :return: DataFrame with observations and pandas Series with referenced observation, with selected columns
48 | """
49 | if selected_variables is None:
50 | return data, observation
51 | try:
52 | if variable_names is None:
53 | if isinstance(data, pd.core.frame.DataFrame):
54 | variable_names = data.columns
55 | else:
56 | raise ValueError("Impossible to detect variable names")
57 | indices = [list(variable_names).index(var) for var in selected_variables]
58 | except ValueError:
59 | logging.warning("Selected variables: {} is not a subset of variables: {}".format(
60 | selected_variables, variable_names))
61 | return data, observation
62 |
63 | subset_data = data.iloc[:, indices]
64 | return subset_data, observation[indices]
65 |
66 |
67 | def select_neighbours(data, observation, y=None, variable_names=None, selected_variables=None, dist_fun='gower', n=20):
68 | """
69 | Select observations from the dataset that are similar to a given observation
70 |
71 | :param data: array or DataFrame with observations
72 | :param observation: reference observation for neighbours selection
73 | :param y: labels for observations
74 | :param variable_names: names of variables
75 | :param selected_variables: selected variables - requires supplying variable names along with the data
76 | :param dist_fun: 'gower' or a distance function with the sklearn pairwise-distances signature; gower works with missing data
77 | :param n: size of the sample
78 | :return: DataFrame with selected observations and pandas Series with corresponding labels if provided
79 | """
80 | if n > data.shape[0]:
81 | logging.warning("Given n ({}) is larger than data size ({})".format(n, data.shape[0]))
82 | n = data.shape[0]
83 |
84 | if not isinstance(data, pd.core.frame.DataFrame):
85 | data = pd.DataFrame(data)
86 |
87 | observation = transform_into_Series(observation)
88 |
89 | # columns are selected for the purpose of distance calculation
90 | selected_data, observation = _select_columns(data, observation, variable_names, selected_variables)
91 |
92 | if dist_fun == 'gower':
93 | distances = gower_distances(selected_data, observation)
94 | else:
95 | if not callable(dist_fun):
96 | raise ValueError('Distance has to be "gower" or a custom function')
97 | distances = dist_fun([observation], selected_data)[0]
98 |
99 | indices = np.argpartition(distances, n - 1)[:n]
100 |
101 | # selected points have all variables
102 | selected_points = data.iloc[indices]
103 | selected_points.reset_index(drop=True, inplace=True)
104 |
105 | if y is not None:
106 | y = transform_into_Series(y)
107 | return selected_points, y.iloc[indices].reset_index(drop=True)
108 | else:
109 | return selected_points
110 |
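111 | 
112 | # A usage sketch (assumed numeric-only data): any pairwise-distance function
113 | # with the sklearn signature can replace the default Gower metric, e.g.:
114 | # >>> from sklearn.metrics.pairwise import euclidean_distances
115 | # >>> select_neighbours(data, observation, dist_fun=euclidean_distances, n=10)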
--------------------------------------------------------------------------------
/ceteris_paribus/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 |
7 | def save_profiles(profiles, filename):
8 | data = dump_profiles(profiles)
9 | with open(filename, 'w') as f:
10 | f.write("profile = {};".format(json.dumps(data, indent=2, default=default)))
11 |
12 |
13 | def dump_profiles(profiles):
14 | """
15 |     Dump the profiles into the JSON format accepted by the plotting library
16 |
17 | :return: list of dicts representing points in the profiles
18 | """
19 | data = []
20 | for cp_profile in profiles:
21 | for i, row in cp_profile.profile.iterrows():
22 | data.append(dict(zip(cp_profile.profile.columns, row)))
23 | return data
24 |
25 |
26 | def default(o):
27 | """
28 |     Workaround for dumping numpy numeric types (e.g. np.int64) into JSON
29 | From: https://stackoverflow.com/a/50577730/7828646
30 |
31 | """
32 | if isinstance(o, np.int64):
33 | return int(o)
34 | return float(o)
35 |
36 |
37 | def save_observations(profiles, filename):
38 | data = dump_observations(profiles)
39 | with open(filename, 'w') as f:
40 | f.write("observation = {};".format(json.dumps(data, indent=2, default=default)))
41 |
42 |
43 | def dump_observations(profiles):
44 | """
45 |     Dump observation data into the JSON format accepted by the plotting library
46 |
47 | :return: list of dicts representing observations in the profiles
48 | """
49 | data = []
50 | for profile in profiles:
51 | for i, yhat in enumerate(profile.new_observation_predictions):
52 | d = dict(zip(profile.all_variable_names, profile.new_observation.iloc[i]))
53 | d['_yhat_'] = yhat
54 | d['_label_'] = profile._label
55 | d['_ids_'] = i
56 | d['_y_'] = profile.new_observation_true[i] if profile.new_observation_true is not None else None
57 | data.append(d)
58 | return data
59 |
60 |
61 | def transform_into_Series(y):
62 | if isinstance(y, pd.core.frame.DataFrame):
63 | y = pd.Series(y.iloc[:, 0])
64 | else:
65 | y = pd.Series(y)
66 | return y
67 |
--------------------------------------------------------------------------------
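
A short sketch of why the default helper above exists (an illustration, not part of the module; it assumes only numpy and the standard library): plain json.dumps rejects numpy scalars such as np.int64.

    import json

    import numpy as np

    from ceteris_paribus.utils import default

    data = {'age': np.int64(42), 'bmi': np.float32(27.5)}
    # json.dumps(data) alone raises "Object of type int64 is not JSON
    # serializable"; passing default= converts numpy scalars to plain
    # Python numbers (int for np.int64, float for everything else).
    print(json.dumps(data, default=default))  # {"age": 42, "bmi": 27.5}
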
/codecov.yml:
--------------------------------------------------------------------------------
1 | ignore:
2 | - "examples/"
3 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SOURCEDIR = .
8 | BUILDDIR = _build
9 |
10 | # Put it first so that "make" without argument is like "make help".
11 | help:
12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
13 |
14 | .PHONY: help Makefile
15 |
16 | # Catch-all target: route all unknown targets to Sphinx using the new
17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
18 | %: Makefile
19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/ceteris_paribus.plots.rst:
--------------------------------------------------------------------------------
1 | ceteris\_paribus.plots package
2 | ==============================
3 |
4 | Submodules
5 | ----------
6 |
7 | ceteris\_paribus.plots.plots module
8 | -----------------------------------
9 |
10 | .. automodule:: ceteris_paribus.plots.plots
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 |
16 | Module contents
17 | ---------------
18 |
19 | .. automodule:: ceteris_paribus.plots
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
--------------------------------------------------------------------------------
/docs/ceteris_paribus.rst:
--------------------------------------------------------------------------------
1 | ceteris\_paribus package
2 | ========================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |
9 | ceteris_paribus.plots
10 |
11 | Submodules
12 | ----------
13 |
14 | ceteris\_paribus.explainer module
15 | ---------------------------------
16 |
17 | .. automodule:: ceteris_paribus.explainer
18 | :members:
19 | :undoc-members:
20 |
21 | ceteris\_paribus.gower module
22 | -----------------------------
23 | Gower distance is a measure that can be used to calculate the similarity between two observations containing both categorical and numerical values. It also permits missing values in categorical variables, which makes it applicable to a wide range of datasets. Here it is used as the default function for finding the observations closest to a given one.
24 |
25 | The original paper describing the idea can be found `here `_.
26 |
27 | .. automodule:: ceteris_paribus.gower
28 | :members:
29 | :undoc-members:
30 | :show-inheritance:
31 |
32 | ceteris\_paribus.profiles module
33 | --------------------------------
34 |
35 | .. automodule:: ceteris_paribus.profiles
36 | :members:
37 | :undoc-members:
38 | :show-inheritance:
39 |
40 | ceteris\_paribus.select\_data module
41 | ------------------------------------
42 |
43 | .. automodule:: ceteris_paribus.select_data
44 | :members:
45 | :undoc-members:
46 | :show-inheritance:
47 |
48 |
49 | Module contents
50 | ---------------
51 |
52 | .. automodule:: ceteris_paribus
53 | :members:
54 | :undoc-members:
55 | :show-inheritance:
56 |
--------------------------------------------------------------------------------
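
A minimal sketch of computing Gower distances directly, assuming the gower_distances(data, observation) signature used in select_data.py and the bundled insurance dataset (an illustration, not part of the documentation source):

    import os

    import pandas as pd

    from ceteris_paribus.datasets import DATASETS_DIR
    from ceteris_paribus.gower import gower_distances

    df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv')).drop(['charges'], axis=1)
    # One distance per row, measured against the first observation;
    # numerical and categorical columns are combined into a single score.
    distances = gower_distances(df, df.iloc[0])
    print(distances[:5])
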
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | #
15 | import os
16 | import sys
17 |
18 | sys.path.insert(0, os.path.abspath('../'))
19 |
20 | # -- Project information -----------------------------------------------------
21 |
22 | project = 'pyCeterisParibus'
23 | copyright = '2019, Michał Kuźba'
24 | author = 'Michał Kuźba'
25 |
26 | # The short X.Y version
27 | version = ''
28 | # The full version, including alpha/beta/rc tags
29 | release = ''
30 |
31 | # -- General configuration ---------------------------------------------------
32 |
33 | # If your documentation needs a minimal Sphinx version, state it here.
34 | #
35 | # needs_sphinx = '1.0'
36 |
37 | # Add any Sphinx extension module names here, as strings. They can be
38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
39 | # ones.
40 | extensions = [
41 | 'sphinx.ext.autodoc',
42 | 'm2r'
43 | ]
44 |
45 | # Add any paths that contain templates here, relative to this directory.
46 | templates_path = ['_templates']
47 |
48 | # The suffix(es) of source filenames.
49 | # You can specify multiple suffix as a list of string:
50 | #
51 | source_suffix = ['.rst', '.md']
52 | # source_suffix = '.rst'
53 |
54 | # The master toctree document.
55 | master_doc = 'index'
56 |
57 | # The language for content autogenerated by Sphinx. Refer to documentation
58 | # for a list of supported languages.
59 | #
60 | # This is also used if you do content translation via gettext catalogs.
61 | # Usually you set "language" from the command line for these cases.
62 | language = None
63 |
64 | # List of patterns, relative to source directory, that match files and
65 | # directories to ignore when looking for source files.
66 | # This pattern also affects html_static_path and html_extra_path.
67 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
68 |
69 | # The name of the Pygments (syntax highlighting) style to use.
70 | pygments_style = None
71 |
72 | # -- Options for HTML output -------------------------------------------------
73 |
74 | # The theme to use for HTML and HTML Help pages. See the documentation for
75 | # a list of builtin themes.
76 | #
77 | html_theme = 'default'
78 | # html_theme_path = sphinx_bootstrap_theme.get_html_theme_path()
79 |
80 | # Theme options are theme-specific and customize the look and feel of a theme
81 | # further. For a list of options available for each theme, see the
82 | # documentation.
83 | #
84 | # html_theme_options = {}
85 |
86 | # Add any paths that contain custom static files (such as style sheets) here,
87 | # relative to this directory. They are copied after the builtin static files,
88 | # so a file named "default.css" will overwrite the builtin "default.css".
89 | html_static_path = ['_static']
90 |
91 | # Custom sidebar templates, must be a dictionary that maps document names
92 | # to template names.
93 | #
94 | # The default sidebars (for documents that don't match any pattern) are
95 | # defined by theme itself. Builtin themes are using these templates by
96 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
97 | # 'searchbox.html']``.
98 | #
99 | # html_sidebars = {}
100 |
101 |
102 | # -- Options for HTMLHelp output ---------------------------------------------
103 |
104 | # Output file base name for HTML help builder.
105 | htmlhelp_basename = 'pyCeterisParibusdoc'
106 |
107 | # -- Options for LaTeX output ------------------------------------------------
108 |
109 | latex_elements = {
110 | # The paper size ('letterpaper' or 'a4paper').
111 | #
112 | # 'papersize': 'letterpaper',
113 |
114 | # The font size ('10pt', '11pt' or '12pt').
115 | #
116 | # 'pointsize': '10pt',
117 |
118 | # Additional stuff for the LaTeX preamble.
119 | #
120 | # 'preamble': '',
121 |
122 | # Latex figure (float) alignment
123 | #
124 | # 'figure_align': 'htbp',
125 | }
126 |
127 | # Grouping the document tree into LaTeX files. List of tuples
128 | # (source start file, target name, title,
129 | # author, documentclass [howto, manual, or own class]).
130 | latex_documents = [
131 | (master_doc, 'pyCeterisParibus.tex', 'pyCeterisParibus Documentation',
132 | 'Michał Kuźba', 'manual'),
133 | ]
134 |
135 | # -- Options for manual page output ------------------------------------------
136 |
137 | # One entry per manual page. List of tuples
138 | # (source start file, name, description, authors, manual section).
139 | man_pages = [
140 | (master_doc, 'pyceterisparibus', 'pyCeterisParibus Documentation',
141 | [author], 1)
142 | ]
143 |
144 | # -- Options for Texinfo output ----------------------------------------------
145 |
146 | # Grouping the document tree into Texinfo files. List of tuples
147 | # (source start file, target name, title, author,
148 | # dir menu entry, description, category)
149 | texinfo_documents = [
150 | (master_doc, 'pyCeterisParibus', 'pyCeterisParibus Documentation',
151 | author, 'pyCeterisParibus', 'One line description of project.',
152 | 'Miscellaneous'),
153 | ]
154 |
155 | # -- Options for Epub output -------------------------------------------------
156 |
157 | # Bibliographic Dublin Core info.
158 | epub_title = project
159 |
160 | # The unique identifier of the text. This can be a ISBN number
161 | # or the project homepage.
162 | #
163 | # epub_identifier = ''
164 |
165 | # A unique identification for the text.
166 | #
167 | # epub_uid = ''
168 |
169 | # A list of files that should not be packed into the epub file.
170 | epub_exclude_files = ['search.html']
171 |
172 | # -- Extension configuration -------------------------------------------------
173 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. pyCeterisParibus documentation master file, created by
2 | sphinx-quickstart on Tue Jan 29 19:38:46 2019.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 |
7 | Welcome to pyCeterisParibus's documentation!
8 | ============================================
9 |
10 | .. toctree::
11 | :maxdepth: 2
12 | :caption: Contents:
13 |
14 | readme_include
15 | modules
16 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/misc:
--------------------------------------------------------------------------------
1 | ../misc/
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | ceteris_paribus
2 | ===============
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | ceteris_paribus
8 |
--------------------------------------------------------------------------------
/docs/readme_include.rst:
--------------------------------------------------------------------------------
1 | .. mdinclude:: ../README.md
2 | .. mdinclude:: ../misc
3 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/examples/__init__.py
--------------------------------------------------------------------------------
/examples/categorical_variables.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 |
5 | from ceteris_paribus.datasets import DATASETS_DIR
6 | from ceteris_paribus.plots.plots import plot
7 |
8 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv'))
9 |
10 | x = df.drop(['charges'], inplace=False, axis=1)
11 |
12 | y = df['charges']
13 |
14 | var_names = list(x)
15 |
16 | from sklearn.pipeline import Pipeline
17 | from sklearn.preprocessing import StandardScaler, OneHotEncoder
18 | from sklearn.compose import ColumnTransformer
19 | from sklearn.ensemble import RandomForestRegressor
20 |
21 | # We create the preprocessing pipelines for both numeric and categorical data.
22 | numeric_features = ['age', 'bmi', 'children']
23 | numeric_transformer = Pipeline(steps=[
24 | ('scaler', StandardScaler())])
25 |
26 | categorical_features = ['sex', 'smoker', 'region']
27 | categorical_transformer = Pipeline(steps=[
28 | ('onehot', OneHotEncoder(handle_unknown='ignore'))])
29 |
30 | preprocessor = ColumnTransformer(
31 | transformers=[
32 | ('num', numeric_transformer, numeric_features),
33 | ('cat', categorical_transformer, categorical_features)])
34 |
35 | # Append the regressor to the preprocessing pipeline.
36 | # Now we have a full prediction pipeline.
37 | clf = Pipeline(steps=[('preprocessor', preprocessor),
38 |                       ('regressor', RandomForestRegressor())])
39 |
40 | clf.fit(x, y)
41 |
42 | from ceteris_paribus.explainer import explain
43 |
44 | explainer_cat = explain(clf, var_names, x, y, label="categorical_model")
45 |
46 | from ceteris_paribus.profiles import individual_variable_profile
47 |
48 | cp_cat = individual_variable_profile(explainer_cat, x.iloc[:10], y.iloc[:10])
49 |
50 | cp_cat.print_profile()
51 | plot(cp_cat)
52 |
53 | plot(cp_cat, color="smoker")
54 |
--------------------------------------------------------------------------------
/examples/cheatsheet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | from sklearn import ensemble, svm
4 | from sklearn.datasets import load_iris
5 | from sklearn.linear_model import LinearRegression
6 |
7 | from ceteris_paribus.explainer import explain
8 | from ceteris_paribus.plots.plots import plot
9 | from ceteris_paribus.profiles import individual_variable_profile
10 | from ceteris_paribus.select_data import select_neighbours
11 | from ceteris_paribus.datasets import DATASETS_DIR
12 |
13 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv'))
14 |
15 | df = df[['age', 'bmi', 'children', 'charges']]
16 |
17 | x = df.drop(['charges'], inplace=False, axis=1)
18 | y = df['charges']
19 | var_names = list(x.columns)
20 | x = x.values
21 | y = y.values
22 |
23 | iris = load_iris()
24 |
25 |
26 | def random_forest_classifier():
27 | rf_model = ensemble.RandomForestClassifier(n_estimators=100, random_state=42)
28 |
29 | rf_model.fit(iris['data'], iris['target'])
30 |
31 | return rf_model, iris['data'], iris['target'], iris['feature_names']
32 |
33 |
34 | def linear_regression_model():
35 | linear_model = LinearRegression()
36 | linear_model.fit(x, y)
37 | # model, data, labels, variable_names
38 | return linear_model, x, y, var_names
39 |
40 |
41 | def gradient_boosting_model():
42 | gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42)
43 | gb_model.fit(x, y)
44 | return gb_model, x, y, var_names
45 |
46 |
47 | def support_vector_machines_model():
48 | svm_model = svm.SVR(C=0.01, gamma='scale', kernel='poly')
49 | svm_model.fit(x, y)
50 | return svm_model, x, y, var_names
51 |
52 |
53 | if __name__ == "__main__":
54 |
55 | (linear_model, data, labels, variable_names) = linear_regression_model()
56 | (gb_model, _, _, _) = gradient_boosting_model()
57 |     (svm_model, _, _, _) = support_vector_machines_model()
58 |
59 | explainer_linear = explain(linear_model, variable_names, data, y)
60 | explainer_gb = explain(gb_model, variable_names, data, y)
61 | explainer_svm = explain(svm_model, variable_names, data, y)
62 |
63 | # single profile
64 | cp_1 = individual_variable_profile(explainer_gb, x[0], y[0])
65 | plot(cp_1, destination="notebook", selected_variables=["bmi"], print_observations=False)
66 |
67 | # local fit
68 | neighbours_x, neighbours_y = select_neighbours(x, x[10], y=y, n=10)
69 | cp_2 = individual_variable_profile(explainer_gb,
70 | neighbours_x, neighbours_y)
71 | plot(cp_2, show_residuals=True, selected_variables=["age"], print_observations=False, color_residuals='red',
72 | plot_title='')
73 |
74 | # aggregate profiles
75 | plot(cp_2, aggregate_profiles="mean", selected_variables=["age"], color_pdps='black', size_pdps=6,
76 | alpha_pdps=0.7, print_observations=False,
77 | plot_title='')
78 |
79 | # many variables
80 | plot(cp_1, selected_variables=["bmi", "age", "children"], print_observations=False, plot_title='', width=950)
81 |
82 | # many models
83 | cp_svm = individual_variable_profile(explainer_svm, x[0], y[0])
84 | cp_linear = individual_variable_profile(explainer_linear, x[0], y[0])
85 | plot(cp_1, cp_svm, cp_linear, print_observations=False, plot_title='', width=1050)
86 |
87 | # color by feature
88 | plot(cp_2, color="bmi", print_observations=False, plot_title='', width=1050, selected_variables=["age"], size=3)
89 |
90 | # classification multiplot
91 | rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier()
92 |
93 |     explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y,
94 |                             predict_function=lambda X: rf_model.predict_proba(X)[:, 0], label=iris.target_names[0])
95 |     explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y,
96 |                             predict_function=lambda X: rf_model.predict_proba(X)[:, 1], label=iris.target_names[1])
97 |     explainer_rf3 = explain(rf_model, iris_var_names, iris_x, iris_y,
98 |                             predict_function=lambda X: rf_model.predict_proba(X)[:, 2], label=iris.target_names[2])
99 |
100 |
101 | cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0])
102 | cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0])
103 | cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0])
104 |
105 | plot(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'],
106 | plot_title='', print_observations=False, width=1050)
--------------------------------------------------------------------------------
/examples/classification_example.py:
--------------------------------------------------------------------------------
1 | from sklearn.datasets import load_iris
2 | from sklearn.ensemble import RandomForestClassifier
3 | from sklearn.model_selection import train_test_split
4 |
5 | from ceteris_paribus.explainer import explain
6 | from ceteris_paribus.plots.plots import plot
7 | from ceteris_paribus.profiles import individual_variable_profile
8 |
9 | iris = load_iris()
10 |
11 | X = iris['data']
12 | y = iris['target']
13 |
14 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
15 |
16 | print(iris['feature_names'])
17 |
18 | def random_forest_classifier():
19 | rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
20 |
21 | rf_model.fit(X_train, y_train)
22 |
23 | return rf_model, X_train, y_train, iris['feature_names']
24 |
25 |
26 | if __name__ == "__main__":
27 | (model, data, labels, variable_names) = random_forest_classifier()
28 |     predict_function = lambda X: model.predict_proba(X)[:, 0]
29 | explainer_rf = explain(model, variable_names, data, labels, predict_function=predict_function)
30 | cp_profile = individual_variable_profile(explainer_rf, X[1], y=y[1])
31 | plot(cp_profile)
32 |
--------------------------------------------------------------------------------
/examples/keras_example.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Dense, Activation
2 | from keras.models import Sequential
3 | from keras.wrappers.scikit_learn import KerasRegressor
4 | from sklearn.datasets import load_boston
5 | from sklearn.model_selection import train_test_split
6 | from sklearn.pipeline import Pipeline
7 | from sklearn.preprocessing import StandardScaler
8 |
9 | from ceteris_paribus.explainer import explain
10 | from ceteris_paribus.plots.plots import plot
11 | from ceteris_paribus.profiles import individual_variable_profile
12 |
13 | boston = load_boston()
14 | x = boston.data
15 | y = boston.target
16 | x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42)
17 |
18 |
19 | def network_architecture():
20 | model = Sequential()
21 | model.add(Dense(640, input_dim=x.shape[1]))
22 | model.add(Activation('tanh'))
23 | model.add(Dense(320))
24 | model.add(Activation('tanh'))
25 | model.add(Dense(1))
26 | model.compile(loss='mean_squared_error', optimizer='adam')
27 | return model
28 |
29 |
30 | def keras_model():
31 | estimators = [
32 | ('scaler', StandardScaler()),
33 | ('mlp', KerasRegressor(build_fn=network_architecture, epochs=200))
34 | ]
35 | model = Pipeline(estimators)
36 | model.fit(x_train, y_train)
37 | return model, x_train, y_train, boston.feature_names
38 |
39 |
40 | if __name__ == "__main__":
41 | model, x_train, y_train, var_names = keras_model()
42 | explainer_keras = explain(model, var_names, x_train, y_train, label='KerasMLP')
43 | cp = individual_variable_profile(explainer_keras, x_train[:10], y=y_train[:10],
44 | variables=["CRIM", "ZN", "AGE", "INDUS", "B"])
45 | plot(cp, show_residuals=True, selected_variables=["CRIM", "ZN", "AGE", "B"], show_observations=True,
46 | show_rugs=True)
47 |
--------------------------------------------------------------------------------
/examples/multiple_models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pandas as pd
4 | from sklearn import ensemble, svm
5 | from sklearn.linear_model import LinearRegression
6 |
7 | from ceteris_paribus.datasets import DATASETS_DIR
8 | from ceteris_paribus.explainer import explain
9 | from ceteris_paribus.plots.plots import plot
10 | from ceteris_paribus.profiles import individual_variable_profile
11 | from ceteris_paribus.select_data import select_sample
12 |
13 | df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv'))
14 |
15 | df = df[['age', 'bmi', 'children', 'charges']]
16 |
17 | x = df.drop(['charges'], inplace=False, axis=1)
18 | y = df['charges']
19 | var_names = list(x.columns)
20 | x = x.values
21 | y = y.values
22 |
23 |
24 | def linear_regression_model():
25 | # Create linear regression object
26 | linear_model = LinearRegression()
27 |
28 | # Train the model using the training set
29 | linear_model.fit(x, y)
30 |
31 | # model, data, labels, variable_names
32 | return linear_model, x, y, var_names
33 |
34 |
35 | def gradient_boosting_model():
36 | gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42)
37 | gb_model.fit(x, y)
38 | return gb_model, x, y, var_names
39 |
40 |
41 | def support_vector_machines_model():
42 | svm_model = svm.SVR(C=0.01, gamma='scale')
43 | svm_model.fit(x, y)
44 | return svm_model, x, y, var_names
45 |
46 |
47 | if __name__ == "__main__":
48 | (linear_model, data, labels, variable_names) = linear_regression_model()
49 | (gb_model, _, _, _) = gradient_boosting_model()
50 |     (svm_model, _, _, _) = support_vector_machines_model()
51 |
52 | explainer_linear = explain(linear_model, variable_names, data, y)
53 | explainer_gb = explain(gb_model, variable_names, data, y)
54 | explainer_svm = explain(svm_model, variable_names, data, y)
55 |
56 | cp_profile = individual_variable_profile(explainer_linear, x[0], y[0])
57 | plot(cp_profile, show_residuals=True)
58 |
59 | sample_x, sample_y = select_sample(x, y, n=10)
60 | cp2 = individual_variable_profile(explainer_gb, sample_x, y=sample_y)
61 |
62 | cp3 = individual_variable_profile(explainer_gb, x[0], y[0])
63 | plot(cp3, show_residuals=True)
64 |
65 | plot(cp_profile, cp3, show_residuals=True)
66 |
--------------------------------------------------------------------------------
/examples/regression_example.py:
--------------------------------------------------------------------------------
1 | from sklearn import datasets, ensemble
2 | from sklearn.model_selection import train_test_split
3 |
4 | from ceteris_paribus.explainer import explain
5 | from ceteris_paribus.plots.plots import plot
6 | from ceteris_paribus.profiles import individual_variable_profile
7 | from ceteris_paribus.select_data import select_sample, select_neighbours
8 |
9 | boston = datasets.load_boston()
10 |
11 | X = boston['data']
12 | y = boston['target']
13 |
14 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
15 |
16 |
17 | def random_forest_regression():
18 |     # Create a random forest regression object
19 | rf_model = ensemble.RandomForestRegressor(n_estimators=100, random_state=42)
20 |
21 | # Train the model using the training set
22 | rf_model.fit(X_train, y_train)
23 |
24 | # model, data, labels, variable_names
25 | return rf_model, X_train, y_train, list(boston['feature_names'])
26 |
27 |
28 | if __name__ == "__main__":
29 | (model, data, labels, variable_names) = random_forest_regression()
30 | explainer_rf = explain(model, variable_names, data, labels)
31 |
32 | cp_profile = individual_variable_profile(explainer_rf, X_train[0], y=y_train[0], variables=['TAX', 'CRIM'])
33 | plot(cp_profile)
34 |
35 | sample = select_sample(X_train, n=3)
36 | cp2 = individual_variable_profile(explainer_rf, sample, variables=['TAX', 'CRIM'])
37 | plot(cp2)
38 |
39 | neighbours = select_neighbours(X_train, X_train[0], variable_names=variable_names,
40 | selected_variables=variable_names, n=15)
41 | cp3 = individual_variable_profile(explainer_rf, neighbours, variables=['LSTAT', 'RM'],
42 | variable_splits={'LSTAT': [10, 20, 30], 'RM': [4, 5, 6, 7]})
43 | plot(cp3)
44 |
--------------------------------------------------------------------------------
/jupyter-notebooks/CeterisParibusCheatsheet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Ceteris Paribus cheatsheet"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "### Prepare dataset"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "from ceteris_paribus.datasets import DATASETS_DIR\n",
25 | "import os\n",
26 | "\n",
27 | "df = pd.read_csv(os.path.join(DATASETS_DIR, 'insurance.csv'))\n",
28 | "\n",
29 | "df = df[['age', 'bmi', 'children', 'charges']]\n",
30 | "\n",
31 | "x = df.drop(['charges'], inplace=False, axis=1)\n",
32 | "y = df['charges']\n",
33 | "var_names = list(x.columns)\n",
34 | "x = x.values\n",
35 | "y = y.values"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "### Prepare regression models"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 2,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "from sklearn.linear_model import LinearRegression\n",
52 | "from sklearn import ensemble, svm\n",
53 | "\n",
54 | "def linear_regression_model():\n",
55 | " linear_model = LinearRegression()\n",
56 | " linear_model.fit(x, y)\n",
57 | " # model, data, labels, variable_names\n",
58 | " return linear_model, x, y, var_names\n",
59 | "\n",
60 | "\n",
61 | "def gradient_boosting_model():\n",
62 | " gb_model = ensemble.GradientBoostingRegressor(n_estimators=1000, random_state=42)\n",
63 | " gb_model.fit(x, y)\n",
64 | " return gb_model, x, y, var_names\n",
65 | "\n",
66 | "\n",
67 | "def supported_vector_machines_model():\n",
68 | " svm_model = svm.SVR(C=0.01, gamma='scale', kernel='poly')\n",
69 | " svm_model.fit(x, y)\n",
70 | " return svm_model, x, y, var_names"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "### Calculate single profile variables"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 3,
83 | "metadata": {},
84 | "outputs": [
85 | {
86 | "name": "stderr",
87 | "output_type": "stream",
88 | "text": [
89 | "WARNING:root:Model is unlabeled... \n",
90 | " You can add label using method set_label\n"
91 | ]
92 | },
93 | {
94 | "data": {
95 | "text/html": [
96 | "\n",
97 | " \n",
104 | " "
105 | ],
106 | "text/plain": [
107 | ""
108 | ]
109 | },
110 | "metadata": {},
111 | "output_type": "display_data"
112 | }
113 | ],
114 | "source": [
115 | "from ceteris_paribus.explainer import explain\n",
116 | "from ceteris_paribus.plots.plots import plot_notebook\n",
117 | "from ceteris_paribus.profiles import individual_variable_profile\n",
118 | "\n",
119 | "(gb_model, data, labels, variable_names) = gradient_boosting_model()\n",
120 | "\n",
121 | "explainer_gb = explain(gb_model, variable_names, data, y)\n",
122 | "\n",
123 | "cp_1 = individual_variable_profile(explainer_gb, x[0], y[0])\n",
124 | "plot_notebook(cp_1, selected_variables=[\"bmi\"], print_observations=False)"
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "### Local fit plots"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 4,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "(svm_model, _, _, _) = supported_vector_machines_model()\n",
141 | "(linear_model, data, labels, variable_names) = linear_regression_model()\n",
142 | "\n",
143 | "explainer_linear = explain(linear_model, variable_names, data, y, label='linear_model')\n",
144 | "explainer_svm = explain(svm_model, variable_names, data, y, label='svm_model')"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 5,
150 | "metadata": {},
151 | "outputs": [
152 | {
153 | "data": {
154 | "text/html": [
155 | "\n",
156 | " \n",
163 | " "
164 | ],
165 | "text/plain": [
166 | ""
167 | ]
168 | },
169 | "metadata": {},
170 | "output_type": "display_data"
171 | }
172 | ],
173 | "source": [
174 | "from ceteris_paribus.select_data import select_neighbours\n",
175 | "\n",
176 | "neighbours_x, neighbours_y = select_neighbours(x, x[10], y=y, n=10)\n",
177 | "cp_2 = individual_variable_profile(explainer_gb,\n",
178 | " neighbours_x, neighbours_y)\n",
179 | "plot_notebook(cp_2, show_residuals=True, selected_variables=[\"age\"], print_observations=False, color_residuals='red', \n",
180 | " plot_title='')"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "### Aggregate profiles"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 6,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "data": {
197 | "text/html": [
198 | "\n",
199 | " \n",
206 | " "
207 | ],
208 | "text/plain": [
209 | ""
210 | ]
211 | },
212 | "metadata": {},
213 | "output_type": "display_data"
214 | }
215 | ],
216 | "source": [
217 | "plot_notebook(cp_2, aggregate_profiles=\"mean\", selected_variables=[\"age\"], color_pdps='black', size_pdps=6,\n",
218 | " alpha_pdps=0.7, print_observations=False,\n",
219 | " plot_title='')"
220 | ]
221 | },
222 | {
223 | "cell_type": "markdown",
224 | "metadata": {},
225 | "source": [
226 | "### Many variables"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 7,
232 | "metadata": {},
233 | "outputs": [
234 | {
235 | "data": {
236 | "text/html": [
237 | "\n",
238 | " \n",
245 | " "
246 | ],
247 | "text/plain": [
248 | ""
249 | ]
250 | },
251 | "metadata": {},
252 | "output_type": "display_data"
253 | }
254 | ],
255 | "source": [
256 | "plot_notebook(cp_1, selected_variables=[\"bmi\", \"age\", \"children\"], print_observations=False, plot_title='', width=900)"
257 | ]
258 | },
259 | {
260 | "cell_type": "markdown",
261 | "metadata": {},
262 | "source": [
263 | "### Many models"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 8,
269 | "metadata": {},
270 | "outputs": [
271 | {
272 | "data": {
273 | "text/html": [
274 | "\n",
275 | " \n",
282 | " "
283 | ],
284 | "text/plain": [
285 | ""
286 | ]
287 | },
288 | "metadata": {},
289 | "output_type": "display_data"
290 | }
291 | ],
292 | "source": [
293 | "cp_svm = individual_variable_profile(explainer_svm, x[0], y[0])\n",
294 | "cp_linear = individual_variable_profile(explainer_linear, x[0], y[0])\n",
295 | "plot_notebook(cp_1, cp_svm, cp_linear, print_observations=False, plot_title='', width=850, size=3, alpha=0.7)"
296 | ]
297 | },
298 | {
299 | "cell_type": "markdown",
300 | "metadata": {},
301 | "source": [
302 | "### Color by feature"
303 | ]
304 | },
305 | {
306 | "cell_type": "code",
307 | "execution_count": 9,
308 | "metadata": {},
309 | "outputs": [
310 | {
311 | "data": {
312 | "text/html": [
313 | "\n",
314 | " \n",
321 | " "
322 | ],
323 | "text/plain": [
324 | ""
325 | ]
326 | },
327 | "metadata": {},
328 | "output_type": "display_data"
329 | }
330 | ],
331 | "source": [
332 | "plot_notebook(cp_2, color=\"age\", plot_title='', width=900, size=3)"
333 | ]
334 | },
335 | {
336 | "cell_type": "markdown",
337 | "metadata": {},
338 | "source": [
339 | "### Prepare classification example"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 10,
345 | "metadata": {},
346 | "outputs": [],
347 | "source": [
348 | "from sklearn.datasets import load_iris\n",
349 | "from sklearn.ensemble import RandomForestClassifier\n",
350 | "iris = load_iris()\n",
351 | "\n",
352 | "def random_forest_classifier():\n",
353 | " rf_model = ensemble.RandomForestClassifier(n_estimators=100, random_state=42)\n",
354 | "\n",
355 | " rf_model.fit(iris['data'], iris['target'])\n",
356 | "\n",
357 | " return rf_model, iris['data'], iris['target'], iris['feature_names']"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 11,
363 | "metadata": {},
364 | "outputs": [],
365 | "source": [
366 | "rf_model, iris_x, iris_y, iris_var_names = random_forest_classifier()\n",
367 | "\n",
368 | "explainer_rf1 = explain(rf_model, iris_var_names, iris_x, iris_y,\n",
369 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 0], label=iris.target_names[0])\n",
370 | "explainer_rf2 = explain(rf_model, iris_var_names, iris_x, iris_y,\n",
371 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 1], label=iris.target_names[1])\n",
372 | "explainer_rf3 = explain(rf_model, iris_var_names, iris_x, iris_y,\n",
373 | " predict_function= lambda X: rf_model.predict_proba(X)[::, 2], label=iris.target_names[2])\n",
374 | "\n",
375 | "\n",
376 | "cp_rf1 = individual_variable_profile(explainer_rf1, iris_x[0], iris_y[0])\n",
377 | "cp_rf2 = individual_variable_profile(explainer_rf2, iris_x[0], iris_y[0])\n",
378 | "cp_rf3 = individual_variable_profile(explainer_rf3, iris_x[0], iris_y[0])"
379 | ]
380 | },
381 | {
382 | "cell_type": "markdown",
383 | "metadata": {},
384 | "source": [
385 | "### Multiclass profiles"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": 14,
391 | "metadata": {},
392 | "outputs": [
393 | {
394 | "data": {
395 | "text/html": [
396 | "\n",
397 | " \n",
404 | " "
405 | ],
406 | "text/plain": [
407 | ""
408 | ]
409 | },
410 | "metadata": {},
411 | "output_type": "display_data"
412 | }
413 | ],
414 | "source": [
415 | "plot_notebook(cp_rf1, cp_rf2, cp_rf3, selected_variables=['petal length (cm)', 'petal width (cm)', 'sepal length (cm)'],\n",
416 | " plot_title='', print_observations=False, width=800, height=300, size=4, alpha=0.9)"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": []
425 | }
426 | ],
427 | "metadata": {
428 | "kernelspec": {
429 | "display_name": "ceteris",
430 | "language": "python",
431 | "name": "ceteris"
432 | },
433 | "language_info": {
434 | "codemirror_mode": {
435 | "name": "ipython",
436 | "version": 3
437 | },
438 | "file_extension": ".py",
439 | "mimetype": "text/x-python",
440 | "name": "python",
441 | "nbconvert_exporter": "python",
442 | "pygments_lexer": "ipython3",
443 | "version": "3.5.2"
444 | }
445 | },
446 | "nbformat": 4,
447 | "nbformat_minor": 2
448 | }
449 |
--------------------------------------------------------------------------------
/jupyter-notebooks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelOriented/pyCeterisParibus/e38a18ff3a0faf485f3128be2505a7faeeb27234/jupyter-notebooks/__init__.py
--------------------------------------------------------------------------------
/jupyter-notebooks/titanic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Titanic survival"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "### Read data"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 1,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pandas as pd\n",
24 | "from ceteris_paribus.datasets import DATASETS_DIR\n",
25 | "import os\n",
26 | "df = pd.read_csv(os.path.join(DATASETS_DIR, 'titanic_train.csv'))"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 2,
32 | "metadata": {},
33 | "outputs": [
34 | {
35 | "data": {
36 | "text/html": [
37 | "