├── .gitignore ├── .travis.yml ├── AUTHORS ├── LICENSE ├── README.md ├── examples └── regression_iris.ipynb ├── requirements.txt ├── resources ├── fo_ale_quant.png └── so_ale_quant.png ├── setup.cfg ├── setup.py ├── src └── alepython │ ├── __init__.py │ └── ale.py └── tests ├── __init__.py ├── test_ale_calculation.py ├── test_figure_creation.py ├── test_parameters.py └── utils ├── __init__.py └── test_models.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | .dmypy.json 114 | dmypy.json 115 | 116 | # Pyre type checker 117 | .pyre/ 118 | 119 | # Package version. 120 | _version.py 121 | 122 | # Pre-commit configuration file. 123 | .pre-commit-config.yaml 124 | 125 | # Git attributes. 
126 | .gitattributes 127 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | - "3.6" 5 | - "3.7" 6 | - "3.8" 7 | before_install: 8 | - python --version 9 | - pip install -U pip 10 | install: pip install -e ".[test]" 11 | script: pytest 12 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Maxime Jumelle - maxime@blent.ai - (gh: @MaximeJumelle) 2 | Sanjif Rajaratnam (gh: @sanjifr3) 3 | Alexander Kuhn-Regnier (gh: @akuhnregnier) 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018- Maxime Jumelle 190 | Copyright 2020- ALEPython developers (see AUTHORS file) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/MaximeJumelle/ALEPython.svg?branch=dev)](https://travis-ci.org/MaximeJumelle/ALEPython) 2 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black) 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | 5 | Python Accumulated Local Effects package. 6 | 7 | # Why ALE? 8 | 9 | Explaining model predictions is very common when you have to deploy a Machine Learning algorithm on a large scale. 10 | There are many methods that help us understand our model; one of these uses Partial Dependency Plots (PDP), which have been widely used for years. 
11 | 12 | However, they suffer from a stringent assumption: **features have to be uncorrelated**. 13 | In real world scenarios, features are often correlated, whether because some are directly computed from others, or because observed phenomena produce correlated distributions. 14 | 15 | Accumulated Local Effects (or ALE) plots first proposed by [_Apley and Zhu_ (2016)](#1) alleviate this issue reasonably by using actual conditional marginal distributions instead of considering each marginal distribution of features. 16 | This is more reliable when handling (even strongly) correlated variables. 17 | 18 | This package aims to provide useful and quick access to ALE plots, so that you can easily explain your model through predictions. 19 | 20 | For further details about model interpretability and ALE plots, see e.g. [_Molnar_ (2020)](#2). 21 | 22 | # Install 23 | 24 | ALEPython is supported on Python >= 3.5. 25 | You can either install the package via `pip`: 26 | 27 | ```sh 28 | pip install alepython 29 | ``` 30 | directly from source (including requirements): 31 | ```sh 32 | pip install git+https://github.com/MaximeJumelle/ALEPython.git@dev#egg=alepython 33 | ``` 34 | or after cloning (or forking) for development purposes, including test dependencies: 35 | ```sh 36 | git clone https://github.com/MaximeJumelle/ALEPython.git 37 | pip install -e "ALEPython[test]" 38 | ``` 39 | 40 | # Usage 41 | 42 | ```python 43 | from alepython import ale_plot 44 | # Plots ALE of feature 'cont' with Monte-Carlo replicas (default: 50). 45 | ale_plot(model, X_train, 'cont', monte_carlo=True) 46 | ``` 47 | 48 | # Highlights 49 | 50 | - First-order ALE plots of continuous features 51 | - Second-order ALE plots of continuous features 52 | 53 | # Gallery 54 | 55 | ## First-order ALE plots of continuous features 56 | 57 | 
58 | 59 | --- 60 | 61 | ## Second-order ALE plots of continuous features 62 | 63 |
64 | 65 | # Work In Progress 66 | 67 | - First-order ALE plots of categorical features 68 | - Enhanced visualization of first-order plots 69 | - Second-order ALE plots of categorical features 70 | - Documentation and API reference 71 | - Jupyter Notebook examples 72 | - Upload to PyPi 73 | - Upload to conda-forge 74 | - Use of matplotlib styles or kwargs to allow overriding plotting appearance 75 | 76 | If you are interested in the project, I would be happy to collaborate with you since there are still quite a lot of improvements needed. 77 | 78 | ## References 79 | 80 | 81 | Apley, Daniel W., and Jingyu Zhu. 2016. Visualizing the Effects of Predictor Variables in Black Box Supervised Learning Models. . 82 | 83 | 84 | Molnar, Christoph. 2020. Interpretable Machine Learning. . 85 | -------------------------------------------------------------------------------- /examples/regression_iris.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "from loguru import logger\n", 12 | "from sklearn import datasets\n", 13 | "from sklearn.ensemble import RandomForestRegressor" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "## Loading the Dataset\n", 21 | "\n", 22 | "Define our response `y` and predictor, `X`, variables." 
23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "iris = datasets.load_iris()\n", 32 | "dataset = pd.DataFrame(data=iris[\"data\"], columns=iris[\"feature_names\"])\n", 33 | "X = dataset.iloc[:, 1:]\n", 34 | "y = dataset.iloc[:, 0]\n", 35 | "dataset.head()" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Fit a Random Forest Regressor" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "model = RandomForestRegressor(n_estimators=20, bootstrap=True)\n", 52 | "model.fit(X, y)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Create our ALE Plots\n", 60 | "\n", 61 | "### 1D Main Effect ALE Plot" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "import matplotlib as mpl\n", 71 | "\n", 72 | "from alepython import ale_plot\n", 73 | "\n", 74 | "mpl.rc(\"figure\", figsize=(9, 6))\n", 75 | "ale_plot(\n", 76 | " model,\n", 77 | " X,\n", 78 | " X.columns[:1],\n", 79 | " bins=20,\n", 80 | " monte_carlo=True,\n", 81 | " monte_carlo_rep=100,\n", 82 | " monte_carlo_ratio=0.6,\n", 83 | ")" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "### 2D Second-Order ALE Plot" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "mpl.rc(\"figure\", figsize=(9, 6))\n", 100 | "ale_plot(model, X, X.columns[:2], bins=10)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "mpl.rc(\"figure\", figsize=(9, 6))\n", 110 | "ale_plot(model, X, X.columns[1:], bins=10)" 111 | ] 112 | } 113 | ], 114 | 
"metadata": { 115 | "kernelspec": { 116 | "display_name": "Python 3", 117 | "language": "python", 118 | "name": "python3" 119 | } 120 | }, 121 | "nbformat": 4, 122 | "nbformat_minor": 4 123 | } 124 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | loguru>=0.4.1 2 | matplotlib>=2.2.3 3 | numpy>=1.15.4 4 | pandas>=0.22.0 5 | scipy>=1.1.0 6 | seaborn>=0.9.0 7 | -------------------------------------------------------------------------------- /resources/fo_ale_quant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/resources/fo_ale_quant.png -------------------------------------------------------------------------------- /resources/so_ale_quant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/resources/so_ale_quant.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE 3 | 4 | [tool:pytest] 5 | addopts = 6 | --doctest-modules 7 | --cov=alepython 8 | -r a 9 | -v 10 | markers = 11 | slow: mark a test as being very slow. 
12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open("README.md") as f: 6 | long_description = f.read() 7 | 8 | with open("requirements.txt", "r") as f: 9 | required = f.read().splitlines() 10 | 11 | setup( 12 | name="alepython", 13 | description="Python Accumulated Local Effects (ALE) package.", 14 | author="Maxime Jumelle", 15 | author_email="maxime@blent.ai", 16 | license="Apache 2", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | url="https://github.com/MaximeJumelle/alepython/", 20 | install_requires=required, 21 | extras_require={"test": ["pytest>=5.4", "pytest-cov>=2.8"]}, 22 | setup_requires=["setuptools-scm"], 23 | python_requires=">=3.5", 24 | use_scm_version=dict(write_to="src/alepython/_version.py"), 25 | keywords="alepython", 26 | package_dir={"": "src"}, 27 | packages=find_packages(where="src"), 28 | include_package_data=True, 29 | classifiers=[ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approched :: Apache 2", 32 | "Operating System :: OS Independent", 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /src/alepython/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from ._version import version as __version__ 3 | from .ale import * 4 | 5 | del _version 6 | -------------------------------------------------------------------------------- /src/alepython/ale.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ALE plotting for continuous or categorical features.""" 3 | from collections.abc import Iterable 4 | from functools import reduce 5 | from itertools import product 6 | 
from operator import add 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import pandas as pd 11 | import scipy 12 | import seaborn as sns 13 | from loguru import logger 14 | from matplotlib.patches import Rectangle 15 | from scipy.spatial import cKDTree 16 | 17 | logger.disable("alepython") 18 | 19 | 20 | __all__ = ("ale_plot",) 21 | 22 | 23 | def _parse_features(features): 24 | """Standardise representation of column labels. 25 | 26 | Args: 27 | features : object 28 | One or more column labels. 29 | 30 | Returns: 31 | features : array-like 32 | An array of input features. 33 | 34 | Examples 35 | -------- 36 | >>> _parse_features(1) 37 | array([1]) 38 | >>> _parse_features(('a', 'b')) 39 | array(['a', 'b'], dtype='>> _check_two_ints(1) 74 | (1, 1) 75 | >>> _check_two_ints((1, 2)) 76 | (1, 2) 77 | >>> _check_two_ints((1,)) 78 | (1, 1) 79 | 80 | """ 81 | if isinstance(values, (int, np.integer)): 82 | values = (values, values) 83 | elif len(values) == 1: 84 | values = (values[0], values[0]) 85 | elif len(values) != 2: 86 | raise ValueError( 87 | "'{}' values were given. Expected at most 2.".format(len(values)) 88 | ) 89 | 90 | if not all(isinstance(n_bin, (int, np.integer)) for n_bin in values): 91 | raise ValueError( 92 | "All values must be an integer. Got types '{}' instead.".format( 93 | {type(n_bin) for n_bin in values} 94 | ) 95 | ) 96 | return values 97 | 98 | 99 | def _get_centres(x): 100 | """Return bin centres from bin edges. 101 | 102 | Parameters 103 | ---------- 104 | x : array-like 105 | The first axis of `x` will be averaged. 106 | 107 | Returns 108 | ------- 109 | centres : array-like 110 | The centres of `x`, the shape of which is (N - 1, ...) for 111 | `x` with shape (N, ...). 
112 | 113 | Examples 114 | -------- 115 | >>> import numpy as np 116 | >>> x = np.array([0, 1, 2, 3]) 117 | >>> _get_centres(x) 118 | array([0.5, 1.5, 2.5]) 119 | 120 | """ 121 | return (x[1:] + x[:-1]) / 2 122 | 123 | 124 | def _ax_title(ax, title, subtitle=""): 125 | """Add title to axis. 126 | 127 | Parameters 128 | ---------- 129 | ax : matplotlib.axes.Axes 130 | Axes object to add title to. 131 | title : str 132 | Axis title. 133 | subtitle : str, optional 134 | Sub-title for figure. Will appear one line below `title`. 135 | 136 | """ 137 | ax.set_title("\n".join((title, subtitle))) 138 | 139 | 140 | def _ax_labels(ax, xlabel=None, ylabel=None): 141 | """Add labels to axis. 142 | 143 | Parameters 144 | ---------- 145 | ax : matplotlib.axes.Axes 146 | Axes object to add labels to. 147 | xlabel : str, optional 148 | X axis label. 149 | ylabel : str, optional 150 | Y axis label. 151 | 152 | """ 153 | if xlabel is not None: 154 | ax.set_xlabel(xlabel) 155 | if ylabel is not None: 156 | ax.set_ylabel(ylabel) 157 | 158 | 159 | def _ax_quantiles(ax, quantiles, twin="x"): 160 | """Plot quantiles of a feature onto axis. 161 | 162 | Parameters 163 | ---------- 164 | ax : matplotlib.axes.Axes 165 | Axis to modify. 166 | quantiles : array-like 167 | Quantiles to plot. 168 | twin : {'x', 'y'}, optional 169 | Select the axis for which to plot quantiles. 170 | 171 | Raises 172 | ------ 173 | ValueError 174 | If `twin` is not one of 'x' or 'y'. 175 | 176 | """ 177 | if twin not in ("x", "y"): 178 | raise ValueError("'twin' should be one of 'x' or 'y'.") 179 | 180 | logger.debug("Quantiles: {}.", quantiles) 181 | 182 | # Duplicate the 'opposite' axis so we can define a distinct set of ticks for the 183 | # desired axis (`twin`). 184 | ax_mod = ax.twiny() if twin == "x" else ax.twinx() 185 | 186 | # Set the new axis' ticks for the desired axis. 187 | getattr(ax_mod, "set_{twin}ticks".format(twin=twin))(quantiles) 188 | # Set the corresponding tick labels. 
189 | 190 | # Calculate tick label percentage values for each quantile (bin edge). 191 | percentages = ( 192 | 100 * np.arange(len(quantiles), dtype=np.float64) / (len(quantiles) - 1) 193 | ) 194 | 195 | # If there is a fractional part, add a decimal place to show (part of) it. 196 | fractional = (~np.isclose(percentages % 1, 0)).astype("int8") 197 | 198 | getattr(ax_mod, "set_{twin}ticklabels".format(twin=twin))( 199 | [ 200 | "{0:0.{1}f}%".format(percent, format_fraction) 201 | for percent, format_fraction in zip(percentages, fractional) 202 | ], 203 | color="#545454", 204 | fontsize=7, 205 | ) 206 | getattr(ax_mod, "set_{twin}lim".format(twin=twin))( 207 | getattr(ax, "get_{twin}lim".format(twin=twin))() 208 | ) 209 | 210 | 211 | def _first_order_quant_plot(ax, quantiles, ale, **kwargs): 212 | """First order ALE plot. 213 | 214 | Parameters 215 | ---------- 216 | ax : matplotlib.axes.Axes 217 | Axis to plot onto. 218 | quantiles : array-like 219 | ALE quantiles. 220 | ale : array-like 221 | ALE to plot. 222 | **kwargs : plot properties, optional 223 | Additional keyword parameters are passed to `ax.plot`. 224 | 225 | """ 226 | ax.plot(_get_centres(quantiles), ale, **kwargs) 227 | 228 | 229 | def _second_order_quant_plot( 230 | fig, ax, quantiles_list, ale, mark_empty=True, n_interp=50, **kwargs 231 | ): 232 | """Second order ALE plot. 233 | 234 | Parameters 235 | ---------- 236 | ax : matplotlib.axes.Axes 237 | Axis to plot onto. 238 | quantiles_list : array-like 239 | ALE quantiles for the first (`quantiles_list[0]`) and second 240 | (`quantiles_list[1]`) features. 241 | ale : masked array 242 | ALE to plot. Where `ale.mask` is True, this denotes bins where no samples were 243 | available. See `mark_empty`. 244 | mark_empty : bool, optional 245 | If True, plot rectangles over bins that did not contain any samples. 246 | n_interp : [2-iterable of] int, optional 247 | The number of interpolated samples generated from `ale` prior to contour 248 | plotting. 
Two integers may be given to specify different interpolation steps 249 | for the two features. 250 | **kwargs : contourf properties, optional 251 | Additional keyword parameters are passed to `ax.contourf`. 252 | 253 | Raises 254 | ------ 255 | ValueError 256 | If `n_interp` values were not integers. 257 | ValueError 258 | If more than 2 values were given for `n_interp`. 259 | 260 | """ 261 | centres_list = [_get_centres(quantiles) for quantiles in quantiles_list] 262 | n_x, n_y = _check_two_ints(n_interp) 263 | x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x) 264 | y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y) 265 | 266 | X, Y = np.meshgrid(x, y, indexing="xy") 267 | ale_interp = scipy.interpolate.interp2d(centres_list[0], centres_list[1], ale.T) 268 | CF = ax.contourf(X, Y, ale_interp(x, y), cmap="bwr", levels=30, alpha=0.7, **kwargs) 269 | 270 | if mark_empty and np.any(ale.mask): 271 | # Do not autoscale, so that boxes at the edges (contourf only plots the bin 272 | # centres, not their edges) don't enlarge the plot. 273 | plt.autoscale(False) 274 | # Add rectangles to indicate cells without samples. 275 | for i, j in zip(*np.where(ale.mask)): 276 | ax.add_patch( 277 | Rectangle( 278 | [quantiles_list[0][i], quantiles_list[1][j]], 279 | quantiles_list[0][i + 1] - quantiles_list[0][i], 280 | quantiles_list[1][j + 1] - quantiles_list[1][j], 281 | linewidth=1, 282 | edgecolor="k", 283 | facecolor="none", 284 | alpha=0.4, 285 | ) 286 | ) 287 | fig.colorbar(CF) 288 | 289 | 290 | def _get_quantiles(train_set, feature, bins): 291 | """Get quantiles from a feature in a dataset. 292 | 293 | Parameters 294 | ---------- 295 | train_set : pandas.core.frame.DataFrame 296 | Dataset containing feature `feature`. 297 | feature : column label 298 | Feature for which to calculate quantiles. 299 | bins : int 300 | The number of quantiles is calculated as `bins + 1`. 301 | 302 | Returns 303 | ------- 304 | quantiles : array-like 305 | Quantiles. 
306 | bins : int 307 | Number of bins, `len(quantiles) - 1`. This may be lower than the original 308 | `bins` if identical quantiles were present. 309 | 310 | Raises 311 | ------ 312 | ValueError 313 | If `bins` is not an integer. 314 | 315 | Notes 316 | ----- 317 | When using this definition of quantiles in combination with a half open interval 318 | (lower quantile, upper quantile], care has to taken that the smallest observation 319 | is included in the first bin. This is handled transparently by `np.digitize`. 320 | 321 | """ 322 | if not isinstance(bins, (int, np.integer)): 323 | raise ValueError( 324 | "Expected integer 'bins', but got type '{}'.".format(type(bins)) 325 | ) 326 | quantiles = np.unique( 327 | np.quantile( 328 | train_set[feature], np.linspace(0, 1, bins + 1), interpolation="lower" 329 | ) 330 | ) 331 | bins = len(quantiles) - 1 332 | return quantiles, bins 333 | 334 | 335 | def _first_order_ale_quant(predictor, train_set, feature, bins): 336 | """Estimate the first-order ALE function for single continuous feature data. 337 | 338 | Parameters 339 | ---------- 340 | predictor : callable 341 | Prediction function. 342 | train_set : pandas.core.frame.DataFrame 343 | Training set on which the model was trained. 344 | feature : column label 345 | Feature name. A single column label. 346 | bins : int 347 | This defines the number of bins to compute. The effective number of bins may 348 | be less than this as only unique quantile values of train_set[feature] are 349 | used. 350 | 351 | Returns 352 | ------- 353 | ale : array-like 354 | The first order ALE. 355 | quantiles : array-like 356 | The quantiles used. 357 | 358 | """ 359 | quantiles, _ = _get_quantiles(train_set, feature, bins) 360 | logger.debug("Quantiles: {}.", quantiles) 361 | 362 | # Define the bins the feature samples fall into. Shift and clip to ensure we are 363 | # getting the index of the left bin edge and the smallest sample retains its index 364 | # of 0. 
365 | indices = np.clip( 366 | np.digitize(train_set[feature], quantiles, right=True) - 1, 0, None 367 | ) 368 | 369 | # Assign the feature quantile values to two copied training datasets, one for each 370 | # bin edge. Then compute the difference between the corresponding predictions 371 | predictions = [] 372 | for offset in range(2): 373 | mod_train_set = train_set.copy() 374 | mod_train_set[feature] = quantiles[indices + offset] 375 | predictions.append(predictor(mod_train_set)) 376 | # The individual effects. 377 | effects = predictions[1] - predictions[0] 378 | 379 | # Average these differences within each bin. 380 | index_groupby = pd.DataFrame({"index": indices, "effects": effects}).groupby( 381 | "index" 382 | ) 383 | 384 | mean_effects = index_groupby.mean().to_numpy().flatten() 385 | 386 | ale = np.array([0, *np.cumsum(mean_effects)]) 387 | 388 | # The uncentred mean main effects at the bin centres. 389 | ale = _get_centres(ale) 390 | 391 | # Centre the effects by subtracting the mean (the mean of the individual 392 | # `effects`, which is equivalently calculated using `mean_effects` and the number 393 | # of samples in each bin). 394 | ale -= np.sum(ale * index_groupby.size() / train_set.shape[0]) 395 | return ale, quantiles 396 | 397 | 398 | def _second_order_ale_quant(predictor, train_set, features, bins): 399 | """Estimate the second-order ALE function for two continuous feature data. 400 | 401 | Parameters 402 | ---------- 403 | predictor : callable 404 | Prediction function. 405 | train_set : pandas.core.frame.DataFrame 406 | Training set on which the model was trained. 407 | features : 2-iterable of column label 408 | The two desired features, as two column labels. 409 | bins : [2-iterable of] int 410 | This defines the number of bins to compute. The effective number of bins may 411 | be less than this as only unique quantile values of train_set[feature] are 412 | used. If one integer is given, this is used for both features. 
413 | 414 | Returns 415 | ------- 416 | ale : (M, N) masked array 417 | The second order ALE. Elements are masked where no data was available. 418 | quantiles : 2-tuple of array-like 419 | The quantiles used: first the quantiles for `features[0]` with shape (M + 1,), 420 | then for `features[1]` with shape (N + 1,). 421 | 422 | Raises 423 | ------ 424 | ValueError 425 | If `features` does not contain 2 features. 426 | ValueError 427 | If more than 2 bins are given. 428 | ValueError 429 | If bins are not integers. 430 | 431 | """ 432 | features = _parse_features(features) 433 | if len(features) != 2: 434 | raise ValueError( 435 | "'features' contained '{n_feat}' features. Expected 2.".format( 436 | n_feat=len(features) 437 | ) 438 | ) 439 | 440 | quantiles_list, bins_list = tuple( 441 | zip( 442 | *( 443 | _get_quantiles(train_set, feature, n_bin) 444 | for feature, n_bin in zip(features, _check_two_ints(bins)) 445 | ) 446 | ) 447 | ) 448 | logger.debug("Quantiles: {}.", quantiles_list) 449 | 450 | # Define the bins the feature samples fall into. Shift and clip to ensure we are 451 | # getting the index of the left bin edge and the smallest sample retains its index 452 | # of 0. 453 | indices_list = [ 454 | np.clip(np.digitize(train_set[feature], quantiles, right=True) - 1, 0, None) 455 | for feature, quantiles in zip(features, quantiles_list) 456 | ] 457 | 458 | # Invoke the predictor at the corners of the bins. Then compute the second order 459 | # difference between the predictions at the bin corners. 460 | predictions = {} 461 | for shifts in product(*(range(2),) * 2): 462 | mod_train_set = train_set.copy() 463 | for i in range(2): 464 | mod_train_set[features[i]] = quantiles_list[i][indices_list[i] + shifts[i]] 465 | predictions[shifts] = predictor(mod_train_set) 466 | # The individual effects. 
467 | effects = (predictions[(1, 1)] - predictions[(1, 0)]) - ( 468 | predictions[(0, 1)] - predictions[(0, 0)] 469 | ) 470 | 471 | # Group the effects by their indices along both axes. 472 | index_groupby = pd.DataFrame( 473 | {"index_0": indices_list[0], "index_1": indices_list[1], "effects": effects} 474 | ).groupby(["index_0", "index_1"]) 475 | 476 | # Compute mean effects. 477 | mean_effects = index_groupby.mean() 478 | # Get the indices of the mean values. 479 | group_indices = mean_effects.index 480 | valid_grid_indices = tuple(zip(*group_indices)) 481 | # Extract only the data. 482 | mean_effects = mean_effects.to_numpy().flatten() 483 | 484 | # Get the number of samples in each bin. 485 | n_samples = index_groupby.size().to_numpy() 486 | 487 | # Create a 2D array of the number of samples in each bin. 488 | samples_grid = np.zeros(bins_list) 489 | samples_grid[valid_grid_indices] = n_samples 490 | 491 | ale = np.ma.MaskedArray( 492 | np.zeros((len(quantiles_list[0]), len(quantiles_list[1]))), 493 | mask=np.ones((len(quantiles_list[0]), len(quantiles_list[1]))), 494 | ) 495 | # Mark the first row/column as valid, since these are meant to contain 0s. 496 | ale.mask[0, :] = False 497 | ale.mask[:, 0] = False 498 | 499 | # Place the mean effects into the final array. 500 | # Since `ale` contains `len(quantiles)` rows/columns the first of which are 501 | # guaranteed to be valid (and filled with 0s), ignore the first row and column. 502 | ale[1:, 1:][valid_grid_indices] = mean_effects 503 | 504 | # Record where elements were missing. 505 | missing_bin_mask = ale.mask.copy()[1:, 1:] 506 | 507 | if np.any(missing_bin_mask): 508 | # Replace missing entries with their nearest neighbours. 509 | 510 | # Calculate the dense location matrices (for both features) of all bin centres. 
511 | centres_list = np.meshgrid( 512 | *(_get_centres(quantiles) for quantiles in quantiles_list), indexing="ij" 513 | ) 514 | 515 | # Select only those bin centres which are valid (had observation). 516 | valid_indices_list = np.where(~missing_bin_mask) 517 | tree = cKDTree( 518 | np.hstack( 519 | tuple( 520 | centres[valid_indices_list][:, np.newaxis] 521 | for centres in centres_list 522 | ) 523 | ) 524 | ) 525 | 526 | row_indices = np.hstack( 527 | [inds.reshape(-1, 1) for inds in np.where(missing_bin_mask)] 528 | ) 529 | # Select both columns for each of the rows above. 530 | column_indices = np.hstack( 531 | ( 532 | np.zeros((row_indices.shape[0], 1), dtype=np.int8), 533 | np.ones((row_indices.shape[0], 1), dtype=np.int8), 534 | ) 535 | ) 536 | 537 | # Determine the indices of the points which are nearest to the empty bins. 538 | nearest_points = tree.query(tree.data[row_indices, column_indices])[1] 539 | 540 | nearest_indices = tuple( 541 | valid_indices[nearest_points] for valid_indices in valid_indices_list 542 | ) 543 | 544 | # Replace the invalid bin values with the nearest valid ones. 545 | ale[1:, 1:][missing_bin_mask] = ale[1:, 1:][nearest_indices] 546 | 547 | # Compute the cumulative sums. 548 | ale = np.cumsum(np.cumsum(ale, axis=0), axis=1) 549 | 550 | # Subtract first order effects along both axes separately. 551 | for i in range(2): 552 | # Depending on `i`, reverse the arguments to operate on the opposite axis. 553 | flip = slice(None, None, 1 - 2 * i) 554 | 555 | # Undo the cumulative sum along the axis. 556 | first_order = ale[(slice(1, None), ...)[flip]] - ale[(slice(-1), ...)[flip]] 557 | # Average the diffs across the other axis. 558 | first_order = ( 559 | first_order[(..., slice(1, None))[flip]] 560 | + first_order[(..., slice(-1))[flip]] 561 | ) / 2 562 | # Weight by the number of samples in each bin. 563 | first_order *= samples_grid 564 | # Take the sum along the axis. 
565 | first_order = np.sum(first_order, axis=1 - i) 566 | # Normalise by the number of samples in the bins along the axis. 567 | first_order /= np.sum(samples_grid, axis=1 - i) 568 | # The final result is the cumulative sum (with an additional 0). 569 | first_order = np.array([0, *np.cumsum(first_order)]).reshape((-1, 1)[flip]) 570 | 571 | # Subtract the first order effect. 572 | ale -= first_order 573 | 574 | # Compute the ALE at the bin centres. 575 | ale = ( 576 | reduce( 577 | add, 578 | ( 579 | ale[i : ale.shape[0] - 1 + i, j : ale.shape[1] - 1 + j] 580 | for i, j in list(product(*(range(2),) * 2)) 581 | ), 582 | ) 583 | / 4 584 | ) 585 | 586 | # Centre the ALE by subtracting its expectation value. 587 | ale -= np.sum(samples_grid * ale) / train_set.shape[0] 588 | 589 | # Mark the originally missing points as such to enable later interpretation. 590 | ale.mask = missing_bin_mask 591 | return ale, quantiles_list 592 | 593 | 594 | def _first_order_ale_cat( 595 | predictor, train_set, feature, features_classes, feature_encoder=None 596 | ): 597 | """Compute the first-order ALE function on single categorical feature data. 598 | 599 | Parameters 600 | ---------- 601 | predictor : callable 602 | Prediction function. 603 | train_set : pandas.core.frame.DataFrame 604 | Training set on which model was trained. 605 | feature : str 606 | Feature name. 607 | features_classes : iterable or str 608 | Feature's classes. 609 | feature_encoder : callable or iterable 610 | Encoder that was used to encode categorical feature. If features_classes is 611 | not None, this parameter is skipped. 612 | 613 | """ 614 | num_cat = len(features_classes) 615 | ale = np.zeros(num_cat) # Final ALE function. 616 | 617 | for i in range(num_cat): 618 | subset = train_set[train_set[feature] == features_classes[i]] 619 | 620 | # Without any observation, local effect on split area is null. 
621 | if len(subset) != 0: 622 | z_low = subset.copy() 623 | z_up = subset.copy() 624 | # The main ALE idea that compute prediction difference between same data 625 | # except feature's one. 626 | z_low[feature] = quantiles[i - 1] 627 | z_up[feature] = quantiles[i] 628 | ale[i] += (predictor(z_up) - predictor(z_low)).sum() / subset.shape[0] 629 | 630 | # The accumulated effect. 631 | ale = ale.cumsum() 632 | # Now we have to center ALE function in order to obtain null expectation for ALE 633 | # function. 634 | ale -= ale.mean() 635 | return ale 636 | 637 | 638 | def ale_plot( 639 | model, 640 | train_set, 641 | features, 642 | bins=10, 643 | monte_carlo=False, 644 | predictor=None, 645 | features_classes=None, 646 | monte_carlo_rep=50, 647 | monte_carlo_ratio=0.1, 648 | rugplot_lim=1000, 649 | ): 650 | """Plots ALE function of specified features based on training set. 651 | 652 | Parameters 653 | ---------- 654 | model : object 655 | An object that implements a 'predict' method. If None, a `predictor` function 656 | must be supplied which will be used instead of `model.predict`. 657 | train_set : pandas.core.frame.DataFrame 658 | Training set on which model was trained. 659 | features : [2-iterable of] column label 660 | One or two features for which to plot the ALE plot. 661 | bins : [2-iterable of] int, optional 662 | Number of bins used to split feature's space. 2 integers can only be given 663 | when 2 features are supplied in order to compute a different number of 664 | quantiles for each feature. 665 | monte_carlo : boolean, optional 666 | Compute and plot Monte-Carlo samples. 667 | predictor : callable 668 | Custom prediction function. See `model`. 669 | features_classes : iterable of str, optional 670 | If features is first-order and a categorical variable, plot ALE according to 671 | discrete aspect of data. 672 | monte_carlo_rep : int 673 | Number of Monte-Carlo replicas. 
674 | monte_carlo_ratio : float 675 | Proportion of randomly selected samples from dataset for each Monte-Carlo 676 | replica. 677 | rugplot_lim : int, optional 678 | If `train_set` has more rows than `rugplot_lim`, no rug plot will be plotted. 679 | Set to None to always plot rug plots. Set to 0 to always plot rug plots. 680 | 681 | Raises 682 | ------ 683 | ValueError 684 | If both `model` and `predictor` are None. 685 | ValueError 686 | If `len(features)` not in {1, 2}. 687 | ValueError 688 | If multiple bins were given for 1 feature. 689 | NotImplementedError 690 | If `features_classes` is not None. 691 | 692 | """ 693 | if model is None and predictor is None: 694 | raise ValueError("If 'model' is None, 'predictor' must be supplied.") 695 | 696 | if features_classes is not None: 697 | raise NotImplementedError("'features_classes' is not implemented yet.") 698 | 699 | fig, ax = plt.subplots() 700 | 701 | features = _parse_features(features) 702 | 703 | if len(features) == 1: 704 | if not isinstance(bins, (int, np.integer)): 705 | raise ValueError("1 feature was given, but 'bins' was not an integer.") 706 | 707 | if features_classes is None: 708 | # Continuous data. 709 | 710 | if monte_carlo: 711 | mc_replicates = np.asarray( 712 | [ 713 | [ 714 | np.random.choice(range(train_set.shape[0])) 715 | for _ in range(int(monte_carlo_ratio * train_set.shape[0])) 716 | ] 717 | for _ in range(monte_carlo_rep) 718 | ] 719 | ) 720 | for k, rep in enumerate(mc_replicates): 721 | train_set_rep = train_set.iloc[rep, :] 722 | # Make this recursive? 723 | if features_classes is None: 724 | # The same quantiles cannot be reused here as this could cause 725 | # some bins to be empty or contain disproportionate numbers of 726 | # samples. 
727 | mc_ale, mc_quantiles = _first_order_ale_quant( 728 | model.predict if predictor is None else predictor, 729 | train_set_rep, 730 | features[0], 731 | bins, 732 | ) 733 | _first_order_quant_plot( 734 | ax, mc_quantiles, mc_ale, color="#1f77b4", alpha=0.06 735 | ) 736 | 737 | ale, quantiles = _first_order_ale_quant( 738 | model.predict if predictor is None else predictor, 739 | train_set, 740 | features[0], 741 | bins, 742 | ) 743 | _ax_labels(ax, "Feature '{}'".format(features[0]), "") 744 | _ax_title( 745 | ax, 746 | "First-order ALE of feature '{0}'".format(features[0]), 747 | "Bins : {0} - Monte-Carlo : {1}".format( 748 | len(quantiles) - 1, 749 | mc_replicates.shape[0] if monte_carlo else "False", 750 | ), 751 | ) 752 | ax.grid(True, linestyle="-", alpha=0.4) 753 | if rugplot_lim is None or train_set.shape[0] <= rugplot_lim: 754 | sns.rugplot(train_set[features[0]], ax=ax, alpha=0.2) 755 | _first_order_quant_plot(ax, quantiles, ale, color="black") 756 | _ax_quantiles(ax, quantiles) 757 | 758 | elif len(features) == 2: 759 | if features_classes is None: 760 | # Continuous data. 
761 | ale, quantiles_list = _second_order_ale_quant( 762 | model.predict if predictor is None else predictor, 763 | train_set, 764 | features, 765 | bins, 766 | ) 767 | _second_order_quant_plot(fig, ax, quantiles_list, ale) 768 | _ax_labels( 769 | ax, 770 | "Feature '{}'".format(features[0]), 771 | "Feature '{}'".format(features[1]), 772 | ) 773 | for twin, quantiles in zip(("x", "y"), quantiles_list): 774 | _ax_quantiles(ax, quantiles, twin=twin) 775 | _ax_title( 776 | ax, 777 | "Second-order ALE of features '{0}' and '{1}'".format( 778 | features[0], features[1] 779 | ), 780 | "Bins : {0}x{1}".format(*[len(quant) - 1 for quant in quantiles_list]), 781 | ) 782 | else: 783 | raise ValueError( 784 | "'{n_feat}' 'features' were given, but only up to 2 are supported.".format( 785 | n_feat=len(features) 786 | ) 787 | ) 788 | plt.show() 789 | return ax 790 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blent-ai/ALEPython/286350ab674980a32270db2a0b5ccca1380312a7/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_ale_calculation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from alepython.ale import _first_order_ale_quant, _get_centres, _second_order_ale_quant 7 | 8 | from .utils import interaction_predictor, linear_predictor 9 | 10 | 11 | def test_linear(): 12 | """We expect both X['a'] and X['b'] to have a linear relationship. 13 | 14 | There should be no second order interaction. 
15 | 16 | """ 17 | 18 | def linear(x, a, b): 19 | return a * x + b 20 | 21 | np.random.seed(1) 22 | 23 | N = int(1e5) 24 | X = pd.DataFrame({"a": np.random.random(N), "b": np.random.random(N)}) 25 | 26 | # Test that the first order relationships are linear. 27 | for column in X.columns: 28 | ale, quantiles = _first_order_ale_quant(linear_predictor, X, column, 21) 29 | centres = _get_centres(quantiles) 30 | p, V = np.polyfit(centres, ale, 1, cov=True) 31 | assert np.all(np.isclose(p, [1, -0.5], atol=1e-3)) 32 | assert np.all(np.isclose(np.sqrt(np.diag(V)), 0)) 33 | 34 | # Test that a second order relationship does not exist. 35 | ale_second_order, quantiles_list = _second_order_ale_quant( 36 | linear_predictor, X, X.columns, 21 37 | ) 38 | assert np.all(np.isclose(ale_second_order, 0)) 39 | 40 | 41 | def test_interaction(): 42 | """Ensure that the method picks up a trivial interaction term.""" 43 | np.random.seed(1) 44 | 45 | N = int(1e6) 46 | X = pd.DataFrame({"a": np.random.random(N), "b": np.random.random(N)}) 47 | ale, quantiles_list = _second_order_ale_quant( 48 | interaction_predictor, X, X.columns, 61 49 | ) 50 | 51 | # XXX: There seems to be a small deviation proportional to the first axis ('a') 52 | # that is preventing this from being closer to 0. 
53 | assert np.all(np.abs(ale[:, :30] - ale[:, 31:][::-1]) < 1e-2) 54 | -------------------------------------------------------------------------------- /tests/test_figure_creation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from contextlib import contextmanager 3 | from string import ascii_lowercase 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import pytest 9 | 10 | from alepython import ale_plot 11 | 12 | from .utils import SimpleModel, simple_predictor 13 | 14 | 15 | @contextmanager 16 | def assert_n_created_figures(n=1): 17 | """Assert that a given number of figures are created.""" 18 | initial_fignums = plt.get_fignums() 19 | yield # Do not catch exceptions (ie. no try except). 20 | new_fignums = set(plt.get_fignums()) - set(initial_fignums) 21 | n_new = len(new_fignums) 22 | assert n_new == n, "Expected '{n}' figure(s), got '{n_new}'.".format( 23 | n=n, n_new=n_new 24 | ) 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "features,columns", 29 | (("a", ("a", "b")), (("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c"))), 30 | ) 31 | def test_model(features, columns): 32 | """Given a model with a predict method, a plot should be created.""" 33 | plt.ion() # Prevent plt.show() from blocking. 34 | np.random.seed(1) 35 | train_set = pd.DataFrame(np.random.random((100, len(columns))), columns=columns) 36 | with assert_n_created_figures(): 37 | ale_plot(SimpleModel(), train_set, features) 38 | # Clean up the created figure. 39 | plt.close() 40 | 41 | 42 | @pytest.mark.parametrize( 43 | "features,columns", 44 | (("a", ("a", "b")), (("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c"))), 45 | ) 46 | def test_predictor(features, columns): 47 | """Given a predictor function, a plot should be created.""" 48 | plt.ion() # Prevent plt.show() from blocking. 
49 | np.random.seed(1) 50 | train_set = pd.DataFrame(np.random.random((100, len(columns))), columns=columns) 51 | with assert_n_created_figures(): 52 | ale_plot( 53 | model=None, 54 | train_set=train_set, 55 | features=features, 56 | predictor=simple_predictor, 57 | ) 58 | # Clean up the created figure. 59 | plt.close() 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "features,columns", ((("a",), ("a", "b", "c")), (("a", "b"), ("a", "b", "c"))) 64 | ) 65 | def test_monte_carlo(features, columns): 66 | plt.ion() # Prevent plt.show() from blocking. 67 | np.random.seed(1) 68 | train_set = pd.DataFrame(np.random.random((100, len(columns))), columns=columns) 69 | with assert_n_created_figures(): 70 | ale_plot( 71 | model=None, 72 | train_set=train_set, 73 | features=features, 74 | predictor=simple_predictor, 75 | monte_carlo=True, 76 | ) 77 | # Clean up the created figure. 78 | plt.close() 79 | 80 | 81 | def test_df_column_features(): 82 | """Test the handling of the `features` argument. 83 | 84 | No matter the type of the `features` iterable, `ale_plot` should be able to select 85 | the right columns. 86 | 87 | """ 88 | plt.ion() # Prevent plt.show() from blocking. 89 | n_col = 3 90 | np.random.seed(1) 91 | train_set = pd.DataFrame( 92 | np.random.random((100, n_col)), columns=list(ascii_lowercase[:n_col]) 93 | ) 94 | with assert_n_created_figures(): 95 | ale_plot(SimpleModel(), train_set, train_set.columns[:1]) 96 | # Clean up the created figure. 
97 | plt.close() 98 | -------------------------------------------------------------------------------- /tests/test_parameters.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | import pytest 5 | 6 | from alepython import ale_plot 7 | from alepython.ale import ( 8 | _ax_quantiles, 9 | _check_two_ints, 10 | _get_quantiles, 11 | _second_order_ale_quant, 12 | ) 13 | 14 | from .utils import SimpleModel 15 | 16 | 17 | def test_two_ints(): 18 | with pytest.raises(ValueError, match=r"'3' values.*"): 19 | _check_two_ints((1, 2, 3)) 20 | with pytest.raises(ValueError, match=r".*Got type(s) '{}'.*"): 21 | _check_two_ints(("1", "2")) 22 | 23 | 24 | def test_quantiles(): 25 | with pytest.raises(ValueError, match=r".*type ''.*"): 26 | _get_quantiles(pd.DataFrame({"a": [1, 2, 3]}), "a", "1") 27 | 28 | 29 | def test_second_order_ale_quant(): 30 | with pytest.raises(ValueError, match=r".*contained '1' features.*"): 31 | _second_order_ale_quant(lambda x: None, pd.DataFrame({"a": [1, 2, 3]}), "a", 1) 32 | 33 | 34 | def test_ale_plot(): 35 | """Test that proper errors are raised.""" 36 | with pytest.raises(ValueError, match=r".*'model'.*'predictor'.*"): 37 | ale_plot(model=None, train_set=pd.DataFrame([1]), features=[0]) 38 | 39 | with pytest.raises(ValueError, match=r"'3' 'features'.*"): 40 | ale_plot( 41 | model=SimpleModel(), train_set=pd.DataFrame([1]), features=list(range(3)) 42 | ) 43 | 44 | with pytest.raises(ValueError, match=r"'0' 'features'.*"): 45 | ale_plot(model=SimpleModel(), train_set=pd.DataFrame([1]), features=[]) 46 | 47 | with pytest.raises( 48 | NotImplementedError, match="'features_classes' is not implemented yet." 
49 | ): 50 | ale_plot( 51 | model=SimpleModel(), 52 | train_set=pd.DataFrame([1]), 53 | features=[0], 54 | features_classes=["a"], 55 | ) 56 | 57 | with pytest.raises(ValueError, match=r"1 feature.*but 'bins' was not an integer."): 58 | ale_plot( 59 | model=SimpleModel(), train_set=pd.DataFrame([1]), features=[0], bins=1.0, 60 | ) 61 | 62 | 63 | def test_ax_quantiles(): 64 | fig, ax = plt.subplots() 65 | with pytest.raises(ValueError, match="'twin' should be one of 'x' or 'y'."): 66 | _ax_quantiles(ax, list(range(2)), "z") 67 | plt.close(fig) 68 | 69 | fig, ax = plt.subplots() 70 | _ax_quantiles(ax, list(range(2)), "x") 71 | plt.close(fig) 72 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | from .test_models import * 5 | -------------------------------------------------------------------------------- /tests/utils/test_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | class SimpleModel: 6 | """A simple predictive model for testing purposes. 7 | 8 | Methods 9 | ------- 10 | predict(X) 11 | Given input data `X`, predict response variable. 12 | 13 | """ 14 | 15 | def predict(self, X): 16 | return np.mean(X, axis=1) 17 | 18 | 19 | def simple_predictor(X): 20 | return np.mean(X, axis=1) 21 | 22 | 23 | def linear_predictor(X): 24 | """A simple linear effect with features 'a' and 'b'.""" 25 | return X["a"] + X["b"] 26 | 27 | 28 | def interaction_predictor(X): 29 | """Interaction changes sign at b = 0.5.""" 30 | a = X["a"] 31 | b = X["b"] 32 | 33 | out = np.empty_like(a) 34 | 35 | mask = b <= 0.5 36 | out[mask] = a[mask] * b[mask] 37 | mask = ~mask 38 | out[mask] = -a[mask] * b[mask] 39 | 40 | return out 41 | --------------------------------------------------------------------------------