├── .gitignore
├── .travis.yml
├── AUTHORS
├── LICENSE
├── README.md
├── examples
└── regression_iris.ipynb
├── requirements.txt
├── resources
├── fo_ale_quant.png
└── so_ale_quant.png
├── setup.cfg
├── setup.py
├── src
└── alepython
│ ├── __init__.py
│ └── ale.py
└── tests
├── __init__.py
├── test_ale_calculation.py
├── test_figure_creation.py
├── test_parameters.py
└── utils
├── __init__.py
└── test_models.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # celery beat schedule file
87 | celerybeat-schedule
88 |
89 | # SageMath parsed files
90 | *.sage.py
91 |
92 | # Environments
93 | .env
94 | .venv
95 | env/
96 | venv/
97 | ENV/
98 | env.bak/
99 | venv.bak/
100 |
101 | # Spyder project settings
102 | .spyderproject
103 | .spyproject
104 |
105 | # Rope project settings
106 | .ropeproject
107 |
108 | # mkdocs documentation
109 | /site
110 |
111 | # mypy
112 | .mypy_cache/
113 | .dmypy.json
114 | dmypy.json
115 |
116 | # Pyre type checker
117 | .pyre/
118 |
119 | # Package version.
120 | _version.py
121 |
122 | # Pre-commit configuration file.
123 | .pre-commit-config.yaml
124 |
125 | # Git attributes.
126 | .gitattributes
127 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - "3.5"
4 | - "3.6"
5 | - "3.7"
6 | - "3.8"
7 | before_install:
8 | - python --version
9 | - pip install -U pip
10 | install: pip install -e ".[test]"
11 | script: pytest
12 |
--------------------------------------------------------------------------------
/AUTHORS:
--------------------------------------------------------------------------------
1 | Maxime Jumelle - maxime@blent.ai - (gh: @MaximeJumelle)
2 | Sanjif Rajaratnam (gh: @sanjifr3)
3 | Alexander Kuhn-Regnier (gh: @akuhnregnier)
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2018- Maxime Jumelle
190 | Copyright 2020- ALEPython developers (see AUTHORS file)
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/MaximeJumelle/ALEPython)
2 | [](https://github.com/ambv/black)
3 | [](https://opensource.org/licenses/Apache-2.0)
4 |
5 | Python Accumulated Local Effects package.
6 |
7 | # Why ALE?
8 |
9 | Explaining model predictions is very common when you have to deploy a Machine Learning algorithm on a large scale.
10 | There are many methods that help us understand our model; one of these uses Partial Dependence Plots (PDP), which have been widely used for years.
11 |
12 | However, they suffer from a stringent assumption: **features have to be uncorrelated**.
13 | In real world scenarios, features are often correlated, whether because some are directly computed from others, or because observed phenomena produce correlated distributions.
14 |
15 | Accumulated Local Effects (or ALE) plots first proposed by [_Apley and Zhu_ (2016)](#1) alleviate this issue reasonably by using actual conditional marginal distributions instead of considering each marginal distribution of features.
16 | This is more reliable when handling (even strongly) correlated variables.
17 |
18 | This package aims to provide useful and quick access to ALE plots, so that you can easily explain your model through predictions.
19 |
20 | For further details about model interpretability and ALE plots, see e.g. [_Molnar_ (2020)](#2).
21 |
22 | # Install
23 |
24 | ALEPython is supported on Python >= 3.5.
25 | You can either install the package via `pip`:
26 |
27 | ```sh
28 | pip install alepython
29 | ```
30 | or directly from source (including requirements):
31 | ```sh
32 | pip install git+https://github.com/MaximeJumelle/ALEPython.git@dev#egg=alepython
33 | ```
34 | or after cloning (or forking) for development purposes, including test dependencies:
35 | ```sh
36 | git clone https://github.com/MaximeJumelle/ALEPython.git
37 | pip install -e "ALEPython[test]"
38 | ```
39 |
40 | # Usage
41 |
42 | ```python
43 | from alepython import ale_plot
44 | # Plots ALE of feature 'cont' with Monte-Carlo replicas (default : 50).
45 | ale_plot(model, X_train, 'cont', monte_carlo=True)
46 | ```
47 |
48 | # Highlights
49 |
50 | - First-order ALE plots of continuous features
51 | - Second-order ALE plots of continuous features
52 |
53 | # Gallery
54 |
55 | ## First-order ALE plots of continuous features
56 |
57 |
58 |
59 | ---
60 |
61 | ## Second-order ALE plots of continuous features
62 |
63 |
64 |
65 | # Work In Progress
66 |
67 | - First-order ALE plots of categorical features
68 | - Enhanced visualization of first-order plots
69 | - Second-order ALE plots of categorical features
70 | - Documentation and API reference
71 | - Jupyter Notebook examples
72 | - Upload to PyPi
73 | - Upload to conda-forge
74 | - Use of matplotlib styles or kwargs to allow overriding plotting appearance
75 |
76 | If you are interested in the project, I would be happy to collaborate with you since there are still quite a lot of improvements needed.
77 |
78 | ## References
79 |
80 |
81 | Apley, Daniel W., and Jingyu Zhu. 2016. Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models. .
82 |
83 |
84 | Molnar, Christoph. 2020. Interpretable Machine Learning. .
85 |
--------------------------------------------------------------------------------
/examples/regression_iris.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import pandas as pd\n",
11 | "from loguru import logger\n",
12 | "from sklearn import datasets\n",
13 | "from sklearn.ensemble import RandomForestRegressor"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "metadata": {},
19 | "source": [
20 | "## Loading the Dataset\n",
21 | "\n",
22 | "Define our response `y` and predictor, `X`, variables."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "iris = datasets.load_iris()\n",
32 | "dataset = pd.DataFrame(data=iris[\"data\"], columns=iris[\"feature_names\"])\n",
33 | "X = dataset.iloc[:, 1:]\n",
34 | "y = dataset.iloc[:, 0]\n",
35 | "dataset.head()"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Fit a Random Forest Regressor"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "model = RandomForestRegressor(n_estimators=20, bootstrap=True)\n",
52 | "model.fit(X, y)"
53 | ]
54 | },
55 | {
56 | "cell_type": "markdown",
57 | "metadata": {},
58 | "source": [
59 | "## Create our ALE Plots\n",
60 | "\n",
61 | "### 1D Main Effect ALE Plot"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "import matplotlib as mpl\n",
71 | "\n",
72 | "from alepython import ale_plot\n",
73 | "\n",
74 | "mpl.rc(\"figure\", figsize=(9, 6))\n",
75 | "ale_plot(\n",
76 | " model,\n",
77 | " X,\n",
78 | " X.columns[:1],\n",
79 | " bins=20,\n",
80 | " monte_carlo=True,\n",
81 | " monte_carlo_rep=100,\n",
82 | " monte_carlo_ratio=0.6,\n",
83 | ")"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "### 2D Second-Order ALE Plot"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "metadata": {},
97 | "outputs": [],
98 | "source": [
99 | "mpl.rc(\"figure\", figsize=(9, 6))\n",
100 | "ale_plot(model, X, X.columns[:2], bins=10)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "mpl.rc(\"figure\", figsize=(9, 6))\n",
110 | "ale_plot(model, X, X.columns[1:], bins=10)"
111 | ]
112 | }
113 | ],
114 | "metadata": {
115 | "kernelspec": {
116 | "display_name": "Python 3",
117 | "language": "python",
118 | "name": "python3"
119 | }
120 | },
121 | "nbformat": 4,
122 | "nbformat_minor": 4
123 | }
124 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | loguru>=0.4.1
2 | matplotlib>=2.2.3
3 | numpy>=1.15.4
4 | pandas>=0.22.0
5 | scipy>=1.1.0
6 | seaborn>=0.9.0
7 |
--------------------------------------------------------------------------------
/resources/fo_ale_quant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/resources/fo_ale_quant.png
--------------------------------------------------------------------------------
/resources/so_ale_quant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/resources/so_ale_quant.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_files = LICENSE
3 |
4 | [tool:pytest]
5 | addopts =
6 | --doctest-modules
7 | --cov=alepython
8 | -r a
9 | -v
10 | markers =
11 | slow: mark a test as being very slow.
12 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Installation script for the alepython package."""
from setuptools import find_packages, setup

# Long description for PyPI is taken verbatim from the README.
with open("README.md") as f:
    long_description = f.read()

# Runtime dependencies are kept in requirements.txt as the single source of truth.
with open("requirements.txt", "r") as f:
    required = f.read().splitlines()

setup(
    name="alepython",
    description="Python Accumulated Local Effects (ALE) package.",
    author="Maxime Jumelle",
    author_email="maxime@blent.ai",
    license="Apache 2",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/MaximeJumelle/alepython/",
    install_requires=required,
    extras_require={"test": ["pytest>=5.4", "pytest-cov>=2.8"]},
    setup_requires=["setuptools-scm"],
    python_requires=">=3.5",
    # setuptools-scm derives the version from git tags and writes it to
    # src/alepython/_version.py (which is .gitignore'd).
    use_scm_version=dict(write_to="src/alepython/_version.py"),
    keywords="alepython",
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    include_package_data=True,
    classifiers=[
        "Programming Language :: Python :: 3",
        # Fixed: the previous value "License :: OSI Approched :: Apache 2" is
        # not a valid Trove classifier (typo plus non-standard name) and would
        # be rejected on upload to PyPI.
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
)
35 |
--------------------------------------------------------------------------------
/src/alepython/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Package entry point: expose the installed version and the public ALE API.
from ._version import version as __version__
# Re-export the public names declared by `ale` (its `__all__`, i.e. `ale_plot`).
from .ale import *

# Importing from `._version` also binds the `_version` submodule as an
# attribute of this package; delete it so it does not leak into the public
# namespace (the version itself remains available as `__version__`).
del _version
6 |
--------------------------------------------------------------------------------
/src/alepython/ale.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ALE plotting for continuous or categorical features."""
3 | from collections.abc import Iterable
4 | from functools import reduce
5 | from itertools import product
6 | from operator import add
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | import pandas as pd
11 | import scipy
12 | import seaborn as sns
13 | from loguru import logger
14 | from matplotlib.patches import Rectangle
15 | from scipy.spatial import cKDTree
16 |
17 | logger.disable("alepython")
18 |
19 |
20 | __all__ = ("ale_plot",)
21 |
22 |
23 | def _parse_features(features):
24 | """Standardise representation of column labels.
25 |
26 | Args:
27 | features : object
28 | One or more column labels.
29 |
30 | Returns:
31 | features : array-like
32 | An array of input features.
33 |
34 | Examples
35 | --------
36 | >>> _parse_features(1)
37 | array([1])
38 | >>> _parse_features(('a', 'b'))
39 | array(['a', 'b'], dtype='>> _check_two_ints(1)
74 | (1, 1)
75 | >>> _check_two_ints((1, 2))
76 | (1, 2)
77 | >>> _check_two_ints((1,))
78 | (1, 1)
79 |
80 | """
81 | if isinstance(values, (int, np.integer)):
82 | values = (values, values)
83 | elif len(values) == 1:
84 | values = (values[0], values[0])
85 | elif len(values) != 2:
86 | raise ValueError(
87 | "'{}' values were given. Expected at most 2.".format(len(values))
88 | )
89 |
90 | if not all(isinstance(n_bin, (int, np.integer)) for n_bin in values):
91 | raise ValueError(
92 | "All values must be an integer. Got types '{}' instead.".format(
93 | {type(n_bin) for n_bin in values}
94 | )
95 | )
96 | return values
97 |
98 |
def _get_centres(x):
    """Compute midpoints between consecutive bin edges.

    Parameters
    ----------
    x : array-like
        Bin edges; the average is taken along the first axis.

    Returns
    -------
    centres : array-like
        Midpoints of `x`, with shape (N - 1, ...) for `x` of shape (N, ...).

    Examples
    --------
    >>> import numpy as np
    >>> x = np.array([0, 1, 2, 3])
    >>> _get_centres(x)
    array([0.5, 1.5, 2.5])

    """
    # Average each edge with its successor to obtain the bin centres.
    upper_edges = x[1:]
    lower_edges = x[:-1]
    return (upper_edges + lower_edges) / 2
122 |
123 |
def _ax_title(ax, title, subtitle=""):
    """Set a two-line title (title above subtitle) on an axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes object to add title to.
    title : str
        Axis title.
    subtitle : str, optional
        Sub-title for figure. Will appear one line below `title`.

    """
    # An empty subtitle still contributes the newline, matching a joined pair.
    ax.set_title("{}\n{}".format(title, subtitle))
138 |
139 |
def _ax_labels(ax, xlabel=None, ylabel=None):
    """Attach optional x and y labels to an axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes object to add labels to.
    xlabel : str, optional
        X axis label. Skipped when None.
    ylabel : str, optional
        Y axis label. Skipped when None.

    """
    # Apply each label through its setter only when a value was supplied.
    for setter_name, label in (("set_xlabel", xlabel), ("set_ylabel", ylabel)):
        if label is not None:
            getattr(ax, setter_name)(label)
157 |
158 |
def _ax_quantiles(ax, quantiles, twin="x"):
    """Plot quantiles of a feature onto axis.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to modify.
    quantiles : array-like
        Quantiles to plot.
    twin : {'x', 'y'}, optional
        Select the axis for which to plot quantiles.

    Raises
    ------
    ValueError
        If `twin` is not one of 'x' or 'y'.

    """
    if twin not in ("x", "y"):
        raise ValueError("'twin' should be one of 'x' or 'y'.")

    logger.debug("Quantiles: {}.", quantiles)

    # Duplicate the 'opposite' axis so we can define a distinct set of ticks for the
    # desired axis (`twin`).
    ax_mod = ax.twiny() if twin == "x" else ax.twinx()

    # Set the new axis' ticks for the desired axis.
    getattr(ax_mod, "set_{twin}ticks".format(twin=twin))(quantiles)
    # Set the corresponding tick labels.

    # Calculate tick label percentage values for each quantile (bin edge).
    # The quantiles are evenly spaced in probability, so labels run from 0% to
    # 100% in steps of 100 / (len(quantiles) - 1).
    percentages = (
        100 * np.arange(len(quantiles), dtype=np.float64) / (len(quantiles) - 1)
    )

    # If there is a fractional part, add a decimal place to show (part of) it.
    # `fractional` is 0 or 1 per label and is reused below as the number of
    # decimal places in the format string.
    fractional = (~np.isclose(percentages % 1, 0)).astype("int8")

    getattr(ax_mod, "set_{twin}ticklabels".format(twin=twin))(
        [
            "{0:0.{1}f}%".format(percent, format_fraction)
            for percent, format_fraction in zip(percentages, fractional)
        ],
        color="#545454",
        fontsize=7,
    )
    # Mirror the data limits of the original axis onto the twin so the quantile
    # ticks line up with the plotted data.
    getattr(ax_mod, "set_{twin}lim".format(twin=twin))(
        getattr(ax, "get_{twin}lim".format(twin=twin))()
    )
209 |
210 |
def _first_order_quant_plot(ax, quantiles, ale, **kwargs):
    """Draw a first-order ALE curve at the bin centres.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to plot onto.
    quantiles : array-like
        ALE quantiles.
    ale : array-like
        ALE to plot.
    **kwargs : plot properties, optional
        Additional keyword parameters are passed to `ax.plot`.

    """
    # The ALE values correspond to bins, so plot them against the bin centres.
    bin_centres = _get_centres(quantiles)
    ax.plot(bin_centres, ale, **kwargs)
227 |
228 |
def _second_order_quant_plot(
    fig, ax, quantiles_list, ale, mark_empty=True, n_interp=50, **kwargs
):
    """Second order ALE plot.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure onto which the colorbar is drawn.
    ax : matplotlib.axes.Axes
        Axis to plot onto.
    quantiles_list : array-like
        ALE quantiles for the first (`quantiles_list[0]`) and second
        (`quantiles_list[1]`) features.
    ale : masked array
        ALE to plot. Where `ale.mask` is True, this denotes bins where no samples were
        available. See `mark_empty`.
    mark_empty : bool, optional
        If True, plot rectangles over bins that did not contain any samples.
    n_interp : [2-iterable of] int, optional
        The number of interpolated samples generated from `ale` prior to contour
        plotting. Two integers may be given to specify different interpolation steps
        for the two features.
    **kwargs : contourf properties, optional
        Additional keyword parameters are passed to `ax.contourf`.

    Raises
    ------
    ValueError
        If `n_interp` values were not integers.
    ValueError
        If more than 2 values were given for `n_interp`.

    """
    # ALE values correspond to bins, so interpolate between the bin centres.
    centres_list = [_get_centres(quantiles) for quantiles in quantiles_list]
    n_x, n_y = _check_two_ints(n_interp)
    # Regular grid spanning the centre range of each feature.
    x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x)
    y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y)

    X, Y = np.meshgrid(x, y, indexing="xy")
    # NOTE(review): this module only does `import scipy` (plus
    # `scipy.spatial`); `import scipy` does not guarantee the `interpolate`
    # submodule is loaded — confirm, or import `scipy.interpolate` explicitly.
    # NOTE(review): `interp2d` is deprecated and removed in SciPy >= 1.14 —
    # TODO migrate to `RegularGridInterpolator` or `RectBivariateSpline`.
    ale_interp = scipy.interpolate.interp2d(centres_list[0], centres_list[1], ale.T)
    CF = ax.contourf(X, Y, ale_interp(x, y), cmap="bwr", levels=30, alpha=0.7, **kwargs)

    if mark_empty and np.any(ale.mask):
        # Do not autoscale, so that boxes at the edges (contourf only plots the bin
        # centres, not their edges) don't enlarge the plot.
        plt.autoscale(False)
        # Add rectangles to indicate cells without samples.
        for i, j in zip(*np.where(ale.mask)):
            ax.add_patch(
                Rectangle(
                    # Lower-left corner of the empty bin, plus its width and height
                    # taken from the surrounding quantile (bin edge) values.
                    [quantiles_list[0][i], quantiles_list[1][j]],
                    quantiles_list[0][i + 1] - quantiles_list[0][i],
                    quantiles_list[1][j + 1] - quantiles_list[1][j],
                    linewidth=1,
                    edgecolor="k",
                    facecolor="none",
                    alpha=0.4,
                )
            )
    # Colorbar for the interpolated ALE contour levels.
    fig.colorbar(CF)
288 |
289 |
def _get_quantiles(train_set, feature, bins):
    """Get quantiles from a feature in a dataset.

    Parameters
    ----------
    train_set : pandas.core.frame.DataFrame
        Dataset containing feature `feature`.
    feature : column label
        Feature for which to calculate quantiles.
    bins : int
        The number of quantiles is calculated as `bins + 1`.

    Returns
    -------
    quantiles : array-like
        Quantiles.
    bins : int
        Number of bins, `len(quantiles) - 1`. This may be lower than the original
        `bins` if identical quantiles were present.

    Raises
    ------
    ValueError
        If `bins` is not an integer.

    Notes
    -----
    When using this definition of quantiles in combination with a half open interval
    (lower quantile, upper quantile], care has to taken that the smallest observation
    is included in the first bin. This is handled transparently by `np.digitize`.

    Examples
    --------
    >>> import pandas as pd
    >>> quantiles, n_bins = _get_quantiles(pd.DataFrame({"f": [1, 2, 3, 4]}), "f", 2)
    >>> [int(q) for q in quantiles]
    [1, 2, 4]
    >>> n_bins
    2

    """
    if not isinstance(bins, (int, np.integer)):
        raise ValueError(
            "Expected integer 'bins', but got type '{}'.".format(type(bins))
        )
    probabilities = np.linspace(0, 1, bins + 1)
    try:
        # NumPy >= 1.22 renamed the 'interpolation' keyword to 'method'; the old
        # name was removed in NumPy 2.0.
        raw_quantiles = np.quantile(train_set[feature], probabilities, method="lower")
    except TypeError:
        # Fall back for older NumPy versions that only accept 'interpolation'.
        raw_quantiles = np.quantile(
            train_set[feature], probabilities, interpolation="lower"
        )
    # Drop duplicate quantiles (which arise for features with many repeated
    # values); this reduces the effective number of bins.
    quantiles = np.unique(raw_quantiles)
    bins = len(quantiles) - 1
    return quantiles, bins
333 |
334 |
def _first_order_ale_quant(predictor, train_set, feature, bins):
    """Estimate the first-order ALE function for single continuous feature data.

    Parameters
    ----------
    predictor : callable
        Prediction function.
    train_set : pandas.core.frame.DataFrame
        Training set on which the model was trained.
    feature : column label
        Feature name. A single column label.
    bins : int
        The requested number of bins. The effective number of bins may be smaller
        than this, since only unique quantiles of `train_set[feature]` are kept.

    Returns
    -------
    ale : array-like
        The first order ALE.
    quantiles : array-like
        The quantiles used.

    """
    quantiles, _ = _get_quantiles(train_set, feature, bins)
    logger.debug("Quantiles: {}.", quantiles)

    # Left-edge bin index for every sample. Digitizing against the quantiles and
    # shifting by one would map the smallest observation to -1, so clip to 0.
    indices = np.clip(
        np.digitize(train_set[feature], quantiles, right=True) - 1, 0, None
    )

    # Evaluate the predictor twice per sample: once with the feature pinned to the
    # lower bin edge and once pinned to the upper edge.
    edge_predictions = []
    for offset in (0, 1):
        shifted = train_set.copy()
        shifted[feature] = quantiles[indices + offset]
        edge_predictions.append(predictor(shifted))
    # The individual (per-sample) effects.
    effects = edge_predictions[1] - edge_predictions[0]

    # Average the per-sample differences within each bin.
    grouped = pd.DataFrame({"index": indices, "effects": effects}).groupby("index")
    bin_means = grouped.mean().to_numpy().flatten()

    # Accumulate the mean effects, starting from zero at the lowest quantile.
    ale = np.array([0, *np.cumsum(bin_means)])

    # The uncentred mean main effects at the bin centres.
    ale = _get_centres(ale)

    # Centre by subtracting the sample-weighted mean effect (each bin weighted by
    # the number of samples it contains).
    ale -= np.sum(ale * grouped.size() / train_set.shape[0])
    return ale, quantiles
396 |
397 |
def _second_order_ale_quant(predictor, train_set, features, bins):
    """Estimate the second-order ALE function for two continuous feature data.

    Parameters
    ----------
    predictor : callable
        Prediction function.
    train_set : pandas.core.frame.DataFrame
        Training set on which the model was trained.
    features : 2-iterable of column label
        The two desired features, as two column labels.
    bins : [2-iterable of] int
        This defines the number of bins to compute. The effective number of bins may
        be less than this as only unique quantile values of train_set[feature] are
        used. If one integer is given, this is used for both features.

    Returns
    -------
    ale : (M, N) masked array
        The second order ALE. Elements are masked where no data was available.
    quantiles : 2-tuple of array-like
        The quantiles used: first the quantiles for `features[0]` with shape (M + 1,),
        then for `features[1]` with shape (N + 1,).

    Raises
    ------
    ValueError
        If `features` does not contain 2 features.
    ValueError
        If more than 2 bins are given.
    ValueError
        If bins are not integers.

    """
    features = _parse_features(features)
    if len(features) != 2:
        raise ValueError(
            "'features' contained '{n_feat}' features. Expected 2.".format(
                n_feat=len(features)
            )
        )

    # Compute quantiles and effective bin counts for each feature separately.
    quantiles_list, bins_list = tuple(
        zip(
            *(
                _get_quantiles(train_set, feature, n_bin)
                for feature, n_bin in zip(features, _check_two_ints(bins))
            )
        )
    )
    logger.debug("Quantiles: {}.", quantiles_list)

    # Define the bins the feature samples fall into. Shift and clip to ensure we are
    # getting the index of the left bin edge and the smallest sample retains its index
    # of 0.
    indices_list = [
        np.clip(np.digitize(train_set[feature], quantiles, right=True) - 1, 0, None)
        for feature, quantiles in zip(features, quantiles_list)
    ]

    # Invoke the predictor at the corners of the bins. Then compute the second order
    # difference between the predictions at the bin corners.
    # `shifts` iterates over the 4 bin corners: (0, 0), (0, 1), (1, 0) and (1, 1).
    predictions = {}
    for shifts in product(*(range(2),) * 2):
        mod_train_set = train_set.copy()
        for i in range(2):
            mod_train_set[features[i]] = quantiles_list[i][indices_list[i] + shifts[i]]
        predictions[shifts] = predictor(mod_train_set)
    # The individual effects.
    effects = (predictions[(1, 1)] - predictions[(1, 0)]) - (
        predictions[(0, 1)] - predictions[(0, 0)]
    )

    # Group the effects by their indices along both axes.
    index_groupby = pd.DataFrame(
        {"index_0": indices_list[0], "index_1": indices_list[1], "effects": effects}
    ).groupby(["index_0", "index_1"])

    # Compute mean effects.
    mean_effects = index_groupby.mean()
    # Get the indices of the mean values.
    group_indices = mean_effects.index
    valid_grid_indices = tuple(zip(*group_indices))
    # Extract only the data.
    mean_effects = mean_effects.to_numpy().flatten()

    # Get the number of samples in each bin.
    n_samples = index_groupby.size().to_numpy()

    # Create a 2D array of the number of samples in each bin.
    samples_grid = np.zeros(bins_list)
    samples_grid[valid_grid_indices] = n_samples

    # Start with every cell masked; cells are unmasked below as they are filled.
    ale = np.ma.MaskedArray(
        np.zeros((len(quantiles_list[0]), len(quantiles_list[1]))),
        mask=np.ones((len(quantiles_list[0]), len(quantiles_list[1]))),
    )
    # Mark the first row/column as valid, since these are meant to contain 0s.
    ale.mask[0, :] = False
    ale.mask[:, 0] = False

    # Place the mean effects into the final array.
    # Since `ale` contains `len(quantiles)` rows/columns the first of which are
    # guaranteed to be valid (and filled with 0s), ignore the first row and column.
    ale[1:, 1:][valid_grid_indices] = mean_effects

    # Record where elements were missing.
    missing_bin_mask = ale.mask.copy()[1:, 1:]

    if np.any(missing_bin_mask):
        # Replace missing entries with their nearest neighbours.

        # Calculate the dense location matrices (for both features) of all bin centres.
        centres_list = np.meshgrid(
            *(_get_centres(quantiles) for quantiles in quantiles_list), indexing="ij"
        )

        # Select only those bin centres which are valid (had observation).
        valid_indices_list = np.where(~missing_bin_mask)
        tree = cKDTree(
            np.hstack(
                tuple(
                    centres[valid_indices_list][:, np.newaxis]
                    for centres in centres_list
                )
            )
        )

        row_indices = np.hstack(
            [inds.reshape(-1, 1) for inds in np.where(missing_bin_mask)]
        )
        # Select both columns for each of the rows above.
        column_indices = np.hstack(
            (
                np.zeros((row_indices.shape[0], 1), dtype=np.int8),
                np.ones((row_indices.shape[0], 1), dtype=np.int8),
            )
        )

        # Determine the indices of the points which are nearest to the empty bins.
        nearest_points = tree.query(tree.data[row_indices, column_indices])[1]

        nearest_indices = tuple(
            valid_indices[nearest_points] for valid_indices in valid_indices_list
        )

        # Replace the invalid bin values with the nearest valid ones.
        ale[1:, 1:][missing_bin_mask] = ale[1:, 1:][nearest_indices]

    # Compute the cumulative sums.
    ale = np.cumsum(np.cumsum(ale, axis=0), axis=1)

    # Subtract first order effects along both axes separately.
    for i in range(2):
        # Depending on `i`, reverse the arguments to operate on the opposite axis.
        # Applied to an index tuple, `flip` leaves it unchanged for i == 0 and
        # reverses it for i == 1.
        flip = slice(None, None, 1 - 2 * i)

        # Undo the cumulative sum along the axis.
        first_order = ale[(slice(1, None), ...)[flip]] - ale[(slice(-1), ...)[flip]]
        # Average the diffs across the other axis.
        first_order = (
            first_order[(..., slice(1, None))[flip]]
            + first_order[(..., slice(-1))[flip]]
        ) / 2
        # Weight by the number of samples in each bin.
        first_order *= samples_grid
        # Take the sum along the axis.
        first_order = np.sum(first_order, axis=1 - i)
        # Normalise by the number of samples in the bins along the axis.
        first_order /= np.sum(samples_grid, axis=1 - i)
        # The final result is the cumulative sum (with an additional 0).
        first_order = np.array([0, *np.cumsum(first_order)]).reshape((-1, 1)[flip])

        # Subtract the first order effect.
        ale -= first_order

    # Compute the ALE at the bin centres by averaging the four corner values of
    # each bin (sum of the 2x2 shifted views, divided by 4).
    ale = (
        reduce(
            add,
            (
                ale[i : ale.shape[0] - 1 + i, j : ale.shape[1] - 1 + j]
                for i, j in list(product(*(range(2),) * 2))
            ),
        )
        / 4
    )

    # Centre the ALE by subtracting its expectation value.
    ale -= np.sum(samples_grid * ale) / train_set.shape[0]

    # Mark the originally missing points as such to enable later interpretation.
    ale.mask = missing_bin_mask
    return ale, quantiles_list
592 |
593 |
594 | def _first_order_ale_cat(
595 | predictor, train_set, feature, features_classes, feature_encoder=None
596 | ):
597 | """Compute the first-order ALE function on single categorical feature data.
598 |
599 | Parameters
600 | ----------
601 | predictor : callable
602 | Prediction function.
603 | train_set : pandas.core.frame.DataFrame
604 | Training set on which model was trained.
605 | feature : str
606 | Feature name.
607 | features_classes : iterable or str
608 | Feature's classes.
609 | feature_encoder : callable or iterable
610 | Encoder that was used to encode categorical feature. If features_classes is
611 | not None, this parameter is skipped.
612 |
613 | """
614 | num_cat = len(features_classes)
615 | ale = np.zeros(num_cat) # Final ALE function.
616 |
617 | for i in range(num_cat):
618 | subset = train_set[train_set[feature] == features_classes[i]]
619 |
620 | # Without any observation, local effect on split area is null.
621 | if len(subset) != 0:
622 | z_low = subset.copy()
623 | z_up = subset.copy()
624 | # The main ALE idea that compute prediction difference between same data
625 | # except feature's one.
626 | z_low[feature] = quantiles[i - 1]
627 | z_up[feature] = quantiles[i]
628 | ale[i] += (predictor(z_up) - predictor(z_low)).sum() / subset.shape[0]
629 |
630 | # The accumulated effect.
631 | ale = ale.cumsum()
632 | # Now we have to center ALE function in order to obtain null expectation for ALE
633 | # function.
634 | ale -= ale.mean()
635 | return ale
636 |
637 |
def ale_plot(
    model,
    train_set,
    features,
    bins=10,
    monte_carlo=False,
    predictor=None,
    features_classes=None,
    monte_carlo_rep=50,
    monte_carlo_ratio=0.1,
    rugplot_lim=1000,
):
    """Plots ALE function of specified features based on training set.

    Parameters
    ----------
    model : object
        An object that implements a 'predict' method. If None, a `predictor` function
        must be supplied which will be used instead of `model.predict`.
    train_set : pandas.core.frame.DataFrame
        Training set on which model was trained.
    features : [2-iterable of] column label
        One or two features for which to plot the ALE plot.
    bins : [2-iterable of] int, optional
        Number of bins used to split feature's space. 2 integers can only be given
        when 2 features are supplied in order to compute a different number of
        quantiles for each feature.
    monte_carlo : boolean, optional
        Compute and plot Monte-Carlo samples.
    predictor : callable
        Custom prediction function. See `model`.
    features_classes : iterable of str, optional
        If features is first-order and a categorical variable, plot ALE according to
        discrete aspect of data.
    monte_carlo_rep : int
        Number of Monte-Carlo replicas.
    monte_carlo_ratio : float
        Proportion of randomly selected samples from dataset for each Monte-Carlo
        replica.
    rugplot_lim : int, optional
        If `train_set` has more rows than `rugplot_lim`, no rug plot will be plotted.
        Set to None to always plot rug plots. Set to 0 to never plot rug plots.

    Returns
    -------
    ax : matplotlib Axes
        The axes the ALE plot was drawn onto.

    Raises
    ------
    ValueError
        If both `model` and `predictor` are None.
    ValueError
        If `len(features)` not in {1, 2}.
    ValueError
        If multiple bins were given for 1 feature.
    NotImplementedError
        If `features_classes` is not None.

    """
    if model is None and predictor is None:
        raise ValueError("If 'model' is None, 'predictor' must be supplied.")

    if features_classes is not None:
        raise NotImplementedError("'features_classes' is not implemented yet.")

    fig, ax = plt.subplots()

    features = _parse_features(features)

    if len(features) == 1:
        if not isinstance(bins, (int, np.integer)):
            raise ValueError("1 feature was given, but 'bins' was not an integer.")

        if features_classes is None:
            # Continuous data.

            if monte_carlo:
                # Sample row indices with replacement for each Monte-Carlo replica.
                mc_replicates = np.asarray(
                    [
                        [
                            np.random.choice(range(train_set.shape[0]))
                            for _ in range(int(monte_carlo_ratio * train_set.shape[0]))
                        ]
                        for _ in range(monte_carlo_rep)
                    ]
                )
                for k, rep in enumerate(mc_replicates):
                    train_set_rep = train_set.iloc[rep, :]
                    # Make this recursive?
                    if features_classes is None:
                        # The same quantiles cannot be reused here as this could cause
                        # some bins to be empty or contain disproportionate numbers of
                        # samples.
                        mc_ale, mc_quantiles = _first_order_ale_quant(
                            model.predict if predictor is None else predictor,
                            train_set_rep,
                            features[0],
                            bins,
                        )
                        # Draw each replica as a faint line under the main curve.
                        _first_order_quant_plot(
                            ax, mc_quantiles, mc_ale, color="#1f77b4", alpha=0.06
                        )

            ale, quantiles = _first_order_ale_quant(
                model.predict if predictor is None else predictor,
                train_set,
                features[0],
                bins,
            )
            _ax_labels(ax, "Feature '{}'".format(features[0]), "")
            _ax_title(
                ax,
                "First-order ALE of feature '{0}'".format(features[0]),
                "Bins : {0} - Monte-Carlo : {1}".format(
                    len(quantiles) - 1,
                    mc_replicates.shape[0] if monte_carlo else "False",
                ),
            )
            ax.grid(True, linestyle="-", alpha=0.4)
            # Only draw the rug plot for sufficiently small datasets (see
            # `rugplot_lim` above).
            if rugplot_lim is None or train_set.shape[0] <= rugplot_lim:
                sns.rugplot(train_set[features[0]], ax=ax, alpha=0.2)
            _first_order_quant_plot(ax, quantiles, ale, color="black")
            _ax_quantiles(ax, quantiles)

    elif len(features) == 2:
        if features_classes is None:
            # Continuous data.
            ale, quantiles_list = _second_order_ale_quant(
                model.predict if predictor is None else predictor,
                train_set,
                features,
                bins,
            )
            _second_order_quant_plot(fig, ax, quantiles_list, ale)
            _ax_labels(
                ax,
                "Feature '{}'".format(features[0]),
                "Feature '{}'".format(features[1]),
            )
            # Annotate both the x and y axes with their feature's quantiles.
            for twin, quantiles in zip(("x", "y"), quantiles_list):
                _ax_quantiles(ax, quantiles, twin=twin)
            _ax_title(
                ax,
                "Second-order ALE of features '{0}' and '{1}'".format(
                    features[0], features[1]
                ),
                "Bins : {0}x{1}".format(*[len(quant) - 1 for quant in quantiles_list]),
            )
    else:
        raise ValueError(
            "'{n_feat}' 'features' were given, but only up to 2 are supported.".format(
                n_feat=len(features)
            )
        )
    plt.show()
    return ax
790 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_ale_calculation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import numpy as np
4 | import pandas as pd
5 |
6 | from alepython.ale import _first_order_ale_quant, _get_centres, _second_order_ale_quant
7 |
8 | from .utils import interaction_predictor, linear_predictor
9 |
10 |
def test_linear():
    """We expect both X['a'] and X['b'] to have a linear relationship.

    There should be no second order interaction.

    """
    np.random.seed(1)

    N = int(1e5)
    X = pd.DataFrame({"a": np.random.random(N), "b": np.random.random(N)})

    # Test that the first order relationships are linear.
    for column in X.columns:
        ale, quantiles = _first_order_ale_quant(linear_predictor, X, column, 21)
        centres = _get_centres(quantiles)
        p, V = np.polyfit(centres, ale, 1, cov=True)
        # Slope ~1, intercept ~-0.5 (the ALE is centred around 0).
        assert np.all(np.isclose(p, [1, -0.5], atol=1e-3))
        assert np.all(np.isclose(np.sqrt(np.diag(V)), 0))

    # Test that a second order relationship does not exist.
    ale_second_order, quantiles_list = _second_order_ale_quant(
        linear_predictor, X, X.columns, 21
    )
    assert np.all(np.isclose(ale_second_order, 0))
39 |
40 |
def test_interaction():
    """Ensure that the method picks up a trivial interaction term."""
    np.random.seed(1)

    N = int(1e6)
    X = pd.DataFrame({"a": np.random.random(N), "b": np.random.random(N)})
    ale, quantiles_list = _second_order_ale_quant(
        interaction_predictor, X, X.columns, 61
    )

    # XXX: There seems to be a small deviation proportional to the first axis ('a')
    # that is preventing this from being closer to 0.
    # NOTE(review): `[::-1]` reverses the FIRST ('a') axis of the right-hand half;
    # if the intent is symmetry about b = 0.5, reversing the second axis
    # (`[:, 31:][:, ::-1]`) may have been intended — confirm.
    assert np.all(np.abs(ale[:, :30] - ale[:, 31:][::-1]) < 1e-2)
54 |
--------------------------------------------------------------------------------
/tests/test_figure_creation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from contextlib import contextmanager
3 | from string import ascii_lowercase
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pandas as pd
8 | import pytest
9 |
10 | from alepython import ale_plot
11 |
12 | from .utils import SimpleModel, simple_predictor
13 |
14 |
@contextmanager
def assert_n_created_figures(n=1):
    """Context manager asserting that exactly `n` new figures appear."""
    before = plt.get_fignums()
    yield  # Exceptions from the body propagate to the caller unchanged.
    created = set(plt.get_fignums()) - set(before)
    n_created = len(created)
    assert n_created == n, "Expected '{n}' figure(s), got '{n_new}'.".format(
        n=n, n_new=n_created
    )
25 |
26 |
@pytest.mark.parametrize(
    "features,columns",
    (("a", ("a", "b")), (("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c"))),
)
def test_model(features, columns):
    """Given a model with a predict method, a plot should be created."""
    # Interactive mode stops plt.show() from blocking the test run.
    plt.ion()
    np.random.seed(1)
    data = np.random.random((100, len(columns)))
    frame = pd.DataFrame(data, columns=columns)
    with assert_n_created_figures():
        ale_plot(SimpleModel(), frame, features)
    # Clean up the figure created by the call above.
    plt.close()
40 |
41 |
@pytest.mark.parametrize(
    "features,columns",
    (("a", ("a", "b")), (("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c"))),
)
def test_predictor(features, columns):
    """Given a predictor function, a plot should be created."""
    # Interactive mode stops plt.show() from blocking the test run.
    plt.ion()
    np.random.seed(1)
    data = np.random.random((100, len(columns)))
    frame = pd.DataFrame(data, columns=columns)
    with assert_n_created_figures():
        ale_plot(
            model=None,
            train_set=frame,
            features=features,
            predictor=simple_predictor,
        )
    # Clean up the figure created by the call above.
    plt.close()
60 |
61 |
@pytest.mark.parametrize(
    "features,columns", ((("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c")))
)
def test_monte_carlo(features, columns):
    """Monte-Carlo plotting should also produce exactly one figure."""
    # Interactive mode stops plt.show() from blocking the test run.
    plt.ion()
    np.random.seed(1)
    data = np.random.random((100, len(columns)))
    frame = pd.DataFrame(data, columns=columns)
    with assert_n_created_figures():
        ale_plot(
            model=None,
            train_set=frame,
            features=features,
            predictor=simple_predictor,
            monte_carlo=True,
        )
    # Clean up the figure created by the call above.
    plt.close()
79 |
80 |
def test_df_column_features():
    """Test the handling of the `features` argument.

    No matter the type of the `features` iterable, `ale_plot` should be able to select
    the right columns.

    """
    # Interactive mode stops plt.show() from blocking the test run.
    plt.ion()
    n_col = 3
    np.random.seed(1)
    names = list(ascii_lowercase[:n_col])
    frame = pd.DataFrame(np.random.random((100, n_col)), columns=names)
    with assert_n_created_figures():
        # Pass a pandas Index slice (not a plain list/tuple) as `features`.
        ale_plot(SimpleModel(), frame, frame.columns[:1])
    # Clean up the figure created by the call above.
    plt.close()
98 |
--------------------------------------------------------------------------------
/tests/test_parameters.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | import pytest
5 |
6 | from alepython import ale_plot
7 | from alepython.ale import (
8 | _ax_quantiles,
9 | _check_two_ints,
10 | _get_quantiles,
11 | _second_order_ale_quant,
12 | )
13 |
14 | from .utils import SimpleModel
15 |
16 |
def test_two_ints():
    """Check that _check_two_ints rejects bad bin specifications."""
    with pytest.raises(ValueError, match=r"'3' values.*"):
        _check_two_ints((1, 2, 3))
    # The previous match pattern contained unescaped regex metacharacters
    # ('(s)' matches the literal text 'types') and an unformatted '{}'
    # placeholder, so it could never match the actual message. Assert only the
    # exception type here.
    with pytest.raises(ValueError):
        _check_two_ints(("1", "2"))
22 |
23 |
def test_quantiles():
    """Check that _get_quantiles rejects a non-integer 'bins' argument."""
    # The previous pattern r".*type ''.*" required two adjacent quotes, which
    # never occur in the actual message
    # "Expected integer 'bins', but got type '<class 'str'>'.", so the match
    # always failed. Anchor on the stable prefix of the message instead.
    with pytest.raises(ValueError, match=r"Expected integer 'bins', but got type"):
        _get_quantiles(pd.DataFrame({"a": [1, 2, 3]}), "a", "1")
27 |
28 |
def test_second_order_ale_quant():
    """A single feature should be rejected with a descriptive error."""
    frame = pd.DataFrame({"a": [1, 2, 3]})
    with pytest.raises(ValueError, match=r".*contained '1' features.*"):
        _second_order_ale_quant(lambda x: None, frame, "a", 1)
32 |
33 |
def test_ale_plot():
    """Test that proper errors are raised."""
    frame = pd.DataFrame([1])

    # Neither a model nor a predictor given.
    with pytest.raises(ValueError, match=r".*'model'.*'predictor'.*"):
        ale_plot(model=None, train_set=frame, features=[0])

    # Too many features.
    with pytest.raises(ValueError, match=r"'3' 'features'.*"):
        ale_plot(model=SimpleModel(), train_set=frame, features=list(range(3)))

    # No features at all.
    with pytest.raises(ValueError, match=r"'0' 'features'.*"):
        ale_plot(model=SimpleModel(), train_set=frame, features=[])

    # Categorical features are not supported yet.
    with pytest.raises(
        NotImplementedError, match="'features_classes' is not implemented yet."
    ):
        ale_plot(
            model=SimpleModel(),
            train_set=frame,
            features=[0],
            features_classes=["a"],
        )

    # Non-integer bins for a single feature.
    with pytest.raises(ValueError, match=r"1 feature.*but 'bins' was not an integer."):
        ale_plot(model=SimpleModel(), train_set=frame, features=[0], bins=1.0)
61 |
62 |
def test_ax_quantiles():
    """_ax_quantiles should accept only 'x' or 'y' as the twin axis."""
    figure, axis = plt.subplots()
    with pytest.raises(ValueError, match="'twin' should be one of 'x' or 'y'."):
        _ax_quantiles(axis, list(range(2)), "z")
    plt.close(figure)

    # A valid twin axis should not raise.
    figure, axis = plt.subplots()
    _ax_quantiles(axis, list(range(2)), "x")
    plt.close(figure)
72 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | from .test_models import *
5 |
--------------------------------------------------------------------------------
/tests/utils/test_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 |
class SimpleModel:
    """A simple predictive model for testing purposes.

    The predicted response is the mean over each input row.

    Methods
    -------
    predict(X)
        Given input data `X`, predict response variable.

    """

    def predict(self, X):
        # Row-wise mean across all feature columns.
        return np.mean(X, axis=1)
17 |
18 |
def simple_predictor(X):
    """Predict the response as the row-wise mean of `X`."""
    return np.mean(X, axis=1)
21 |
22 |
def linear_predictor(X):
    """A simple linear effect with features 'a' and 'b'."""
    # Element-wise sum of the two feature columns.
    return X["a"].add(X["b"])
26 |
27 |
def interaction_predictor(X):
    """Interaction changes sign at b = 0.5."""
    interaction = X["a"] * X["b"]
    # Positive interaction for b <= 0.5, negated above that threshold.
    return np.where(X["b"] <= 0.5, interaction, -interaction)
41 |
--------------------------------------------------------------------------------